Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt][CUDA] Refactor overlap_graph_fetch, simplify gb.DataLoader. #7681

Merged
merged 9 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions examples/graphbolt/disk_based_feature/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def create_dataloader(
else {}
)
datapipe = getattr(datapipe, args.sample_mode)(
graph, fanout if job != "infer" else [-1], **kwargs
graph,
fanout if job != "infer" else [-1],
overlap_fetch=args.overlap_graph_fetch,
**kwargs,
)
# Copy the data to the specified device.
if args.feature_device != "cpu":
Expand All @@ -130,11 +133,7 @@ def create_dataloader(
if args.feature_device == "cpu":
datapipe = datapipe.copy_to(device=device)
# Create and return a DataLoader to handle data loading.
return gb.DataLoader(
datapipe,
num_workers=args.num_workers,
overlap_graph_fetch=args.overlap_graph_fetch,
)
return gb.DataLoader(datapipe, num_workers=args.num_workers)


def train_step(minibatch, optimizer, model, loss_fn):
Expand Down
10 changes: 4 additions & 6 deletions examples/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def create_dataloader(
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = getattr(datapipe, args.sample_mode)(
graph, fanout if job != "infer" else [-1]
graph,
fanout if job != "infer" else [-1],
overlap_fetch=args.storage_device == "pinned",
)

############################################################################
Expand Down Expand Up @@ -156,11 +158,7 @@ def create_dataloader(
# [Role]:
# Initialize a multi-process dataloader to load the data in parallel.
############################################################################
dataloader = gb.DataLoader(
datapipe,
num_workers=num_workers,
overlap_graph_fetch=args.storage_device == "pinned",
)
dataloader = gb.DataLoader(datapipe, num_workers=num_workers)

# Return the fully-initialized DataLoader object.
return dataloader
Expand Down
11 changes: 5 additions & 6 deletions examples/graphbolt/pyg/labor/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ def create_dataloader(
else {}
)
datapipe = getattr(datapipe, args.sample_mode)(
graph, fanout if job != "infer" else [-1], **kwargs
graph,
fanout if job != "infer" else [-1],
overlap_fetch=args.overlap_graph_fetch,
**kwargs,
)
# Copy the data to the specified device.
if args.feature_device != "cpu" and need_copy:
Expand All @@ -163,11 +166,7 @@ def create_dataloader(
if need_copy:
datapipe = datapipe.copy_to(device=device)
# Create and return a DataLoader to handle data loading.
return gb.DataLoader(
datapipe,
num_workers=args.num_workers,
overlap_graph_fetch=args.overlap_graph_fetch,
)
return gb.DataLoader(datapipe, num_workers=args.num_workers)


@torch.compile
Expand Down
14 changes: 6 additions & 8 deletions examples/graphbolt/pyg/node_classification_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ def create_dataloader(
need_copy = False
# Sample neighbors for each node in the mini-batch.
datapipe = getattr(datapipe, args.sample_mode)(
graph, fanout if job != "infer" else [-1]
graph,
fanout if job != "infer" else [-1],
overlap_fetch=args.overlap_graph_fetch,
num_gpu_cached_edges=args.num_gpu_cached_edges,
gpu_cache_threshold=args.gpu_graph_caching_threshold,
)
# Copy the data to the specified device.
if args.feature_device != "cpu" and need_copy:
Expand All @@ -211,13 +215,7 @@ def create_dataloader(
if need_copy:
datapipe = datapipe.copy_to(device=device)
# Create and return a DataLoader to handle data loading.
return gb.DataLoader(
datapipe,
num_workers=args.num_workers,
overlap_graph_fetch=args.overlap_graph_fetch,
num_gpu_cached_edges=args.num_gpu_cached_edges,
gpu_cache_threshold=args.gpu_graph_caching_threshold,
)
return gb.DataLoader(datapipe, num_workers=args.num_workers)


@torch.compile
Expand Down
10 changes: 4 additions & 6 deletions examples/graphbolt/rgcn/hetero_rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def create_dataloader(
# The graph(FusedCSCSamplingGraph) from which to sample neighbors.
# `fanouts`:
# The number of neighbors to sample for each node in each layer.
datapipe = datapipe.sample_neighbor(graph, fanouts=fanouts)
datapipe = datapipe.sample_neighbor(
graph, fanouts=fanouts, overlap_fetch=args.overlap_graph_fetch
)

# Fetch the features for each node in the mini-batch.
# `features`:
Expand All @@ -141,11 +143,7 @@ def create_dataloader(
# Create a DataLoader from the datapipe.
# `num_workers`:
# The number of worker processes to use for data loading.
return gb.DataLoader(
datapipe,
num_workers=num_workers,
overlap_graph_fetch=args.overlap_graph_fetch,
)
return gb.DataLoader(datapipe, num_workers=num_workers)


def extract_embed(node_embed, input_nodes):
Expand Down
10 changes: 4 additions & 6 deletions examples/multigpu/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,14 @@ def create_dataloader(
############################################################################
if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device)
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.sample_neighbor(
graph, args.fanout, overlap_fetch=args.storage_device == "pinned"
Copy link
Collaborator

@frozenbugs frozenbugs Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we release the overlap_fetch in last release, if so, we need to highlight the API change if we really believe this is a good api change.

Copy link
Collaborator Author

@mfbalin mfbalin Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is in DGL 2.3. In this release, this parameter moved from gb.DataLoader to sample_neighbor.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
if args.storage_device == "cpu":
datapipe = datapipe.copy_to(device)

dataloader = gb.DataLoader(
datapipe,
args.num_workers,
overlap_graph_fetch=args.storage_device == "pinned",
)
dataloader = gb.DataLoader(datapipe, args.num_workers)

# Return the fully-initialized DataLoader object.
return dataloader
Expand Down
91 changes: 4 additions & 87 deletions python/dgl/graphbolt/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
"""Graph Bolt DataLoaders"""

from collections import OrderedDict

import torch
import torch.utils.data as torch_data

from .base import CopyTo, get_host_to_device_uva_stream
from .base import CopyTo
from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
from .impl.gpu_graph_cache import GPUGraphCache
from .impl.neighbor_sampler import SamplePerLayer

from .internal import (
Expand All @@ -22,34 +19,9 @@

__all__ = [
"DataLoader",
"construct_gpu_graph_cache",
]


def construct_gpu_graph_cache(
sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold
):
"Construct a GPUGraphCache given a sample_per_layer_obj and cache parameters."
graph = sample_per_layer_obj.sampler.__self__
num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges)
dtypes = OrderedDict()
dtypes["indices"] = graph.indices.dtype
if graph.type_per_edge is not None:
dtypes["type_per_edge"] = graph.type_per_edge.dtype
if graph.edge_attributes is not None:
probs_or_mask = graph.edge_attributes.get(
sample_per_layer_obj.prob_name, None
)
if probs_or_mask is not None:
dtypes["probs_or_mask"] = probs_or_mask.dtype
return GPUGraphCache(
num_gpu_cached_edges,
gpu_cache_threshold,
graph.csc_indptr.dtype,
list(dtypes.values()),
)


def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = find_dps(
Expand Down Expand Up @@ -125,18 +97,6 @@ class DataLoader(torch_data.DataLoader):
If True, the data loader will not shut down the worker processes after a
dataset has been consumed once. This allows to maintain the workers
instances alive.
overlap_graph_fetch : bool, optional
If True, the data loader will overlap the UVA graph fetching operations
with the rest of operations by using an alternative CUDA stream. This
option should be enabled if you have moved your graph to the pinned
memory for optimal performance. Default is False.
num_gpu_cached_edges : int, optional
If positive and overlap_graph_fetch is True, then the GPU will cache
frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth
demand due to pinned graph accesses.
gpu_cache_threshold : int, optional
Determines how many times a vertex needs to be accessed before its
neighborhood ends up being cached on the GPU.
max_uva_threads : int, optional
Limits the number of CUDA threads used for UVA copies so that the rest
of the computations can run simultaneously with it. Setting it to a too
Expand All @@ -150,9 +110,6 @@ def __init__(
datapipe,
num_workers=0,
persistent_workers=True,
overlap_graph_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
max_uva_threads=10240,
):
# Multiprocessing requires two modifications to the datapipe:
Expand Down Expand Up @@ -200,54 +157,14 @@ def __init__(
if feature_fetcher.max_num_stages > 0: # Overlap enabled.
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)

if (
overlap_graph_fetch
and num_workers == 0
and torch.cuda.is_available()
):
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
if num_workers == 0 and torch.cuda.is_available():
samplers = find_dps(
datapipe_graph,
SamplePerLayer,
)
gpu_graph_cache = None
for sampler in samplers:
if num_gpu_cached_edges > 0 and gpu_graph_cache is None:
gpu_graph_cache = construct_gpu_graph_cache(
sampler, num_gpu_cached_edges, gpu_cache_threshold
)
if (
sampler.sampler.__name__ == "sample_layer_neighbors"
or gpu_graph_cache is not None
):
# This code path is not faster for sample_neighbors.
datapipe_graph = replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(
gpu_graph_cache,
get_host_to_device_uva_stream(),
1,
),
)
elif sampler.sampler.__name__ == "sample_neighbors":
# This code path is faster for sample_neighbors.
datapipe_graph = replace_dp(
datapipe_graph,
sampler,
sampler.datapipe.sample_per_layer(
sampler=sampler.sampler,
fanout=sampler.fanout,
replace=sampler.replace,
prob_name=sampler.prob_name,
returning_indices_is_optional=True,
),
)
else:
raise AssertionError(
"overlap_graph_fetch is supported only for "
"sample_neighbor and sample_layer_neighbor."
)
if sampler.overlap_fetch:
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)

# (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
# before it. This enables enables non_blocking copies to the device.
Expand Down
31 changes: 31 additions & 0 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..internal_utils import gb_warning, is_wsl, recursive_apply
from ..sampling_graph import SamplingGraph
from .gpu_graph_cache import GPUGraphCache
from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl


Expand Down Expand Up @@ -315,6 +316,14 @@ def _indptr_node_type_offset_list(
"""Sets the indptr node type offset list if present."""
self._indptr_node_type_offset_list_ = indptr_node_type_offset_list

@property
def _gpu_graph_cache(self) -> Optional[GPUGraphCache]:
return (
self._gpu_graph_cache_
if hasattr(self, "_gpu_graph_cache_")
else None
)

@property
def type_per_edge(self) -> Optional[torch.Tensor]:
"""Returns the edge type tensor if present.
Expand Down Expand Up @@ -1432,6 +1441,28 @@ def _pin(x):

return self._apply_to_members(_pin)

def _initialize_gpu_graph_cache(
self,
num_gpu_cached_edges: int,
gpu_cache_threshold: int,
prob_name: Optional[str] = None,
):
"Construct a GPUGraphCache given the cache parameters."
num_gpu_cached_edges = min(num_gpu_cached_edges, self.total_num_edges)
dtypes = [self.indices.dtype]
if self.type_per_edge is not None:
dtypes.append(self.type_per_edge.dtype)
if self.edge_attributes is not None:
probs_or_mask = self.edge_attributes.get(prob_name, None)
if probs_or_mask is not None:
dtypes.append(probs_or_mask.dtype)
self._gpu_graph_cache_ = GPUGraphCache(
num_gpu_cached_edges,
gpu_cache_threshold,
self.csc_indptr.dtype,
dtypes,
)


def fused_csc_sampling_graph(
csc_indptr: torch.Tensor,
Expand Down
Loading
Loading