[GraphBolt][CUDA] Add CooperativeConv and minor fixes. #7797

Merged 11 commits on Sep 13, 2024
1 change: 1 addition & 0 deletions python/dgl/graphbolt/impl/__init__.py
@@ -15,3 +15,4 @@
from .gpu_graph_cache import *
from .cpu_feature_cache import *
from .cpu_cached_feature import *
from .cooperative_conv import *
109 changes: 109 additions & 0 deletions python/dgl/graphbolt/impl/cooperative_conv.py
@@ -0,0 +1,109 @@
"""Graphbolt cooperative convolution."""
from typing import Dict, Union

import torch

from ..sampled_subgraph import SampledSubgraph
from ..subgraph_sampler import all_to_all, convert_to_hetero, revert_to_homo

__all__ = ["CooperativeConvFunction", "CooperativeConv"]


class CooperativeConvFunction(torch.autograd.Function):
"""Cooperative convolution operation from Cooperative Minibatching.

Implements the `all-to-all` message passing algorithm
in Cooperative Minibatching, which was initially proposed in
`Deep Graph Library PR#4337 <https://github.com/dmlc/dgl/pull/4337>`__ and
was later fully described in
`Cooperative Minibatching in Graph Neural Networks
<https://arxiv.org/abs/2310.12403>`__.
Cooperation between the GPUs eliminates duplicate work performed across the
GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when
performing GNN minibatching. This reduces the redundant computations across
GPUs at the expense of communication.
"""

@staticmethod
def forward(
ctx,
subgraph: SampledSubgraph,
tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],
):
"""Implements the forward pass."""
counts_sent = convert_to_hetero(subgraph._counts_sent)
counts_received = convert_to_hetero(subgraph._counts_received)
seed_inverse_ids = convert_to_hetero(subgraph._seed_inverse_ids)
seed_sizes = convert_to_hetero(subgraph._seed_sizes)
# Dicts of per-type metadata cannot go through save_for_backward, so
# stash them on the autograd context directly.
ctx.communication_variables = (
counts_sent, counts_received, seed_inverse_ids, seed_sizes
)
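# For each node type, exchange features with a single all-to-all: this rank
# sends the features it holds for the seeds it received from other ranks
# (expanded via seed_inverse_ids) and gets back the features computed for
# the seeds it originally sent out.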
outs = {}
for ntype, typed_tensor in convert_to_hetero(tensor).items():
out = typed_tensor.new_empty(
(sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
)
all_to_all(
torch.split(out, counts_sent[ntype]),
torch.split(
typed_tensor[seed_inverse_ids[ntype]],
counts_received[ntype],
),
)
outs[ntype] = out
return revert_to_homo(outs)

@staticmethod
def backward(
ctx, grad_output: Union[torch.Tensor, Dict[str, torch.Tensor]]
):
"""Implements the forward pass."""
(
counts_sent,
counts_received,
seed_inverse_ids,
seed_sizes,
) = ctx.communication_variables
outs = {}
for ntype, typed_grad_output in convert_to_hetero(grad_output).items():
out = typed_grad_output.new_empty(
(sum(counts_received[ntype]),) + typed_grad_output.shape[1:]
)
all_to_all(
torch.split(out, counts_received[ntype]),
torch.split(typed_grad_output, counts_sent[ntype]),
)
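# Scatter-add the received gradients onto the unique seeds: build a sparse
# one-hot matrix that maps each received position (src) to its unique seed
# id (dst), then apply it with a sparse-dense matmul.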
i = out.new_empty(2, out.shape[0], dtype=torch.int64)
i[0] = torch.arange(
out.shape[0], device=typed_grad_output.device
) # src
i[1] = seed_inverse_ids[ntype] # dst
coo = torch.sparse_coo_tensor(
i.flipud(),
torch.ones(i.shape[1], dtype=out.dtype, device=out.device),
size=(seed_sizes[ntype], i.shape[1]),
)
outs[ntype] = torch.sparse.mm(coo, out)
return None, revert_to_homo(outs)


class CooperativeConv(torch.nn.Module):
"""Cooperative convolution operation from Cooperative Minibatching.

Implements the `all-to-all` message passing algorithm
in Cooperative Minibatching, which was initially proposed in
`Deep Graph Library PR#4337 <https://github.com/dmlc/dgl/pull/4337>`__ and
was later fully described in
`Cooperative Minibatching in Graph Neural Networks
<https://arxiv.org/abs/2310.12403>`__.
Cooperation between the GPUs eliminates duplicate work performed across the
GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when
performing GNN minibatching. This reduces the redundant computations across
GPUs at the expense of communication.
"""

def forward(
self,
subgraph: SampledSubgraph,
x: Union[torch.Tensor, Dict[str, torch.Tensor]],
):
"""Implements the forward pass."""
return CooperativeConvFunction.apply(subgraph, x)
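
For context, a minimal sketch of how the new operator slots into a per-layer loop. It mirrors the new unit test in tests/python/pytorch/graphbolt/test_dataloader.py further down: the all-ones features and the sparse-matmul aggregation over `subgraph.to_pyg(x)` are illustrative stand-ins for a real GNN layer, not part of the API added here.

import torch
import dgl.graphbolt as gb


def cooperative_forward(minibatch, feat):
    """Exchange features across ranks before every layer (sketch)."""
    x = feat  # features for this rank's unique seed nodes
    for subgraph in minibatch.sampled_subgraphs:
        # Route the duplicated seed features to the ranks that requested them.
        x = gb.CooperativeConvFunction.apply(subgraph, x)
        # Placeholder aggregation over the sampled edges, as in the test:
        # sum incoming features via a sparse matmul on the PyG-style triple.
        x, edge_index, size = subgraph.to_pyg(x)
        x = x[0]
        ones = torch.ones(edge_index.shape[1], dtype=x.dtype, device=x.device)
        coo = torch.sparse_coo_tensor(
            edge_index.flipud(), ones, size=(size[1], size[0])
        )
        x = torch.sparse.mm(coo, x)
    return x

The `CooperativeConv` module is a thin wrapper around the same autograd function, so `gb.CooperativeConv()(subgraph, x)` is equivalent to the `apply` call inside the loop.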
32 changes: 29 additions & 3 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -601,17 +601,18 @@ def _seeds_cooperative_exchange_2(minibatch):
typed_seeds.split(typed_counts_sent),
)
seeds_received[ntype] = typed_seeds_received
subgraph._seeds_received = seeds_received
counts_sent[ntype] = typed_counts_sent
counts_received[ntype] = typed_counts_received
minibatch._seed_nodes = seeds_received
subgraph._counts_sent = revert_to_homo(counts_sent)
subgraph._counts_received = revert_to_homo(counts_received)
return minibatch

@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
subgraph = minibatch.sampled_subgraphs[0]
nodes = {
ntype: [typed_seeds]
for ntype, typed_seeds in subgraph._seeds_received.items()
for ntype, typed_seeds in minibatch._seed_nodes.items()
}
minibatch._unique_future = unique_and_compact(
nodes, 0, 1, async_op=True
@@ -627,6 +628,11 @@ def _seeds_cooperative_exchange_4(minibatch):
}
minibatch._seed_nodes = revert_to_homo(unique_seeds)
subgraph = minibatch.sampled_subgraphs[0]
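# Record the number of unique seeds per node type; CooperativeConvFunction
# uses it in the backward pass to size the gradient w.r.t. the unique seeds.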
sizes = {
ntype: typed_seeds.size(0)
for ntype, typed_seeds in unique_seeds.items()
}
subgraph._seed_sizes = revert_to_homo(sizes)
subgraph._seed_inverse_ids = revert_to_homo(inverse_seeds)
return minibatch

@@ -831,6 +837,16 @@ class NeighborSampler(NeighborSamplerImpl):
gpu_cache_threshold : int, optional
Determines how many times a vertex needs to be accessed before its
neighborhood ends up being cached on the GPU.
cooperative: bool, optional
Boolean indicating whether Cooperative Minibatching should be used.
Cooperative Minibatching was initially proposed in
`Deep Graph Library PR#4337 <https://github.com/dmlc/dgl/pull/4337>`__
and later fully described in
`Cooperative Minibatching in Graph Neural Networks
<https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
eliminates duplicate work performed across the GPUs due to the
overlapping sampled k-hop neighborhoods of seed nodes when performing
GNN minibatching.
asynchronous: bool
Boolean indicating whether sampling and compaction stages should run
in background threads to hide the latency of CPU GPU synchronization.
@@ -986,6 +1002,16 @@ class LayerNeighborSampler(NeighborSamplerImpl):
gpu_cache_threshold : int, optional
Determines how many times a vertex needs to be accessed before its
neighborhood ends up being cached on the GPU.
cooperative: bool, optional
Boolean indicating whether Cooperative Minibatching should be used.
Cooperative Minibatching was initially proposed in
`Deep Graph Library PR#4337 <https://github.com/dmlc/dgl/pull/4337>`__
and later fully described in
`Cooperative Minibatching in Graph Neural Networks
<https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
eliminates duplicate work performed across the GPUs due to the
overlapping sampled k-hop neighborhoods of seed nodes when performing
GNN minibatching.
asynchronous: bool
Boolean indicating whether sampling and compaction stages should run
in background threads to hide the latency of CPU GPU synchronization.
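
To show where the new flag sits, a sketch of enabling cooperative sampling in a GraphBolt pipeline. The dataset objects (`itemset`, `graph`), batch size, and fanout values are assumptions for illustration; only the `cooperative=True` keyword (alongside the existing `asynchronous` flag) comes from this diff, and `sample_layer_neighbor` can be substituted for `sample_neighbor` to use LayerNeighborSampler.

import torch
import dgl.graphbolt as gb

# Assumed inputs: `itemset` and `graph` loaded elsewhere, e.g. from a
# gb.OnDiskDataset; the names and fanouts are illustrative only.
datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.copy_to(torch.device("cuda"))
datapipe = datapipe.sample_neighbor(
    graph,
    fanouts=[torch.LongTensor([10]), torch.LongTensor([10])],
    cooperative=True,   # enable Cooperative Minibatching across the GPUs
    asynchronous=True,  # hide CPU-GPU synchronization latency
)
dataloader = gb.DataLoader(datapipe)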
15 changes: 15 additions & 0 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -16,6 +16,7 @@
__all__ = [
"SubgraphSampler",
"all_to_all",
"convert_to_hetero",
"revert_to_homo",
]

@@ -89,6 +90,13 @@ def revert_to_homo(d: dict):
return list(d.values())[0] if is_homogenous else d


def convert_to_hetero(item):
"""Utility function to convert homogenous data to heterogenous with a single
node type."""
is_heterogenous = isinstance(item, dict)
return item if is_heterogenous else {"_N": item}


@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
@@ -251,6 +259,8 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
group,
)
seeds_received[ntype] = typed_seeds_received
counts_sent[ntype] = typed_counts_sent
counts_received[ntype] = typed_counts_received
minibatch._seed_nodes = seeds_received
minibatch._counts_sent = revert_to_homo(counts_sent)
minibatch._counts_received = revert_to_homo(counts_received)
@@ -275,6 +285,11 @@ def _seeds_cooperative_exchange_4(minibatch):
ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()
}
minibatch._seed_nodes = revert_to_homo(unique_seeds)
sizes = {
ntype: typed_seeds.size(0)
for ntype, typed_seeds in unique_seeds.items()
}
minibatch._seed_sizes = revert_to_homo(sizes)
minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)
return minibatch

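
A quick illustration of the helper pair exported above; the `_N` key is the canonical homogeneous node type used by GraphBolt, and the example tensors are made up for demonstration.

import torch
from dgl.graphbolt.subgraph_sampler import convert_to_hetero, revert_to_homo

feat = torch.randn(4, 8)                  # homogeneous node features
hetero = convert_to_hetero(feat)          # wrapped as {"_N": feat}
assert hetero["_N"] is feat
assert revert_to_homo(hetero) is feat     # single-type dict -> tensor

typed = {"user": torch.randn(2, 8), "item": torch.randn(3, 8)}
assert convert_to_hetero(typed) is typed  # dicts pass through unchanged
assert revert_to_homo(typed) is typed     # multi-type dicts stay dicts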
19 changes: 19 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -6,6 +6,7 @@

import dgl
import dgl.graphbolt
import dgl.graphbolt as gb
import pytest
import torch
import torch.distributed as thd
@@ -194,5 +195,23 @@ def test_gpu_sampling_DataLoader(
if sampler_name == "LayerNeighborSampler":
assert torch.equal(edge_feature, edge_feature_ref)
assert len(list(dataloader)) == N // B

if asynchronous and cooperative:
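# End-to-end check for CooperativeConv: push all-ones features through
# every sampled layer, exchanging them across the ranks first, and verify
# that the aggregated output lines up one-to-one with the seed nodes.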
for minibatch in minibatches:
x = torch.ones((minibatch.node_ids().size(0), 1), device=F.ctx())
for subgraph in minibatch.sampled_subgraphs:
x = gb.CooperativeConvFunction.apply(subgraph, x)
x, edge_index, size = subgraph.to_pyg(x)
x = x[0]
one = torch.ones(
edge_index.shape[1], dtype=x.dtype, device=x.device
)
coo = torch.sparse_coo_tensor(
edge_index.flipud(), one, size=(size[1], size[0])
)
x = torch.sparse.mm(coo, x)
assert x.shape[0] == minibatch.seeds.shape[0]
assert x.shape[1] == 1

if thd.is_initialized():
thd.destroy_process_group()