From 92c8f08de46fa568e42ab65387e840772c4622e6 Mon Sep 17 00:00:00 2001 From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com> Date: Thu, 11 Jan 2024 18:31:37 +0800 Subject: [PATCH] [Graphbolt]Fix negative sampler (#6933) (#6938) Co-authored-by: peizhou001 <110809584+peizhou001@users.noreply.github.com> Co-authored-by: Ubuntu --- .../graphbolt/fused_csc_sampling_graph.h | 26 --- graphbolt/src/fused_csc_sampling_graph.cc | 12 -- graphbolt/src/python_binding.cc | 3 - .../impl/fused_csc_sampling_graph.py | 32 ++-- .../impl/uniform_negative_sampler.py | 15 +- .../graphbolt/impl/test_negative_sampler.py | 8 +- .../pytorch/graphbolt/test_integration.py | 153 ++++++++---------- 7 files changed, 98 insertions(+), 151 deletions(-) diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index 8a3e23ab509f..9df22fae52d5 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -356,32 +356,6 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { torch::optional node_timestamp_attr_name, torch::optional edge_timestamp_attr_name) const; - /** - * @brief Sample negative edges by randomly choosing negative - * source-destination pairs according to a uniform distribution. For each edge - * ``(u, v)``, it is supposed to generate `negative_ratio` pairs of negative - * edges ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in - * the graph. - * - * @param node_pairs A tuple of two 1D tensors that represent the source and - * destination of positive edges, with 'positive' indicating that these edges - * are present in the graph. It's important to note that within the context of - * a heterogeneous graph, the ids in these tensors signify heterogeneous ids. - * @param negative_ratio The ratio of the number of negative samples to - * positive samples. - * @param max_node_id The maximum ID of the node to be selected. It - * should correspond to the number of nodes of a specific type. - * - * @return A tuple consisting of two 1D tensors represents the source and - * destination of negative edges. In the context of a heterogeneous - * graph, both the input nodes and the selected nodes are represented - * by heterogeneous IDs. Note that negative refers to false negatives, - * which means the edge could be present or not present in the graph. - */ - std::tuple SampleNegativeEdgesUniform( - const std::tuple& node_pairs, - int64_t negative_ratio, int64_t max_node_id) const; - /** * @brief Copy the graph to shared memory. * @param shared_memory_name The name of the shared memory. diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index 34e6a63233dd..998b67800943 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -692,18 +692,6 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( edge_timestamp)); } -std::tuple -FusedCSCSamplingGraph::SampleNegativeEdgesUniform( - const std::tuple& node_pairs, - int64_t negative_ratio, int64_t max_node_id) const { - torch::Tensor pos_src; - std::tie(pos_src, std::ignore) = node_pairs; - auto neg_len = pos_src.size(0) * negative_ratio; - auto neg_src = pos_src.repeat(negative_ratio); - auto neg_dst = torch::randint(0, max_node_id, {neg_len}, pos_src.options()); - return std::make_tuple(neg_src, neg_dst); -} - static c10::intrusive_ptr BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) { helper.InitializeRead(); diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index c60ad4b91180..44b6306d890d 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -52,9 +52,6 @@ TORCH_LIBRARY(graphbolt, m) { .def( "temporal_sample_neighbors", &FusedCSCSamplingGraph::TemporalSampleNeighbors) - .def( - "sample_negative_edges_uniform", - &FusedCSCSamplingGraph::SampleNegativeEdgesUniform) .def("copy_to_shared_memory", &FusedCSCSamplingGraph::CopyToSharedMemory) .def_pickle( // __getstate__ diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index a9261b7e7e6a..6bbc120b1538 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -850,7 +850,8 @@ def sample_negative_edges_uniform( pairs according to a uniform distribution. For each edge ``(u, v)``, it is supposed to generate `negative_ratio` pairs of negative edges ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in - the graph. + the graph. As ``u`` is exactly same as the corresponding positive edges, + it returns None for negative sources. Parameters ---------- @@ -877,23 +878,22 @@ def sample_negative_edges_uniform( `edge_type`. Note that negative refers to false negatives, which means the edge could be present or not present in the graph. """ - if edge_type is not None: - assert ( - self.node_type_offset is not None - ), "The 'node_type_offset' array is necessary for performing \ - negative sampling by edge type." - _, _, dst_node_type = etype_str_to_tuple(edge_type) - dst_node_type_id = self.node_type_to_id[dst_node_type] - max_node_id = ( - self.node_type_offset[dst_node_type_id + 1] - - self.node_type_offset[dst_node_type_id] - ) + if edge_type: + _, _, dst_ntype = etype_str_to_tuple(edge_type) + max_node_id = self.num_nodes[dst_ntype] else: max_node_id = self.total_num_nodes - return self._c_csc_graph.sample_negative_edges_uniform( - node_pairs, - negative_ratio, - max_node_id, + pos_src, _ = node_pairs + num_negative = pos_src.size(0) * negative_ratio + return ( + None, + torch.randint( + 0, + max_node_id, + (num_negative,), + dtype=pos_src.dtype, + device=pos_src.device, + ), ) def copy_to_shared_memory(self, shared_memory_name: str): diff --git a/python/dgl/graphbolt/impl/uniform_negative_sampler.py b/python/dgl/graphbolt/impl/uniform_negative_sampler.py index 512bd7ab5bc9..f979fd603249 100644 --- a/python/dgl/graphbolt/impl/uniform_negative_sampler.py +++ b/python/dgl/graphbolt/impl/uniform_negative_sampler.py @@ -32,20 +32,23 @@ class UniformNegativeSampler(NegativeSampler): Examples -------- >>> from dgl import graphbolt as gb - >>> indptr = torch.LongTensor([0, 2, 4, 5]) - >>> indices = torch.LongTensor([1, 2, 0, 2, 0]) + >>> indptr = torch.LongTensor([0, 1, 2, 3, 4]) + >>> indices = torch.LongTensor([1, 2, 3, 0]) >>> graph = gb.fused_csc_sampling_graph(indptr, indices) - >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) + >>> node_pairs = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) >>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_sampler = gb.ItemSampler( - ... item_set, batch_size=1,) + ... item_set, batch_size=4,) >>> neg_sampler = gb.UniformNegativeSampler( ... item_sampler, graph, 2) >>> for minibatch in neg_sampler: ... print(minibatch.negative_srcs) ... print(minibatch.negative_dsts) - (tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0])) - (tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0])) + None + tensor([[2, 1], + [2, 1], + [3, 2], + [1, 3]]) """ def __init__( diff --git a/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py b/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py index 7905f32798f1..577ade0e6f3f 100644 --- a/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py @@ -46,8 +46,7 @@ def test_UniformNegativeSampler_invoke(): def _verify(negative_sampler): for data in negative_sampler: # Assertation - assert data.negative_srcs.size(0) == batch_size - assert data.negative_srcs.size(1) == negative_ratio + assert data.negative_srcs is None assert data.negative_dsts.size(0) == batch_size assert data.negative_dsts.size(1) == negative_ratio @@ -90,12 +89,9 @@ def test_Uniform_NegativeSampler(negative_ratio): # Assertation assert len(pos_src) == batch_size assert len(pos_dst) == batch_size - assert len(neg_src) == batch_size assert len(neg_dst) == batch_size - assert neg_src.numel() == batch_size * negative_ratio + assert neg_src is None assert neg_dst.numel() == batch_size * negative_ratio - expected_src = pos_src.repeat(negative_ratio).view(-1, negative_ratio) - assert torch.equal(expected_src, neg_src) def get_hetero_graph(): diff --git a/tests/python/pytorch/graphbolt/test_integration.py b/tests/python/pytorch/graphbolt/test_integration.py index e6c16567b8ba..fe3b5c0a2daa 100644 --- a/tests/python/pytorch/graphbolt/test_integration.py +++ b/tests/python/pytorch/graphbolt/test_integration.py @@ -48,7 +48,7 @@ def test_integration_link_prediction(): } feature_store = gb.BasicFeatureStore(features) datapipe = gb.ItemSampler(item_set, batch_size=4) - datapipe = datapipe.sample_uniform_negative(graph, 1) + datapipe = datapipe.sample_uniform_negative(graph, 2) fanouts = torch.LongTensor([1]) datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True) datapipe = datapipe.transform(gb.exclude_seed_edges) @@ -62,23 +62,23 @@ def test_integration_link_prediction(): str( """MiniBatch(seed_nodes=None, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]), - indices=tensor([5, 4]), + indices=tensor([0, 4]), ), original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_edge_ids=None, original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]), ), - SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1]), - indices=tensor([5]), + SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]), + indices=tensor([5, 4]), ), original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_edge_ids=None, - original_column_node_ids=tensor([5, 3, 1, 2, 0]), + original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]), )], positive_node_pairs=(tensor([0, 1, 1, 1]), tensor([2, 3, 3, 1])), - node_pairs_with_labels=((tensor([0, 1, 1, 1, 0, 1, 1, 1]), tensor([2, 3, 3, 1, 4, 4, 1, 4])), - tensor([1., 1., 1., 1., 0., 0., 0., 0.])), + node_pairs_with_labels=((tensor([0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]), tensor([2, 3, 3, 1, 4, 4, 1, 4, 0, 1, 1, 5])), + tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])), node_pairs=(tensor([5, 3, 3, 3]), tensor([1, 2, 2, 3])), node_features={'feat': tensor([[0.5160, 0.2486], @@ -87,131 +87,120 @@ def test_integration_link_prediction(): [0.2109, 0.1089], [0.9634, 0.2294], [0.5503, 0.8223]])}, - negative_srcs=tensor([[5], - [3], - [3], - [3]]), - negative_node_pairs=(tensor([0, 1, 1, 1]), - tensor([4, 4, 1, 4])), - negative_dsts=tensor([[0], - [0], - [3], - [0]]), + negative_srcs=None, + negative_node_pairs=(tensor([0, 0, 1, 1, 1, 1, 1, 1]), + tensor([4, 4, 1, 4, 0, 1, 1, 5])), + negative_dsts=tensor([[0, 0], + [3, 0], + [5, 3], + [3, 4]]), labels=None, input_nodes=tensor([5, 3, 1, 2, 0, 4]), edge_features=[{}, {}], compacted_node_pairs=(tensor([0, 1, 1, 1]), tensor([2, 3, 3, 1])), - compacted_negative_srcs=tensor([[0], - [1], - [1], - [1]]), - compacted_negative_dsts=tensor([[4], - [4], - [1], - [4]]), + compacted_negative_srcs=None, + compacted_negative_dsts=tensor([[4, 4], + [1, 4], + [0, 1], + [1, 5]]), blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2), - Block(num_src_nodes=6, num_dst_nodes=5, num_edges=1)], + Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)], )""" ), str( """MiniBatch(seed_nodes=None, - sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]), - indices=tensor([1, 3]), + sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]), + indices=tensor([4, 1, 0]), ), - original_row_node_ids=tensor([3, 4, 0, 5, 1]), + original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]), original_edge_ids=None, - original_column_node_ids=tensor([3, 4, 0, 5, 1]), + original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]), ), - SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]), - indices=tensor([1, 3]), + SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]), + indices=tensor([4, 4, 0]), ), - original_row_node_ids=tensor([3, 4, 0, 5, 1]), + original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]), original_edge_ids=None, - original_column_node_ids=tensor([3, 4, 0, 5, 1]), + original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]), )], positive_node_pairs=(tensor([0, 1, 1, 2]), tensor([0, 0, 1, 1])), - node_pairs_with_labels=((tensor([0, 1, 1, 2, 0, 1, 1, 2]), tensor([0, 0, 1, 1, 1, 1, 3, 4])), - tensor([1., 1., 1., 1., 0., 0., 0., 0.])), + node_pairs_with_labels=((tensor([0, 1, 1, 2, 0, 0, 1, 1, 1, 1, 2, 2]), tensor([0, 0, 1, 1, 3, 4, 5, 4, 1, 0, 3, 4])), + tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])), node_pairs=(tensor([3, 4, 4, 0]), tensor([3, 3, 4, 4])), node_features={'feat': tensor([[0.8672, 0.2276], [0.5503, 0.8223], [0.9634, 0.2294], + [0.6172, 0.7865], [0.5160, 0.2486], - [0.6172, 0.7865]])}, - negative_srcs=tensor([[3], - [4], - [4], - [0]]), - negative_node_pairs=(tensor([0, 1, 1, 2]), - tensor([1, 1, 3, 4])), - negative_dsts=tensor([[4], - [4], - [5], - [1]]), + [0.2109, 0.1089]])}, + negative_srcs=None, + negative_node_pairs=(tensor([0, 0, 1, 1, 1, 1, 2, 2]), + tensor([3, 4, 5, 4, 1, 0, 3, 4])), + negative_dsts=tensor([[1, 5], + [2, 5], + [4, 3], + [1, 5]]), labels=None, - input_nodes=tensor([3, 4, 0, 5, 1]), + input_nodes=tensor([3, 4, 0, 1, 5, 2]), edge_features=[{}, {}], compacted_node_pairs=(tensor([0, 1, 1, 2]), tensor([0, 0, 1, 1])), - compacted_negative_srcs=tensor([[0], - [1], - [1], - [2]]), - compacted_negative_dsts=tensor([[1], - [1], - [3], - [4]]), - blocks=[Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2), - Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2)], + compacted_negative_srcs=None, + compacted_negative_dsts=tensor([[3, 4], + [5, 4], + [1, 0], + [3, 4]]), + blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3), + Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)], )""" ), str( """MiniBatch(seed_nodes=None, - sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1]), - indices=tensor([1]), + sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]), + indices=tensor([1, 0]), ), - original_row_node_ids=tensor([5, 4]), + original_row_node_ids=tensor([5, 4, 0, 1]), original_edge_ids=None, - original_column_node_ids=tensor([5, 4]), + original_column_node_ids=tensor([5, 4, 0, 1]), ), - SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1]), - indices=tensor([1]), + SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]), + indices=tensor([1, 0]), ), - original_row_node_ids=tensor([5, 4]), + original_row_node_ids=tensor([5, 4, 0, 1]), original_edge_ids=None, - original_column_node_ids=tensor([5, 4]), + original_column_node_ids=tensor([5, 4, 0, 1]), )], positive_node_pairs=(tensor([0, 1]), tensor([0, 0])), - node_pairs_with_labels=((tensor([0, 1, 0, 1]), tensor([0, 0, 0, 0])), - tensor([1., 1., 0., 0.])), + node_pairs_with_labels=((tensor([0, 1, 0, 0, 1, 1]), tensor([0, 0, 2, 1, 2, 3])), + tensor([1., 1., 0., 0., 0., 0.])), node_pairs=(tensor([5, 4]), tensor([5, 5])), node_features={'feat': tensor([[0.5160, 0.2486], - [0.5503, 0.8223]])}, - negative_srcs=tensor([[5], - [4]]), - negative_node_pairs=(tensor([0, 1]), - tensor([0, 0])), - negative_dsts=tensor([[5], - [5]]), + [0.5503, 0.8223], + [0.9634, 0.2294], + [0.6172, 0.7865]])}, + negative_srcs=None, + negative_node_pairs=(tensor([0, 0, 1, 1]), + tensor([2, 1, 2, 3])), + negative_dsts=tensor([[0, 4], + [0, 1]]), labels=None, - input_nodes=tensor([5, 4]), + input_nodes=tensor([5, 4, 0, 1]), edge_features=[{}, {}], compacted_node_pairs=(tensor([0, 1]), tensor([0, 0])), - compacted_negative_srcs=tensor([[0], - [1]]), - compacted_negative_dsts=tensor([[0], - [0]]), - blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1), - Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)], + compacted_negative_srcs=None, + compacted_negative_dsts=tensor([[2, 1], + [2, 3]]), + blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2), + Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)], )""" ), ]