From 78a86115f1e701e42adacd9d579c5de80825082a Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Sat, 20 Jul 2024 15:34:26 -0400 Subject: [PATCH] [GraphBolt] `CPUCachedFeature.read_async` branch 3 [8] (#7553) --- .../dgl/graphbolt/impl/cpu_cached_feature.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/impl/cpu_cached_feature.py b/python/dgl/graphbolt/impl/cpu_cached_feature.py index 1fe1cb8d1b59..e823aeac92e0 100644 --- a/python/dgl/graphbolt/impl/cpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/cpu_cached_feature.py @@ -104,12 +104,60 @@ def read_async(self, ids: torch.Tensor): ... future = next(async_handle) >>> result = future.wait() # result contains the read values. """ + policy = self._feature._policy + cache = self._feature._cache if ids.is_cuda and self._is_pinned: pass elif ids.is_cuda: pass else: - pass + policy_future = policy.query_async(ids) + + yield + + positions, index, missing_keys, found_keys = policy_future.wait() + self._feature.total_queries += ids.shape[0] + self._feature.total_miss += missing_keys.shape[0] + values_future = cache.query_async(positions, index, ids.shape[0]) + + positions_future = policy.replace_async(missing_keys) + + fallback_reader = self._fallback_feature.read_async(missing_keys) + for _ in range( + self._fallback_feature.read_async_num_stages( + missing_keys.device + ) + ): + missing_values_future = next(fallback_reader, None) + yield # fallback feature stages. + + values = values_future.wait() + reading_completed = policy.reading_completed_async(found_keys) + + missing_index = index[positions.size(0) :] + + missing_values = missing_values_future.wait() + replace_future = cache.replace_async( + positions_future.wait(), missing_values + ) + values = torch.ops.graphbolt.scatter_async( + values, missing_index, missing_values + ) + + yield + + reading_completed.wait() + replace_future.wait() + reading_completed = policy.reading_completed_async(missing_keys) + + class _Waiter: + @staticmethod + def wait(): + """Returns the stored value when invoked.""" + reading_completed.wait() + return values.wait() + + yield _Waiter() def read_async_num_stages(self, ids_device: torch.device): """The number of stages of the read_async operation. See read_async