Skip to content

Commit

Permalink
Merge branch 'master' into gb_batched_unique_and_compact
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Apr 2, 2024
2 parents 1bdf3ac + b743cde commit da15251
Show file tree
Hide file tree
Showing 7 changed files with 391 additions and 258 deletions.
121 changes: 114 additions & 7 deletions python/dgl/graphbolt/item_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,50 @@ def minibatcher_default(batch, names):
else:
init_data = {name: item for item, name in zip(batch, names)}
minibatch = MiniBatch()
# TODO(#7254): Hacks for original `seed_nodes` and `node_pairs`, which need
# to be cleaned up later.
if "node_pairs" in names:
pos_seeds = init_data["node_pairs"]
# Build negative graph.
if "negative_srcs" in names and "negative_dsts" in names:
neg_srcs = init_data["negative_srcs"]
neg_dsts = init_data["negative_dsts"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(
pos_seeds, neg_srcs=neg_srcs, neg_dsts=neg_dsts
)
elif "negative_srcs" in names:
neg_srcs = init_data["negative_srcs"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(pos_seeds, neg_srcs=neg_srcs)
elif "negative_dsts" in names:
neg_dsts = init_data["negative_dsts"]
(
init_data["seeds"],
init_data["labels"],
init_data["indexes"],
) = _construct_seeds(pos_seeds, neg_dsts=neg_dsts)
else:
init_data["seeds"] = pos_seeds
for name, item in init_data.items():
if not hasattr(minibatch, name):
dgl_warning(
f"Unknown item name '{name}' is detected and added into "
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
)
if name == "node_pairs":
# `node_pairs` is passed as a tensor in shape of `(N, 2)` and
# should be converted to a tuple of `(src, dst)`.
if isinstance(item, Mapping):
item = {key: (item[key][:, 0], item[key][:, 1]) for key in item}
else:
item = (item[:, 0], item[:, 1])
# TODO(#7254): Hacks for original `seed_nodes` and `node_pairs`, which
# need to be cleaned up later.
if name == "seed_nodes":
name = "seeds"
if name in ("node_pairs", "negative_srcs", "negative_dsts"):
continue
setattr(minibatch, name, item)
return minibatch

Expand Down Expand Up @@ -744,3 +774,80 @@ def __init__(
)
self._world_size = dist.get_world_size()
self._rank = dist.get_rank()


def _construct_seeds(pos_seeds, neg_srcs=None, neg_dsts=None):
# For homogeneous graph.
if isinstance(pos_seeds, torch.Tensor):
negative_ratio = neg_srcs.size(1) if neg_srcs else neg_dsts.size(1)
neg_srcs = (
neg_srcs
if neg_srcs is not None
else pos_seeds[:, 0].repeat_interleave(negative_ratio)
).view(-1)
neg_dsts = (
neg_dsts
if neg_dsts is not None
else pos_seeds[:, 1].repeat_interleave(negative_ratio)
).view(-1)
neg_seeds = torch.cat((neg_srcs, neg_dsts)).view(2, -1).T
seeds = torch.cat((pos_seeds, neg_seeds))
pos_seeds_num = pos_seeds.size(0)
labels = torch.empty(seeds.size(0), device=pos_seeds.device)
labels[:pos_seeds_num] = 1
labels[pos_seeds_num:] = 0
pos_indexes = torch.arange(
0,
pos_seeds_num,
device=pos_seeds.device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes = torch.cat((pos_indexes, neg_indexes))
# For heterogeneous graph.
else:
negative_ratio = (
list(neg_srcs.values())[0].size(1)
if neg_srcs
else list(neg_dsts.values())[0].size(1)
)
seeds = {}
labels = {}
indexes = {}
for etype in pos_seeds:
neg_src = (
neg_srcs[etype]
if neg_srcs is not None
else pos_seeds[etype][:, 0].repeat_interleave(negative_ratio)
).view(-1)
neg_dst = (
neg_dsts[etype]
if neg_dsts is not None
else pos_seeds[etype][:, 1].repeat_interleave(negative_ratio)
).view(-1)
seeds[etype] = torch.cat(
(
pos_seeds[etype],
torch.cat(
(
neg_src,
neg_dst,
)
)
.view(2, -1)
.T,
)
)
pos_seeds_num = pos_seeds[etype].size(0)
labels[etype] = torch.empty(
seeds[etype].size(0), device=pos_seeds[etype].device
)
labels[etype][:pos_seeds_num] = 1
labels[etype][pos_seeds_num:] = 0
pos_indexes = torch.arange(
0,
pos_seeds_num,
device=pos_seeds[etype].device,
)
neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
indexes[etype] = torch.cat((pos_indexes, neg_indexes))
return seeds, labels, indexes
16 changes: 6 additions & 10 deletions tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,22 @@ def original_indices(minibatch):
return _indices

mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([0]).to(F.ctx()))
assert torch.equal(mn.seeds, torch.LongTensor([0]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 3]).to(F.ctx()),
)

mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([5]).to(F.ctx()))
assert torch.equal(mn.seeds, torch.LongTensor([5]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()),
)
assert torch.equal(original_indices(mn), torch.tensor([1, 4]).to(F.ctx()))

mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([3]).to(F.ctx()))
assert torch.equal(mn.seeds, torch.LongTensor([3]).to(F.ctx()))
assert torch.equal(
mn.sampled_subgraphs[0].sampled_csc.indptr,
torch.tensor([0, 2]).to(F.ctx()),
Expand Down Expand Up @@ -176,9 +176,7 @@ def test_InSubgraphSampler_hetero():
it = iter(in_subgraph_sampler)

mn = next(it)
assert torch.equal(
mn.seed_nodes["N0"], torch.LongTensor([1, 0]).to(F.ctx())
)
assert torch.equal(mn.seeds["N0"], torch.LongTensor([1, 0]).to(F.ctx()))
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 3]),
Expand All @@ -203,7 +201,7 @@ def test_InSubgraphSampler_hetero():
)

mn = next(it)
assert mn.seed_nodes == {
assert mn.seeds == {
"N0": torch.LongTensor([2]).to(F.ctx()),
"N1": torch.LongTensor([0]).to(F.ctx()),
}
Expand All @@ -230,9 +228,7 @@ def test_InSubgraphSampler_hetero():
)

mn = next(it)
assert torch.equal(
mn.seed_nodes["N1"], torch.LongTensor([2, 1]).to(F.ctx())
)
assert torch.equal(mn.seeds["N1"], torch.LongTensor([2, 1]).to(F.ctx()))
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
Expand Down
37 changes: 27 additions & 10 deletions tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ def test_UniformNegativeSampler_node_pairs_invoke():
def _verify(negative_sampler):
for data in negative_sampler:
# Assertion
assert data.negative_srcs is None
assert data.negative_dsts.size(0) == batch_size
assert data.negative_dsts.size(1) == negative_ratio
seeds_len = batch_size + batch_size * negative_ratio
assert data.seeds.size(0) == seeds_len
assert data.labels.size(0) == seeds_len
assert data.indexes.size(0) == seeds_len

# Invoke UniformNegativeSampler via class constructor.
negative_sampler = gb.UniformNegativeSampler(
Expand Down Expand Up @@ -137,14 +138,30 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
pos_src, pos_dst = data.node_pairs
neg_src, neg_dst = data.negative_srcs, data.negative_dsts
expected_labels = torch.empty(
batch_size * (negative_ratio + 1), device=F.ctx()
)
expected_labels[:batch_size] = 1
expected_labels[batch_size:] = 0
expected_indexes = torch.arange(batch_size, device=F.ctx())
expected_indexes = torch.cat(
(
expected_indexes,
expected_indexes.repeat_interleave(negative_ratio),
)
)
expected_neg_src = data.seeds[:batch_size][:, 0].repeat_interleave(
negative_ratio
)
# Assertion
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
assert len(neg_dst) == batch_size
assert neg_src is None
assert neg_dst.numel() == batch_size * negative_ratio
assert data.negative_srcs is None
assert data.negative_dsts is None
assert data.labels is not None
assert data.indexes is not None
assert data.seeds.size(0) == batch_size * (negative_ratio + 1)
assert torch.equal(data.labels, expected_labels)
assert torch.equal(data.indexes, expected_indexes)
assert torch.equal(data.seeds[batch_size:][:, 0], expected_neg_src)


@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
Expand Down
40 changes: 11 additions & 29 deletions tests/python/pytorch/graphbolt/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
item_sampler = gb.ItemSampler(
gb.ItemSet(torch.arange(20), names="seed_nodes"), 4
gb.ItemSet(torch.arange(20), names="seeds"), 4
)

# Invoke CopyTo via class constructor.
dp = gb.CopyTo(item_sampler, "cuda")
for data in dp:
assert data.seed_nodes.device.type == "cuda"
assert data.seeds.device.type == "cuda"

# Invoke CopyTo via functional form.
dp = item_sampler.copy_to("cuda")
for data in dp:
assert data.seed_nodes.device.type == "cuda"
assert data.seeds.device.type == "cuda"


@pytest.mark.parametrize(
Expand All @@ -37,7 +37,6 @@ def test_CopyTo():
"link_prediction",
"edge_classification",
"extra_attrs",
"other",
],
)
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
Expand All @@ -63,11 +62,6 @@ def test_CopyToWithMiniBatches_original(task):
(torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
names=("node_pairs", "labels"),
)
else:
itemset = gb.ItemSet(
(torch.arange(2 * N).reshape(-1, 2), torch.arange(N)),
names=("node_pairs", "seed_nodes"),
)
graph = gb_test_utils.rand_csc_graph(100, 0.15, bidirection_edge=True)

features = {}
Expand Down Expand Up @@ -96,38 +90,25 @@ def test_CopyToWithMiniBatches_original(task):
"sampled_subgraphs",
"labels",
"blocks",
"seeds",
]
elif task == "node_inference":
copied_attrs = [
"seed_nodes",
"seeds",
"sampled_subgraphs",
"blocks",
"labels",
]
elif task == "link_prediction":
elif task == "link_prediction" or task == "edge_classification":
copied_attrs = [
"compacted_node_pairs",
"node_features",
"edge_features",
"labels",
"compacted_seeds",
"sampled_subgraphs",
"compacted_negative_srcs",
"compacted_negative_dsts",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
]
elif task == "edge_classification":
copied_attrs = [
"compacted_node_pairs",
"indexes",
"node_features",
"edge_features",
"sampled_subgraphs",
"labels",
"blocks",
"positive_node_pairs",
"negative_node_pairs",
"node_pairs_with_labels",
"seeds",
]
elif task == "extra_attrs":
copied_attrs = [
Expand All @@ -137,6 +118,7 @@ def test_CopyToWithMiniBatches_original(task):
"labels",
"blocks",
"seed_nodes",
"seeds",
]

def test_data_device(datapipe):
Expand Down
4 changes: 2 additions & 2 deletions tests/python/pytorch/graphbolt/test_feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_FeatureFetcher_with_edges_homo():
)

def add_node_and_edge_ids(minibatch):
seeds = minibatch.seed_nodes
seeds = minibatch.seeds
subgraphs = []
for _ in range(3):
sampled_csc = gb.CSCFormatBase(
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_FeatureFetcher_with_edges_hetero():
b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])

def add_node_and_edge_ids(minibatch):
seeds = minibatch.seed_nodes
seeds = minibatch.seeds
subgraphs = []
original_edge_ids = {
"n1:e1:n2": torch.randint(0, 50, (10,)),
Expand Down
Loading

0 comments on commit da15251

Please sign in to comment.