Skip to content

Commit

Permalink
negative node pairs should be 2D
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Jan 15, 2024
1 parent 90e57e7 commit f04a106
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 26 deletions.
42 changes: 22 additions & 20 deletions python/dgl/graphbolt/minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ def negative_node_pairs(self):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_node_pairs = (
self.compacted_negative_srcs.view(-1),
self.compacted_negative_dsts.view(-1),
self.compacted_negative_srcs,
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_negative_dsts[etype].view(-1),
neg_src,
self.compacted_negative_dsts[etype],
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
Expand All @@ -319,10 +319,10 @@ def negative_node_pairs(self):
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_ratio = self.compacted_negative_srcs.size(1)
negative_node_pairs = (
self.compacted_negative_srcs.view(-1),
self.compacted_node_pairs[1].repeat_interleave(
negative_ratio
),
self.compacted_negative_srcs,
self.compacted_node_pairs[1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
# For heterogeneous graph.
else:
Expand All @@ -331,10 +331,10 @@ def negative_node_pairs(self):
].size(1)
negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_node_pairs[etype][1].repeat_interleave(
negative_ratio
),
neg_src,
self.compacted_node_pairs[etype][1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
Expand All @@ -346,10 +346,10 @@ def negative_node_pairs(self):
if isinstance(self.compacted_negative_dsts, torch.Tensor):
negative_ratio = self.compacted_negative_dsts.size(1)
negative_node_pairs = (
self.compacted_node_pairs[0].repeat_interleave(
negative_ratio
),
self.compacted_negative_dsts.view(-1),
self.compacted_node_pairs[0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
Expand All @@ -358,10 +358,10 @@ def negative_node_pairs(self):
].size(1)
negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][0].repeat_interleave(
negative_ratio
),
neg_dst.view(-1),
self.compacted_node_pairs[etype][0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
neg_dst,
)
for etype, neg_dst in self.compacted_negative_dsts.items()
}
Expand Down Expand Up @@ -396,6 +396,7 @@ def node_pairs_with_labels(self):
for etype in positive_node_pairs:
pos_src, pos_dst = positive_node_pairs[etype]
neg_src, neg_dst = negative_node_pairs[etype]
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs_by_etype[etype] = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
Expand All @@ -410,6 +411,7 @@ def node_pairs_with_labels(self):
# Homogeneous graph.
pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
Expand Down
6 changes: 6 additions & 0 deletions python/dgl/graphbolt/subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,16 @@ def _node_pairs_preprocess(self, minibatch):
for etype, _ in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype)
compacted_negative_srcs[etype] = compacted[src_type].pop(0)
compacted_negative_srcs[etype] = compacted_negative_srcs[
etype
].view(neg_src[etype].shape)
if has_neg_dst:
for etype, _ in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype)
compacted_negative_dsts[etype] = compacted[dst_type].pop(0)
compacted_negative_dsts[etype] = compacted_negative_dsts[
etype
].view(neg_dst[etype].shape)
else:
# Collect nodes from all types of input.
nodes = list(node_pairs)
Expand Down
26 changes: 20 additions & 6 deletions tests/python/pytorch/graphbolt/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,14 @@ def test_integration_link_prediction():
[0.9634, 0.2294],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1, 1, 1, 1, 1]),
tensor([4, 4, 1, 4, 0, 1, 1, 5])),
negative_node_pairs=(tensor([[0, 0],
[1, 1],
[1, 1],
[1, 1]]),
tensor([[4, 4],
[1, 4],
[0, 1],
[1, 5]])),
negative_dsts=tensor([[0, 0],
[3, 0],
[5, 3],
Expand Down Expand Up @@ -138,8 +144,14 @@ def test_integration_link_prediction():
[0.5160, 0.2486],
[0.2109, 0.1089]])},
negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1, 1, 1, 2, 2]),
tensor([3, 4, 5, 4, 1, 0, 3, 4])),
negative_node_pairs=(tensor([[0, 0],
[1, 1],
[1, 1],
[2, 2]]),
tensor([[3, 4],
[5, 4],
[1, 0],
[3, 4]])),
negative_dsts=tensor([[1, 5],
[2, 5],
[4, 3],
Expand Down Expand Up @@ -186,8 +198,10 @@ def test_integration_link_prediction():
[0.9634, 0.2294],
[0.6172, 0.7865]])},
negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1]),
tensor([2, 1, 2, 3])),
negative_node_pairs=(tensor([[0, 0],
[1, 1]]),
tensor([[2, 1],
[2, 3]])),
negative_dsts=tensor([[0, 4],
[0, 1]]),
labels=None,
Expand Down

0 comments on commit f04a106

Please sign in to comment.