
Commit

Merge branch 'master' into gb_dataloader_datapipe
mfbalin committed Aug 22, 2024
2 parents 026d2b8 + c45d299 commit 6edc601
Showing 11 changed files with 252 additions and 83 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/python/dgl.graphbolt.rst
@@ -56,7 +56,9 @@ collection of features.
TorchBasedFeature
TorchBasedFeatureStore
DiskBasedFeature
cpu_cached_feature
CPUCachedFeature
gpu_cached_feature
GPUCachedFeature


4 changes: 2 additions & 2 deletions examples/graphbolt/disk_based_feature/node_classification.py
@@ -457,7 +457,7 @@ def main():
if args.cpu_cache_size_in_gigabytes > 0 and isinstance(
features[("node", None, "feat")], gb.DiskBasedFeature
):
features[("node", None, "feat")] = gb.CPUCachedFeature(
features[("node", None, "feat")] = gb.cpu_cached_feature(
features[("node", None, "feat")],
int(args.cpu_cache_size_in_gigabytes * 1024 * 1024 * 1024),
args.cpu_feature_cache_policy,
@@ -474,7 +474,7 @@ def main():
host-to-device copy operations for this feature.
"""
if args.gpu_cache_size_in_gigabytes > 0 and args.feature_device != "cuda":
features[("node", None, "feat")] = gb.GPUCachedFeature(
features[("node", None, "feat")] = gb.gpu_cached_feature(
features[("node", None, "feat")],
int(args.gpu_cache_size_in_gigabytes * 1024 * 1024 * 1024),
)
4 changes: 2 additions & 2 deletions examples/graphbolt/pyg/labor/node_classification.py
@@ -494,7 +494,7 @@ def main():
if args.num_cpu_cached_features > 0 and isinstance(
features[("node", None, "feat")], gb.DiskBasedFeature
):
features[("node", None, "feat")] = gb.CPUCachedFeature(
features[("node", None, "feat")] = gb.cpu_cached_feature(
features[("node", None, "feat")],
args.num_cpu_cached_features * feature_num_bytes,
args.cpu_feature_cache_policy,
@@ -505,7 +505,7 @@ def main():
else:
cpu_cache_miss_rate_fn = lambda: 1
if args.num_gpu_cached_features > 0 and args.feature_device != "cuda":
features[("node", None, "feat")] = gb.GPUCachedFeature(
features[("node", None, "feat")] = gb.gpu_cached_feature(
features[("node", None, "feat")],
args.num_gpu_cached_features * feature_num_bytes,
)
2 changes: 1 addition & 1 deletion examples/graphbolt/pyg/node_classification_advanced.py
@@ -441,7 +441,7 @@ def main():
num_classes = dataset.tasks[0].metadata["num_classes"]

if args.gpu_cache_size > 0 and args.feature_device != "cuda":
features._features[("node", None, "feat")] = gb.GPUCachedFeature(
features._features[("node", None, "feat")] = gb.gpu_cached_feature(
features._features[("node", None, "feat")],
args.gpu_cache_size,
)
2 changes: 1 addition & 1 deletion examples/multigpu/graphbolt/node_classification.py
@@ -302,7 +302,7 @@ def run(rank, world_size, args, devices, dataset):
out_size = num_classes

if args.gpu_cache_size > 0 and args.storage_device != "cuda":
feature[("node", None, "feat")] = gb.GPUCachedFeature(
feature[("node", None, "feat")] = gb.gpu_cached_feature(
feature[("node", None, "feat")],
args.gpu_cache_size,
)
64 changes: 62 additions & 2 deletions python/dgl/graphbolt/feature_store.py
@@ -1,10 +1,16 @@
"""Feature store for GraphBolt."""

from typing import NamedTuple
from typing import Dict, NamedTuple, Union

import torch

__all__ = ["Feature", "FeatureStore", "FeatureKey"]
__all__ = [
"bytes_to_number_of_items",
"Feature",
"FeatureStore",
"FeatureKey",
"wrap_with_cached_feature",
]


class FeatureKey(NamedTuple):
@@ -289,3 +295,57 @@ def keys(self):
feat_name)` format.
"""
raise NotImplementedError


def bytes_to_number_of_items(cache_capacity_in_bytes, single_item):
"""Returns the number of rows to be cached."""
item_bytes = single_item.nbytes
# Round up so that we never get a size of 0, unless bytes is 0.
return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes
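
# --- Editorial usage sketch (illustrative, not part of this commit) ---
# The ceiling division above caches at least one row for any non-zero byte
# budget. With 1024-byte float32 rows, a 4 MiB budget yields 4096 rows.
import torch

row = torch.empty(256, dtype=torch.float32)  # one row: 256 * 4 = 1024 bytes
assert bytes_to_number_of_items(4 * 1024 * 1024, row) == 4096
assert bytes_to_number_of_items(1, row) == 1  # rounded up, never 0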


def wrap_with_cached_feature(
cached_feature_type,
fallback_features: Union[Feature, Dict[FeatureKey, Feature]],
max_cache_size_in_bytes: int,
*args,
**kwargs,
) -> Union[Feature, Dict[FeatureKey, Feature]]:
"""Wraps the given features with the given cached feature type using
a single cache instance."""
if not isinstance(fallback_features, dict):
assert isinstance(fallback_features, Feature)
return wrap_with_cached_feature(
cached_feature_type,
{"a": fallback_features},
max_cache_size_in_bytes,
*args,
**kwargs,
)["a"]
row_bytes = None
cache = None
wrapped_features = {}
offset = 0
for feature_key, fallback_feature in fallback_features.items():
# Fetching the feature dimension from the underlying feature.
feat0 = fallback_feature.read(torch.tensor([0]))
if row_bytes is None:
row_bytes = feat0.nbytes
else:
assert (
row_bytes == feat0.nbytes
), "The # bytes of a single row of the features should match."
cache_size = bytes_to_number_of_items(max_cache_size_in_bytes, feat0)
if cache is None:
cache = cached_feature_type._cache_type(
cache_shape=(cache_size,) + feat0.shape[1:],
dtype=feat0.dtype,
*args,
**kwargs,
)
wrapped_features[feature_key] = cached_feature_type(
fallback_feature, cache=cache, offset=offset
)
offset += fallback_feature.count()

return wrapped_features
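
# --- Editorial usage sketch (illustrative, not part of this commit) ---
# Two in-memory fallback features share a single cache instance; ids of the
# second feature are shifted by the first feature's row count through the
# `offset` argument passed above. Rows of all features must have equal byte
# size. The feature keys and sizes here are placeholders.
import torch
import dgl.graphbolt as gb

fallbacks = {
    ("node", None, "feat"): gb.TorchBasedFeature(torch.randn(100, 16)),
    ("node", None, "embed"): gb.TorchBasedFeature(torch.randn(200, 16)),
}
wrapped = wrap_with_cached_feature(
    gb.CPUCachedFeature, fallbacks, max_cache_size_in_bytes=64 * 1024
)
print(wrapped[("node", None, "embed")].read(torch.tensor([0, 5])).shape)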
110 changes: 71 additions & 39 deletions python/dgl/graphbolt/impl/cpu_cached_feature.py
@@ -1,67 +1,53 @@
"""CPU cached feature for GraphBolt."""
from typing import Dict, Optional, Union

import torch

from ..base import get_device_to_host_uva_stream, get_host_to_device_uva_stream
from ..feature_store import Feature
from ..feature_store import (
bytes_to_number_of_items,
Feature,
FeatureKey,
wrap_with_cached_feature,
)

from .cpu_feature_cache import CPUFeatureCache

__all__ = ["CPUCachedFeature"]


def bytes_to_number_of_items(cache_capacity_in_bytes, single_item):
"""Returns the number of rows to be cached."""
item_bytes = single_item.nbytes
# Round up so that we never get a size of 0, unless bytes is 0.
return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes
__all__ = ["CPUCachedFeature", "cpu_cached_feature"]


class CPUCachedFeature(Feature):
r"""CPU cached feature wrapping a fallback feature.
r"""CPU cached feature wrapping a fallback feature. Use `cpu_feature_cache`
to construct an instance of this class.
Parameters
----------
fallback_feature : Feature
The fallback feature.
max_cache_size_in_bytes : int
The capacity of the cache in bytes. The size should be a few factors
larger than the size of each read request. Otherwise, the caching policy
will hang due to all cache entries being read and/or write locked,
resulting in a deadlock.
policy : str
The cache eviction policy algorithm name. The available policies are
["s3-fifo", "sieve", "lru", "clock"]. Default is "sieve".
pin_memory : bool
Whether the cache storage should be allocated on system pinned memory.
Default is False.
cache : CPUFeatureCache
A CPUFeatureCache instance to serve as the cache backend.
offset : int, optional
The offset value to add to the given ids before using the cache. This
parameter is useful if multiple `CPUCachedFeature`s are sharing a single
CPUFeatureCache object.
"""

_cache_type = CPUFeatureCache

def __init__(
self,
fallback_feature: Feature,
max_cache_size_in_bytes: int,
policy: str = None,
pin_memory: bool = False,
cache: CPUFeatureCache,
offset: int = 0,
):
super(CPUCachedFeature, self).__init__()
assert isinstance(fallback_feature, Feature), (
f"The fallback_feature must be an instance of Feature, but got "
f"{type(fallback_feature)}."
)
self._fallback_feature = fallback_feature
self.max_cache_size_in_bytes = max_cache_size_in_bytes
# Fetching the feature dimension from the underlying feature.
feat0 = fallback_feature.read(torch.tensor([0]))
cache_size = bytes_to_number_of_items(max_cache_size_in_bytes, feat0)
self._feature = CPUFeatureCache(
(cache_size,) + feat0.shape[1:],
feat0.dtype,
policy=policy,
pin_memory=pin_memory,
)
self._is_pinned = pin_memory
self._offset = 0
self._feature = cache
self._offset = offset

def read(self, ids: torch.Tensor = None):
"""Read the feature by index.
@@ -111,7 +97,7 @@ def read_async(self, ids: torch.Tensor):
"""
policy = self._feature._policy
cache = self._feature._cache
if ids.is_cuda and self._is_pinned:
if ids.is_cuda and self.is_pinned():
ids_device = ids.device
current_stream = torch.cuda.current_stream()
device_to_host_stream = get_device_to_host_uva_stream()
@@ -450,18 +436,64 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
feat0 = value[:1]
self._fallback_feature.update(value)
cache_size = min(
bytes_to_number_of_items(self.max_cache_size_in_bytes, feat0),
bytes_to_number_of_items(self.cache_size_in_bytes, feat0),
value.shape[0],
)
self._feature = None # Destroy the existing cache first.
self._feature = CPUFeatureCache(
self._feature = self._cache_type(
(cache_size,) + feat0.shape[1:], feat0.dtype
)
else:
self._fallback_feature.update(value, ids)
self._feature.replace(ids, value, None, self._offset)

def is_pinned(self):
"""Returns True if the cache storage is pinned."""
return self._feature.is_pinned()

@property
def cache_size_in_bytes(self):
"""Return the size taken by the cache in bytes."""
return self._feature.max_size_in_bytes

@property
def miss_rate(self):
"""Returns the cache miss rate since creation."""
return self._feature.miss_rate


def cpu_cached_feature(
fallback_features: Union[Feature, Dict[FeatureKey, Feature]],
max_cache_size_in_bytes: int,
policy: Optional[str] = None,
pin_memory: bool = False,
) -> Union[CPUCachedFeature, Dict[FeatureKey, CPUCachedFeature]]:
r"""CPU cached feature wrapping a fallback feature.
Parameters
----------
fallback_features : Union[Feature, Dict[FeatureKey, Feature]]
The fallback feature(s).
max_cache_size_in_bytes : int
The capacity of the cache in bytes. The size should be a few times
larger than the size of each read request. Otherwise, the caching policy
will hang due to all cache entries being read and/or write locked,
resulting in a deadlock.
policy : str, optional
The cache eviction policy algorithm name. The available policies are
["s3-fifo", "sieve", "lru", "clock"]. Default is "sieve".
pin_memory : bool, optional
Whether the cache storage should be allocated on system pinned memory.
Default is False.
Returns
-------
Union[CPUCachedFeature, Dict[FeatureKey, CPUCachedFeature]]
New feature(s) wrapped with CPUCachedFeature.
"""
return wrap_with_cached_feature(
CPUCachedFeature,
fallback_features,
max_cache_size_in_bytes,
policy=policy,
pin_memory=pin_memory,
)
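
# --- Editorial usage sketch (illustrative, not part of this commit) ---
# The factory above replaces direct CPUCachedFeature construction (compare the
# example-script updates earlier in this diff). An in-memory TorchBasedFeature
# stands in for the usual DiskBasedFeature fallback; sizes are placeholders.
import torch
import dgl.graphbolt as gb

fallback = gb.TorchBasedFeature(torch.randn(1000, 128))  # 512 bytes per row
cached = gb.cpu_cached_feature(fallback, 2 * 1024 * 1024, policy="sieve")
out = cached.read(torch.arange(64))
print(out.shape, cached.miss_rate, cached.cache_size_in_bytes)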
(Diffs for the remaining 4 changed files are not shown.)
