[GraphBolt][CUDA] Refactor overlap_graph_fetch, simplify `gb.DataLoader`.
mfbalin committed Aug 11, 2024
1 parent c86776d commit 3a424dc
Showing 5 changed files with 118 additions and 184 deletions.
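
For orientation before the diffs: a minimal before/after sketch of the user-facing change. The keyword names on the old `gb.DataLoader` come from the removed code below; the `datapipe` pipeline and the exact way a sampling stage opts into overlap via its `overlap_fetch` attribute are illustrative assumptions, not the library's confirmed API.

import dgl.graphbolt as gb

# Before this commit: graph-fetch overlap and the GPU graph cache were
# configured on the DataLoader itself (all three kwargs are removed below).
dataloader = gb.DataLoader(
    datapipe,  # hypothetical sampling pipeline built beforehand
    overlap_graph_fetch=True,
    num_gpu_cached_edges=10_000_000,
    gpu_cache_threshold=1,
)

# After this commit: the DataLoader takes none of those kwargs. It only
# raises the UVA thread limit when a SamplePerLayer stage in the datapipe
# graph has overlap_fetch set, so the decision now lives with the sampler.
dataloader = gb.DataLoader(datapipe)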
80 changes: 5 additions & 75 deletions python/dgl/graphbolt/dataloader.py
@@ -1,13 +1,10 @@
"""Graph Bolt DataLoaders"""

from collections import OrderedDict

import torch
import torch.utils.data as torch_data

from .base import CopyTo, get_host_to_device_uva_stream
from .base import CopyTo
from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
from .impl.gpu_graph_cache import GPUGraphCache
from .impl.neighbor_sampler import SamplePerLayer

from .internal import (
@@ -22,34 +19,9 @@

__all__ = [
"DataLoader",
"construct_gpu_graph_cache",
]


def construct_gpu_graph_cache(
sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold
):
"Construct a GPUGraphCache given a sample_per_layer_obj and cache parameters."
graph = sample_per_layer_obj.sampler.__self__
num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges)
dtypes = OrderedDict()
dtypes["indices"] = graph.indices.dtype
if graph.type_per_edge is not None:
dtypes["type_per_edge"] = graph.type_per_edge.dtype
if graph.edge_attributes is not None:
probs_or_mask = graph.edge_attributes.get(
sample_per_layer_obj.prob_name, None
)
if probs_or_mask is not None:
dtypes["probs_or_mask"] = probs_or_mask.dtype
return GPUGraphCache(
num_gpu_cached_edges,
gpu_cache_threshold,
graph.csc_indptr.dtype,
list(dtypes.values()),
)


def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = find_dps(
@@ -150,9 +122,6 @@ def __init__(
datapipe,
num_workers=0,
persistent_workers=True,
overlap_graph_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
max_uva_threads=10240,
):
# Multiprocessing requires two modifications to the datapipe:
@@ -200,54 +169,15 @@ def __init__(
if feature_fetcher.max_num_stages > 0: # Overlap enabled.
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)

if (
overlap_graph_fetch
and num_workers == 0
and torch.cuda.is_available()
):
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
if num_workers == 0 and torch.cuda.is_available():
samplers = find_dps(
datapipe_graph,
SamplePerLayer,
)
gpu_graph_cache = None
for sampler in samplers:
if num_gpu_cached_edges > 0 and gpu_graph_cache is None:
gpu_graph_cache = construct_gpu_graph_cache(
sampler, num_gpu_cached_edges, gpu_cache_threshold
)
if (
sampler.sampler.__name__ == "sample_layer_neighbors"
or gpu_graph_cache is not None
):
# This code path is not faster for sample_neighbors.
datapipe_graph = replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(
gpu_graph_cache,
get_host_to_device_uva_stream(),
1,
),
)
elif sampler.sampler.__name__ == "sample_neighbors":
# This code path is faster for sample_neighbors.
datapipe_graph = replace_dp(
datapipe_graph,
sampler,
sampler.datapipe.sample_per_layer(
sampler=sampler.sampler,
fanout=sampler.fanout,
replace=sampler.replace,
prob_name=sampler.prob_name,
returning_indices_is_optional=True,
),
)
else:
raise AssertionError(
"overlap_graph_fetch is supported only for "
"sample_neighbor and sample_layer_neighbor."
)
if sampler.overlap_fetch:
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)

        # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
        # before it. This enables non_blocking copies to the device.
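
The comment above refers to a general CUDA technique; here is a self-contained sketch of the idea in plain PyTorch, independent of GraphBolt's internal wrapper:

import torch

# Staging the batch in pinned (page-locked) host memory makes the
# host-to-device copy asynchronous: with non_blocking=True the .to() call
# returns immediately and the copy can overlap with GPU compute.
batch_cpu = torch.randn(1024, 128).pin_memory()
batch_gpu = batch_cpu.to("cuda", non_blocking=True)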
28 changes: 28 additions & 0 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -10,6 +10,7 @@
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..internal_utils import gb_warning, is_wsl, recursive_apply
from ..sampling_graph import SamplingGraph
from .gpu_graph_cache import GPUGraphCache
from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl


@@ -314,6 +315,14 @@ def _indptr_node_type_offset_list(
):
"""Sets the indptr node type offset list if present."""
self._indptr_node_type_offset_list_ = indptr_node_type_offset_list

    @property
    def _gpu_graph_cache(self) -> Optional[GPUGraphCache]:
        """The GPUGraphCache attached to this graph, or None if not set."""
        return (
            self._gpu_graph_cache_
            if hasattr(self, "_gpu_graph_cache_")
            else None
        )

@property
def type_per_edge(self) -> Optional[torch.Tensor]:
@@ -1432,6 +1441,25 @@ def _pin(x):

return self._apply_to_members(_pin)

    def _initialize_gpu_graph_cache(
        self,
        num_gpu_cached_edges: int,
        gpu_cache_threshold: int,
        prob_name: Optional[str] = None,
    ):
        """Construct a GPUGraphCache given the cache parameters."""
        num_gpu_cached_edges = min(num_gpu_cached_edges, self.total_num_edges)
        dtypes = [self.indices.dtype]
        if self.type_per_edge is not None:
            dtypes.append(self.type_per_edge.dtype)
        if self.edge_attributes is not None:
            probs_or_mask = self.edge_attributes.get(prob_name, None)
            if probs_or_mask is not None:
                dtypes.append(probs_or_mask.dtype)
        self._gpu_graph_cache_ = GPUGraphCache(
            num_gpu_cached_edges,
            gpu_cache_threshold,
            self.csc_indptr.dtype,
            dtypes,
        )


def fused_csc_sampling_graph(
csc_indptr: torch.Tensor,
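A hedged usage sketch of the two members added above: `_initialize_gpu_graph_cache` assembles the cache dtypes from `indices`, the optional `type_per_edge`, and the optional probability/mask edge attribute, and stores the result where the `_gpu_graph_cache` property finds it. The call site and argument values below are illustrative assumptions, not the library's actual wiring.

# Hypothetical: `graph` is a FusedCSCSamplingGraph; cache up to 10M edges
# on the GPU, including the edge attribute registered under "prob".
graph._initialize_gpu_graph_cache(
    num_gpu_cached_edges=10_000_000,
    gpu_cache_threshold=1,  # illustrative value
    prob_name="prob",
)
assert graph._gpu_graph_cache is not None  # property added in this diff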
(Diff for the remaining 3 changed files not shown.)
