From 6e46c5607ab3957110a04b1787d0fefe5f8a99a7 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Wed, 8 May 2024 22:28:03 -0400 Subject: [PATCH 1/9] [GraphBolt][CUDA] GPUCachedFeature update fix. --- python/dgl/graphbolt/impl/gpu_cached_feature.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index 0be929ba4abf..99c2b95e00af 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -104,12 +104,19 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None): updated. """ if ids is None: + new_feat0 = value[0] + old_feat0 = self._fallback_feature.read(torch.tensor([0])) self._fallback_feature.update(value) - size = min(self.cache_size, value.shape[0]) - self._feature.replace( - torch.arange(0, size, device="cuda"), - value[:size].to("cuda"), - ) + if new_feat0.dtype != old_feat0.dtype or new_feat0.shape != old_feat0.shape: + self.cache_size = self.cache_size * old_feat0.nbytes // new_feat0.nbytes + self._feature = None + self._feature = GPUCache((self.cache_size,) + new_feat0.shape[1:], new_feat0.dtype) + else: + size = min(self.cache_size, value.shape[0]) + self._feature.replace( + torch.arange(0, size, device="cuda"), + value[:size].to("cuda"), + ) else: self._fallback_feature.update(value, ids) self._feature.replace(ids, value) From 2279907885f375a7048452c2cae693eb14ad9f7f Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 9 May 2024 03:49:03 +0000 Subject: [PATCH 2/9] Refine implementation, modify examples and lint. --- .../multigpu/graphbolt/node_classification.py | 2 +- .../pyg/node_classification_advanced.py | 12 ++++++++ .../dgl/graphbolt/impl/gpu_cached_feature.py | 29 +++++++++---------- .../graphbolt/impl/test_gpu_cached_feature.py | 4 +++ 4 files changed, 30 insertions(+), 17 deletions(-) diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py index 3df09bf852ea..3ab8ed41a839 100644 --- a/examples/multigpu/graphbolt/node_classification.py +++ b/examples/multigpu/graphbolt/node_classification.py @@ -399,7 +399,7 @@ def parse_args(): "--gpu-cache-size", type=int, default=0, - help="The capacity of the GPU cache, the number of features to store.", + help="The capacity of the GPU cache in bytes.", ) parser.add_argument( "--dataset", diff --git a/examples/sampling/graphbolt/pyg/node_classification_advanced.py b/examples/sampling/graphbolt/pyg/node_classification_advanced.py index 2b5fb19d7518..2f25db523b56 100644 --- a/examples/sampling/graphbolt/pyg/node_classification_advanced.py +++ b/examples/sampling/graphbolt/pyg/node_classification_advanced.py @@ -350,6 +350,12 @@ def parse_args(): help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM," " 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.", ) + parser.add_argument( + "--gpu-cache-size", + type=int, + default=0, + help="The capacity of the GPU cache in bytes.", + ) parser.add_argument( "--sample-mode", default="sample_neighbor", @@ -403,6 +409,12 @@ def main(): num_classes = dataset.tasks[0].metadata["num_classes"] + if args.gpu_cache_size > 0 and args.feature_device != "cuda": + features._features[("node", None, "feat")] = gb.GPUCachedFeature( + features._features[("node", None, "feat")], + args.gpu_cache_size, + ) + train_dataloader, valid_dataloader = ( create_dataloader( graph=graph, diff --git 
a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index 99c2b95e00af..a58b3c139029 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -17,8 +17,8 @@ class GPUCachedFeature(Feature): ---------- fallback_feature : Feature The fallback feature. - cache_size : int - The capacity of the GPU cache, the number of features to store. + max_cache_size_in_bytes : int + The capacity of the GPU cache in bytes. Examples -------- @@ -42,16 +42,17 @@ class GPUCachedFeature(Feature): torch.Size([5]) """ - def __init__(self, fallback_feature: Feature, cache_size: int): + def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int): super(GPUCachedFeature, self).__init__() assert isinstance(fallback_feature, Feature), ( f"The fallback_feature must be an instance of Feature, but got " f"{type(fallback_feature)}." ) self._fallback_feature = fallback_feature - self.cache_size = cache_size + self.max_cache_size_in_bytes = max_cache_size_in_bytes # Fetching the feature dimension from the underlying feature. feat0 = fallback_feature.read(torch.tensor([0])) + cache_size = max_cache_size_in_bytes // feat0.nbytes self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype) def read(self, ids: torch.Tensor = None): @@ -104,19 +105,15 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None): updated. """ if ids is None: - new_feat0 = value[0] - old_feat0 = self._fallback_feature.read(torch.tensor([0])) + feat0 = value[:1] self._fallback_feature.update(value) - if new_feat0.dtype != old_feat0.dtype or new_feat0.shape != old_feat0.shape: - self.cache_size = self.cache_size * old_feat0.nbytes // new_feat0.nbytes - self._feature = None - self._feature = GPUCache((self.cache_size,) + new_feat0.shape[1:], new_feat0.dtype) - else: - size = min(self.cache_size, value.shape[0]) - self._feature.replace( - torch.arange(0, size, device="cuda"), - value[:size].to("cuda"), - ) + cache_size = min( + self.max_cache_size_in_bytes // feat0.nbytes, value.shape[0] + ) + self._feature = None # Destroy the existing cache first. + self._feature = GPUCache( + (cache_size,) + feat0.shape[1:], feat0.dtype + ) else: self._fallback_feature.update(value, ids) self._feature.replace(ids, value) diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py index eb9a62babff1..3e847e25b391 100644 --- a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py +++ b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py @@ -94,3 +94,7 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b): feat_store_a.read(), torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to("cuda"), ) + + # Test with different dimensionality + feat_store_a.update(b) + assert torch.equal(feat_store_a.read(), b.to("cuda")) From dab096ad39d2dcc5c2e3c55910227fa8f5444335 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 9 May 2024 03:54:33 +0000 Subject: [PATCH 3/9] update test cache_size argument. 
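
GPUCachedFeature now takes its capacity in bytes rather than as a number of
rows, so the parametrized cache sizes in the test are scaled by the byte size
of a single feature row before the stores are constructed. A minimal sketch of
the conversion, with illustrative values (`feat` and `rows_to_cache` are not
part of the test; the expression is equivalent to `Tensor.nbytes` on
torch >= 2.1):

    import torch

    feat = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    rows_to_cache = 2
    # One float32[3] row occupies 3 * 4 = 12 bytes.
    bytes_per_row = feat[:1].numel() * feat[:1].element_size()
    cache_size_in_bytes = rows_to_cache * bytes_per_row  # 24 bytes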
--- tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py index 3e847e25b391..b3cf34440fe9 100644 --- a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py +++ b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py @@ -36,6 +36,9 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b): [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True ) + cache_size_a *= a[:1].nbytes + cache_size_b *= b[:1].nbytes + feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), cache_size_a) feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), cache_size_b) From fe917bf445b1f8619d81855f70853ca12bab7da4 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 9 May 2024 00:57:32 -0400 Subject: [PATCH 4/9] avoid using Tensor.nbytes as it is available after torch 2.1. --- python/dgl/graphbolt/impl/gpu_cached_feature.py | 8 ++++++-- .../pytorch/graphbolt/impl/test_gpu_cached_feature.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index a58b3c139029..f0c526346eaf 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -8,6 +8,10 @@ __all__ = ["GPUCachedFeature"] +def nbytes(tensor): + return tensor.numel() * tensor.element_size() + + class GPUCachedFeature(Feature): r"""GPU cached feature wrapping a fallback feature. @@ -52,7 +56,7 @@ def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int): self.max_cache_size_in_bytes = max_cache_size_in_bytes # Fetching the feature dimension from the underlying feature. feat0 = fallback_feature.read(torch.tensor([0])) - cache_size = max_cache_size_in_bytes // feat0.nbytes + cache_size = max_cache_size_in_bytes // nbytes(feat0) self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype) def read(self, ids: torch.Tensor = None): @@ -108,7 +112,7 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None): feat0 = value[:1] self._fallback_feature.update(value) cache_size = min( - self.max_cache_size_in_bytes // feat0.nbytes, value.shape[0] + self.max_cache_size_in_bytes // nbytes(feat0), value.shape[0] ) self._feature = None # Destroy the existing cache first. self._feature = GPUCache( diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py index b3cf34440fe9..2a2c82fc7101 100644 --- a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py +++ b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py @@ -36,8 +36,8 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b): [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True ) - cache_size_a *= a[:1].nbytes - cache_size_b *= b[:1].nbytes + cache_size_a *= a[:1].element_size() * a[:1].numel() + cache_size_b *= b[:1].element_size() * b[:1].numel() feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), cache_size_a) feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), cache_size_b) From 6e0bcc9d6b262bba470334ef2089f9159a29e9ea Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 9 May 2024 01:01:37 -0400 Subject: [PATCH 5/9] linting. 
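
Only adds the one-line docstring the linter requires on the `nbytes` helper
introduced in the previous commit; no behavior changes. As an illustrative
sanity check of that helper (not part of this change):

    import torch

    t = torch.ones(4, 3, dtype=torch.float32)
    assert t.numel() * t.element_size() == 48  # 12 elements * 4 bytes each
    # On torch >= 2.1 this matches the built-in: t.nbytes == 48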
--- python/dgl/graphbolt/impl/gpu_cached_feature.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index f0c526346eaf..e743937ebdc8 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -9,6 +9,7 @@ def nbytes(tensor): + """Returns the number of bytes to store the given tensor.""" return tensor.numel() * tensor.element_size() From b908775b106aeaabe3b53a93d648c779c2295776 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 9 May 2024 01:21:49 -0400 Subject: [PATCH 6/9] fix the remaining bug. --- python/dgl/graphbolt/impl/gpu_cached_feature.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index e743937ebdc8..74b5e497687a 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -57,7 +57,9 @@ def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int): self.max_cache_size_in_bytes = max_cache_size_in_bytes # Fetching the feature dimension from the underlying feature. feat0 = fallback_feature.read(torch.tensor([0])) - cache_size = max_cache_size_in_bytes // nbytes(feat0) + cache_size = (max_cache_size_in_bytes + nbytes(feat0) - 1) // nbytes( + feat0 + ) self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype) def read(self, ids: torch.Tensor = None): @@ -113,7 +115,9 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None): feat0 = value[:1] self._fallback_feature.update(value) cache_size = min( - self.max_cache_size_in_bytes // nbytes(feat0), value.shape[0] + (self.max_cache_size_in_bytes + nbytes(feat0) - 1) + // nbytes(feat0), + value.shape[0], ) self._feature = None # Destroy the existing cache first. self._feature = GPUCache( From 3d8cf821be4f5949f890e0838c69a0426a0dbc4c Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 9 May 2024 03:49:36 -0400 Subject: [PATCH 7/9] add comment about torch versions. --- python/dgl/graphbolt/impl/gpu_cached_feature.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index 74b5e497687a..2fa27101ded1 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -9,7 +9,11 @@ def nbytes(tensor): - """Returns the number of bytes to store the given tensor.""" + """Returns the number of bytes to store the given tensor. + + Needs to be defined only for torch versions less than 2.1. In torch >= 2.1, + we can simply use "tensor.nbytes". + """ return tensor.numel() * tensor.element_size() From 66a954d5c4fbb64c8b847afed12dea774b7b8a68 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 9 May 2024 04:03:22 -0400 Subject: [PATCH 8/9] refactor the division. 
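
Moves the ceiling division that converts a byte budget into a row count out of
`__init__` and `update` into a shared `num_cache_items` helper. The rounding
behavior, worked through with illustrative numbers:

    (bytes + item_bytes - 1) // item_bytes
    # capacity = 10 bytes, item_bytes = 4  ->  (10 + 3) // 4 == 3 rows
    # capacity =  1 byte,  item_bytes = 4  ->  ( 1 + 3) // 4 == 1 row
    # capacity =  0 bytes, item_bytes = 4  ->  ( 0 + 3) // 4 == 0 rows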
--- python/dgl/graphbolt/impl/gpu_cached_feature.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index 2fa27101ded1..c0fb086d6879 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -17,6 +17,13 @@ def nbytes(tensor): return tensor.numel() * tensor.element_size() +def num_cache_items(bytes, single_item): + """Returns the number of rows to be cached.""" + item_bytes = nbytes(single_item) + # Round up so that we never get a size of 0, unless bytes is 0. + return (bytes + item_bytes - 1) // item_bytes + + class GPUCachedFeature(Feature): r"""GPU cached feature wrapping a fallback feature. @@ -61,9 +68,7 @@ def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int): self.max_cache_size_in_bytes = max_cache_size_in_bytes # Fetching the feature dimension from the underlying feature. feat0 = fallback_feature.read(torch.tensor([0])) - cache_size = (max_cache_size_in_bytes + nbytes(feat0) - 1) // nbytes( - feat0 - ) + cache_size = num_cache_items(max_cache_size_in_bytes, feat0) self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype) def read(self, ids: torch.Tensor = None): @@ -119,8 +124,7 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None): feat0 = value[:1] self._fallback_feature.update(value) cache_size = min( - (self.max_cache_size_in_bytes + nbytes(feat0) - 1) - // nbytes(feat0), + num_cache_items(self.max_cache_size_in_bytes, feat0), value.shape[0], ) self._feature = None # Destroy the existing cache first. From d2e41f8a2e71580ba400a8dda9d04bada5eafb96 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 9 May 2024 04:07:40 -0400 Subject: [PATCH 9/9] linting --- python/dgl/graphbolt/impl/gpu_cached_feature.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index c0fb086d6879..e03402ad4162 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -17,11 +17,11 @@ def nbytes(tensor): return tensor.numel() * tensor.element_size() -def num_cache_items(bytes, single_item): +def num_cache_items(cache_capacity_in_bytes, single_item): """Returns the number of rows to be cached.""" item_bytes = nbytes(single_item) # Round up so that we never get a size of 0, unless bytes is 0. - return (bytes + item_bytes - 1) // item_bytes + return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes class GPUCachedFeature(Feature):
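
-- 
A minimal usage sketch of the API after this series, assuming the usual
`dgl.graphbolt` import alias, a pinned CPU tensor as the fallback, and a CUDA
build of DGL; the values and variable names are illustrative, and the read ids
are assumed to live on the GPU as in the tests above:

    import torch
    import dgl.graphbolt as gb

    feat = torch.arange(12, dtype=torch.float32).reshape(4, 3).pin_memory()
    # Capacity is now specified in bytes, not rows: one float32[3] row is
    # 12 bytes, so this budget caches at most two rows.
    bytes_per_row = feat[:1].numel() * feat[:1].element_size()
    store = gb.GPUCachedFeature(gb.TorchBasedFeature(feat), 2 * bytes_per_row)
    print(store.read(torch.tensor([1, 3], device="cuda")))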