From 3a424dc0614f6e7cff05e1649be7645dc29e72ec Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 14:19:03 -0400 Subject: [PATCH 1/9] [GraphBolt][CUDA] Refactor `overlap_graph_fetch`, simplify `gb.DataLoader`. --- python/dgl/graphbolt/dataloader.py | 80 +-------- .../impl/fused_csc_sampling_graph.py | 28 ++++ python/dgl/graphbolt/impl/neighbor_sampler.py | 157 +++++++++--------- .../graphbolt/impl/test_neighbor_sampler.py | 27 +-- .../pytorch/graphbolt/test_dataloader.py | 10 +- 5 files changed, 118 insertions(+), 184 deletions(-) diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py index bb37cf2e9806..3be6ab685c48 100644 --- a/python/dgl/graphbolt/dataloader.py +++ b/python/dgl/graphbolt/dataloader.py @@ -1,13 +1,10 @@ """Graph Bolt DataLoaders""" -from collections import OrderedDict - import torch import torch.utils.data as torch_data -from .base import CopyTo, get_host_to_device_uva_stream +from .base import CopyTo from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker -from .impl.gpu_graph_cache import GPUGraphCache from .impl.neighbor_sampler import SamplePerLayer from .internal import ( @@ -22,34 +19,9 @@ __all__ = [ "DataLoader", - "construct_gpu_graph_cache", ] -def construct_gpu_graph_cache( - sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold -): - "Construct a GPUGraphCache given a sample_per_layer_obj and cache parameters." - graph = sample_per_layer_obj.sampler.__self__ - num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges) - dtypes = OrderedDict() - dtypes["indices"] = graph.indices.dtype - if graph.type_per_edge is not None: - dtypes["type_per_edge"] = graph.type_per_edge.dtype - if graph.edge_attributes is not None: - probs_or_mask = graph.edge_attributes.get( - sample_per_layer_obj.prob_name, None - ) - if probs_or_mask is not None: - dtypes["probs_or_mask"] = probs_or_mask.dtype - return GPUGraphCache( - num_gpu_cached_edges, - gpu_cache_threshold, - graph.csc_indptr.dtype, - list(dtypes.values()), - ) - - def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs): """Find parent of target_datapipe and wrap it with .""" datapipes = find_dps( @@ -150,9 +122,6 @@ def __init__( datapipe, num_workers=0, persistent_workers=True, - overlap_graph_fetch=False, - num_gpu_cached_edges=0, - gpu_cache_threshold=1, max_uva_threads=10240, ): # Multiprocessing requires two modifications to the datapipe: @@ -200,54 +169,15 @@ def __init__( if feature_fetcher.max_num_stages > 0: # Overlap enabled. torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) - if ( - overlap_graph_fetch - and num_workers == 0 - and torch.cuda.is_available() - ): - torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) + if num_workers == 0 and torch.cuda.is_available(): samplers = find_dps( datapipe_graph, SamplePerLayer, ) - gpu_graph_cache = None for sampler in samplers: - if num_gpu_cached_edges > 0 and gpu_graph_cache is None: - gpu_graph_cache = construct_gpu_graph_cache( - sampler, num_gpu_cached_edges, gpu_cache_threshold - ) - if ( - sampler.sampler.__name__ == "sample_layer_neighbors" - or gpu_graph_cache is not None - ): - # This code path is not faster for sample_neighbors. - datapipe_graph = replace_dp( - datapipe_graph, - sampler, - sampler.fetch_and_sample( - gpu_graph_cache, - get_host_to_device_uva_stream(), - 1, - ), - ) - elif sampler.sampler.__name__ == "sample_neighbors": - # This code path is faster for sample_neighbors. - datapipe_graph = replace_dp( - datapipe_graph, - sampler, - sampler.datapipe.sample_per_layer( - sampler=sampler.sampler, - fanout=sampler.fanout, - replace=sampler.replace, - prob_name=sampler.prob_name, - returning_indices_is_optional=True, - ), - ) - else: - raise AssertionError( - "overlap_graph_fetch is supported only for " - "sample_neighbor and sample_layer_neighbor." - ) + if sampler.overlap_fetch: + torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) + # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching # before it. This enables enables non_blocking copies to the device. diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 7288c969d0a7..04bc0ce05c4c 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -10,6 +10,7 @@ from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID from ..internal_utils import gb_warning, is_wsl, recursive_apply from ..sampling_graph import SamplingGraph +from .gpu_graph_cache import GPUGraphCache from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl @@ -314,6 +315,14 @@ def _indptr_node_type_offset_list( ): """Sets the indptr node type offset list if present.""" self._indptr_node_type_offset_list_ = indptr_node_type_offset_list + + @property + def _gpu_graph_cache(self) -> Optional[GPUGraphCache]: + return ( + self._gpu_graph_cache_ + if hasattr(self, "_gpu_graph_cache_") + else None + ) @property def type_per_edge(self) -> Optional[torch.Tensor]: @@ -1432,6 +1441,25 @@ def _pin(x): return self._apply_to_members(_pin) + def _initialize_gpu_graph_cache( + self, num_gpu_cached_edges: int, gpu_cache_threshold: int, prob_name: Optional[str] = None + ): + "Construct a GPUGraphCache given the cache parameters." + num_gpu_cached_edges = min(num_gpu_cached_edges, self.total_num_edges) + dtypes = [self.indices.dtype] + if self.type_per_edge is not None: + dtypes.append(self.type_per_edge.dtype) + if self.edge_attributes is not None: + probs_or_mask = self.edge_attributes.get(prob_name, None) + if probs_or_mask is not None: + dtypes.append(probs_or_mask.dtype) + self._gpu_graph_cache_ = GPUGraphCache( + num_gpu_cached_edges, + gpu_cache_threshold, + self.csc_indptr.dtype, + dtypes, + ) + def fused_csc_sampling_graph( csc_indptr: torch.Tensor, diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 8c997eb6848f..7ee2d85f0133 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -24,7 +24,6 @@ "NeighborSampler", "LayerNeighborSampler", "SamplePerLayer", - "SamplePerLayerFromFetchedSubgraph", "FetchInsubgraphData", "ConcatHeteroSeeds", "CombineCachedAndFetchedInSubgraph", @@ -55,9 +54,9 @@ class CombineCachedAndFetchedInSubgraph(Mapper): found inside the GPUGraphCache. """ - def __init__(self, datapipe, sample_per_layer_obj): + def __init__(self, datapipe, prob_name): super().__init__(datapipe, self._combine_per_layer) - self.prob_name = sample_per_layer_obj.prob_name + self.prob_name = prob_name def _combine_per_layer(self, minibatch): subgraph = minibatch._sliced_sampling_graph @@ -94,9 +93,9 @@ def _combine_per_layer(self, minibatch): class ConcatHeteroSeeds(Mapper): """Concatenates the seeds into a single tensor in the hetero case.""" - def __init__(self, datapipe, sample_per_layer_obj): + def __init__(self, datapipe, graph): super().__init__(datapipe, self._concat) - self.graph = sample_per_layer_obj.sampler.__self__ + self.graph = graph def _concat(self, minibatch): seeds = minibatch._seed_nodes @@ -124,20 +123,21 @@ class FetchInsubgraphData(Mapper): def __init__( self, datapipe, - sample_per_layer_obj, - gpu_graph_cache, - stream=None, + graph, + prob_name, ): - self.graph = sample_per_layer_obj.sampler.__self__ - datapipe = datapipe.concat_hetero_seeds(sample_per_layer_obj) - if gpu_graph_cache is not None: - datapipe = datapipe.fetch_cached_insubgraph_data(gpu_graph_cache) + datapipe = datapipe.concat_hetero_seeds(graph) + if graph._gpu_graph_cache is not None: + datapipe = datapipe.fetch_cached_insubgraph_data(graph._gpu_graph_cache) + self.graph = graph + self.prob_name = prob_name super().__init__(datapipe, self._fetch_per_layer) - self.prob_name = sample_per_layer_obj.prob_name - self.stream = stream - def _fetch_per_layer_impl(self, minibatch, stream): - with torch.cuda.stream(self.stream): + def _fetch_per_layer(self, minibatch): + stream = torch.cuda.current_stream() + uva_stream = get_host_to_device_uva_stream() + uva_stream.wait_stream(stream) + with torch.cuda.stream(uva_stream): seeds = minibatch._seeds seed_offsets = minibatch._seed_offsets delattr(minibatch, "_seeds") @@ -146,7 +146,7 @@ def _fetch_per_layer_impl(self, minibatch, stream): seeds.record_stream(torch.cuda.current_stream()) def record_stream(tensor): - if stream is not None and tensor.is_cuda: + if tensor.is_cuda: tensor.record_stream(stream) return tensor @@ -210,49 +210,10 @@ def record_stream(tensor): subgraph._indptr_node_type_offset_list = seed_offsets minibatch._sliced_sampling_graph = subgraph - if self.stream is not None: - minibatch.wait = torch.cuda.current_stream().record_event().wait + minibatch.wait = torch.cuda.current_stream().record_event().wait return minibatch - def _fetch_per_layer(self, minibatch): - current_stream = None - if self.stream is not None: - current_stream = torch.cuda.current_stream() - self.stream.wait_stream(current_stream) - return self._fetch_per_layer_impl(minibatch, current_stream) - - -@functional_datapipe("sample_per_layer_from_fetched_subgraph") -class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer): - """Sample neighbor edges from a graph for a single layer.""" - - def __init__(self, datapipe, sample_per_layer_obj): - super().__init__(datapipe, self._sample_per_layer_from_fetched_subgraph) - self.sampler_name = sample_per_layer_obj.sampler.__name__ - self.fanout = sample_per_layer_obj.fanout - self.replace = sample_per_layer_obj.replace - self.prob_name = sample_per_layer_obj.prob_name - - def _sample_per_layer_from_fetched_subgraph(self, minibatch): - subgraph = minibatch._sliced_sampling_graph - delattr(minibatch, "_sliced_sampling_graph") - kwargs = { - key[1:]: getattr(minibatch, key) - for key in ["_random_seed", "_seed2_contribution"] - if hasattr(minibatch, key) - } - sampled_subgraph = getattr(subgraph, self.sampler_name)( - None, - self.fanout, - self.replace, - self.prob_name, - **kwargs, - ) - minibatch.sampled_subgraphs.insert(0, sampled_subgraph) - - return minibatch - @functional_datapipe("sample_per_layer") class SamplePerLayer(MiniBatchTransformer): @@ -265,10 +226,11 @@ def __init__( fanout, replace, prob_name, - returning_indices_is_optional=False, + overlap_fetch, ): graph = sampler.__self__ - if returning_indices_is_optional and graph.indices.is_pinned(): + self.returning_indices_is_optional = False + if overlap_fetch and sampler.__name__ == "sample_neighbors" and graph.indices.is_pinned() and graph._gpu_graph_cache is None: datapipe = datapipe.transform(self._sample_per_layer) datapipe = ( datapipe.transform(partial(self._fetch_indices, graph.indices)) @@ -285,13 +247,20 @@ def __init__( ) ) super().__init__(datapipe) + self.returning_indices_is_optional = True + elif overlap_fetch: + datapipe = datapipe.fetch_insubgraph_data(graph, prob_name) + datapipe = datapipe.buffer().wait() + if graph._gpu_graph_cache is not None: + datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name) + super().__init__(datapipe, self._sample_per_layer_from_fetched_subgraph) else: super().__init__(datapipe, self._sample_per_layer) self.sampler = sampler self.fanout = fanout self.replace = replace self.prob_name = prob_name - self.returning_indices_is_optional = returning_indices_is_optional + self.overlap_fetch = overlap_fetch def _sample_per_layer(self, minibatch): kwargs = { @@ -310,6 +279,25 @@ def _sample_per_layer(self, minibatch): minibatch.sampled_subgraphs.insert(0, subgraph) return minibatch + def _sample_per_layer_from_fetched_subgraph(self, minibatch): + subgraph = minibatch._sliced_sampling_graph + delattr(minibatch, "_sliced_sampling_graph") + kwargs = { + key[1:]: getattr(minibatch, key) + for key in ["_random_seed", "_seed2_contribution"] + if hasattr(minibatch, key) + } + sampled_subgraph = getattr(subgraph, self.sampler.__name__)( + None, + self.fanout, + self.replace, + self.prob_name, + **kwargs, + ) + minibatch.sampled_subgraphs.insert(0, sampled_subgraph) + + return minibatch + @staticmethod def _fetch_indices(indices, minibatch): stream = torch.cuda.current_stream() @@ -398,27 +386,6 @@ def _compact_per_layer(self, minibatch): return minibatch -@functional_datapipe("fetch_and_sample") -class FetcherAndSampler(MiniBatchTransformer): - """Overlapped graph sampling operation replacement.""" - - def __init__( - self, - sampler, - gpu_graph_cache, - stream, - buffer_size, - ): - datapipe = sampler.datapipe.fetch_insubgraph_data( - sampler, gpu_graph_cache, stream - ) - datapipe = datapipe.buffer(buffer_size).wait() - if gpu_graph_cache is not None: - datapipe = datapipe.combine_cached_and_fetched_insubgraph(sampler) - datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler) - super().__init__(datapipe) - - class NeighborSamplerImpl(SubgraphSampler): # pylint: disable=abstract-method """Base class for NeighborSamplers.""" @@ -433,6 +400,9 @@ def __init__( prob_name, deduplicate, sampler, + overlap_fetch, + num_gpu_cached_edges, + gpu_cache_threshold, layer_dependency=None, batch_dependency=None, ): @@ -446,6 +416,9 @@ def __init__( prob_name, deduplicate, sampler, + overlap_fetch, + num_gpu_cached_edges, + gpu_cache_threshold, layer_dependency, ) @@ -520,8 +493,16 @@ def sampling_stages( prob_name, deduplicate, sampler, + overlap_fetch, + num_gpu_cached_edges, + gpu_cache_threshold, layer_dependency, ): + if overlap_fetch and num_gpu_cached_edges > 0: + if graph._gpu_graph_cache is None: + graph._initialize_gpu_graph_cache( + num_gpu_cached_edges, gpu_cache_threshold + ) datapipe = datapipe.transform( partial(self._prepare, graph.node_type_to_id) ) @@ -533,7 +514,7 @@ def sampling_stages( if not isinstance(fanout, torch.Tensor): fanout = torch.LongTensor([int(fanout)]) datapipe = datapipe.sample_per_layer( - sampler, fanout, replace, prob_name + sampler, fanout, replace, prob_name, overlap_fetch ) datapipe = datapipe.compact_per_layer(deduplicate) if is_labor and not layer_dependency: @@ -638,6 +619,9 @@ def __init__( replace=False, prob_name=None, deduplicate=True, + overlap_fetch=False, + num_gpu_cached_edges=0, + gpu_cache_threshold=1, ): super().__init__( datapipe, @@ -647,6 +631,9 @@ def __init__( prob_name, deduplicate, graph.sample_neighbors, + overlap_fetch, + num_gpu_cached_edges, + gpu_cache_threshold, ) @@ -776,6 +763,9 @@ def __init__( deduplicate=True, layer_dependency=False, batch_dependency=1, + overlap_fetch=False, + num_gpu_cached_edges=0, + gpu_cache_threshold=1, ): super().__init__( datapipe, @@ -785,6 +775,9 @@ def __init__( prob_name, deduplicate, graph.sample_layer_neighbors, + overlap_fetch, + num_gpu_cached_edges, + gpu_cache_threshold, layer_dependency, batch_dependency, ) diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index 6c5818bbf601..affcd43b7820 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -7,7 +7,6 @@ import dgl.graphbolt as gb import pytest import torch -from dgl.graphbolt.dataloader import construct_gpu_graph_cache def get_hetero_graph(): @@ -72,21 +71,8 @@ def test_NeighborSampler_GraphFetch( compact_per_layer = sample_per_layer.compact_per_layer(True) gb.seed(123) expected_results = list(compact_per_layer) - gpu_graph_cache = None - if num_cached_edges > 0: - gpu_graph_cache = construct_gpu_graph_cache( - sample_per_layer, num_cached_edges, 1 - ) - datapipe = gb.FetchInsubgraphData( - datapipe, - sample_per_layer, - gpu_graph_cache, - ) - if num_cached_edges > 0: - datapipe = gb.CombineCachedAndFetchedInSubgraph( - datapipe, sample_per_layer - ) - datapipe = gb.SamplePerLayerFromFetchedSubgraph(datapipe, sample_per_layer) + graph._initialize_gpu_graph_cache(num_cached_edges, 1, prob_name) + datapipe = datapipe.sample_per_layer(graph.sample_neighbors, fanout, False, prob_name, True) datapipe = datapipe.compact_per_layer(True) gb.seed(123) new_results = list(datapipe) @@ -99,10 +85,10 @@ def remove_input_nodes(minibatch): return minibatch datapipe = item_sampler.sample_neighbor( - graph, [fanout], False, prob_name=prob_name + graph, [fanout], False, prob_name=prob_name, overlap_fetch=True ) datapipe = datapipe.transform(remove_input_nodes) - dataloader = gb.DataLoader(datapipe, overlap_graph_fetch=True) + dataloader = gb.DataLoader(datapipe) gb.seed(123) new_results = list(dataloader) assert len(expected_results) == len(new_results) @@ -133,12 +119,11 @@ def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch): datapipe = datapipe.sample_layer_neighbor( graph, fanouts, + overlap_fetch=overlap_graph_fetch, layer_dependency=layer_dependency, batch_dependency=batch_dependency, ) - dataloader = gb.DataLoader( - datapipe, overlap_graph_fetch=overlap_graph_fetch - ) + dataloader = gb.DataLoader(datapipe) res = list(dataloader) assert len(res) == batch_dependency + 1 if layer_dependency: diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index ba22bfdda293..ebce4f55a784 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -108,6 +108,9 @@ def test_gpu_sampling_DataLoader( datapipe, graph, fanouts=[torch.LongTensor([2]) for _ in range(num_layers)], + overlap_fetch=overlap_graph_fetch, + num_gpu_cached_edges=num_gpu_cached_edges, + gpu_cache_threshold=gpu_cache_threshold, ) if enable_feature_fetch: datapipe = dgl.graphbolt.FeatureFetcher( @@ -119,12 +122,7 @@ def test_gpu_sampling_DataLoader( ) if i == 0: dataloaders.append( - dgl.graphbolt.DataLoader( - datapipe, - overlap_graph_fetch=overlap_graph_fetch, - num_gpu_cached_edges=num_gpu_cached_edges, - gpu_cache_threshold=gpu_cache_threshold, - ) + dgl.graphbolt.DataLoader(datapipe) ) else: dataloaders.append(dgl.graphbolt.DataLoader(datapipe)) From bde921435b1aabbd973111023c2b0ba1813fcc43 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 14:23:17 -0400 Subject: [PATCH 2/9] linting --- python/dgl/graphbolt/dataloader.py | 1 - .../impl/fused_csc_sampling_graph.py | 7 +++++-- python/dgl/graphbolt/impl/neighbor_sampler.py | 19 +++++++++++++++---- .../graphbolt/impl/test_neighbor_sampler.py | 4 +++- .../pytorch/graphbolt/test_dataloader.py | 4 +--- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py index 3be6ab685c48..5e4eca7add98 100644 --- a/python/dgl/graphbolt/dataloader.py +++ b/python/dgl/graphbolt/dataloader.py @@ -177,7 +177,6 @@ def __init__( for sampler in samplers: if sampler.overlap_fetch: torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) - # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching # before it. This enables enables non_blocking copies to the device. diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 04bc0ce05c4c..53cfcc76bbbb 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -315,7 +315,7 @@ def _indptr_node_type_offset_list( ): """Sets the indptr node type offset list if present.""" self._indptr_node_type_offset_list_ = indptr_node_type_offset_list - + @property def _gpu_graph_cache(self) -> Optional[GPUGraphCache]: return ( @@ -1442,7 +1442,10 @@ def _pin(x): return self._apply_to_members(_pin) def _initialize_gpu_graph_cache( - self, num_gpu_cached_edges: int, gpu_cache_threshold: int, prob_name: Optional[str] = None + self, + num_gpu_cached_edges: int, + gpu_cache_threshold: int, + prob_name: Optional[str] = None, ): "Construct a GPUGraphCache given the cache parameters." num_gpu_cached_edges = min(num_gpu_cached_edges, self.total_num_edges) diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 7ee2d85f0133..90632193171e 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -128,7 +128,9 @@ def __init__( ): datapipe = datapipe.concat_hetero_seeds(graph) if graph._gpu_graph_cache is not None: - datapipe = datapipe.fetch_cached_insubgraph_data(graph._gpu_graph_cache) + datapipe = datapipe.fetch_cached_insubgraph_data( + graph._gpu_graph_cache + ) self.graph = graph self.prob_name = prob_name super().__init__(datapipe, self._fetch_per_layer) @@ -230,7 +232,12 @@ def __init__( ): graph = sampler.__self__ self.returning_indices_is_optional = False - if overlap_fetch and sampler.__name__ == "sample_neighbors" and graph.indices.is_pinned() and graph._gpu_graph_cache is None: + if ( + overlap_fetch + and sampler.__name__ == "sample_neighbors" + and graph.indices.is_pinned() + and graph._gpu_graph_cache is None + ): datapipe = datapipe.transform(self._sample_per_layer) datapipe = ( datapipe.transform(partial(self._fetch_indices, graph.indices)) @@ -252,8 +259,12 @@ def __init__( datapipe = datapipe.fetch_insubgraph_data(graph, prob_name) datapipe = datapipe.buffer().wait() if graph._gpu_graph_cache is not None: - datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name) - super().__init__(datapipe, self._sample_per_layer_from_fetched_subgraph) + datapipe = datapipe.combine_cached_and_fetched_insubgraph( + prob_name + ) + super().__init__( + datapipe, self._sample_per_layer_from_fetched_subgraph + ) else: super().__init__(datapipe, self._sample_per_layer) self.sampler = sampler diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index affcd43b7820..c959a25864c3 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -72,7 +72,9 @@ def test_NeighborSampler_GraphFetch( gb.seed(123) expected_results = list(compact_per_layer) graph._initialize_gpu_graph_cache(num_cached_edges, 1, prob_name) - datapipe = datapipe.sample_per_layer(graph.sample_neighbors, fanout, False, prob_name, True) + datapipe = datapipe.sample_per_layer( + graph.sample_neighbors, fanout, False, prob_name, True + ) datapipe = datapipe.compact_per_layer(True) gb.seed(123) new_results = list(datapipe) diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index ebce4f55a784..069d2ca4cce0 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -121,9 +121,7 @@ def test_gpu_sampling_DataLoader( overlap_fetch=overlap_feature_fetch and i == 0, ) if i == 0: - dataloaders.append( - dgl.graphbolt.DataLoader(datapipe) - ) + dataloaders.append(dgl.graphbolt.DataLoader(datapipe)) else: dataloaders.append(dgl.graphbolt.DataLoader(datapipe)) dataloader, dataloader2 = dataloaders From 3b06ef011fce7c394a9a5d7c47fecf7db10800a8 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 14:31:21 -0400 Subject: [PATCH 3/9] fix bug --- python/dgl/graphbolt/impl/neighbor_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 90632193171e..f2d3c701edf4 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -512,7 +512,7 @@ def sampling_stages( if overlap_fetch and num_gpu_cached_edges > 0: if graph._gpu_graph_cache is None: graph._initialize_gpu_graph_cache( - num_gpu_cached_edges, gpu_cache_threshold + num_gpu_cached_edges, gpu_cache_threshold, prob_name ) datapipe = datapipe.transform( partial(self._prepare, graph.node_type_to_id) From 6e2970e99ec1e74a39bebcca74aec58e9e217aa5 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 14:36:06 -0400 Subject: [PATCH 4/9] fix the test. --- .../python/pytorch/graphbolt/test_dataloader.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 069d2ca4cce0..832c303b4c6c 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -104,13 +104,18 @@ def test_gpu_sampling_DataLoader( for i in range(2): datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B) datapipe = datapipe.copy_to(F.ctx()) + kwargs = {} + if i == 0: + kwargs = { + "overlap_fetch": overlap_graph_fetch, + "num_gpu_cached_edges": num_gpu_cached_edges, + "gpu_cache_threshold": gpu_cache_threshold, + } datapipe = getattr(dgl.graphbolt, sampler_name)( datapipe, graph, fanouts=[torch.LongTensor([2]) for _ in range(num_layers)], - overlap_fetch=overlap_graph_fetch, - num_gpu_cached_edges=num_gpu_cached_edges, - gpu_cache_threshold=gpu_cache_threshold, + **kwargs ) if enable_feature_fetch: datapipe = dgl.graphbolt.FeatureFetcher( @@ -120,10 +125,7 @@ def test_gpu_sampling_DataLoader( ["d"], overlap_fetch=overlap_feature_fetch and i == 0, ) - if i == 0: - dataloaders.append(dgl.graphbolt.DataLoader(datapipe)) - else: - dataloaders.append(dgl.graphbolt.DataLoader(datapipe)) + dataloaders.append(dgl.graphbolt.DataLoader(datapipe)) dataloader, dataloader2 = dataloaders bufferer_cnt = int(enable_feature_fetch and overlap_feature_fetch) From be4009d669876bb26717208f956f2b42e9860d44 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 14:36:41 -0400 Subject: [PATCH 5/9] make the code easier to read. --- tests/python/pytorch/graphbolt/test_dataloader.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 832c303b4c6c..93045c0113e9 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -104,13 +104,13 @@ def test_gpu_sampling_DataLoader( for i in range(2): datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B) datapipe = datapipe.copy_to(F.ctx()) - kwargs = {} - if i == 0: - kwargs = { - "overlap_fetch": overlap_graph_fetch, - "num_gpu_cached_edges": num_gpu_cached_edges, - "gpu_cache_threshold": gpu_cache_threshold, - } + kwargs = { + "overlap_fetch": overlap_graph_fetch, + "num_gpu_cached_edges": num_gpu_cached_edges, + "gpu_cache_threshold": gpu_cache_threshold, + } + if i != 0: + kwargs = {} datapipe = getattr(dgl.graphbolt, sampler_name)( datapipe, graph, From 788cd3a0b95b0aa914628a3af16ba80efc0bca6c Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 14:46:21 -0400 Subject: [PATCH 6/9] modify examples for the API change. --- .../disk_based_feature/node_classification.py | 11 ++++----- examples/graphbolt/node_classification.py | 10 ++++---- .../pyg/labor/node_classification.py | 11 ++++----- .../pyg/node_classification_advanced.py | 14 +++++------ examples/graphbolt/rgcn/hetero_rgcn.py | 10 ++++---- .../multigpu/graphbolt/node_classification.py | 10 ++++---- python/dgl/graphbolt/dataloader.py | 12 ---------- python/dgl/graphbolt/impl/neighbor_sampler.py | 24 +++++++++++++++++++ 8 files changed, 52 insertions(+), 50 deletions(-) diff --git a/examples/graphbolt/disk_based_feature/node_classification.py b/examples/graphbolt/disk_based_feature/node_classification.py index afb6a0ac542a..aaca410947cc 100644 --- a/examples/graphbolt/disk_based_feature/node_classification.py +++ b/examples/graphbolt/disk_based_feature/node_classification.py @@ -115,7 +115,10 @@ def create_dataloader( else {} ) datapipe = getattr(datapipe, args.sample_mode)( - graph, fanout if job != "infer" else [-1], **kwargs + graph, + fanout if job != "infer" else [-1], + overlap_fetch=args.overlap_graph_fetch, + **kwargs, ) # Copy the data to the specified device. if args.feature_device != "cpu": @@ -130,11 +133,7 @@ def create_dataloader( if args.feature_device == "cpu": datapipe = datapipe.copy_to(device=device) # Create and return a DataLoader to handle data loading. - return gb.DataLoader( - datapipe, - num_workers=args.num_workers, - overlap_graph_fetch=args.overlap_graph_fetch, - ) + return gb.DataLoader(datapipe, num_workers=args.num_workers) def train_step(minibatch, optimizer, model, loss_fn): diff --git a/examples/graphbolt/node_classification.py b/examples/graphbolt/node_classification.py index ff9ed2399b23..e5e17de88bde 100644 --- a/examples/graphbolt/node_classification.py +++ b/examples/graphbolt/node_classification.py @@ -117,7 +117,9 @@ def create_dataloader( # Initialize a neighbor sampler for sampling the neighborhoods of nodes. ############################################################################ datapipe = getattr(datapipe, args.sample_mode)( - graph, fanout if job != "infer" else [-1] + graph, + fanout if job != "infer" else [-1], + overlap_fetch=args.storage_device == "pinned", ) ############################################################################ @@ -156,11 +158,7 @@ def create_dataloader( # [Role]: # Initialize a multi-process dataloader to load the data in parallel. ############################################################################ - dataloader = gb.DataLoader( - datapipe, - num_workers=num_workers, - overlap_graph_fetch=args.storage_device == "pinned", - ) + dataloader = gb.DataLoader(datapipe, num_workers=num_workers) # Return the fully-initialized DataLoader object. return dataloader diff --git a/examples/graphbolt/pyg/labor/node_classification.py b/examples/graphbolt/pyg/labor/node_classification.py index b799b3de8cbe..09f8cb3cf050 100644 --- a/examples/graphbolt/pyg/labor/node_classification.py +++ b/examples/graphbolt/pyg/labor/node_classification.py @@ -147,7 +147,10 @@ def create_dataloader( else {} ) datapipe = getattr(datapipe, args.sample_mode)( - graph, fanout if job != "infer" else [-1], **kwargs + graph, + fanout if job != "infer" else [-1], + overlap_fetch=args.overlap_graph_fetch, + **kwargs, ) # Copy the data to the specified device. if args.feature_device != "cpu" and need_copy: @@ -163,11 +166,7 @@ def create_dataloader( if need_copy: datapipe = datapipe.copy_to(device=device) # Create and return a DataLoader to handle data loading. - return gb.DataLoader( - datapipe, - num_workers=args.num_workers, - overlap_graph_fetch=args.overlap_graph_fetch, - ) + return gb.DataLoader(datapipe, num_workers=args.num_workers) @torch.compile diff --git a/examples/graphbolt/pyg/node_classification_advanced.py b/examples/graphbolt/pyg/node_classification_advanced.py index 5066016f7e18..3b066a511b32 100644 --- a/examples/graphbolt/pyg/node_classification_advanced.py +++ b/examples/graphbolt/pyg/node_classification_advanced.py @@ -195,7 +195,11 @@ def create_dataloader( need_copy = False # Sample neighbors for each node in the mini-batch. datapipe = getattr(datapipe, args.sample_mode)( - graph, fanout if job != "infer" else [-1] + graph, + fanout if job != "infer" else [-1], + overlap_fetch=args.overlap_graph_fetch, + num_gpu_cached_edges=args.num_gpu_cached_edges, + gpu_cache_threshold=args.gpu_graph_caching_threshold, ) # Copy the data to the specified device. if args.feature_device != "cpu" and need_copy: @@ -211,13 +215,7 @@ def create_dataloader( if need_copy: datapipe = datapipe.copy_to(device=device) # Create and return a DataLoader to handle data loading. - return gb.DataLoader( - datapipe, - num_workers=args.num_workers, - overlap_graph_fetch=args.overlap_graph_fetch, - num_gpu_cached_edges=args.num_gpu_cached_edges, - gpu_cache_threshold=args.gpu_graph_caching_threshold, - ) + return gb.DataLoader(datapipe, num_workers=args.num_workers) @torch.compile diff --git a/examples/graphbolt/rgcn/hetero_rgcn.py b/examples/graphbolt/rgcn/hetero_rgcn.py index d834b84f7d9e..eec00e12f11f 100644 --- a/examples/graphbolt/rgcn/hetero_rgcn.py +++ b/examples/graphbolt/rgcn/hetero_rgcn.py @@ -124,7 +124,9 @@ def create_dataloader( # The graph(FusedCSCSamplingGraph) from which to sample neighbors. # `fanouts`: # The number of neighbors to sample for each node in each layer. - datapipe = datapipe.sample_neighbor(graph, fanouts=fanouts) + datapipe = datapipe.sample_neighbor( + graph, fanouts=fanouts, overlap_fetch=args.overlap_graph_fetch + ) # Fetch the features for each node in the mini-batch. # `features`: @@ -141,11 +143,7 @@ def create_dataloader( # Create a DataLoader from the datapipe. # `num_workers`: # The number of worker processes to use for data loading. - return gb.DataLoader( - datapipe, - num_workers=num_workers, - overlap_graph_fetch=args.overlap_graph_fetch, - ) + return gb.DataLoader(datapipe, num_workers=num_workers) def extract_embed(node_embed, input_nodes): diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py index 32da55f782d2..35ae6fcc38d4 100644 --- a/examples/multigpu/graphbolt/node_classification.py +++ b/examples/multigpu/graphbolt/node_classification.py @@ -134,16 +134,14 @@ def create_dataloader( ############################################################################ if args.storage_device != "cpu": datapipe = datapipe.copy_to(device) - datapipe = datapipe.sample_neighbor(graph, args.fanout) + datapipe = datapipe.sample_neighbor( + graph, args.fanout, overlap_fetch=args.storage_device == "pinned" + ) datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) if args.storage_device == "cpu": datapipe = datapipe.copy_to(device) - dataloader = gb.DataLoader( - datapipe, - args.num_workers, - overlap_graph_fetch=args.storage_device == "pinned", - ) + dataloader = gb.DataLoader(datapipe, args.num_workers) # Return the fully-initialized DataLoader object. return dataloader diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py index 5e4eca7add98..d76cb48fa0db 100644 --- a/python/dgl/graphbolt/dataloader.py +++ b/python/dgl/graphbolt/dataloader.py @@ -97,18 +97,6 @@ class DataLoader(torch_data.DataLoader): If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers instances alive. - overlap_graph_fetch : bool, optional - If True, the data loader will overlap the UVA graph fetching operations - with the rest of operations by using an alternative CUDA stream. This - option should be enabled if you have moved your graph to the pinned - memory for optimal performance. Default is False. - num_gpu_cached_edges : int, optional - If positive and overlap_graph_fetch is True, then the GPU will cache - frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth - demand due to pinned graph accesses. - gpu_cache_threshold : int, optional - Determines how many times a vertex needs to be accessed before its - neighborhood ends up being cached on the GPU. max_uva_threads : int, optional Limits the number of CUDA threads used for UVA copies so that the rest of the computations can run simultaneously with it. Setting it to a too diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index f2d3c701edf4..d01417a56c42 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -581,6 +581,18 @@ class NeighborSampler(NeighborSamplerImpl): Boolean indicating whether seeds between hops will be deduplicated. If True, the same elements in seeds will be deleted to only one. Otherwise, the same elements will be remained. + overlap_fetch : bool, optional + If True, the data loader will overlap the UVA graph fetching operations + with the rest of operations by using an alternative CUDA stream. This + option should be enabled if you have moved your graph to the pinned + memory for optimal performance. Default is False. + num_gpu_cached_edges : int, optional + If positive and overlap_graph_fetch is True, then the GPU will cache + frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth + demand due to pinned graph accesses. + gpu_cache_threshold : int, optional + Determines how many times a vertex needs to be accessed before its + neighborhood ends up being cached on the GPU. Examples ------- @@ -716,6 +728,18 @@ class LayerNeighborSampler(NeighborSamplerImpl): the random variates proportional to :math:`\\frac{1}{\\kappa}`. Implements the dependent minibatching approach in `arXiv:2310.12403 `__. + overlap_fetch : bool, optional + If True, the data loader will overlap the UVA graph fetching operations + with the rest of operations by using an alternative CUDA stream. This + option should be enabled if you have moved your graph to the pinned + memory for optimal performance. Default is False. + num_gpu_cached_edges : int, optional + If positive and overlap_graph_fetch is True, then the GPU will cache + frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth + demand due to pinned graph accesses. + gpu_cache_threshold : int, optional + Determines how many times a vertex needs to be accessed before its + neighborhood ends up being cached on the GPU. Examples ------- From ee4ae8423bafa28e190f7c0e09c46bf21da1c1a2 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 15:07:48 -0400 Subject: [PATCH 7/9] improve the code --- python/dgl/graphbolt/impl/neighbor_sampler.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index d01417a56c42..07c891ae545a 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -417,6 +417,11 @@ def __init__( layer_dependency=None, batch_dependency=None, ): + if overlap_fetch and num_gpu_cached_edges > 0: + if graph._gpu_graph_cache is None: + graph._initialize_gpu_graph_cache( + num_gpu_cached_edges, gpu_cache_threshold, prob_name + ) if sampler.__name__ == "sample_layer_neighbors": self._init_seed(batch_dependency) super().__init__( @@ -428,8 +433,6 @@ def __init__( deduplicate, sampler, overlap_fetch, - num_gpu_cached_edges, - gpu_cache_threshold, layer_dependency, ) @@ -505,15 +508,8 @@ def sampling_stages( deduplicate, sampler, overlap_fetch, - num_gpu_cached_edges, - gpu_cache_threshold, layer_dependency, ): - if overlap_fetch and num_gpu_cached_edges > 0: - if graph._gpu_graph_cache is None: - graph._initialize_gpu_graph_cache( - num_gpu_cached_edges, gpu_cache_threshold, prob_name - ) datapipe = datapipe.transform( partial(self._prepare, graph.node_type_to_id) ) From a76856c53eadbf2aaf13f45bda938d5b801663af Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 15:09:35 -0400 Subject: [PATCH 8/9] make the test same as before. --- tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index c959a25864c3..70c6ceba29dd 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -71,7 +71,8 @@ def test_NeighborSampler_GraphFetch( compact_per_layer = sample_per_layer.compact_per_layer(True) gb.seed(123) expected_results = list(compact_per_layer) - graph._initialize_gpu_graph_cache(num_cached_edges, 1, prob_name) + if num_cached_edges > 0: + graph._initialize_gpu_graph_cache(num_cached_edges, 1, prob_name) datapipe = datapipe.sample_per_layer( graph.sample_neighbors, fanout, False, prob_name, True ) From 45541ad33fe47f57342ad2031dd89528e116bc8b Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 11 Aug 2024 20:51:16 -0400 Subject: [PATCH 9/9] fix test failure. --- tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index 70c6ceba29dd..3f827923e2f5 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -102,6 +102,8 @@ def remove_input_nodes(minibatch): @pytest.mark.parametrize("layer_dependency", [False, True]) @pytest.mark.parametrize("overlap_graph_fetch", [False, True]) def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch): + if F._default_context_str != "gpu" and overlap_graph_fetch: + pytest.skip("overlap_graph_fetch is only available for GPU.") num_edges = 200 csc_indptr = torch.cat( (