Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt][CUDA] Enable non_blocking copy_to in gb.DataLoader. #7603

Merged
merged 12 commits into from
Jul 28, 2024
5 changes: 3 additions & 2 deletions python/dgl/graphbolt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,6 @@ def apply_to(x, device, non_blocking=False):
return x
if not non_blocking:
return x.to(device)
# The copy is non blocking only if the objects are pinned.
assert x.is_pinned(), f"{x} should be pinned."
return x.to(device, non_blocking=True)


Expand Down Expand Up @@ -373,6 +371,9 @@ def __init__(self, datapipe, device, non_blocking=False):

def __iter__(self):
    """Iterate the wrapped datapipe, yielding each item moved to ``self.device``.

    When ``self.non_blocking`` is set, every item must already reside in
    pinned (page-locked) host memory, since only pinned buffers can be
    copied to the device asynchronously.
    """
    device = self.device
    non_blocking = self.non_blocking
    for item in self.datapipe:
        if non_blocking:
            # An asynchronous host-to-device copy is only truly
            # non-blocking when the source tensors are pinned.
            assert item.is_pinned(), f"{item} should be pinned."
        yield recursive_apply(item, apply_to, device, non_blocking)
Expand Down
25 changes: 17 additions & 8 deletions python/dgl/graphbolt/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,23 @@ def __init__(
),
)

# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# data pipeline up to the CopyTo operation to run in a separate thread.
datapipe_graph = _find_and_wrap_parent(
datapipe_graph,
CopyTo,
dp.iter.Prefetcher,
buffer_size=2,
)
# (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
# before it. This enables non_blocking copies to the device.
# Prefetching enables the data pipeline up to the CopyTo to run in a
# separate thread.
if torch.cuda.is_available():
copiers = dp_utils.find_dps(datapipe_graph, CopyTo)
for copier in copiers:
if copier.device.type == "cuda":
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
copier,
copier.datapipe.transform(
lambda x: x.pin_memory()
).prefetch(2)
# After the data gets pinned, we can copy non_blocking.
.copy_to(copier.device, non_blocking=True),
)

# The stages after feature fetching are still done in the main process.
# So we set num_workers to 0 here.
Expand Down
Loading