
Commit 9370208

Avoid using Tensor.nbytes, as it is only available from torch 2.1 onwards.
mfbalin committed May 9, 2024
1 parent dab096a commit 9370208
Showing 2 changed files with 7 additions and 4 deletions.
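
As a quick note on the workaround (not part of the commit itself): the helper introduced below computes the same quantity as Tensor.nbytes from primitives that exist in older torch releases. A minimal sketch, assuming only that PyTorch is installed:

import torch

def nbytes(tensor):
    # Byte size of the tensor's elements, the same value Tensor.nbytes
    # reports on torch >= 2.1, but computed so it works on older versions too.
    return tensor.numel() * tensor.element_size()

t = torch.zeros(4, 3, dtype=torch.float32)
print(nbytes(t))  # 12 elements * 4 bytes = 48
if hasattr(t, "nbytes"):  # attribute only present on newer torch
    assert nbytes(t) == t.nbytes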
7 changes: 5 additions & 2 deletions python/dgl/graphbolt/impl/gpu_cached_feature.py
@@ -8,6 +8,9 @@
__all__ = ["GPUCachedFeature"]


+def nbytes(tensor):
+    return tensor.numel() * tensor.element_size()
+
Check warning on line 13 in python/dgl/graphbolt/impl/gpu_cached_feature.py (GitHub Actions / lintrunner): UFMT format. Run `lintrunner -a` to apply this patch.
class GPUCachedFeature(Feature):
r"""GPU cached feature wrapping a fallback feature.
@@ -52,7 +55,7 @@ def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int):
        self.max_cache_size_in_bytes = max_cache_size_in_bytes
        # Fetching the feature dimension from the underlying feature.
        feat0 = fallback_feature.read(torch.tensor([0]))
-        cache_size = max_cache_size_in_bytes // feat0.nbytes
+        cache_size = max_cache_size_in_bytes // nbytes(feat0)
        self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype)

    def read(self, ids: torch.Tensor = None):
@@ -108,7 +111,7 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
        feat0 = value[:1]
        self._fallback_feature.update(value)
        cache_size = min(
-            self.max_cache_size_in_bytes // feat0.nbytes, value.shape[0]
+            self.max_cache_size_in_bytes // nbytes(feat0), value.shape[0]
        )
        self._feature = None  # Destroy the existing cache first.
        self._feature = GPUCache(
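
To make the two changed call sites above concrete: the cache is sized as the number of feature rows that fit into the byte budget, using one representative row to measure bytes per row. A minimal sketch with made-up shapes and numbers (feat and budget are illustrative, not from the repository):

import torch

def nbytes(tensor):
    return tensor.numel() * tensor.element_size()

feat = torch.zeros(1000, 128, dtype=torch.float32)  # hypothetical feature matrix
row = feat[:1]                      # one representative row, as in __init__/update
budget = 1 << 20                    # 1 MiB byte budget for the GPU cache
cache_rows = budget // nbytes(row)  # 1048576 // (128 * 4) = 2048 rows fit
print(cache_rows)                   # 2048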
2 changes: 2 additions & 2 deletions in the GPUCachedFeature test file
@@ -36,8 +36,8 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True
    )

-    cache_size_a *= a[:1].nbytes
-    cache_size_b *= b[:1].nbytes
+    cache_size_a *= a[:1].element_size() * a[:1].numel()
+    cache_size_b *= b[:1].element_size() * b[:1].numel()

    feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), cache_size_a)
    feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), cache_size_b)
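
In the test above, cache_size_a and cache_size_b start out as row counts and are turned into the byte budgets that GPUCachedFeature expects, again without relying on Tensor.nbytes. A small standalone sketch of that conversion (the tensor and row count are arbitrary examples):

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
rows_to_cache = 2
# Bytes per row: element size times number of elements in one row slice.
bytes_per_row = a[:1].element_size() * a[:1].numel()  # 4 * 3 = 12
cache_size_a = rows_to_cache * bytes_per_row          # 24-byte budget
print(cache_size_a)  # 24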
