Skip to content

Commit

Permalink
[GraphBolt] CPUCachedFeature.read_async branch 3 [8]
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 20, 2024
1 parent b5992a2 commit 04241f1
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion python/dgl/graphbolt/impl/cpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,60 @@ def read_async(self, ids: torch.Tensor):
... future = next(async_handle)
>>> result = future.wait() # result contains the read values.
"""
policy = self._feature._policy
cache = self._feature._cache
if ids.is_cuda and self._is_pinned:
pass
elif ids.is_cuda:
pass
else:
pass
policy_future = policy.query_async(ids)

yield

positions, index, missing_keys, found_keys = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
values_future = cache.query_async(positions, index, ids.shape[0])

positions_future = policy.replace_async(missing_keys)

fallback_reader = self._fallback_feature.read_async(missing_keys)
for _ in range(
self._fallback_feature.read_async_num_stages(
missing_keys.device
)
):
missing_values_future = next(fallback_reader, None)
yield # fallback feature stages.

values = values_future.wait()
reading_completed = policy.reading_completed_async(found_keys)

missing_index = index[positions.size(0) :]

missing_values = missing_values_future.wait()
replace_future = cache.replace_async(
positions_future.wait(), missing_values
)
values = torch.ops.graphbolt.scatter_async(
values, missing_index, missing_values
)

yield

reading_completed.wait()
replace_future.wait()
reading_completed = policy.reading_completed_async(missing_keys)

class _Waiter:
@staticmethod
def wait():
"""Returns the stored value when invoked."""
reading_completed.wait()
return values.wait()

yield _Waiter()

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
Expand Down

0 comments on commit 04241f1

Please sign in to comment.