Skip to content

Commit

Permalink
[GraphBolt] Fix gpu NegativeSampler for seeds. (#7068)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
  • Loading branch information
yxy235 and Ubuntu committed Feb 4, 2024
1 parent 3d854a6 commit af0b63e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 25 deletions.
8 changes: 7 additions & 1 deletion python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,13 @@ def sample_negative_edges_uniform_2(
torch.cat(
(
pos_src.repeat_interleave(negative_ratio),
torch.randint(0, max_node_id, (num_negative,)),
torch.randint(
0,
max_node_id,
(num_negative,),
dtype=node_pairs.dtype,
device=node_pairs.device,
),
),
)
.view(2, num_negative)
Expand Down
21 changes: 18 additions & 3 deletions python/dgl/graphbolt/impl/uniform_negative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,30 @@ def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
# Construct indexes for all node pairs.
num_pos_node_pairs = node_pairs.shape[0]
negative_ratio = self.negative_ratio
pos_indexes = torch.arange(0, num_pos_node_pairs)
pos_indexes = torch.arange(
0,
num_pos_node_pairs,
device=seeds.device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes = torch.cat((pos_indexes, neg_indexes))
# Construct labels for all node pairs.
pos_num = node_pairs.shape[0]
neg_num = seeds.shape[0] - pos_num
labels = torch.cat(
(torch.ones(pos_num), torch.zeros(neg_num))
).bool()
(
torch.ones(
pos_num,
dtype=torch.bool,
device=seeds.device,
),
torch.zeros(
neg_num,
dtype=torch.bool,
device=seeds.device,
),
),
)
return seeds, labels, indexes
else:
return self.graph.sample_negative_edges_uniform(
Expand Down
75 changes: 54 additions & 21 deletions tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

import backend as F

import dgl.graphbolt as gb
import pytest
import torch
Expand All @@ -14,7 +16,9 @@ def test_NegativeSampler_invoke():
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2

# Invoke NegativeSampler via class constructor.
Expand All @@ -35,13 +39,17 @@ def test_NegativeSampler_invoke():

def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2

def _verify(negative_sampler):
Expand Down Expand Up @@ -70,13 +78,17 @@ def _verify(negative_sampler):

def test_UniformNegativeSampler_node_pairs_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2

# Verify iteration over UniformNegativeSampler.
Expand Down Expand Up @@ -106,13 +118,17 @@ def _verify(negative_sampler):
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
Expand All @@ -134,13 +150,17 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler(negative_ratio):
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
Expand All @@ -159,12 +179,15 @@ def test_Uniform_NegativeSampler(negative_ratio):
neg_src = data.seeds[batch_size:, 0]
assert torch.equal(pos_src.repeat_interleave(negative_ratio), neg_src)
# Check labels.
assert torch.equal(data.labels[:batch_size], torch.ones(batch_size))
assert torch.equal(
data.labels[batch_size:], torch.zeros(batch_size * negative_ratio)
data.labels[:batch_size], torch.ones(batch_size).to(F.ctx())
)
assert torch.equal(
data.labels[batch_size:],
torch.zeros(batch_size * negative_ratio).to(F.ctx()),
)
# Check indexes.
pos_indexes = torch.arange(0, batch_size)
pos_indexes = torch.arange(0, batch_size).to(F.ctx())
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
expected_indexes = torch.cat((pos_indexes, neg_indexes))
assert torch.equal(data.indexes, expected_indexes)
Expand All @@ -173,13 +196,17 @@ def test_Uniform_NegativeSampler(negative_ratio):
def test_Uniform_NegativeSampler_error_shape():
# 1. seeds with shape N*3.
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to(
F.ctx()
)
num_seeds = 30
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 3).reshape(-1, 3), names="seeds"
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
negative_ratio = 2
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
Expand All @@ -201,7 +228,9 @@ def test_Uniform_NegativeSampler_error_shape():
item_set = gb.ItemSet(
torch.arange(0, num_seeds * 2).reshape(-1, 2, 1), names="seeds"
)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
Expand All @@ -220,7 +249,9 @@ def test_Uniform_NegativeSampler_error_shape():
# 3. seeds with shape N.
# Construct FusedCSCSamplingGraph.
item_set = gb.ItemSet(torch.arange(0, num_seeds), names="seeds")
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
F.ctx()
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
item_sampler,
Expand Down Expand Up @@ -260,7 +291,7 @@ def get_hetero_graph():


def test_NegativeSampler_Hetero_node_pairs_Data():
graph = get_hetero_graph()
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
Expand All @@ -274,13 +305,13 @@ def test_NegativeSampler_Hetero_node_pairs_Data():
}
)

item_sampler = gb.ItemSampler(itemset, batch_size=2)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1)
assert len(list(negative_dp)) == 5


def test_NegativeSampler_Hetero_Data():
graph = get_hetero_graph()
graph = get_hetero_graph().to(F.ctx())
itemset = gb.ItemSetDict(
{
"n1:e1:n2": gb.ItemSet(
Expand All @@ -295,7 +326,9 @@ def test_NegativeSampler_Hetero_Data():
)
batch_size = 2
negative_ratio = 1
item_sampler = gb.ItemSampler(itemset, batch_size=batch_size)
item_sampler = gb.ItemSampler(itemset, batch_size=batch_size).copy_to(
F.ctx()
)
negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio)
assert len(list(negative_dp)) == 5
# Perform negative sampling.
Expand All @@ -311,5 +344,5 @@ def test_NegativeSampler_Hetero_Data():
for etype, seeds_data in data.seeds.items():
neg_src = seeds_data[batch_size:, 0]
neg_dst = seeds_data[batch_size:, 1]
assert torch.equal(expected_neg_src[i][etype], neg_src)
assert torch.equal(expected_neg_src[i][etype].to(F.ctx()), neg_src)
assert (neg_dst < 3).all(), neg_dst

0 comments on commit af0b63e

Please sign in to comment.