[GraphBolt][CUDA] Pipelined sampling optimization #7039

Merged

72 commits
b4c045f
prototyping
mfbalin Jan 19, 2024
a26e354
fix bug
mfbalin Jan 20, 2024
58a1190
remove print expressions, works now
mfbalin Jan 20, 2024
eedb8d1
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 23, 2024
3d2ed98
add tests
mfbalin Jan 24, 2024
105924a
use seeds_timestamp in preprocess
mfbalin Jan 24, 2024
bfb28ec
add docstring for linting
mfbalin Jan 24, 2024
e4becc9
fix linting
mfbalin Jan 24, 2024
428ff24
fix argument bug
mfbalin Jan 24, 2024
85b0601
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 24, 2024
866316e
fix the bug
mfbalin Jan 24, 2024
e2793fd
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 24, 2024
2473722
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 27, 2024
2d1dda9
address reviews
mfbalin Jan 29, 2024
fad7c50
add docstring to the new `MinibatchTransformer`.
mfbalin Jan 29, 2024
a8fdfc6
address review properly.
mfbalin Jan 29, 2024
933246f
remove unused `Mapper` import for linting.
mfbalin Jan 29, 2024
cd68728
NeighborSampler2 now derives from `MinibatchTransformer`.
mfbalin Jan 30, 2024
c3a903d
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 30, 2024
dcbfb4e
Final refactoring of NeighborSampler.
mfbalin Jan 30, 2024
21fe633
Fix not only preprocess but also postprocess issue.
mfbalin Jan 30, 2024
29861f1
take back test changes.
mfbalin Jan 30, 2024
232f2f3
fix in_subgraph_sampler
mfbalin Jan 30, 2024
03bea25
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 30, 2024
86d9c43
add docstring for `append_sampling_step`.
mfbalin Jan 30, 2024
f995d20
Address reviews, minimize changes, keep API exactly the same.
mfbalin Jan 31, 2024
a64d34e
remove leftover changes.
mfbalin Jan 31, 2024
e46b8c7
minor change.
mfbalin Jan 31, 2024
8cc858c
Make the function into a proper one so that it can be pickled.
mfbalin Jan 31, 2024
02ca357
make the lambda into a proper function so that it can be pickled.
mfbalin Jan 31, 2024
19b4367
linting.
mfbalin Jan 31, 2024
67d6f71
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 31, 2024
144134c
final linting.
mfbalin Jan 31, 2024
96bac52
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 31, 2024
718cab8
Cleanup NeighborSampler as it does not need to store anything itself.
mfbalin Jan 31, 2024
ee3a7d7
linting
mfbalin Jan 31, 2024
1d906e7
address reviews by not passing sampler as string argument.
mfbalin Feb 1, 2024
6ab2e75
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Feb 1, 2024
6f880c0
Talk about `sampling_stages` in the SubgraphSampler API.
mfbalin Feb 1, 2024
5d907ee
add more documentation for `sampling_stages`.
mfbalin Feb 1, 2024
e24f4d7
add pipelined sampling optimization datapipes
mfbalin Jan 27, 2024
57def7f
Fix and get the fetch and sample datapipes working.
mfbalin Jan 27, 2024
ebde732
more progress
mfbalin Jan 28, 2024
822abc9
get the implementation working
mfbalin Jan 28, 2024
15b9951
set back the previous default thread count.
mfbalin Jan 28, 2024
4b0bd60
share the thread used for fetching insubgraph.
mfbalin Jan 29, 2024
bca8172
rebase onto refactor neighbor sampler branch
mfbalin Jan 29, 2024
866a74e
add test for the new feature.
mfbalin Jan 29, 2024
bc9a46c
add tests for the graph sampling pipelining components.
mfbalin Jan 29, 2024
6718d83
add buffer_size parameter to pipelined sampling
mfbalin Jan 30, 2024
c5e811e
add docstring for linting.
mfbalin Jan 30, 2024
c99dd5a
remove unused import after rebase
mfbalin Jan 30, 2024
5a500e1
fix the test
mfbalin Jan 30, 2024
0875f3b
Use maximum BLOCK_SIZE for UVA kernels.
mfbalin Jan 30, 2024
18ad8e3
fix the test after rebase.
mfbalin Jan 30, 2024
4f4bae5
fix errors after rebase.
mfbalin Jan 31, 2024
a0a787c
Merge branch 'master' into gb_cuda_pipelined_sampling_optimization
mfbalin Feb 1, 2024
4b29657
linting after other PR merged.
mfbalin Feb 1, 2024
a370010
remove duplicate changes due to merging of the prev PR.
mfbalin Feb 1, 2024
4be1860
remove FetcherAndSampler to address reviews.
mfbalin Feb 1, 2024
f7e90c4
linting
mfbalin Feb 1, 2024
d788b95
Address reviews, move FetcherAndSampler to neighbor_sampler.py
mfbalin Feb 1, 2024
1e8615a
linting
mfbalin Feb 1, 2024
626bded
refactor `MiniBatchTransformer`.
mfbalin Feb 1, 2024
16414e6
fix dataloader text after renaming Awaiter to Waiter.
mfbalin Feb 1, 2024
10c24df
remove unused function from FetcherAndSampler.
mfbalin Feb 1, 2024
91f84ad
Merge branch 'master' into gb_cuda_pipelined_sampling_optimization
mfbalin Feb 2, 2024
05c7ab0
address reviews, add unit test.
mfbalin Feb 4, 2024
23ade71
Merge branch 'master' into gb_cuda_pipelined_sampling_optimization
mfbalin Feb 4, 2024
6b21742
change to `index_select_csc_with_indptr`.
mfbalin Feb 4, 2024
0a21186
Merge branch 'master' into gb_cuda_pipelined_sampling_optimization
mfbalin Feb 4, 2024
e0fa30b
Minor change to trigger CI again.
mfbalin Feb 4, 2024
75 changes: 75 additions & 0 deletions python/dgl/graphbolt/base.py
@@ -1,5 +1,6 @@
"""Base types and utilities for Graph Bolt."""

from collections import deque
from dataclasses import dataclass

import torch
@@ -14,6 +15,10 @@
"etype_str_to_tuple",
"etype_tuple_to_str",
"CopyTo",
"Waiter",
"Bufferer",
"FutureWaiter",
"EndMarker",
"isin",
"expand_indptr",
"CSCFormatBase",
@@ -219,6 +224,76 @@ def __iter__(self):
yield data


@functional_datapipe("mark_end")
class EndMarker(IterDataPipe):
"""Used to mark the end of a datapipe and is a no-op."""

def __init__(self, datapipe):
self.datapipe = datapipe

def __iter__(self):
yield from self.datapipe
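
`EndMarker` is a pure pass-through; it exists so the graph-traversal utilities (such as the `dp_utils.traverse_dps` call in `dataloader.py` below) have a stable sink node to start from. A minimal sketch, assuming `dgl.graphbolt` has been imported so the `mark_end` functional form is registered:

```python
from torchdata.datapipes.iter import IterableWrapper

import dgl.graphbolt  # noqa: F401  # side effect: registers .mark_end()

dp = IterableWrapper(range(3)).mark_end()
assert list(dp) == [0, 1, 2]  # items pass through unchanged
```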


@functional_datapipe("buffer")
class Bufferer(IterDataPipe):
"""Buffers items before yielding them.

Parameters
----------
datapipe : DataPipe
The data pipeline.
buffer_size : int, optional
The size of the buffer which stores the fetched samples. If data coming
from datapipe has latency spikes, consider setting to a higher value.
Default is 1.
"""

def __init__(self, datapipe, buffer_size=1):
self.datapipe = datapipe
if buffer_size <= 0:
raise ValueError(
"'buffer_size' is required to be a positive integer."
)
self.buffer = deque(maxlen=buffer_size)

def __iter__(self):
for data in self.datapipe:
if len(self.buffer) < self.buffer.maxlen:
self.buffer.append(data)
else:
return_data = self.buffer.popleft()
self.buffer.append(data)
yield return_data
while len(self.buffer) > 0:
yield self.buffer.popleft()
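
A standalone sketch of `Bufferer` (again assuming `dgl.graphbolt` is imported so `.buffer()` is registered): with `buffer_size=2`, two items are pulled from the source before the first one is yielded, so the producer can run ahead of the consumer while order is preserved:

```python
from torchdata.datapipes.iter import IterableWrapper

import dgl.graphbolt  # noqa: F401  # side effect: registers .buffer()

buffered = IterableWrapper(range(5)).buffer(buffer_size=2)
# The deque fills to 2 items before the first yield; output order is unchanged.
assert list(buffered) == [0, 1, 2, 3, 4]
```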


@functional_datapipe("wait")
class Waiter(IterDataPipe):
"""Calls the wait function of all items."""

def __init__(self, datapipe):
self.datapipe = datapipe

def __iter__(self):
for data in self.datapipe:
data.wait()
yield data
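
`Waiter` only assumes each item exposes a `wait()` method, as the minibatches produced by the asynchronous fetchers do. A sketch with a hypothetical `_SlowItem` type:

```python
from torchdata.datapipes.iter import IterableWrapper

import dgl.graphbolt  # noqa: F401  # side effect: registers .buffer()/.wait()


class _SlowItem:
    """Hypothetical item; a real one would block on a CUDA event or thread."""

    def __init__(self, value):
        self.value = value
        self.ready = False

    def wait(self):
        self.ready = True  # stand-in for synchronizing asynchronous work


items = IterableWrapper([_SlowItem(i) for i in range(3)])
# buffer(1) keeps the next item in flight while the current one is consumed;
# wait() guarantees each item is ready before it reaches the training loop.
for item in items.buffer(1).wait():
    assert item.ready
```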


@functional_datapipe("wait_future")
class FutureWaiter(IterDataPipe):
"""Calls the result function of all items and returns their results."""

def __init__(self, datapipe):
self.datapipe = datapipe

def __iter__(self):
for data in self.datapipe:
yield data.result()
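
`FutureWaiter` plays the same role for `concurrent.futures.Future` items: an upstream stage can submit work to an executor and a downstream `wait_future()` collects the results in order. A minimal sketch under the same import assumption:

```python
from concurrent.futures import ThreadPoolExecutor

from torchdata.datapipes.iter import IterableWrapper

import dgl.graphbolt  # noqa: F401  # side effect: registers .wait_future()

executor = ThreadPoolExecutor(max_workers=1)
# Each element of the pipe is a Future; wait_future() resolves them in order.
futures = IterableWrapper([executor.submit(lambda x=x: x * x) for x in range(4)])
assert list(futures.wait_future()) == [0, 1, 4, 9]
```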


@dataclass
class CSCFormatBase:
r"""Basic class representing data in Compressed Sparse Column (CSC) format.
87 changes: 27 additions & 60 deletions python/dgl/graphbolt/dataloader.py
@@ -1,6 +1,6 @@
"""Graph Bolt DataLoaders"""

from collections import deque
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.utils.data
@@ -9,15 +9,14 @@

from .base import CopyTo
from .feature_fetcher import FeatureFetcher
from .impl.neighbor_sampler import SamplePerLayer

from .internal import datapipe_graph_to_adjlist
from .item_sampler import ItemSampler


__all__ = [
"DataLoader",
"Awaiter",
"Bufferer",
]


@@ -40,61 +39,6 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
return datapipe_graph


class EndMarker(dp.iter.IterDataPipe):
"""Used to mark the end of a datapipe and is a no-op."""

def __init__(self, datapipe):
self.datapipe = datapipe

def __iter__(self):
yield from self.datapipe


class Bufferer(dp.iter.IterDataPipe):
"""Buffers items before yielding them.

Parameters
----------
datapipe : DataPipe
The data pipeline.
buffer_size : int, optional
The size of the buffer which stores the fetched samples. If data coming
from datapipe has latency spikes, consider setting to a higher value.
Default is 1.
"""

def __init__(self, datapipe, buffer_size=1):
self.datapipe = datapipe
if buffer_size <= 0:
raise ValueError(
"'buffer_size' is required to be a positive integer."
)
self.buffer = deque(maxlen=buffer_size)

def __iter__(self):
for data in self.datapipe:
if len(self.buffer) < self.buffer.maxlen:
self.buffer.append(data)
else:
return_data = self.buffer.popleft()
self.buffer.append(data)
yield return_data
while len(self.buffer) > 0:
yield self.buffer.popleft()


class Awaiter(dp.iter.IterDataPipe):
"""Calls the wait function of all items."""

def __init__(self, datapipe):
self.datapipe = datapipe

def __iter__(self):
for data in self.datapipe:
data.wait()
yield data


class MultiprocessingWrapper(dp.iter.IterDataPipe):
"""Wraps a datapipe with multiprocessing.

@@ -156,6 +100,10 @@ class DataLoader(torch.utils.data.DataLoader):
If True, the data loader will overlap the UVA feature fetcher operations
with the rest of operations by using an alternative CUDA stream. Default
is True.
overlap_graph_fetch : bool, optional
If True, the data loader will overlap the UVA graph fetching operations
with the rest of operations by using an alternative CUDA stream. Default
is False.
max_uva_threads : int, optional
Limits the number of CUDA threads used for UVA copies so that the rest
of the computations can run simultaneously with it. Setting it to a too
@@ -170,6 +118,7 @@ def __init__(
num_workers=0,
persistent_workers=True,
overlap_feature_fetch=True,
overlap_graph_fetch=False,
max_uva_threads=6144,
):
# Multiprocessing requires two modifications to the datapipe:
@@ -179,7 +128,7 @@
# 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader.

datapipe = EndMarker(datapipe)
datapipe = datapipe.mark_end()
datapipe_graph = dp_utils.traverse_dps(datapipe)

# (1) Insert minibatch distribution.
@@ -223,7 +172,25 @@ def __init__(
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
feature_fetcher,
Awaiter(Bufferer(feature_fetcher, buffer_size=1)),
feature_fetcher.buffer(1).wait(),
)

if (
overlap_graph_fetch
and num_workers == 0
and torch.cuda.is_available()
):
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
samplers = dp_utils.find_dps(
datapipe_graph,
SamplePerLayer,
)
executor = ThreadPoolExecutor(max_workers=1)
for sampler in samplers:
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(_get_uva_stream(), executor, 1),
)

# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
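
To round out the diff, a usage sketch of the new flag. The `itemset`, `graph` (a pinned, UVA-accessible sampling graph), and `features` objects are placeholders for a real dataset's components, and the exact placement of `copy_to` depends on whether sampling runs on the GPU:

```python
import torch
import dgl.graphbolt as gb

# Placeholders: itemset, graph, and features come from a real GraphBolt dataset.
datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(graph, [10, 10])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(torch.device("cuda"))
dataloader = gb.DataLoader(
    datapipe,
    overlap_graph_fetch=True,  # overlap UVA graph access with sampling
    max_uva_threads=6144,      # cap on CUDA threads for UVA copies (the default)
)
for minibatch in dataloader:
    ...  # train on minibatch
```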