
Commit

Merge branch 'master' into gb_cuda_gpu_graph_cache_cpp
mfbalin committed Jun 26, 2024
2 parents 2187274 + 95dc96a commit 0fcf5fe
Showing 5 changed files with 192 additions and 19 deletions.
14 changes: 14 additions & 0 deletions examples/sampling/graphbolt/pyg/node_classification_advanced.py
@@ -197,6 +197,8 @@ def create_dataloader(
datapipe,
num_workers=args.num_workers,
overlap_graph_fetch=args.overlap_graph_fetch,
num_gpu_cached_edges=args.num_gpu_cached_edges,
gpu_cache_threshold=args.gpu_graph_caching_threshold,
)


@@ -370,6 +372,18 @@ def parse_args():
"with the rest of operations by using an alternative CUDA stream. Disabled"
"by default.",
)
parser.add_argument(
"--num-gpu-cached-edges",
type=int,
default=0,
help="The number of edges to be cached from the graph on the GPU.",
)
parser.add_argument(
"--gpu-graph-caching-threshold",
type=int,
default=1,
help="The number of accesses after which a vertex neighborhood will be cached.",
)
parser.add_argument(
"--torch-compile",
action="store_true",
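
The two flags added above are forwarded straight into the GraphBolt DataLoader. Below is a minimal sketch of that wiring, assuming a datapipe and parsed args like the ones the example script builds; the launch command in the comment is illustrative, not taken from the commit.

# Illustrative launch (flag values are made up):
#   python node_classification_advanced.py --overlap-graph-fetch \
#       --num-gpu-cached-edges=10000000 --gpu-graph-caching-threshold=1
import dgl.graphbolt as gb

def build_loader(datapipe, args):
    # The cache only becomes active when overlap_graph_fetch is enabled and
    # num_gpu_cached_edges is positive; the default of 0 leaves it disabled.
    return gb.DataLoader(
        datapipe,
        num_workers=args.num_workers,
        overlap_graph_fetch=args.overlap_graph_fetch,
        num_gpu_cached_edges=args.num_gpu_cached_edges,
        gpu_cache_threshold=args.gpu_graph_caching_threshold,
    )
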
48 changes: 47 additions & 1 deletion python/dgl/graphbolt/dataloader.py
@@ -1,5 +1,6 @@
"""Graph Bolt DataLoaders"""

from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor

import torch
@@ -9,6 +10,7 @@

from .base import CopyTo
from .feature_fetcher import FeatureFetcher
from .impl.gpu_graph_cache import GPUGraphCache
from .impl.neighbor_sampler import SamplePerLayer

from .internal import datapipe_graph_to_adjlist
@@ -17,9 +19,34 @@

__all__ = [
"DataLoader",
"construct_gpu_graph_cache",
]


def construct_gpu_graph_cache(
sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold
):
"Construct a GPUGraphCache given a sample_per_layer_obj and cache parameters."
graph = sample_per_layer_obj.sampler.__self__
num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges)
dtypes = OrderedDict()
dtypes["indices"] = graph.indices.dtype
if graph.type_per_edge is not None:
dtypes["type_per_edge"] = graph.type_per_edge.dtype
if graph.edge_attributes is not None:
probs_or_mask = graph.edge_attributes.get(
sample_per_layer_obj.prob_name, None
)
if probs_or_mask is not None:
dtypes["probs_or_mask"] = probs_or_mask.dtype
return GPUGraphCache(
num_gpu_cached_edges,
gpu_cache_threshold,
graph.csc_indptr.dtype,
list(dtypes.values()),
)


def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = dp_utils.find_dps(
@@ -106,6 +133,13 @@ class DataLoader(torch.utils.data.DataLoader):
If True, the data loader will overlap the UVA graph fetching operations
with the rest of operations by using an alternative CUDA stream. Default
is False.
num_gpu_cached_edges : int, optional
If positive and overlap_graph_fetch is True, then the GPU will cache
frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth
demand due to pinned graph accesses.
gpu_cache_threshold : int, optional
Determines how many times a vertex needs to be accessed before its
neighborhood ends up being cached on the GPU.
max_uva_threads : int, optional
Limits the number of CUDA threads used for UVA copies so that the rest
of the computations can run simultaneously with it. Setting it to a too
@@ -121,6 +155,8 @@ def __init__(
persistent_workers=True,
overlap_feature_fetch=True,
overlap_graph_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
max_uva_threads=6144,
):
# Multiprocessing requires two modifications to the datapipe:
@@ -188,11 +224,21 @@ def __init__(
SamplePerLayer,
)
executor = ThreadPoolExecutor(max_workers=1)
gpu_graph_cache = None
for sampler in samplers:
if gpu_graph_cache is None:
gpu_graph_cache = construct_gpu_graph_cache(
sampler, num_gpu_cached_edges, gpu_cache_threshold
)
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(_get_uva_stream(), executor, 1),
sampler.fetch_and_sample(
gpu_graph_cache,
_get_uva_stream(),
executor,
1,
),
)

# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
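
For reference, a hedged sketch of calling the new construct_gpu_graph_cache helper directly, the way the unit test further down does; sample_per_layer stands for a SamplePerLayer datapipe whose bound sampler exposes the underlying FusedCSCSamplingGraph.

from dgl.graphbolt.dataloader import construct_gpu_graph_cache

# num_gpu_cached_edges is capped at graph.total_num_edges inside the helper;
# a gpu_cache_threshold of 2 caches a neighborhood after its second access.
cache = construct_gpu_graph_cache(sample_per_layer, 1_000_000, 2)
# The resulting GPUGraphCache is parameterized with csc_indptr's dtype plus the
# per-edge dtypes in the order collected above: indices, then type_per_edge and
# probs_or_mask when the graph carries them.
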
122 changes: 106 additions & 16 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -21,9 +21,90 @@
"SamplePerLayer",
"SamplePerLayerFromFetchedSubgraph",
"FetchInsubgraphData",
"ConcatHeteroSeeds",
"CombineCachedAndFetchedInSubgraph",
]


@functional_datapipe("fetch_cached_insubgraph_data")
class FetchCachedInsubgraphData(Mapper):
"""Queries the GPUGraphCache and returns the missing seeds and a lambda
function that can be called with the fetched graph structure.
"""

def __init__(self, datapipe, gpu_graph_cache):
super().__init__(datapipe, self._fetch_per_layer)
self.cache = gpu_graph_cache

def _fetch_per_layer(self, minibatch):
minibatch._seeds, minibatch._replace = self.cache.query(
minibatch._seeds
)

return minibatch


@functional_datapipe("combine_cached_and_fetched_insubgraph")
class CombineCachedAndFetchedInSubgraph(Mapper):
"""Combined the fetched graph structure with the graph structure already
found inside the GPUGraphCache.
"""

def __init__(self, datapipe, sample_per_layer_obj):
super().__init__(datapipe, self._combine_per_layer)
self.prob_name = sample_per_layer_obj.prob_name

def _combine_per_layer(self, minibatch):
subgraph = minibatch._sliced_sampling_graph

edge_tensors = [subgraph.indices]
if subgraph.type_per_edge is not None:
edge_tensors.append(subgraph.type_per_edge)
probs_or_mask = subgraph.edge_attribute(self.prob_name)
if probs_or_mask is not None:
edge_tensors.append(probs_or_mask)

subgraph.csc_indptr, edge_tensors = minibatch._replace(
subgraph.csc_indptr, edge_tensors
)
delattr(minibatch, "_replace")

subgraph.indices = edge_tensors[0]
edge_tensors = edge_tensors[1:]
if subgraph.type_per_edge is not None:
subgraph.type_per_edge = edge_tensors[0]
edge_tensors = edge_tensors[1:]
if probs_or_mask is not None:
subgraph.add_edge_attribute(self.prob_name, edge_tensors[0])
edge_tensors = edge_tensors[1:]
assert len(edge_tensors) == 0

return minibatch


@functional_datapipe("concat_hetero_seeds")
class ConcatHeteroSeeds(Mapper):
"""Concatenates the seeds into a single tensor in the hetero case."""

def __init__(self, datapipe, sample_per_layer_obj):
super().__init__(datapipe, self._concat)
self.graph = sample_per_layer_obj.sampler.__self__

def _concat(self, minibatch):
seeds = minibatch._seed_nodes
if isinstance(seeds, dict):
(
seeds,
seed_offsets,
) = self.graph._convert_to_homogeneous_nodes(seeds)
else:
seed_offsets = None
minibatch._seeds = seeds
minibatch._seed_offsets = seed_offsets

return minibatch


@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(Mapper):
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
@@ -33,10 +114,18 @@ class FetchInsubgraphData(Mapper):
read as well."""

def __init__(
self, datapipe, sample_per_layer_obj, stream=None, executor=None
self,
datapipe,
sample_per_layer_obj,
gpu_graph_cache,
stream=None,
executor=None,
):
super().__init__(datapipe, self._fetch_per_layer)
self.graph = sample_per_layer_obj.sampler.__self__
datapipe = datapipe.concat_hetero_seeds(sample_per_layer_obj)
if gpu_graph_cache is not None:
datapipe = datapipe.fetch_cached_insubgraph_data(gpu_graph_cache)
super().__init__(datapipe, self._fetch_per_layer)
self.prob_name = sample_per_layer_obj.prob_name
self.stream = stream
if executor is None:
@@ -46,18 +135,10 @@ def __init__(

def _fetch_per_layer_impl(self, minibatch, stream):
with torch.cuda.stream(self.stream):
seeds = minibatch._seed_nodes
is_hetero = isinstance(seeds, dict)
if is_hetero:
for idx in seeds.values():
idx.record_stream(torch.cuda.current_stream())
(
seeds,
seed_offsets,
) = self.graph._convert_to_homogeneous_nodes(seeds)
else:
seeds.record_stream(torch.cuda.current_stream())
seed_offsets = None
seeds = minibatch._seeds
seed_offsets = minibatch._seed_offsets
delattr(minibatch, "_seeds")
delattr(minibatch, "_seed_offsets")

def record_stream(tensor):
if stream is not None and tensor.is_cuda:
@@ -222,11 +303,20 @@ def _compact_per_layer(self, minibatch):
class FetcherAndSampler(MiniBatchTransformer):
"""Overlapped graph sampling operation replacement."""

def __init__(self, sampler, stream, executor, buffer_size):
def __init__(
self,
sampler,
gpu_graph_cache,
stream,
executor,
buffer_size,
):
datapipe = sampler.datapipe.fetch_insubgraph_data(
sampler, stream, executor
sampler, gpu_graph_cache, stream, executor
)
datapipe = datapipe.buffer(buffer_size).wait_future().wait()
if gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(sampler)
datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler)
super().__init__(datapipe)

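
To make the new cache datapipes concrete, here is a hedged sketch of the query/replace protocol around GPUGraphCache that the code above relies on; fetch_missing is a hypothetical stand-in for the pinned-graph read performed by FetchInsubgraphData.

import torch
from dgl.graphbolt.impl.gpu_graph_cache import GPUGraphCache

def query_then_combine(cache: GPUGraphCache, seeds: torch.Tensor, fetch_missing):
    # FetchCachedInsubgraphData: ask the cache which seeds it cannot serve and
    # receive a callable that later splices the cached entries back in.
    missing_seeds, replace = cache.query(seeds)
    # Only the missing seeds are fetched from the pinned graph; edge tensors are
    # ordered [indices, type_per_edge?, probs_or_mask?], matching the dtype
    # order used in construct_gpu_graph_cache.
    csc_indptr, edge_tensors = fetch_missing(missing_seeds)
    # CombineCachedAndFetchedInSubgraph: merge cached and freshly fetched data.
    return replace(csc_indptr, edge_tensors)
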
21 changes: 19 additions & 2 deletions tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
@@ -7,6 +7,7 @@
import dgl.graphbolt as gb
import pytest
import torch
from dgl.graphbolt.dataloader import construct_gpu_graph_cache


def get_hetero_graph():
@@ -42,7 +43,10 @@ def get_hetero_graph():
@pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("prob_name", [None, "weight", "mask"])
@pytest.mark.parametrize("sorted", [False, True])
def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
@pytest.mark.parametrize("num_cached_edges", [0, 10])
def test_NeighborSampler_GraphFetch(
hetero, prob_name, sorted, num_cached_edges
):
if sorted:
items = torch.arange(3)
else:
@@ -66,8 +70,21 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
compact_per_layer = sample_per_layer.compact_per_layer(True)
gb.seed(123)
expected_results = list(compact_per_layer)
datapipe = gb.FetchInsubgraphData(datapipe, sample_per_layer)
gpu_graph_cache = None
if num_cached_edges > 0:
gpu_graph_cache = construct_gpu_graph_cache(
sample_per_layer, num_cached_edges, 1
)
datapipe = gb.FetchInsubgraphData(
datapipe,
sample_per_layer,
gpu_graph_cache,
)
datapipe = datapipe.wait_future()
if num_cached_edges > 0:
datapipe = gb.CombineCachedAndFetchedInSubgraph(
datapipe, sample_per_layer
)
datapipe = gb.SamplePerLayerFromFetchedSubgraph(datapipe, sample_per_layer)
datapipe = datapipe.compact_per_layer(True)
gb.seed(123)
6 changes: 6 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -53,11 +53,15 @@ def test_DataLoader():
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_graph_fetch", [True, False])
@pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024])
@pytest.mark.parametrize("gpu_cache_threshold", [1, 3])
def test_gpu_sampling_DataLoader(
sampler_name,
enable_feature_fetch,
overlap_feature_fetch,
overlap_graph_fetch,
num_gpu_cached_edges,
gpu_cache_threshold,
):
N = 40
B = 4
@@ -94,6 +98,8 @@ def test_gpu_sampling_DataLoader(
datapipe,
overlap_feature_fetch=overlap_feature_fetch,
overlap_graph_fetch=overlap_graph_fetch,
num_gpu_cached_edges=num_gpu_cached_edges,
gpu_cache_threshold=gpu_cache_threshold,
)
bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
if overlap_graph_fetch:
