From df678ef195f960454e18e070f2d226184bc95da6 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 19:16:45 -0400
Subject: [PATCH 1/8] [GraphBolt] `recursive_apply_reduce_all`.

---
 python/dgl/graphbolt/internal_utils.py | 38 ++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/python/dgl/graphbolt/internal_utils.py b/python/dgl/graphbolt/internal_utils.py
index 53e7ed8ed5f6..240462442fa2 100644
--- a/python/dgl/graphbolt/internal_utils.py
+++ b/python/dgl/graphbolt/internal_utils.py
@@ -113,6 +113,44 @@ def recursive_apply(data, fn, *args, **kwargs):
     return fn(data, *args, **kwargs)
 
 
+def recursive_apply_reduce_all(data, fn, *args, **kwargs):
+    """Recursively apply a function to every element in a container and
+    reduce the boolean results with `all`.
+
+    If the input data is a list or any sequence other than a string, returns
+    True if and only if the given function returns True for all elements.
+
+    If the input data is a dict or any mapping, returns True if and only if
+    the given function returns True for all values.
+
+    If the input data is a nested container, the result is reduced over the
+    nested structure, where each element is tested recursively.
+
+    The function receives an individual element of the input data as its
+    first argument, followed by the arguments in :attr:`args` and
+    :attr:`kwargs`.
+
+    Parameters
+    ----------
+    data : any
+        Any object.
+    fn : callable
+        Any function returning a boolean.
+    args, kwargs :
+        Additional arguments and keyword arguments passed to the function.
+    """
+    if isinstance(data, Mapping):
+        return all(
+            recursive_apply_reduce_all(v, fn, *args, **kwargs)
+            for v in data.values()
+        )
+    elif isinstance(data, tuple) or is_listlike(data):
+        return all(
+            recursive_apply_reduce_all(v, fn, *args, **kwargs) for v in data
+        )
+    else:
+        return fn(data, *args, **kwargs)
+
+
 def download(
     url,
     path=None,

From b99db080164170e901dcca77011ee17a43e78de0 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 19:33:35 -0400
Subject: [PATCH 2/8] [GraphBolt] Move `get_property` utils.
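
Moves `get_attributes` and `get_nonproperty_attributes` out of
`graphbolt/internal/utils.py` and into `graphbolt/internal_utils.py`,
next to `recursive_apply`, which the importing modules already pull from
there. To illustrate what the moved helpers return, here is a minimal
sketch (the `Example` class is hypothetical and not part of this change):

    from dgl.graphbolt.internal_utils import (
        get_attributes,
        get_nonproperty_attributes,
    )

    class Example:
        def __init__(self):
            self.a = 1     # plain data attribute
            self.b = None  # plain data attribute

        @property
        def total(self):   # property: skipped by the nonproperty variant
            return self.a

    print(get_attributes(Example()))              # ['a', 'b', 'total']
    print(get_nonproperty_attributes(Example()))  # ['a', 'b']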
---
 python/dgl/graphbolt/impl/ondisk_dataset.py  |  8 ++++--
 .../graphbolt/impl/sampled_subgraph_impl.py  |  2 +-
 python/dgl/graphbolt/internal/utils.py       | 26 -------------------
 python/dgl/graphbolt/internal_utils.py       | 26 +++++++++++++++++++
 python/dgl/graphbolt/minibatch.py            |  7 +++--
 5 files changed, 38 insertions(+), 31 deletions(-)

diff --git a/python/dgl/graphbolt/impl/ondisk_dataset.py b/python/dgl/graphbolt/impl/ondisk_dataset.py
index 3a2dc601c523..d669dc825509 100644
--- a/python/dgl/graphbolt/impl/ondisk_dataset.py
+++ b/python/dgl/graphbolt/impl/ondisk_dataset.py
@@ -19,11 +19,15 @@
     calculate_dir_hash,
     check_dataset_change,
     copy_or_convert_data,
-    get_attributes,
     read_data,
     read_edges,
 )
-from ..internal_utils import download, extract_archive, gb_warning
+from ..internal_utils import (
+    download,
+    extract_archive,
+    gb_warning,
+    get_attributes,
+)
 from ..itemset import HeteroItemSet, ItemSet
 from ..sampling_graph import SamplingGraph
 from .fused_csc_sampling_graph import (
diff --git a/python/dgl/graphbolt/impl/sampled_subgraph_impl.py b/python/dgl/graphbolt/impl/sampled_subgraph_impl.py
index aea13163ca04..d8a4833104c1 100644
--- a/python/dgl/graphbolt/impl/sampled_subgraph_impl.py
+++ b/python/dgl/graphbolt/impl/sampled_subgraph_impl.py
@@ -6,7 +6,7 @@
 import torch
 
 from ..base import CSCFormatBase, etype_str_to_tuple
-from ..internal import get_attributes
+from ..internal_utils import get_attributes
 from ..sampled_subgraph import SampledSubgraph
 
 __all__ = ["SampledSubgraphImpl"]
diff --git a/python/dgl/graphbolt/internal/utils.py b/python/dgl/graphbolt/internal/utils.py
index 614e2ac5561c..c423ffaa198b 100644
--- a/python/dgl/graphbolt/internal/utils.py
+++ b/python/dgl/graphbolt/internal/utils.py
@@ -144,32 +144,6 @@ def copy_or_convert_data(
         save_data(data, output_path, output_format)
 
 
-def get_nonproperty_attributes(_obj) -> list:
-    """Get attributes of the class except for the properties."""
-    attributes = [
-        attribute
-        for attribute in dir(_obj)
-        if not attribute.startswith("__")
-        and (
-            not hasattr(type(_obj), attribute)
-            or not isinstance(getattr(type(_obj), attribute), property)
-        )
-        and not callable(getattr(_obj, attribute))
-    ]
-    return attributes
-
-
-def get_attributes(_obj) -> list:
-    """Get attributes of the class."""
-    attributes = [
-        attribute
-        for attribute in dir(_obj)
-        if not attribute.startswith("__")
-        and not callable(getattr(_obj, attribute))
-    ]
-    return attributes
-
-
 def read_edges(dataset_dir, edge_fmt, edge_path):
     """Read edge data from numpy or csv."""
     assert edge_fmt in [
diff --git a/python/dgl/graphbolt/internal_utils.py b/python/dgl/graphbolt/internal_utils.py
index 240462442fa2..089050ad83b6 100644
--- a/python/dgl/graphbolt/internal_utils.py
+++ b/python/dgl/graphbolt/internal_utils.py
@@ -151,6 +151,32 @@ def recursive_apply_reduce_all(data, fn, *args, **kwargs):
     return fn(data, *args, **kwargs)
 
 
+def get_nonproperty_attributes(_obj) -> list:
+    """Get attributes of the class except for the properties."""
+    attributes = [
+        attribute
+        for attribute in dir(_obj)
+        if not attribute.startswith("__")
+        and (
+            not hasattr(type(_obj), attribute)
+            or not isinstance(getattr(type(_obj), attribute), property)
+        )
+        and not callable(getattr(_obj, attribute))
+    ]
+    return attributes
+
+
+def get_attributes(_obj) -> list:
+    """Get attributes of the class."""
+    attributes = [
+        attribute
+        for attribute in dir(_obj)
+        if not attribute.startswith("__")
+        and not callable(getattr(_obj, attribute))
+    ]
+    return attributes
+
+
 def download(
     url,
     path=None,
diff --git a/python/dgl/graphbolt/minibatch.py b/python/dgl/graphbolt/minibatch.py
index f46288e40a7f..bc2e62011ee7 100644
--- a/python/dgl/graphbolt/minibatch.py
+++ b/python/dgl/graphbolt/minibatch.py
@@ -6,8 +6,11 @@
 import torch
 
 from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
-from .internal import get_attributes, get_nonproperty_attributes
-from .internal_utils import recursive_apply
+from .internal_utils import (
+    get_attributes,
+    get_nonproperty_attributes,
+    recursive_apply,
+)
 from .sampled_subgraph import SampledSubgraph
 
 __all__ = ["MiniBatch"]

From 550eb5469157a0bfaf7b8bd24e7ce39309cfdcb3 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 19:42:15 -0400
Subject: [PATCH 3/8] [GraphBolt][CUDA] Add `non_blocking` option to `CopyTo`.

---
 python/dgl/graphbolt/base.py                | 64 ++++++++++++++++++---
 python/dgl/graphbolt/minibatch.py           | 27 +++++++--
 python/dgl/graphbolt/sampled_subgraph.py    | 20 ++++++-
 tests/python/pytorch/graphbolt/test_base.py | 13 ++++-
 4 files changed, 105 insertions(+), 19 deletions(-)

diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py
index 1c27c8cff8c1..02532b6b0b73 100644
--- a/python/dgl/graphbolt/base.py
+++ b/python/dgl/graphbolt/base.py
@@ -20,7 +20,11 @@
 from torch.utils.data import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
 
-from .internal_utils import recursive_apply
+from .internal_utils import (
+    get_nonproperty_attributes,
+    recursive_apply,
+    recursive_apply_reduce_all,
+)
 
 __all__ = [
     "CANONICAL_ETYPE_DELIMITER",
@@ -306,10 +310,32 @@ def seed_type_str_to_ntypes(seed_type, seed_size):
     return ntypes
 
 
-def apply_to(x, device):
+def apply_to(x, device, non_blocking=False):
     """Apply `to` function to object x only if it has `to`."""
-    return x.to(device) if hasattr(x, "to") else x
+    if device == "pinned" and hasattr(x, "pin_memory"):
+        return x.pin_memory()
+    if not hasattr(x, "to"):
+        return x
+    if not non_blocking:
+        return x.to(device)
+    # The copy is non-blocking only if the objects are pinned.
+    assert x.is_pinned(), f"{x} should be pinned."
+    return x.to(device, non_blocking=True)
+
+
+def is_object_pinned(obj):
+    """Recursively check all members of the object and return True if and
+    only if all are pinned."""
+
+    for attr in get_nonproperty_attributes(obj):
+        member_result = recursive_apply_reduce_all(
+            getattr(obj, attr),
+            lambda x: x is None or x.is_pinned(),
+        )
+        if not member_result:
+            return False
+    return True
 
 
 @functional_datapipe("copy_to")
@@ -334,17 +360,22 @@ class CopyTo(IterDataPipe):
         The DataPipe.
     device : torch.device
         The PyTorch CUDA device.
+    non_blocking : bool
+        Whether the copy should be performed without blocking. All elements
+        must already be in pinned system memory if enabled. Default is False.
""" - def __init__(self, datapipe, device): + def __init__(self, datapipe, device, non_blocking=False): super().__init__() self.datapipe = datapipe - self.device = device + self.device = torch.device(device) + self.non_blocking = non_blocking def __iter__(self): for data in self.datapipe: - data = recursive_apply(data, apply_to, self.device) - yield data + yield recursive_apply( + data, apply_to, self.device, self.non_blocking + ) @functional_datapipe("mark_end") @@ -460,7 +491,9 @@ def __init__(self, indptr: torch.Tensor, indices: torch.Tensor): def __repr__(self) -> str: return _csc_format_base_str(self) - def to(self, device: torch.device) -> None: # pylint: disable=invalid-name + def to( + self, device: torch.device, non_blocking=False + ) -> None: # pylint: disable=invalid-name """Copy `CSCFormatBase` to the specified device using reflection.""" for attr in dir(self): @@ -470,12 +503,25 @@ def to(self, device: torch.device) -> None: # pylint: disable=invalid-name self, attr, recursive_apply( - getattr(self, attr), lambda x: apply_to(x, device) + getattr(self, attr), + apply_to, + device, + non_blocking=non_blocking, ), ) return self + def pin_memory(self): + """Copy `SampledSubgraph` to the pinned memory using reflection.""" + + return self.to("pinned") + + def is_pinned(self) -> bool: + """Check whether `SampledSubgraph` is pinned using reflection.""" + + return is_object_pinned(self) + def _csc_format_base_str(csc_format_base: CSCFormatBase) -> str: final_str = "CSCFormatBase(" diff --git a/python/dgl/graphbolt/minibatch.py b/python/dgl/graphbolt/minibatch.py index bc2e62011ee7..7b34a8d5f1d3 100644 --- a/python/dgl/graphbolt/minibatch.py +++ b/python/dgl/graphbolt/minibatch.py @@ -5,7 +5,13 @@ import torch -from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr +from .base import ( + apply_to, + CSCFormatBase, + etype_str_to_tuple, + expand_indptr, + is_object_pinned, +) from .internal_utils import ( get_attributes, get_nonproperty_attributes, @@ -350,20 +356,31 @@ def to_pyg_data(self): ) return pyg_data - def to(self, device: torch.device): # pylint: disable=invalid-name + def to( + self, device: torch.device, non_blocking=False + ): # pylint: disable=invalid-name """Copy `MiniBatch` to the specified device using reflection.""" - def _to(x): - return x.to(device) if hasattr(x, "to") else x + copy_fn = lambda x: apply_to(x, device, non_blocking=non_blocking) transfer_attrs = get_nonproperty_attributes(self) for attr in transfer_attrs: # Only copy member variables. 
-            setattr(self, attr, recursive_apply(getattr(self, attr), _to))
+            setattr(self, attr, recursive_apply(getattr(self, attr), copy_fn))
 
         return self
 
+    def pin_memory(self):
+        """Copy `MiniBatch` to pinned memory using reflection."""
+
+        return self.to("pinned")
+
+    def is_pinned(self) -> bool:
+        """Check whether `MiniBatch` is pinned using reflection."""
+
+        return is_object_pinned(self)
+
 
 def _minibatch_str(minibatch: MiniBatch) -> str:
     final_str = ""
diff --git a/python/dgl/graphbolt/sampled_subgraph.py b/python/dgl/graphbolt/sampled_subgraph.py
index 1e4f238e1367..d46535115170 100644
--- a/python/dgl/graphbolt/sampled_subgraph.py
+++ b/python/dgl/graphbolt/sampled_subgraph.py
@@ -10,6 +10,7 @@
     CSCFormatBase,
     etype_str_to_tuple,
     expand_indptr,
+    is_object_pinned,
     isin,
 )
 
@@ -232,7 +233,9 @@ def exclude_edges(
         )
         return calling_class(*_slice_subgraph(self, index))
 
-    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
+    def to(
+        self, device: torch.device, non_blocking=False
+    ) -> None:  # pylint: disable=invalid-name
         """Copy `SampledSubgraph` to the specified device using reflection."""
 
         for attr in dir(self):
@@ -242,12 +245,25 @@ def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
                 self,
                 attr,
                 recursive_apply(
-                    getattr(self, attr), lambda x: apply_to(x, device)
+                    getattr(self, attr),
+                    apply_to,
+                    device,
+                    non_blocking=non_blocking,
                 ),
             )
 
         return self
 
+    def pin_memory(self):
+        """Copy `SampledSubgraph` to pinned memory using reflection."""
+
+        return self.to("pinned")
+
+    def is_pinned(self) -> bool:
+        """Check whether `SampledSubgraph` is pinned using reflection."""
+
+        return is_object_pinned(self)
+
 
 def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
     indptr = node_pair.indptr
diff --git a/tests/python/pytorch/graphbolt/test_base.py b/tests/python/pytorch/graphbolt/test_base.py
index a65264369ed3..3ca506314a49 100644
--- a/tests/python/pytorch/graphbolt/test_base.py
+++ b/tests/python/pytorch/graphbolt/test_base.py
@@ -12,19 +12,26 @@
 from . import gb_test_utils
 
 
-@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
-def test_CopyTo():
+@unittest.skipIf(F._default_context_str != "gpu", "CopyTo needs GPU to test")
+@pytest.mark.parametrize("non_blocking", [False, True])
+def test_CopyTo(non_blocking):
     item_sampler = gb.ItemSampler(
         gb.ItemSet(torch.arange(20), names="seeds"), 4
     )
+    if non_blocking:
+        item_sampler = item_sampler.transform(lambda x: x.pin_memory())
 
     # Invoke CopyTo via class constructor.
     dp = gb.CopyTo(item_sampler, "cuda")
     for data in dp:
         assert data.seeds.device.type == "cuda"
 
+    dp = gb.CopyTo(item_sampler, "cuda", non_blocking)
+    for data in dp:
+        assert data.seeds.device.type == "cuda"
+
     # Invoke CopyTo via functional form.
-    dp = item_sampler.copy_to("cuda")
+    dp = item_sampler.copy_to("cuda", non_blocking)
     for data in dp:
         assert data.seeds.device.type == "cuda"

From f26836292dc722aaa9ce1096611942b204b54940 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 19:54:38 -0400
Subject: [PATCH 4/8] [GraphBolt][CUDA] Enable `non_blocking` `copy_to` in
 `gb.DataLoader`.
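
With this change, `gb.DataLoader` rewires the datapipe graph so that
batches headed to a CUDA device are pinned in host memory first and then
copied with `non_blocking=True`. The equivalent manual pipeline, following
the pattern used in `test_CopyTo` above (a sketch; the itemset contents
are illustrative):

    import torch
    import dgl.graphbolt as gb

    item_sampler = gb.ItemSampler(
        gb.ItemSet(torch.arange(20), names="seeds"), 4
    )
    # Pin each batch in host memory; a non-blocking copy requires a
    # pinned source tensor.
    datapipe = item_sampler.transform(lambda x: x.pin_memory())
    datapipe = datapipe.copy_to("cuda", non_blocking=True)
    for minibatch in datapipe:
        assert minibatch.seeds.device.type == "cuda"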
---
 python/dgl/graphbolt/dataloader.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py
index 71e27693f1db..516f85070ae8 100644
--- a/python/dgl/graphbolt/dataloader.py
+++ b/python/dgl/graphbolt/dataloader.py
@@ -224,14 +224,24 @@ def __init__(
             ),
         )
 
-        # (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
-        # data pipeline up to the CopyTo operation to run in a separate thread.
-        datapipe_graph = _find_and_wrap_parent(
-            datapipe_graph,
-            CopyTo,
-            dp.iter.Prefetcher,
-            buffer_size=2,
-        )
+        # (4) Cut datapipe at CopyTo and wrap with PinMemory before CopyTo. This
+        # enables non_blocking copies to the device. PinMemory already
+        # is a PrefetcherIterDataPipe so the data pipeline up to the CopyTo will
+        # run in a separate thread.
+        if torch.cuda.is_available():
+            copiers = dp_utils.find_dps(datapipe_graph, CopyTo)
+            for copier in copiers:
+                if copier.device.type == "cuda":
+                    datapipe_graph = dp_utils.replace_dp(
+                        datapipe_graph,
+                        copier,
+                        # Prefetcher is inside this datapipe already.
+                        dp.iter.PinMemory(
+                            copier.datapipe,
+                            pin_memory_fn=lambda x, _: x.pin_memory(),
+                        ).copy_to(copier.device, non_blocking=True),
+                        # After the data gets pinned, we copy non_blocking.
+                    )
 
         # The stages after feature fetching are still done in the main process.
         # So we set num_workers to 0 here.

From 13e9a6880b8371ab0bc5cde129a7141acb0fd48e Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 19:58:08 -0400
Subject: [PATCH 5/8] fix linting

---
 python/dgl/graphbolt/base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py
index 02532b6b0b73..c1249170d51f 100644
--- a/python/dgl/graphbolt/base.py
+++ b/python/dgl/graphbolt/base.py
@@ -491,9 +491,9 @@ def __init__(self, indptr: torch.Tensor, indices: torch.Tensor):
     def __repr__(self) -> str:
         return _csc_format_base_str(self)
 
-    def to(
+    def to(  # pylint: disable=invalid-name
         self, device: torch.device, non_blocking=False
-    ) -> None:  # pylint: disable=invalid-name
+    ) -> None:
         """Copy `CSCFormatBase` to the specified device using reflection."""
 
         for attr in dir(self):

From 8f594c2cc766166e75c7f014772a86411ca78a2e Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 21:07:44 -0400
Subject: [PATCH 6/8] move the is_pinned check inside `CopyTo`.

---
 python/dgl/graphbolt/base.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py
index c1249170d51f..8e7e6365f413 100644
--- a/python/dgl/graphbolt/base.py
+++ b/python/dgl/graphbolt/base.py
@@ -319,8 +319,6 @@ def apply_to(x, device, non_blocking=False):
         return x
     if not non_blocking:
         return x.to(device)
-    # The copy is non-blocking only if the objects are pinned.
-    assert x.is_pinned(), f"{x} should be pinned."
     return x.to(device, non_blocking=True)
 
 
@@ -373,6 +371,9 @@ def __init__(self, datapipe, device, non_blocking=False):
 
     def __iter__(self):
         for data in self.datapipe:
+            if self.non_blocking:
+                # The copy is non-blocking only if the contents of data are pinned.
+                assert data.is_pinned(), f"{data} should be pinned."
             yield recursive_apply(
                 data, apply_to, self.device, self.non_blocking
             )

From c9945a9dc3ce8b347c3d101146696fea1bef283c Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 21:44:53 -0400
Subject: [PATCH 7/8] stop using `PinMemory`.
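
Replaces torchdata's `PinMemory` datapipe with an explicit
transform-prefetch-copy chain with the same behavior: batches are pinned,
prefetched on a background thread, and then copied to the device without
blocking. A sketch of the rewritten pipeline, where `copier` is a `CopyTo`
node found in the datapipe graph:

    replacement = (
        copier.datapipe.transform(lambda x: x.pin_memory())  # pin batches
        .prefetch(2)  # overlap pinning with downstream stages
        .copy_to(copier.device, non_blocking=True)  # async device copy
    )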
---
 python/dgl/graphbolt/dataloader.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py
index 516f85070ae8..9aba0b2d30f8 100644
--- a/python/dgl/graphbolt/dataloader.py
+++ b/python/dgl/graphbolt/dataloader.py
@@ -224,10 +224,10 @@ def __init__(
             ),
         )
 
-        # (4) Cut datapipe at CopyTo and wrap with PinMemory before CopyTo. This
-        # enables non_blocking copies to the device. PinMemory already
-        # is a PrefetcherIterDataPipe so the data pipeline up to the CopyTo will
-        # run in a separate thread.
+        # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
+        # before CopyTo. This enables non_blocking copies to the device.
+        # Prefetching enables the data pipeline up to the CopyTo to run in a
+        # separate thread.
         if torch.cuda.is_available():
             copiers = dp_utils.find_dps(datapipe_graph, CopyTo)
             for copier in copiers:
@@ -235,12 +235,11 @@ def __init__(
                     datapipe_graph = dp_utils.replace_dp(
                         datapipe_graph,
                         copier,
-                        # Prefetcher is inside this datapipe already.
-                        dp.iter.PinMemory(
-                            copier.datapipe,
-                            pin_memory_fn=lambda x, _: x.pin_memory(),
-                        ).copy_to(copier.device, non_blocking=True),
-                        # After the data gets pinned, we copy non_blocking.
+                        copier.datapipe.transform(
+                            lambda x: x.pin_memory()
+                        ).prefetch(2)
+                        # After the data gets pinned, we can copy non_blocking.
+                        .copy_to(copier.device, non_blocking=True),
                     )

From 177b9129fac55aba35188d64d8b48cce57f863a6 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Sat, 27 Jul 2024 21:45:44 -0400
Subject: [PATCH 8/8] refine comment.

---
 python/dgl/graphbolt/dataloader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py
index 9aba0b2d30f8..c81a6b0a7e6a 100644
--- a/python/dgl/graphbolt/dataloader.py
+++ b/python/dgl/graphbolt/dataloader.py
@@ -225,7 +225,7 @@ def __init__(
         )
 
         # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
-        # before CopyTo. This enables non_blocking copies to the device.
+        # before it. This enables non_blocking copies to the device.
         # Prefetching enables the data pipeline up to the CopyTo to run in a
         # separate thread.
         if torch.cuda.is_available():
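
After this series, pinning and pinned-checks are available on GraphBolt
objects through reflection. A minimal sketch of the resulting API, where
`minibatch` stands for any `MiniBatch` produced by a pipeline:

    minibatch = minibatch.pin_memory()  # recursively pins all members
    assert minibatch.is_pinned()        # checked via is_object_pinned
    minibatch = minibatch.to("cuda", non_blocking=True)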