Merge branch 'master' into dist_partition
Rhett-Ying committed Sep 12, 2024
2 parents 3834358 + 165e250 commit fe751b1
Showing 3 changed files with 184 additions and 14 deletions.
1 change: 1 addition & 0 deletions python/dgl/graphbolt/internal/sample_utils.py
@@ -349,6 +349,7 @@ def wait(self):
if is_homogeneous:
compacted_csc_formats = list(compacted_csc_formats.values())[0]
unique_nodes = list(unique_nodes.values())[0]
offsets = list(offsets.values())[0]

return unique_nodes, compacted_csc_formats, offsets

193 changes: 180 additions & 13 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -5,10 +5,12 @@
from typing import Dict

import torch
import torch.distributed as thd
from torch.utils.data import functional_datapipe

from .base import seed_type_str_to_ntypes
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch import MiniBatch
from .minibatch_transformer import MiniBatchTransformer

__all__ = [
@@ -28,6 +30,25 @@ def wait(self):
return result


def _shift(inputs: list, group=None):
cutoff = len(inputs) - thd.get_rank(group)
return inputs[cutoff:] + inputs[:cutoff]


def all_to_all(outputs, inputs, group=None, async_op=False):
"""Wrapper for thd.all_to_all that permuted outputs and inputs before
calling it. The arguments have the permutation
`rank, ..., world_size - 1, 0, ..., rank - 1` and we make it
`0, world_size - 1` before calling `thd.all_to_all`."""
shift_fn = partial(_shift, group=group)
return thd.all_to_all(shift_fn(outputs), shift_fn(inputs), group, async_op)
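As a sanity check on the permutation described in the docstring above, here is a minimal, self-contained sketch of what `_shift` does; the rank and the chunk labels are assumed example values, purely illustrative.

rank, world_size = 1, 4                        # assumed example values
inputs = ["to_r1", "to_r2", "to_r3", "to_r0"]  # chunks arrive in rank-first order
cutoff = len(inputs) - rank                    # same rotation as _shift above
rotated = inputs[cutoff:] + inputs[:cutoff]
assert rotated == ["to_r0", "to_r1", "to_r2", "to_r3"]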


def _revert_to_homo(d: dict):
is_homogeneous = len(d) == 1 and "_N" in d
return list(d.values())[0] if is_homogeneous else d


@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
@@ -49,8 +70,8 @@ class SubgraphSampler(MiniBatchTransformer):
Arguments to be passed into sampling_stages.
kwargs : Keyword Arguments
Arguments to be passed into sampling_stages. The preprocessing stage
makes use of the `asynchronous` and `cooperative` parameters before
they are passed to the sampling stages.
"""

def __init__(
@@ -60,10 +81,22 @@ def __init__(
**kwargs,
):
async_op = kwargs.get("asynchronous", False)
cooperative = kwargs.get("cooperative", False)
preprocess_fn = partial(
self._preprocess, cooperative=cooperative, async_op=async_op
)
datapipe = datapipe.transform(preprocess_fn)
if async_op:
fn = partial(self._wait_preprocess_future, cooperative=cooperative)
datapipe = datapipe.buffer().transform(fn)
if cooperative:
datapipe = datapipe.transform(self._seeds_cooperative_exchange_1)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_2)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_3)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._seeds_cooperative_exchange_4)
datapipe = self.sampling_stages(datapipe, *args, **kwargs)
datapipe = datapipe.transform(self._postprocess)
super().__init__(datapipe)
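For orientation, a hypothetical usage sketch of the two flags handled above. `itemset`, `graph` and `fanouts` are placeholders, and it assumes a `NeighborSampler`-style subclass that forwards `asynchronous` and `cooperative` through `**kwargs` into this constructor; it is not taken from the commit.

import dgl.graphbolt as gb

datapipe = gb.ItemSampler(itemset, batch_size=1024)  # placeholder itemset
datapipe = datapipe.sample_neighbor(                 # assumes kwargs are forwarded
    graph, fanouts, asynchronous=True, cooperative=True
)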
@@ -75,30 +108,142 @@ def _postprocess(minibatch):
return minibatch

@staticmethod
def _preprocess(minibatch, cooperative: bool, async_op: bool):
if minibatch.seeds is None:
raise ValueError(
f"Invalid minibatch {minibatch}: `seeds` should have a value."
)
rank = thd.get_rank() if cooperative else 0
world_size = thd.get_world_size() if cooperative else 1
results = SubgraphSampler._seeds_preprocess(
minibatch, rank, world_size, async_op
)
if async_op:
minibatch._preprocess_future = results
else:
(
minibatch._seed_nodes,
minibatch._seeds_timestamp,
minibatch.compacted_seeds,
offsets,
) = results
if cooperative:
minibatch._seeds_offsets = offsets
return minibatch

@staticmethod
def _wait_preprocess_future(minibatch, cooperative: bool):
(
minibatch._seed_nodes,
minibatch._seeds_timestamp,
minibatch.compacted_seeds,
offsets,
) = minibatch._preprocess_future.wait()
delattr(minibatch, "_preprocess_future")
if cooperative:
minibatch._seeds_offsets = offsets
return minibatch

@staticmethod
def _seeds_cooperative_exchange_1(minibatch, group=None):
rank = thd.get_rank(group)
world_size = thd.get_world_size(group)
assert world_size > 1
seeds = minibatch._seed_nodes
is_homogeneous = not isinstance(seeds, dict)
if is_homogeneous:
seeds = {"_N": seeds}
if minibatch._seeds_offsets is None:
seeds_list = list(seeds.values())
(
sorted_seeds_list,
index_list,
offsets_list,
) = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
assert minibatch.compacted_seeds is None
sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
num_ntypes = len(seeds.keys())
for i, (
seed_type,
typed_sorted_seeds,
typed_index,
typed_offsets,
) in enumerate(
zip(
seeds.keys(),
sorted_seeds_list,
index_list,
offsets_list,
)
):
sorted_seeds[seed_type] = typed_sorted_seeds
sorted_compacted[seed_type] = typed_index
sorted_offsets[seed_type] = typed_offsets

minibatch._seed_nodes = sorted_seeds
minibatch.compacted_seeds = sorted_compacted
minibatch._seeds_offsets = sorted_offsets
else:
minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
for i, offsets in enumerate(minibatch._seeds_offsets.values()):
counts_sent[
torch.arange(i, world_size * num_ntypes, num_ntypes)
] = offsets.diff()
delattr(minibatch, "_seeds_offsets")
counts_received = torch.empty_like(counts_sent)
minibatch._counts_future = all_to_all(
counts_received.split(num_ntypes),
counts_sent.split(num_ntypes),
group=group,
async_op=True,
)
minibatch._counts_sent = counts_sent
minibatch._counts_received = counts_received
return minibatch
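A minimal numeric sketch of how the send counts above relate to the `rank_sort` offsets, with assumed values (a single node type, world_size = 3). Each consecutive difference of the offsets is the number of seeds destined for one peer, and the asynchronous all_to_all launched above trades these numbers so every rank learns how many seeds it will receive in the next stage.

import torch

offsets = torch.tensor([0, 4, 9, 11])  # chunk i belongs to rank (rank + i) % world_size
counts_sent = offsets.diff()           # tensor([4, 5, 2]): seeds per destination rank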

@staticmethod
def _seeds_cooperative_exchange_2(minibatch, group=None):
world_size = thd.get_world_size(group)
seeds = minibatch._seed_nodes
minibatch._counts_future.wait()
delattr(minibatch, "_counts_future")
counts_received = minibatch._counts_received
num_ntypes = len(seeds.keys())
seeds_received = {}
counts_sent = {}
counts_received = {}
for i, (ntype, typed_seeds) in enumerate(seeds.items()):
idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
typed_counts_sent = minibatch._counts_sent[idx].tolist()
typed_counts_received = minibatch._counts_received[idx].tolist()
typed_seeds_received = typed_seeds.new_empty(
sum(typed_counts_received)
)
all_to_all(
typed_seeds_received.split(typed_counts_received),
typed_seeds.split(typed_counts_sent),
group,
)
seeds_received[ntype] = typed_seeds_received
counts_sent[ntype] = typed_counts_sent
counts_received[ntype] = typed_counts_received
minibatch._seed_nodes = _revert_to_homo(seeds_received)
minibatch._counts_sent = _revert_to_homo(counts_sent)
minibatch._counts_received = _revert_to_homo(counts_received)
return minibatch
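For clarity, a sketch of the variable-size exchange performed in the loop above for a single node type. It assumes `torch.distributed` is already initialized, that the chunks of `typed_seeds` are ordered by destination rank `0, ..., world_size - 1` (the rank rotation applied by the `all_to_all` wrapper above is omitted), and that `counts_sent` / `counts_received` are the per-rank Python int lists computed earlier; `exchange_seeds` is an illustrative helper, not part of the commit.

import torch
import torch.distributed as thd

def exchange_seeds(typed_seeds, counts_sent, counts_received, group=None):
    # One receive slot per incoming seed, then scatter/gather the
    # variable-sized chunks with a single all_to_all call.
    seeds_received = typed_seeds.new_empty(sum(counts_received))
    thd.all_to_all(
        list(seeds_received.split(counts_received)),
        list(typed_seeds.split(counts_sent)),
        group=group,
    )
    return seeds_received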

@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
minibatch._unique_future = unique_and_compact(
minibatch._seed_nodes, 0, 1, async_op=True
)
return minibatch

@staticmethod
def _seeds_cooperative_exchange_4(minibatch):
unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait()
delattr(minibatch, "_unique_future")
minibatch._seed_nodes = _revert_to_homo(unique_seeds)
minibatch._seed_inverse_ids = _revert_to_homo(inverse_seeds)
return minibatch
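A stand-in illustration of what stages 3 and 4 leave on the minibatch for a homogeneous graph; `torch.unique` is used here only to approximate graphbolt's `unique_and_compact`, and the seed values are assumed. The received seeds are deduplicated and the inverse mapping is kept so per-seed results can later be routed back to their senders.

import torch

seeds_received = torch.tensor([7, 3, 7, 5, 3])
unique_seeds, inverse_ids = torch.unique(seeds_received, return_inverse=True)
# unique_seeds -> tensor([3, 5, 7])
# inverse_ids  -> tensor([2, 0, 2, 1, 0]); unique_seeds[inverse_ids] == seeds_received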

def _sample(self, minibatch):
@@ -119,7 +264,12 @@ def sampling_stages(self, datapipe):
return datapipe.transform(self._sample)

@staticmethod
def _seeds_preprocess(
minibatch: MiniBatch,
rank: int = 0,
world_size: int = 1,
async_op: bool = False,
):
"""Preprocess `seeds` in a minibatch to construct `unique_seeds`,
`node_timestamp` and `compacted_seeds` for further sampling. It
optionally incorporates timestamps for temporal graphs, organizing and
@@ -130,6 +280,11 @@ def _seeds_preprocess(minibatch, async_op):
----------
minibatch: MiniBatch
The minibatch.
rank : int
The rank of the current process among cooperating processes.
world_size : int
The number of cooperating
(`arXiv:2310.12403 <https://arxiv.org/abs/2310.12403>`__) processes.
async_op: bool
Boolean indicating whether the call is asynchronous. If so, the
result can be obtained by calling wait on the returned future.
@@ -145,8 +300,16 @@ def _seeds_preprocess(minibatch, async_op):
compacted_seeds: torch.tensor or a Dict[str, torch.Tensor]
Representation of compacted seeds corresponding to 'seeds', where
all node ids inside are compacted.
offsets: None or torch.Tensor or Dict[str, torch.Tensor]
The offsets tensor partitions the unique_nodes tensor. It has size
`world_size + 1`, and `unique_nodes[offsets[i]: offsets[i + 1]]`
belongs to rank `(rank + i) % world_size`.
"""
use_timestamp = hasattr(minibatch, "timestamp")
assert (
not use_timestamp or world_size == 1
), "Temporal code path does not currently support Cooperative Minibatching"
seeds = minibatch.seeds
is_heterogeneous = isinstance(seeds, Dict)
if is_heterogeneous:
@@ -164,7 +327,7 @@ def _seeds_preprocess(minibatch, async_op):
if hasattr(minibatch, "timestamp")
else None
)
result = _NoOpWaiter((seeds, nodes_timestamp, None, None))
break
result = None
assert typed_seeds.ndim == 2, (
@@ -200,16 +363,17 @@ def __init__(self, nodes, nodes_timestamp, seeds):
)
else:
self.future = unique_and_compact(
nodes, rank, world_size, async_op
)
self.seeds = seeds

def wait(self):
"""Returns the stored value when invoked."""
if use_timestamp:
unique_seeds, nodes_timestamp, compacted = self.future
offsets = None
else:
unique_seeds, compacted, offsets = (
self.future.wait() if async_op else self.future
)
nodes_timestamp = None
@@ -234,6 +398,7 @@ def wait(self):
unique_seeds,
nodes_timestamp,
compacted_seeds,
offsets,
)

# When typed_seeds is not a one-dimensional tensor
@@ -248,7 +413,7 @@ def wait(self):
if hasattr(minibatch, "timestamp")
else None
)
result = _NoOpWaiter((seeds, nodes_timestamp, None, None))
else:
# Collect nodes from all types of input.
nodes = [seeds.view(-1)]
@@ -289,8 +454,9 @@ def wait(self):
nodes_timestamp,
compacted,
) = self.future
offsets = None
else:
unique_seeds, compacted, offsets = (
self.future.wait() if async_op else self.future
)
nodes_timestamp = None
@@ -305,6 +471,7 @@ def wait(self):
unique_seeds,
nodes_timestamp,
compacted_seeds,
offsets,
)

result = _Waiter(nodes, nodes_timestamp, seeds)
4 changes: 3 additions & 1 deletion tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
@@ -65,7 +65,9 @@ def test_NeighborSampler_GraphFetch(
graph.type_per_edge = None
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
fanout = torch.LongTensor([2])
preprocess_fn = partial(
gb.SubgraphSampler._preprocess, cooperative=False, async_op=False
)
datapipe = item_sampler.map(preprocess_fn)
datapipe = datapipe.map(
partial(gb.NeighborSampler._prepare, graph.node_type_to_id)
