[GraphBolt] Refine FeatureCache and increase test coverage. (#7531)
mfbalin committed Jul 16, 2024
1 parent e38536d commit c3b774d
Showing 2 changed files with 28 additions and 7 deletions.
16 changes: 12 additions & 4 deletions python/dgl/graphbolt/impl/feature_cache.py
@@ -1,4 +1,4 @@
-"""HugeCTR gpu_cache wrapper for graphbolt."""
+"""CPU Feature Cache implementation wrapper for graphbolt."""
 import torch
 
 __all__ = ["FeatureCache"]
@@ -20,21 +20,29 @@ class FeatureCache(object):
         The shape of the cache. cache_shape[0] gives us the capacity.
     dtype : torch.dtype
         The data type of the elements stored in the cache.
-    num_parts: int, optional
-        The number of cache partitions for parallelism. Default is 1.
     policy: str, optional
         The cache policy. Default is "sieve". "s3-fifo", "lru" and "clock" are
         also available.
+    num_parts: int, optional
+        The number of cache partitions for parallelism. Default is
+        `torch.get_num_threads()`.
     pin_memory: bool, optional
         Whether the cache storage should be pinned.
     """
 
     def __init__(
-        self, cache_shape, dtype, num_parts=1, policy="sieve", pin_memory=False
+        self,
+        cache_shape,
+        dtype,
+        policy="sieve",
+        num_parts=None,
+        pin_memory=False,
     ):
         assert (
             policy in caching_policies
         ), f"{list(caching_policies.keys())} are the available caching policies."
+        if num_parts is None:
+            num_parts = torch.get_num_threads()
         self._policy = caching_policies[policy](cache_shape[0], num_parts)
         self._cache = torch.ops.graphbolt.feature_cache(
             cache_shape, dtype, pin_memory
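
For reference, a minimal usage sketch of the reordered constructor (shapes and values are illustrative; query is assumed to return (values, missing_index, missing_keys), matching how the updated test below consumes it):

import torch
import dgl.graphbolt as gb

features = torch.randn(1024, 16)

# policy now comes before num_parts; leaving num_parts unset falls back
# to torch.get_num_threads().
cache = gb.impl.FeatureCache(
    (64,) + features.shape[1:], features.dtype, policy="lru"
)

keys = torch.arange(32)
values, missing_index, missing_keys = cache.query(keys)
# Fill the misses from the backing tensor, then complete the output.
cache.replace(missing_keys, features[missing_keys])
values[missing_index] = features[missing_keys]
assert torch.equal(values, features[keys])
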
19 changes: 16 additions & 3 deletions tests/python/pytorch/graphbolt/impl/test_feature_cache.py
@@ -22,13 +22,15 @@
     ],
 )
 @pytest.mark.parametrize("feature_size", [2, 16])
-@pytest.mark.parametrize("num_parts", [1, 2])
+@pytest.mark.parametrize("num_parts", [1, 2, None])
 @pytest.mark.parametrize("policy", ["s3-fifo", "sieve", "lru", "clock"])
 def test_feature_cache(dtype, feature_size, num_parts, policy):
-    cache_size = 32 * num_parts
+    cache_size = 32 * (
+        torch.get_num_threads() if num_parts is None else num_parts
+    )
     a = torch.randint(0, 2, [1024, feature_size], dtype=dtype)
     cache = gb.impl.FeatureCache(
-        (cache_size,) + a.shape[1:], a.dtype, num_parts, policy
+        (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
     )
 
     keys = torch.tensor([0, 1])
@@ -73,3 +75,14 @@ def test_feature_cache(dtype, feature_size, num_parts, policy):
     cache.replace(missing_keys, missing_values)
     values[missing_index] = missing_values
     assert torch.equal(values, a[keys])
+
+    raw_feature_cache = torch.ops.graphbolt.feature_cache(
+        (cache_size,) + a.shape[1:], a.dtype, pin_memory
+    )
+    idx = torch.tensor([0, 1, 2])
+    raw_feature_cache.replace(idx, a[idx])
+    val = raw_feature_cache.index_select(idx)
+    assert torch.equal(val, a[idx])
+    if pin_memory:
+        val = raw_feature_cache.index_select(idx.to(F.ctx()))
+        assert torch.equal(val, a[idx].to(F.ctx()))
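
The new block exercises the underlying torch.ops.graphbolt.feature_cache op directly. A standalone sketch of the same pattern outside pytest (assuming that importing dgl.graphbolt registers the op, and pinning only when CUDA is available):

import torch
import dgl.graphbolt  # assumed to load the graphbolt library and register torch.ops.graphbolt

a = torch.randint(0, 2, (1024, 16), dtype=torch.int64)
pin_memory = torch.cuda.is_available()

# The raw cache takes (shape, dtype, pin_memory); the eviction policy
# lives in the FeatureCache wrapper, not here.
raw_cache = torch.ops.graphbolt.feature_cache((64, 16), a.dtype, pin_memory)

idx = torch.tensor([0, 1, 2])
raw_cache.replace(idx, a[idx])
assert torch.equal(raw_cache.index_select(idx), a[idx])

if pin_memory:
    # Pinned storage lets CUDA indices gather rows into device memory.
    val = raw_cache.index_select(idx.cuda())
    assert torch.equal(val, a[idx].cuda())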
