[GraphBolt][CUDA] Incremental GPU graph cache into gb.Dataloader. #7475

Merged
merged 12 commits on Jun 26, 2024
48 changes: 47 additions & 1 deletion python/dgl/graphbolt/dataloader.py
@@ -1,5 +1,6 @@
"""Graph Bolt DataLoaders"""

from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor

import torch
@@ -9,6 +10,7 @@

from .base import CopyTo
from .feature_fetcher import FeatureFetcher
from .impl.gpu_graph_cache import GPUGraphCache
from .impl.neighbor_sampler import SamplePerLayer

from .internal import datapipe_graph_to_adjlist
@@ -17,9 +19,34 @@

__all__ = [
"DataLoader",
"construct_gpu_graph_cache",
]


def construct_gpu_graph_cache(
sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold
):
"Construct a GPUGraphCache given a sample_per_layer_obj and cache parameters."
graph = sample_per_layer_obj.sampler.__self__
num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges)
dtypes = OrderedDict()
dtypes["indices"] = graph.indices.dtype
if graph.type_per_edge is not None:
dtypes["type_per_edge"] = graph.type_per_edge.dtype
if graph.edge_attributes is not None:
probs_or_mask = graph.edge_attributes.get(
sample_per_layer_obj.prob_name, None
)
if probs_or_mask is not None:
dtypes["probs_or_mask"] = probs_or_mask.dtype
return GPUGraphCache(
num_gpu_cached_edges,
gpu_cache_threshold,
graph.csc_indptr.dtype,
list(dtypes.values()),
)
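
A usage sketch (an aside, not part of the diff): the OrderedDict fixes the positional layout of the cached edge tensors as indices, then type_per_edge, then probs_or_mask, which is the same order CombineCachedAndFetchedInSubgraph unpacks further below. A hypothetical call, assuming sample_per_layer is a SamplePerLayer datapipe built elsewhere in the pipeline:

    # Hypothetical sketch; `sample_per_layer` exposes .sampler.__self__ (the
    # graph) and .prob_name, which construct_gpu_graph_cache expects.
    cache = construct_gpu_graph_cache(
        sample_per_layer,
        num_gpu_cached_edges=1024,  # capped at graph.total_num_edges
        gpu_cache_threshold=1,  # cache a neighborhood after a single access
    )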


def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = dp_utils.find_dps(
@@ -106,6 +133,13 @@ class DataLoader(torch.utils.data.DataLoader):
If True, the data loader will overlap the UVA graph fetching operations
with the rest of operations by using an alternative CUDA stream. Default
is False.
num_gpu_cached_edges : int, optional
If positive and overlap_graph_fetch is True, then the GPU will cache
frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth
demand due to pinned graph accesses.
gpu_cache_threshold : int, optional
Determines how many times a vertex needs to be accessed before its
neighborhood ends up being cached on the GPU.
max_uva_threads : int, optional
Limits the number of CUDA threads used for UVA copies so that the rest
of the computations can run simultaneously with it. Setting it to a too
@@ -121,6 +155,8 @@ def __init__(
persistent_workers=True,
overlap_feature_fetch=True,
overlap_graph_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
max_uva_threads=6144,
):
# Multiprocessing requires two modifications to the datapipe:
@@ -188,11 +224,21 @@ def __init__(
SamplePerLayer,
)
executor = ThreadPoolExecutor(max_workers=1)
gpu_graph_cache = None
for sampler in samplers:
if gpu_graph_cache is None:
gpu_graph_cache = construct_gpu_graph_cache(
sampler, num_gpu_cached_edges, gpu_cache_threshold
)
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(_get_uva_stream(), executor, 1),
sampler.fetch_and_sample(
gpu_graph_cache,
_get_uva_stream(),
executor,
1,
),
)

# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
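In user code, the cache is enabled purely through the two new DataLoader keyword arguments; a sketch of the intended call (mirroring the test changes further below), assuming datapipe is a GraphBolt pipeline sampling from a pinned/UVA graph:

    # Sketch; per the docstring, num_gpu_cached_edges only takes effect when
    # overlap_graph_fetch=True, and 0 disables caching.
    dataloader = gb.DataLoader(
        datapipe,
        overlap_graph_fetch=True,
        num_gpu_cached_edges=1024,
        gpu_cache_threshold=1,
    )
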
122 changes: 106 additions & 16 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -21,9 +21,90 @@
"SamplePerLayer",
"SamplePerLayerFromFetchedSubgraph",
"FetchInsubgraphData",
"ConcatHeteroSeeds",
"CombineCachedAndFetchedInSubgraph",
]


@functional_datapipe("fetch_cached_insubgraph_data")
class FetchCachedInsubgraphData(Mapper):
"""Queries the GPUGraphCache and returns the missing seeds and a lambda
function that can be called with the fetched graph structure.
"""

def __init__(self, datapipe, gpu_graph_cache):
super().__init__(datapipe, self._fetch_per_layer)
self.cache = gpu_graph_cache

def _fetch_per_layer(self, minibatch):
minibatch._seeds, minibatch._replace = self.cache.query(
minibatch._seeds
)

return minibatch
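
The cache protocol is two-phase: query() splits the seeds into hits and misses and returns a callable alongside the missing seeds, only the misses are fetched from the pinned graph, and the callable then merges fetched and cached structure. A hypothetical sketch of the flow (an aside, not the GPUGraphCache implementation):

    # Hypothetical sketch of the two-phase cache protocol.
    missing_seeds, replace_fn = cache.query(seeds)
    # ... fetch indptr and edge tensors for `missing_seeds` from the pinned
    # graph on the side stream (names here are illustrative) ...
    indptr, edge_tensors = replace_fn(fetched_indptr, fetched_edge_tensors)
    # `indptr` and `edge_tensors` now cover all of `seeds`: hits were served
    # from the cache, and the fetched misses become candidates for insertion.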


@functional_datapipe("combine_cached_and_fetched_insubgraph")
class CombineCachedAndFetchedInSubgraph(Mapper):
"""Combined the fetched graph structure with the graph structure already
found inside the GPUGraphCache.
"""

def __init__(self, datapipe, sample_per_layer_obj):
super().__init__(datapipe, self._combine_per_layer)
self.prob_name = sample_per_layer_obj.prob_name

def _combine_per_layer(self, minibatch):
subgraph = minibatch._sliced_sampling_graph

edge_tensors = [subgraph.indices]
if subgraph.type_per_edge is not None:
edge_tensors.append(subgraph.type_per_edge)
probs_or_mask = subgraph.edge_attribute(self.prob_name)
if probs_or_mask is not None:
edge_tensors.append(probs_or_mask)

subgraph.csc_indptr, edge_tensors = minibatch._replace(
subgraph.csc_indptr, edge_tensors
)
delattr(minibatch, "_replace")

subgraph.indices = edge_tensors[0]
edge_tensors = edge_tensors[1:]
if subgraph.type_per_edge is not None:
subgraph.type_per_edge = edge_tensors[0]
edge_tensors = edge_tensors[1:]
if probs_or_mask is not None:
subgraph.add_edge_attribute(self.prob_name, edge_tensors[0])
edge_tensors = edge_tensors[1:]
assert len(edge_tensors) == 0

return minibatch
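
Note the positional contract here: edge_tensors is packed as indices, then type_per_edge (if present), then probs_or_mask (if present), exactly the dtype order construct_gpu_graph_cache registered, and _replace returns the tensors in the same order, so the code peels them off one by one and the final assert guards the invariant.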


@functional_datapipe("concat_hetero_seeds")
class ConcatHeteroSeeds(Mapper):
"""Concatenates the seeds into a single tensor in the hetero case."""

def __init__(self, datapipe, sample_per_layer_obj):
super().__init__(datapipe, self._concat)
self.graph = sample_per_layer_obj.sampler.__self__

def _concat(self, minibatch):
seeds = minibatch._seed_nodes
if isinstance(seeds, dict):
(
seeds,
seed_offsets,
) = self.graph._convert_to_homogeneous_nodes(seeds)
else:
seed_offsets = None
minibatch._seeds = seeds
minibatch._seed_offsets = seed_offsets

return minibatch
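
For intuition, a toy version of the dict branch (hypothetical; the real conversion is FusedCSCSamplingGraph._convert_to_homogeneous_nodes, which also maps per-type node ids into the homogeneous id space):

    # Hypothetical sketch: concatenate per-type seed tensors and record the
    # boundaries so downstream stages can split results back per node type.
    # `node_type_order` stands in for the graph's fixed node-type ordering.
    parts = [seeds[ntype] for ntype in node_type_order]
    boundaries = torch.tensor([0] + [p.numel() for p in parts])
    seed_offsets = torch.cumsum(boundaries, 0).tolist()
    seeds = torch.cat(parts)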


@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(Mapper):
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
Expand All @@ -33,10 +114,18 @@ class FetchInsubgraphData(Mapper):
read as well."""

def __init__(
self, datapipe, sample_per_layer_obj, stream=None, executor=None
self,
datapipe,
sample_per_layer_obj,
gpu_graph_cache,
stream=None,
executor=None,
):
super().__init__(datapipe, self._fetch_per_layer)
self.graph = sample_per_layer_obj.sampler.__self__
datapipe = datapipe.concat_hetero_seeds(sample_per_layer_obj)
if gpu_graph_cache is not None:
datapipe = datapipe.fetch_cached_insubgraph_data(gpu_graph_cache)
super().__init__(datapipe, self._fetch_per_layer)
self.prob_name = sample_per_layer_obj.prob_name
self.stream = stream
if executor is None:
@@ -46,18 +135,10 @@ def __init__(

def _fetch_per_layer_impl(self, minibatch, stream):
with torch.cuda.stream(self.stream):
seeds = minibatch._seed_nodes
is_hetero = isinstance(seeds, dict)
if is_hetero:
for idx in seeds.values():
idx.record_stream(torch.cuda.current_stream())
(
seeds,
seed_offsets,
) = self.graph._convert_to_homogeneous_nodes(seeds)
else:
seeds.record_stream(torch.cuda.current_stream())
seed_offsets = None
seeds = minibatch._seeds
seed_offsets = minibatch._seed_offsets
delattr(minibatch, "_seeds")
delattr(minibatch, "_seed_offsets")

def record_stream(tensor):
if stream is not None and tensor.is_cuda:
@@ -222,11 +303,20 @@ def _compact_per_layer(self, minibatch):
class FetcherAndSampler(MiniBatchTransformer):
"""Overlapped graph sampling operation replacement."""

def __init__(self, sampler, stream, executor, buffer_size):
def __init__(
self,
sampler,
gpu_graph_cache,
stream,
executor,
buffer_size,
):
datapipe = sampler.datapipe.fetch_insubgraph_data(
sampler, stream, executor
sampler, gpu_graph_cache, stream, executor
)
datapipe = datapipe.buffer(buffer_size).wait_future().wait()
if gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(sampler)
datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler)
super().__init__(datapipe)
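
The stage order is deliberate: fetch_insubgraph_data starts the pinned-memory reads on the side stream, buffer()/wait_future()/wait() let the next minibatch's fetch overlap the current one's sampling, the cached structure is merged back in only after the fetched part has arrived, and per-layer sampling then runs on the combined subgraph.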

21 changes: 19 additions & 2 deletions tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
Expand Up @@ -7,6 +7,7 @@
import dgl.graphbolt as gb
import pytest
import torch
from dgl.graphbolt.dataloader import construct_gpu_graph_cache


def get_hetero_graph():
@@ -42,7 +43,10 @@ def get_hetero_graph():
@pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("prob_name", [None, "weight", "mask"])
@pytest.mark.parametrize("sorted", [False, True])
def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
@pytest.mark.parametrize("num_cached_edges", [0, 10])
def test_NeighborSampler_GraphFetch(
hetero, prob_name, sorted, num_cached_edges
):
if sorted:
items = torch.arange(3)
else:
@@ -66,8 +70,21 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
compact_per_layer = sample_per_layer.compact_per_layer(True)
gb.seed(123)
expected_results = list(compact_per_layer)
datapipe = gb.FetchInsubgraphData(datapipe, sample_per_layer)
gpu_graph_cache = None
if num_cached_edges > 0:
gpu_graph_cache = construct_gpu_graph_cache(
sample_per_layer, num_cached_edges, 1
)
datapipe = gb.FetchInsubgraphData(
datapipe,
sample_per_layer,
gpu_graph_cache,
)
datapipe = datapipe.wait_future()
if num_cached_edges > 0:
datapipe = gb.CombineCachedAndFetchedInSubgraph(
datapipe, sample_per_layer
)
datapipe = gb.SamplePerLayerFromFetchedSubgraph(datapipe, sample_per_layer)
datapipe = datapipe.compact_per_layer(True)
gb.seed(123)
6 changes: 6 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataloader.py
Expand Up @@ -53,11 +53,15 @@ def test_DataLoader():
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_graph_fetch", [True, False])
@pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024])
@pytest.mark.parametrize("gpu_cache_threshold", [1, 3])
def test_gpu_sampling_DataLoader(
sampler_name,
enable_feature_fetch,
overlap_feature_fetch,
overlap_graph_fetch,
num_gpu_cached_edges,
gpu_cache_threshold,
):
N = 40
B = 4
@@ -94,6 +98,8 @@ def test_gpu_sampling_DataLoader(
datapipe,
overlap_feature_fetch=overlap_feature_fetch,
overlap_graph_fetch=overlap_graph_fetch,
num_gpu_cached_edges=num_gpu_cached_edges,
gpu_cache_threshold=gpu_cache_threshold,
)
bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
if overlap_graph_fetch: