From 6b29acb5780f5a6216498bfc6ef7473b093f2fae Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Wed, 3 Jul 2024 01:48:48 -0400 Subject: [PATCH] [GraphBolt] Add temporal labor sampling to graph. --- .../graphbolt/fused_csc_sampling_graph.h | 4 +- graphbolt/src/fused_csc_sampling_graph.cc | 56 ++++++-- .../impl/fused_csc_sampling_graph.py | 133 ++++++++++++++++++ .../impl/test_fused_csc_sampling_graph.py | 14 +- 4 files changed, 190 insertions(+), 17 deletions(-) diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index a2d280777444..3ea2827573c1 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -404,7 +404,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { torch::optional input_nodes_pre_time_window, torch::optional probs_or_mask, torch::optional node_timestamp_attr_name, - torch::optional edge_timestamp_attr_name) const; + torch::optional edge_timestamp_attr_name, + torch::optional random_seed, + double seed2_contribution) const; /** * @brief Copy the graph to shared memory. diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index 3c217e011fce..f74295ed6b7f 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -910,7 +910,9 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( torch::optional input_nodes_pre_time_window, torch::optional probs_or_mask, torch::optional node_timestamp_attr_name, - torch::optional edge_timestamp_attr_name) const { + torch::optional edge_timestamp_attr_name, + torch::optional random_seed, + double seed2_contribution) const { torch::optional> seed_offsets = torch::nullopt; // 1. Get probs_or_mask. if (probs_or_mask.has_value()) { @@ -928,19 +930,45 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name); // 4. Call SampleNeighborsImpl if (layer) { - const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt( - static_cast(0), std::numeric_limits::max()); - SamplerArgs args{indices_, random_seed, NumNodes()}; - return SampleNeighborsImpl( - input_nodes, seed_offsets, fanouts, return_eids, - GetTemporalNumPickFn( - input_nodes_timestamp, this->indices_, fanouts, replace, - type_per_edge_, input_nodes_pre_time_window, probs_or_mask, - node_timestamp, edge_timestamp), - GetTemporalPickFn( - input_nodes_timestamp, this->indices_, fanouts, replace, - indptr_.options(), type_per_edge_, input_nodes_pre_time_window, - probs_or_mask, node_timestamp, edge_timestamp, args)); + if (random_seed.has_value() && random_seed->numel() >= 2) { + SamplerArgs args{ + indices_, + {random_seed.value(), static_cast(seed2_contribution)}, + NumNodes()}; + return SampleNeighborsImpl( + input_nodes, seed_offsets, fanouts, return_eids, + GetTemporalNumPickFn( + input_nodes_timestamp, indices_, fanouts, replace, type_per_edge_, + input_nodes_pre_time_window, probs_or_mask, node_timestamp, + edge_timestamp), + GetTemporalPickFn( + input_nodes_timestamp, indices_, fanouts, replace, + indptr_.options(), type_per_edge_, input_nodes_pre_time_window, + probs_or_mask, node_timestamp, edge_timestamp, args)); + } else { + auto args = [&] { + if (random_seed.has_value() && random_seed->numel() == 1) { + return SamplerArgs{ + indices_, random_seed.value(), NumNodes()}; + } else { + return SamplerArgs{ + indices_, + RandomEngine::ThreadLocal()->RandInt( + static_cast(0), std::numeric_limits::max()), + NumNodes()}; + } + }(); + return SampleNeighborsImpl( + input_nodes, seed_offsets, fanouts, return_eids, + GetTemporalNumPickFn( + input_nodes_timestamp, indices_, fanouts, replace, type_per_edge_, + input_nodes_pre_time_window, probs_or_mask, node_timestamp, + edge_timestamp), + GetTemporalPickFn( + input_nodes_timestamp, indices_, fanouts, replace, + indptr_.options(), type_per_edge_, input_nodes_pre_time_window, + probs_or_mask, node_timestamp, edge_timestamp, args)); + } } else { SamplerArgs args; return SampleNeighborsImpl( diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index f9c583c30d98..1e753a2aeee6 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -1119,6 +1119,139 @@ def temporal_sample_neighbors( probs_or_mask, node_timestamp_attr_name, edge_timestamp_attr_name, + None, # random_seed, labor parameter + 0, # seed2_contribution, labor_parameter + ) + return self._convert_to_sampled_subgraph(C_sampled_subgraph) + + def temporal_sample_layer_neighbors( + self, + nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], + input_nodes_timestamp: Union[torch.Tensor, Dict[str, torch.Tensor]], + fanouts: torch.Tensor, + replace: bool = False, + input_nodes_pre_time_window: Optional[ + Union[torch.Tensor, Dict[str, torch.Tensor]] + ] = None, + probs_name: Optional[str] = None, + node_timestamp_attr_name: Optional[str] = None, + edge_timestamp_attr_name: Optional[str] = None, + random_seed: torch.Tensor = None, + seed2_contribution: float = 0.0, + ) -> torch.ScriptObject: + """Temporally Sample neighboring edges of the given nodes and return the induced + subgraph via layer-neighbor sampling from the NeurIPS 2023 paper + `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs + `__ + + If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given, + the sampled neighbor or edge of an input node must have a timestamp + that is smaller than that of the input node. + + Parameters + ---------- + nodes: torch.Tensor + IDs of the given seed nodes. + input_nodes_timestamp: torch.Tensor + Timestamps of the given seed nodes. + fanouts: torch.Tensor + The number of edges to be sampled for each node with or without + considering edge types. + - When the length is 1, it indicates that the fanout applies to + all neighbors of the node as a collective, regardless of the + edge type. + - Otherwise, the length should equal to the number of edge + types, and each fanout value corresponds to a specific edge + type of the nodes. + The value of each fanout should be >= 0 or = -1. + - When the value is -1, all neighbors (with non-zero probability, + if weighted) will be sampled once regardless of replacement. It + is equivalent to selecting all neighbors with non-zero + probability when the fanout is >= the number of neighbors (and + replace is set to false). + - When the value is a non-negative integer, it serves as a + minimum threshold for selecting neighbors. + replace: bool + Boolean indicating whether the sample is preformed with or + without replacement. If True, a value can be selected multiple + times. Otherwise, each value can be selected only once. + input_nodes_pre_time_window: torch.Tensor + The time window of the nodes represents a period of time before + `input_nodes_timestamp`. If provided, only neighbors and related + edges whose timestamps fall within `[input_nodes_timestamp - + input_nodes_pre_time_window, input_nodes_timestamp]` will be + filtered. + probs_name: str, optional + An optional string specifying the name of an edge attribute. This + attribute tensor should contain (unnormalized) probabilities + corresponding to each neighboring edge of a node. It must be a 1D + floating-point or boolean tensor, with the number of elements + equalling the total number of edges. + node_timestamp_attr_name: str, optional + An optional string specifying the name of an node attribute. + edge_timestamp_attr_name: str, optional + An optional string specifying the name of an edge attribute. + random_seed: torch.Tensor, optional + An int64 tensor with one or two elements. + + The passed random_seed makes it so that for any seed node ``s`` and + its neighbor ``t``, the rolled random variate ``r_t`` is the same + for any call to this function with the same random seed. When + sampling as part of the same batch, one would want identical seeds + so that LABOR can globally sample. One example is that for + heterogenous graphs, there is a single random seed passed for each + edge type. This will sample much fewer nodes compared to having + unique random seeds for each edge type. If one called this function + individually for each edge type for a heterogenous graph with + different random seeds, then it would run LABOR locally for each + edge type, resulting into a larger number of nodes being sampled. + + If this function is called without a ``random_seed``, we get the + random seed by getting a random number from GraphBolt. Use this + argument with identical random_seed if multiple calls to this + function are used to sample as part of a single batch. + + If given two numbers, then the ``seed2_contribution`` argument + determines the interpolation between the two random seeds. + seed2_contribution: float, optional + A float value between [0, 1) that determines the contribution of the + second random seed, ``random_seed[-1]``, to generate the random + variates. + + Returns + ------- + SampledSubgraphImpl + The sampled subgraph. + """ + if isinstance(nodes, dict): + ( + nodes, + input_nodes_timestamp, + input_nodes_pre_time_window, + ) = self._convert_to_homogeneous_nodes( + nodes, input_nodes_timestamp, input_nodes_pre_time_window + ) + + # Ensure nodes is 1-D tensor. + probs_or_mask = self.edge_attributes[probs_name] if probs_name else None + self._check_sampler_arguments(nodes, fanouts, probs_or_mask) + has_original_eids = ( + self.edge_attributes is not None + and ORIGINAL_EDGE_ID in self.edge_attributes + ) + C_sampled_subgraph = self._c_csc_graph.temporal_sample_neighbors( + nodes, + input_nodes_timestamp, + fanouts.tolist(), + replace, + True, + has_original_eids, + input_nodes_pre_time_window, + probs_or_mask, + node_timestamp_attr_name, + edge_timestamp_attr_name, + random_seed, + seed2_contribution, ) return self._convert_to_sampled_subgraph(C_sampled_subgraph) diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py index 4204fba159a6..1f23e8c685af 100644 --- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py +++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py @@ -825,10 +825,16 @@ def test_in_subgraph_hetero(): @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("replace", [False, True]) +@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("use_node_timestamp", [False, True]) @pytest.mark.parametrize("use_edge_timestamp", [False, True]) def test_temporal_sample_neighbors_homo( - indptr_dtype, indices_dtype, replace, use_node_timestamp, use_edge_timestamp + indptr_dtype, + indices_dtype, + replace, + labor, + use_node_timestamp, + use_edge_timestamp, ): """Original graph in COO: 1 0 1 0 1 @@ -853,7 +859,11 @@ def test_temporal_sample_neighbors_homo( # Generate subgraph via sample neighbors. fanouts = torch.LongTensor([2]) - sampler = graph.temporal_sample_neighbors + sampler = ( + graph.temporal_sample_layer_neighbors + if labor + else graph.temporal_sample_neighbors + ) seed_list = [1, 3, 4] seed_timestamp = torch.randint(0, 100, (len(seed_list),), dtype=torch.int64)