[GraphBolt] gb.DataLoader can simply be a datapipe. #7732

Merged · 6 commits · Aug 23, 2024
Changes from all commits
55 changes: 27 additions & 28 deletions python/dgl/graphbolt/dataloader.py
@@ -4,7 +4,6 @@
 import torch.utils.data as torch_data
 
 from .base import CopyTo
-
 from .datapipes import (
     datapipe_graph_to_adjlist,
     find_dps,
@@ -15,6 +14,7 @@
 from .impl.neighbor_sampler import SamplePerLayer
 from .internal_utils import gb_warning
 from .item_sampler import ItemSampler
+from .minibatch_transformer import MiniBatchTransformer
 
 
 __all__ = [
@@ -75,7 +75,7 @@ def __iter__(self):
         yield from self.dataloader
 
 
-class DataLoader(torch_data.DataLoader):
+class DataLoader(MiniBatchTransformer):
     """Multiprocessing DataLoader.
 
     Iterates over the data pipeline with everything before feature fetching
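
Note: for context, the pipeline this class consumes is assembled from GraphBolt datapipe stages before being handed to gb.DataLoader. A rough sketch using dgl.graphbolt's public names; `itemset`, `graph`, `features`, and the exact keyword arguments are placeholders and may differ from a given release:

```python
import dgl.graphbolt as gb

# Placeholders: itemset, graph, and features stand in for a gb.ItemSet,
# a sampling graph, and a feature store prepared elsewhere.
datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(graph, fanouts=[10, 10])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device="cuda")
dataloader = gb.DataLoader(datapipe, num_workers=0)

for minibatch in dataloader:  # after this PR, the dataloader is itself a datapipe
    pass
```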
@@ -122,32 +122,33 @@ def __init__(
         datapipe = datapipe.mark_end()
         datapipe_graph = traverse_dps(datapipe)
 
-        # (1) Insert minibatch distribution.
-        # TODO(BarclayII): Currently I'm using sharding_filter() as a
-        # concept demonstration. Later on minibatch distribution should be
-        # merged into ItemSampler to maximize efficiency.
-        item_samplers = find_dps(
-            datapipe_graph,
-            ItemSampler,
-        )
-        for item_sampler in item_samplers:
-            datapipe_graph = replace_dp(
-                datapipe_graph,
-                item_sampler,
-                item_sampler.sharding_filter(),
-            )
+        if num_workers > 0:
+            # (1) Insert minibatch distribution.
+            # TODO(BarclayII): Currently I'm using sharding_filter() as a
+            # concept demonstration. Later on minibatch distribution should be
+            # merged into ItemSampler to maximize efficiency.
+            item_samplers = find_dps(
+                datapipe_graph,
+                ItemSampler,
+            )
+            for item_sampler in item_samplers:
+                datapipe_graph = replace_dp(
+                    datapipe_graph,
+                    item_sampler,
+                    item_sampler.sharding_filter(),
+                )
 
-        # (2) Cut datapipe at FeatureFetcher and wrap.
-        datapipe_graph = _find_and_wrap_parent(
-            datapipe_graph,
-            FeatureFetcherStartMarker,
-            MultiprocessingWrapper,
-            num_workers=num_workers,
-            persistent_workers=persistent_workers,
-        )
+            # (2) Cut datapipe at FeatureFetcher and wrap.
+            datapipe_graph = _find_and_wrap_parent(
+                datapipe_graph,
+                FeatureFetcherStartMarker,
+                MultiprocessingWrapper,
+                num_workers=num_workers,
+                persistent_workers=persistent_workers,
+            )
 
-        # (3) Limit the number of UVA threads used if the feature_fetcher has
-        # overlapping optimization enabled.
+        # (3) Limit the number of UVA threads used if the feature_fetcher
+        # or any of the samplers have overlapping optimization enabled.
         if num_workers == 0 and torch.cuda.is_available():
             feature_fetchers = find_dps(
                 datapipe_graph,
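
The sharding_filter() inserted in step (1) is standard PyTorch datapipe machinery. A minimal, GraphBolt-independent sketch of what it does, assuming a recent PyTorch where DataLoader applies per-worker sharding to IterDataPipes:

```python
# Each of the two workers keeps only its own shard of the stream, so the
# eight items are produced exactly once overall instead of once per worker.
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(8)).sharding_filter()
loader = DataLoader(dp, batch_size=None, num_workers=2)
print(sorted(int(x) for x in loader))  # [0, 1, ..., 7]; order across workers may vary
```

With the change above, this filter is only inserted when num_workers > 0, since a single-process loader has nothing to shard.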
@@ -187,6 +188,4 @@ def __init__(
                 ),
             )
 
-        # The stages after feature fetching is still done in the main process.
-        # So we set num_workers to 0 here.
-        super().__init__(datapipe, batch_size=None, num_workers=0)
+        super().__init__(datapipe)
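
The new base class makes the loader itself the tail of the datapipe graph. A conceptual sketch of the idea (not DGL's actual MiniBatchTransformer, whose constructor may differ):

```python
# A loader that IS an IterDataPipe: it wraps the assembled pipeline and
# simply yields from it, so datapipe-graph utilities can see through it.
from torch.utils.data import IterDataPipe


class DataPipeLoader(IterDataPipe):
    def __init__(self, datapipe):
        super().__init__()
        self.datapipe = datapipe

    def __iter__(self):
        yield from self.datapipe
```

Because the loader is a datapipe, further stages can still be chained onto it, and it can be passed directly to datapipe utilities instead of being unwrapped first.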
3 changes: 1 addition & 2 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -138,8 +138,7 @@ def test_gpu_sampling_DataLoader(
         bufferer_cnt += 2 * num_layers
     if asynchronous:
         bufferer_cnt += 2 * num_layers
-    datapipe = dataloader.dataset
-    datapipe_graph = traverse_dps(datapipe)
+    datapipe_graph = traverse_dps(dataloader)
     bufferers = find_dps(
         datapipe_graph,
         dgl.graphbolt.Bufferer,
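
Since the loader no longer subclasses torch.utils.data.DataLoader, there is no .dataset attribute to go through; the test reflects the before/after access pattern:

```python
# Before: the datapipe hung off the torch DataLoader as its dataset.
datapipe_graph = traverse_dps(dataloader.dataset)

# After: the dataloader is the terminal datapipe and is traversed directly.
datapipe_graph = traverse_dps(dataloader)
```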