Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Aug 29, 2024
1 parent c6e0d0a commit 97f232b
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/python/pytorch/graphbolt/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def test_exclude_seed_edges_homo_cpu():
F._default_context_str == "cpu",
reason="Fails due to different result on the CPU.",
)
def test_exclude_seed_edges_gpu():
@pytest.mark.parametrize("use_datapipe", [False, True])
@pytest.mark.parametrize("async_op", [False, True])
def test_exclude_seed_edges_gpu(use_datapipe, async_op):
graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
items = torch.LongTensor([[0, 3], [4, 4]])
Expand All @@ -137,7 +139,10 @@ def test_exclude_seed_edges_gpu():
fanouts,
deduplicate=True,
)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
if use_datapipe:
datapipe = datapipe.exclude_seed_edges(asynchronous=async_op)
else:

Check warning on line 144 in tests/python/pytorch/graphbolt/test_utils.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
datapipe = datapipe.transform(partial(gb.exclude_seed_edges, async_op=async_op))
if torch.cuda.get_device_capability()[0] < 7:
original_row_node_ids = [
torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
Expand Down Expand Up @@ -174,6 +179,8 @@ def test_exclude_seed_edges_gpu():
]
for data in datapipe:
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
if async_op and not datapipe:
sampled_subgraph = sampled_subgraph.wait()
assert torch.equal(
sampled_subgraph.original_row_node_ids,
original_row_node_ids[step],
Expand Down

0 comments on commit 97f232b

Please sign in to comment.