Skip to content

Commit

Permalink
Merge branch 'master' into gb_read_async_cpu_cached_impl_branch_1
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 20, 2024
2 parents a1bcb70 + 78a8611 commit 2a5d9f9
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion python/dgl/graphbolt/impl/cpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,53 @@ def wait():
elif ids.is_cuda:
pass
else:
pass
policy_future = policy.query_async(ids)

yield

positions, index, missing_keys, found_keys = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
values_future = cache.query_async(positions, index, ids.shape[0])

positions_future = policy.replace_async(missing_keys)

fallback_reader = self._fallback_feature.read_async(missing_keys)
for _ in range(
self._fallback_feature.read_async_num_stages(
missing_keys.device
)
):
missing_values_future = next(fallback_reader, None)
yield # fallback feature stages.

values = values_future.wait()
reading_completed = policy.reading_completed_async(found_keys)

missing_index = index[positions.size(0) :]

missing_values = missing_values_future.wait()
replace_future = cache.replace_async(
positions_future.wait(), missing_values
)
values = torch.ops.graphbolt.scatter_async(
values, missing_index, missing_values
)

yield

reading_completed.wait()
replace_future.wait()
reading_completed = policy.reading_completed_async(missing_keys)

class _Waiter:
@staticmethod
def wait():
"""Returns the stored value when invoked."""
reading_completed.wait()
return values.wait()

yield _Waiter()

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
Expand Down

0 comments on commit 2a5d9f9

Please sign in to comment.