Skip to content

Commit

Permalink
[GraphBolt] CPUCachedFeature.read_async body.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 20, 2024
1 parent 8d770b6 commit b5d9299
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions python/dgl/graphbolt/impl/cpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
policy=policy,
pin_memory=pin_memory,
)
self._is_pinned = pin_memory

def read(self, ids: torch.Tensor = None):
"""Read the feature by index.
Expand Down Expand Up @@ -103,7 +104,12 @@ def read_async(self, ids: torch.Tensor):
... future = next(async_handle)
>>> result = future.wait() # result contains the read values.
"""
raise NotImplementedError
if ids.is_cuda and self._is_pinned:
raise NotImplementedError
elif ids.is_cuda:
raise NotImplementedError
else:
raise NotImplementedError

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
Expand All @@ -117,7 +123,12 @@ def read_async_num_stages(self, ids_device: torch.device):
int
The number of stages of the read_async operation.
"""
raise NotImplementedError
if ids_device.type == "cuda":
return 4 + self._fallback_feature.read_async_num_stages(
torch.device("cpu")
)
else:
return 3 + self._fallback_feature.read_async_num_stages(ids_device)

def size(self):
"""Get the size of the feature.
Expand Down

0 comments on commit b5d9299

Please sign in to comment.