Skip to content

Commit

Permalink
Put these changes into another PR.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 18, 2024
1 parent e5a9bc4 commit 4f2b098
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
8 changes: 3 additions & 5 deletions python/dgl/graphbolt/impl/feature_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""CPU Feature Cache implementation wrapper for graphbolt."""
import torch

__all__ = ["CPUFeatureCache"]
__all__ = ["FeatureCache"]

caching_policies = {
"s3-fifo": torch.ops.graphbolt.s3_fifo_cache_policy,
Expand All @@ -11,7 +11,7 @@
}


class CPUFeatureCache(object):
class FeatureCache(object):
r"""High level wrapper for the CPU feature cache.
Parameters
Expand All @@ -34,12 +34,10 @@ def __init__(
self,
cache_shape,
dtype,
policy=None,
policy="sieve",
num_parts=None,
pin_memory=False,
):
if policy is None:
policy = "sieve"
assert (
policy in caching_policies
), f"{list(caching_policies.keys())} are the available caching policies."
Expand Down
2 changes: 1 addition & 1 deletion tests/python/pytorch/graphbolt/impl/test_feature_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_feature_cache(dtype, feature_size, num_parts, policy):
torch.get_num_threads() if num_parts is None else num_parts
)
a = torch.randint(0, 2, [1024, feature_size], dtype=dtype)
cache = gb.impl.CPUFeatureCache(
cache = gb.impl.FeatureCache(
(cache_size,) + a.shape[1:], a.dtype, policy, num_parts
)

Expand Down

0 comments on commit 4f2b098

Please sign in to comment.