Skip to content

Commit

Permalink
[BugFix] return batch related ids in g.idtype (dmlc#6578)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored and DominikaJedynak committed Mar 12, 2024
1 parent 17d2df1 commit 84e663f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 40 deletions.
14 changes: 8 additions & 6 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def remove_edges(self, eids, etype=None, store_ids=False):
c_etype_batch_num_edges, one_hot_removed_edges, reducer="sum"
)
self._batch_num_edges[c_etype] = c_etype_batch_num_edges - F.astype(
batch_num_removed_edges, F.int64
batch_num_removed_edges, self.idtype
)

sub_g = self.edge_subgraph(
Expand Down Expand Up @@ -891,7 +891,7 @@ def remove_nodes(self, nids, ntype=None, store_ids=False):
self._batch_num_nodes[
target_ntype
] = c_ntype_batch_num_nodes - F.astype(
batch_num_removed_nodes, F.int64
batch_num_removed_nodes, self.idtype
)
# Record old num_edges to check later whether some edges were removed
old_num_edges = {
Expand All @@ -918,7 +918,7 @@ def remove_nodes(self, nids, ntype=None, store_ids=False):
for c_etype in canonical_etypes:
if self._graph.num_edges(self.get_etype_id(c_etype)) == 0:
self._batch_num_edges[c_etype] = F.zeros(
(self.batch_size,), F.int64, self.device
(self.batch_size,), self.idtype, self.device
)
continue

Expand All @@ -937,7 +937,7 @@ def remove_nodes(self, nids, ntype=None, store_ids=False):
reducer="sum",
)
self._batch_num_edges[c_etype] = F.astype(
batch_num_left_edges, F.int64
batch_num_left_edges, self.idtype
)

if batched and not store_ids:
Expand Down Expand Up @@ -1512,7 +1512,7 @@ def batch_num_nodes(self, ntype=None):
self._batch_num_nodes = {}
for ty in self.ntypes:
bnn = F.copy_to(
F.tensor([self.num_nodes(ty)], F.int64), self.device
F.tensor([self.num_nodes(ty)], self.idtype), self.device
)
self._batch_num_nodes[ty] = bnn
if ntype is None:
Expand Down Expand Up @@ -1602,6 +1602,7 @@ def set_batch_num_nodes(self, val):
batch
unbatch
"""
val = utils.prepare_tensor_or_dict(self, val, "batch_num_nodes")
if not isinstance(val, Mapping):
if len(self.ntypes) != 1:
raise DGLError(
Expand Down Expand Up @@ -1661,7 +1662,7 @@ def batch_num_edges(self, etype=None):
self._batch_num_edges = {}
for ty in self.canonical_etypes:
bne = F.copy_to(
F.tensor([self.num_edges(ty)], F.int64), self.device
F.tensor([self.num_edges(ty)], self.idtype), self.device
)
self._batch_num_edges[ty] = bne
if etype is None:
Expand Down Expand Up @@ -1753,6 +1754,7 @@ def set_batch_num_edges(self, val):
batch
unbatch
"""
val = utils.prepare_tensor_or_dict(self, val, "batch_num_edges")
if not isinstance(val, Mapping):
if len(self.etypes) != 1:
raise DGLError(
Expand Down
68 changes: 34 additions & 34 deletions tests/python/common/transforms/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,21 +1608,21 @@ def test_remove_edges(idtype):
assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([2, 0, 2], dtype=F.int64)
bg_r.batch_num_edges(), F.tensor([2, 0, 2], dtype=idtype)
)

bg_r = dgl.remove_edges(bg, [0, 2])
assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=F.int64)
bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=idtype)
)

bg_r = dgl.remove_edges(bg, F.tensor([0, 2], dtype=idtype))
assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=F.int64)
bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=idtype)
)

# batched heterogeneous graph
Expand Down Expand Up @@ -1659,7 +1659,7 @@ def test_remove_edges(idtype):
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([1, 2, 0], dtype=F.int64)
bg_r.batch_num_edges("follows"), F.tensor([1, 2, 0], dtype=idtype)
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), bg.batch_num_edges("plays")
Expand All @@ -1673,15 +1673,15 @@ def test_remove_edges(idtype):
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([2, 0, 1], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([2, 0, 1], dtype=idtype)
)

bg_r = dgl.remove_edges(bg, [0, 1, 3], etype="follows")
assert bg.batch_size == bg_r.batch_size
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=F.int64)
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=idtype)
)
assert F.array_equal(
bg.batch_num_edges("plays"), bg_r.batch_num_edges("plays")
Expand All @@ -1695,7 +1695,7 @@ def test_remove_edges(idtype):
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
)

bg_r = dgl.remove_edges(
Expand All @@ -1705,7 +1705,7 @@ def test_remove_edges(idtype):
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=F.int64)
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=idtype)
)
assert F.array_equal(
bg.batch_num_edges("plays"), bg_r.batch_num_edges("plays")
Expand All @@ -1719,7 +1719,7 @@ def test_remove_edges(idtype):
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
)


Expand Down Expand Up @@ -1847,28 +1847,28 @@ def test_remove_nodes(idtype):
bg_r = dgl.remove_nodes(bg, 1)
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(
bg_r.batch_num_nodes(), F.tensor([4, 0, 5], dtype=F.int64)
bg_r.batch_num_nodes(), F.tensor([4, 0, 5], dtype=idtype)
)
assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([0, 0, 3], dtype=F.int64)
bg_r.batch_num_edges(), F.tensor([0, 0, 3], dtype=idtype)
)

bg_r = dgl.remove_nodes(bg, [1, 7])
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(
bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=F.int64)
bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=idtype)
)
assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=F.int64)
bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=idtype)
)

bg_r = dgl.remove_nodes(bg, F.tensor([1, 7], dtype=idtype))
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(
bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=F.int64)
bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=idtype)
)
assert F.array_equal(
bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=F.int64)
bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=idtype)
)

# batched heterogeneous graph
Expand Down Expand Up @@ -1902,16 +1902,16 @@ def test_remove_nodes(idtype):
bg_r = dgl.remove_nodes(bg, 1, ntype="user")
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(
bg_r.batch_num_nodes("user"), F.tensor([3, 6, 3], dtype=F.int64)
bg_r.batch_num_nodes("user"), F.tensor([3, 6, 3], dtype=idtype)
)
assert F.array_equal(
bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game")
)
assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 2, 0], dtype=F.int64)
bg_r.batch_num_edges("follows"), F.tensor([0, 2, 0], dtype=idtype)
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 2], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 2], dtype=idtype)
)

bg_r = dgl.remove_nodes(bg, 6, ntype="game")
Expand All @@ -1920,28 +1920,28 @@ def test_remove_nodes(idtype):
bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user")
)
assert F.array_equal(
bg_r.batch_num_nodes("game"), F.tensor([3, 2, 2], dtype=F.int64)
bg_r.batch_num_nodes("game"), F.tensor([3, 2, 2], dtype=idtype)
)
assert F.array_equal(
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([2, 0, 1], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([2, 0, 1], dtype=idtype)
)

bg_r = dgl.remove_nodes(bg, [1, 5, 6, 11], ntype="user")
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(
bg_r.batch_num_nodes("user"), F.tensor([3, 4, 2], dtype=F.int64)
bg_r.batch_num_nodes("user"), F.tensor([3, 4, 2], dtype=idtype)
)
assert F.array_equal(
bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game")
)
assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=F.int64)
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=idtype)
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
)

bg_r = dgl.remove_nodes(bg, [0, 3, 4, 7], ntype="game")
Expand All @@ -1950,30 +1950,30 @@ def test_remove_nodes(idtype):
bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user")
)
assert F.array_equal(
bg_r.batch_num_nodes("game"), F.tensor([2, 0, 2], dtype=F.int64)
bg_r.batch_num_nodes("game"), F.tensor([2, 0, 2], dtype=idtype)
)
assert F.array_equal(
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
)

bg_r = dgl.remove_nodes(
bg, F.tensor([1, 5, 6, 11], dtype=idtype), ntype="user"
)
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(
bg_r.batch_num_nodes("user"), F.tensor([3, 4, 2], dtype=F.int64)
bg_r.batch_num_nodes("user"), F.tensor([3, 4, 2], dtype=idtype)
)
assert F.array_equal(
bg.batch_num_nodes("game"), bg_r.batch_num_nodes("game")
)
assert F.array_equal(
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=F.int64)
bg_r.batch_num_edges("follows"), F.tensor([0, 1, 0], dtype=idtype)
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
)

bg_r = dgl.remove_nodes(
Expand All @@ -1984,13 +1984,13 @@ def test_remove_nodes(idtype):
bg.batch_num_nodes("user"), bg_r.batch_num_nodes("user")
)
assert F.array_equal(
bg_r.batch_num_nodes("game"), F.tensor([2, 0, 2], dtype=F.int64)
bg_r.batch_num_nodes("game"), F.tensor([2, 0, 2], dtype=idtype)
)
assert F.array_equal(
bg.batch_num_edges("follows"), bg_r.batch_num_edges("follows")
)
assert F.array_equal(
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=F.int64)
bg_r.batch_num_edges("plays"), F.tensor([1, 0, 1], dtype=idtype)
)


Expand Down Expand Up @@ -2247,13 +2247,13 @@ def test_remove_selfloop(idtype):
idtype=idtype,
device=F.ctx(),
)
g.set_batch_num_nodes(F.tensor([3, 2], dtype=F.int64))
g.set_batch_num_edges(F.tensor([4, 3], dtype=F.int64))
g.set_batch_num_nodes([3, 2])
g.set_batch_num_edges([4, 3])
g = dgl.remove_self_loop(g)
assert g.num_nodes() == 5
assert g.num_edges() == 3
assert F.array_equal(g.batch_num_nodes(), F.tensor([3, 2], dtype=F.int64))
assert F.array_equal(g.batch_num_edges(), F.tensor([2, 1], dtype=F.int64))
assert F.array_equal(g.batch_num_nodes(), F.tensor([3, 2], dtype=idtype))
assert F.array_equal(g.batch_num_edges(), F.tensor([2, 1], dtype=idtype))


@parametrize_idtype
Expand Down

0 comments on commit 84e663f

Please sign in to comment.