Skip to content

Commit

Permalink
Fix the test.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Sep 13, 2024
1 parent c933aec commit 7730aea
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/dgl/graphbolt/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
from .gpu_graph_cache import *
from .cpu_feature_cache import *
from .cpu_cached_feature import *
from .cooperative_conv import *
1 change: 1 addition & 0 deletions python/dgl/graphbolt/impl/cooperative_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..sampled_subgraph import SampledSubgraph
from ..subgraph_sampler import all_to_all

__all__ = ["CooperativeConvFunction", "CooperativeConv"]

Check warning on line 7 in python/dgl/graphbolt/impl/cooperative_conv.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
class CooperativeConvFunction(torch.autograd.Function):
@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ def _seeds_cooperative_exchange_2(minibatch):
typed_seeds.split(typed_counts_sent),
)
seeds_received[ntype] = typed_seeds_received
counts_sent[ntype] = typed_counts_sent
counts_received[ntype] = typed_counts_received
minibatch._seed_nodes = seeds_received
subgraph._counts_sent = revert_to_homo(counts_sent)
subgraph._counts_received = revert_to_homo(counts_received)
Expand Down
2 changes: 2 additions & 0 deletions python/dgl/graphbolt/subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
group,
)
seeds_received[ntype] = typed_seeds_received
counts_sent[ntype] = typed_counts_sent
counts_received[ntype] = typed_counts_received
minibatch._seed_nodes = seeds_received
minibatch._counts_sent = revert_to_homo(counts_sent)
minibatch._counts_received = revert_to_homo(counts_received)
Expand Down
15 changes: 15 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import dgl
import dgl.graphbolt
import dgl.graphbolt as gb
import pytest
import torch
import torch.distributed as thd
Expand Down Expand Up @@ -194,5 +195,19 @@ def test_gpu_sampling_DataLoader(
if sampler_name == "LayerNeighborSampler":
assert torch.equal(edge_feature, edge_feature_ref)
assert len(list(dataloader)) == N // B

if asynchronous and cooperative:
for minibatch in minibatches:
x = torch.ones((minibatch.node_ids().size(0), 1), device=F.ctx())
for subgraph in minibatch.sampled_subgraphs:
x = gb.CooperativeConvFunction.apply(subgraph, x)
x, edge_index, size = subgraph.to_pyg(x)
x = x[0]

Check warning on line 205 in tests/python/pytorch/graphbolt/test_dataloader.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
one = torch.ones(edge_index.shape[1], dtype=x.dtype, device=x.device)
coo = torch.sparse_coo_tensor(edge_index.flipud(), one, size=(size[1], size[0]))
x = torch.sparse.mm(coo, x)
assert x.shape[0] == minibatch.seeds.shape[0]
assert x.shape[1] == 1

if thd.is_initialized():
thd.destroy_process_group()

0 comments on commit 7730aea

Please sign in to comment.