Commit

Remove dependency on torchdata. (#7638)

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
frozenbugs and Ubuntu committed Aug 6, 2024
1 parent cb4604a commit 26ff09f
Showing 5 changed files with 350 additions and 28 deletions.
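Note: the `dp_utils` helpers removed below (`traverse_dps`, `find_dps`, `replace_dp`) are re-implemented inside `dgl.graphbolt.internal`; those new files account for most of the 350 added lines and are not shown on this page. A rough sketch only (the vendored implementations may differ) of helpers with compatible behavior:

```python
from torch.utils.data import IterDataPipe


def traverse_dps(datapipe):
    # Walk instance attributes to discover parent pipes; returns the nested
    # graph form {id(dp): (dp, parents_graph)} consumed by find_dps and
    # replace_dp. (torchdata's traversal is more thorough, e.g. it also
    # inspects containers of datapipes; this is a sketch.)
    parents = {}
    for value in vars(datapipe).values():
        if isinstance(value, IterDataPipe):
            parents.update(traverse_dps(value))
    return {id(datapipe): (datapipe, parents)}


def find_dps(graph, dp_type):
    # Collect every datapipe in the graph whose exact type is dp_type.
    found = []
    for dp, parents in graph.values():
        if type(dp) is dp_type:
            found.append(dp)
        found.extend(find_dps(parents, dp_type))
    return found


def replace_dp(graph, old_dp, new_dp):
    # Point every consumer of old_dp at new_dp, then re-traverse from the
    # root. (Replacing the root itself is omitted for brevity.)
    def _rewire(g):
        for dp, parents in g.values():
            for name, value in list(vars(dp).items()):
                if value is old_dp:
                    setattr(dp, name, new_dp)
            _rewire(parents)

    _rewire(graph)
    root = next(dp for dp, _ in graph.values())
    return traverse_dps(root)
```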
28 changes: 16 additions & 12 deletions python/dgl/graphbolt/dataloader.py
@@ -4,14 +4,18 @@

 import torch
 import torch.utils.data as torch_data
-import torchdata.dataloader2.graph as dp_utils

 from .base import CopyTo, get_host_to_device_uva_stream
 from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
 from .impl.gpu_graph_cache import GPUGraphCache
 from .impl.neighbor_sampler import SamplePerLayer

-from .internal import datapipe_graph_to_adjlist
+from .internal import (
+    datapipe_graph_to_adjlist,
+    find_dps,
+    replace_dp,
+    traverse_dps,
+)
 from .item_sampler import ItemSampler

@@ -47,7 +51,7 @@ def construct_gpu_graph_cache(

 def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
     """Find parent of target_datapipe and wrap it with ``wrapper``."""
-    datapipes = dp_utils.find_dps(
+    datapipes = find_dps(
         datapipe_graph,
         target_datapipe,
     )
@@ -56,7 +60,7 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
         datapipe_id = id(datapipe)
         for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
             parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
-            datapipe_graph = dp_utils.replace_dp(
+            datapipe_graph = replace_dp(
                 datapipe_graph,
                 parent_datapipe,
                 wrapper(parent_datapipe, **kwargs),
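The `datapipe_adjlist` indexing above relies on the flat adjacency-list form produced by `datapipe_graph_to_adjlist` (which lives in `dgl.graphbolt.internal` and is not shown here). A small sketch of that shape:

```python
from torch.utils.data.datapipes.iter import IterableWrapper

# Each entry maps id(datapipe) -> (datapipe, [ids of its parent datapipes]),
# which is exactly what datapipe_adjlist[datapipe_id][1] walks above.
source = IterableWrapper(range(4))
mapped = source.map(lambda x: x * 2)

adjlist = {
    id(source): (source, []),
    id(mapped): (mapped, [id(source)]),
}
parent_id = adjlist[id(mapped)][1][0]
parent, _ = adjlist[parent_id]
assert parent is source
```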
@@ -157,18 +161,18 @@ def __init__(
         # of the FeatureFetcher with a multiprocessing PyTorch DataLoader.

         datapipe = datapipe.mark_end()
-        datapipe_graph = dp_utils.traverse_dps(datapipe)
+        datapipe_graph = traverse_dps(datapipe)

         # (1) Insert minibatch distribution.
         # TODO(BarclayII): Currently I'm using sharding_filter() as a
         # concept demonstration. Later on minibatch distribution should be
         # merged into ItemSampler to maximize efficiency.
-        item_samplers = dp_utils.find_dps(
+        item_samplers = find_dps(
             datapipe_graph,
             ItemSampler,
         )
         for item_sampler in item_samplers:
-            datapipe_graph = dp_utils.replace_dp(
+            datapipe_graph = replace_dp(
                 datapipe_graph,
                 item_sampler,
                 item_sampler.sharding_filter(),
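For reference, `sharding_filter()` is the stock PyTorch datapipe op that splits elements round-robin across DataLoader workers; a standalone demo independent of GraphBolt:

```python
from torch.utils.data.datapipes.iter import IterableWrapper

# Instance i of n keeps elements whose index % n == i.
dp = IterableWrapper(range(8)).sharding_filter()
dp.apply_sharding(2, 0)  # pretend we are worker 0 of 2
assert list(dp) == [0, 2, 4, 6]
```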
@@ -186,7 +190,7 @@ def __init__(
         # (3) Limit the number of UVA threads used if the feature_fetcher has
         # overlapping optimization enabled.
         if num_workers == 0 and torch.cuda.is_available():
-            feature_fetchers = dp_utils.find_dps(
+            feature_fetchers = find_dps(
                 datapipe_graph,
                 FeatureFetcher,
             )
@@ -200,7 +204,7 @@ def __init__(
             and torch.cuda.is_available()
         ):
             torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
-        samplers = dp_utils.find_dps(
+        samplers = find_dps(
             datapipe_graph,
             SamplePerLayer,
         )
@@ -210,7 +214,7 @@ def __init__(
                 gpu_graph_cache = construct_gpu_graph_cache(
                     sampler, num_gpu_cached_edges, gpu_cache_threshold
                 )
-            datapipe_graph = dp_utils.replace_dp(
+            datapipe_graph = replace_dp(
                 datapipe_graph,
                 sampler,
                 sampler.fetch_and_sample(
@@ -225,10 +229,10 @@ def __init__(
         # Prefetching enables the data pipeline up to the CopyTo to run in a
         # separate thread.
         if torch.cuda.is_available():
-            copiers = dp_utils.find_dps(datapipe_graph, CopyTo)
+            copiers = find_dps(datapipe_graph, CopyTo)
             for copier in copiers:
                 if copier.device.type == "cuda":
-                    datapipe_graph = dp_utils.replace_dp(
+                    datapipe_graph = replace_dp(
                         datapipe_graph,
                         copier,
                         # Add prefetch so that CPU and GPU can run concurrently.
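Taken together, the pass above is a find-and-replace over the datapipe graph. A toy run of the same pattern, using the sketched helpers from the note near the top (GraphBolt applies it to `ItemSampler`, `FeatureFetcher`, `SamplePerLayer`, and `CopyTo` stages):

```python
from torch.utils.data.datapipes.iter import IterableWrapper

source = IterableWrapper(range(6))
pipeline = source.map(lambda x: x + 1)

# Locate a stage by type and swap in a wrapped version, as __init__ does
# for ItemSampler above.
graph = traverse_dps(pipeline)
for dp in find_dps(graph, type(source)):
    graph = replace_dp(graph, dp, dp.sharding_filter())
```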