Skip to content

Commit

Permalink
add cpu test as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 29, 2024
1 parent 3678edd commit 232139b
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/python/pytorch/graphbolt/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def test_add_reverse_edges_hetero():
F._default_context_str == "gpu",
reason="Fails due to different result on the GPU.",
)
def test_exclude_seed_edges_homo_cpu():
@pytest.mark.parametrize("use_datapipe", [False, True])
def test_exclude_seed_edges_homo_cpu(use_datapipe):
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
graph = gb.from_dglgraph(graph, True).to(F.ctx())
items = torch.LongTensor([[0, 3], [4, 4]])
Expand All @@ -83,7 +84,10 @@ def test_exclude_seed_edges_homo_cpu():
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = gb.NeighborSampler
datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
if use_datapipe:
datapipe = datapipe.exclude_seed_edges()
else:
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
original_row_node_ids = [
torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
Expand Down

0 comments on commit 232139b

Please sign in to comment.