Skip to content

Commit

Permalink
[GraphBolt] CPUCachedFeature.read_async branch 1. [6]
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 20, 2024
1 parent b5992a2 commit 89defad
Showing 1 changed file with 84 additions and 1 deletion.
85 changes: 84 additions & 1 deletion python/dgl/graphbolt/impl/cpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch

from ..base import get_device_to_host_uva_stream, get_host_to_device_uva_stream
from ..feature_store import Feature

from .feature_cache import CPUFeatureCache
Expand Down Expand Up @@ -104,8 +105,90 @@ def read_async(self, ids: torch.Tensor):
... future = next(async_handle)
>>> result = future.wait() # result contains the read values.
"""
policy = self._feature._policy
cache = self._feature._cache
if ids.is_cuda and self._is_pinned:
pass
ids_device = ids.device
current_stream = torch.cuda.current_stream()
device_to_host_stream = get_device_to_host_uva_stream()
device_to_host_stream.wait_stream(current_stream)
with torch.cuda.stream(device_to_host_stream):
ids.record_stream(torch.cuda.current_stream())
ids_cuda = ids
ids = ids.to("cpu", non_blocking=True)
ids_copy_event = torch.cuda.Event()
ids_copy_event.record()

yield # first stage is done.

ids_copy_event.synchronize()
policy_future = policy.query_async(ids)

yield

positions, index, missing_keys, found_keys = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
host_to_device_stream = get_host_to_device_uva_stream()
with torch.cuda.stream(host_to_device_stream):
positions_cuda = positions.to(ids_device, non_blocking=True)
values_from_cpu = cache.index_select(positions_cuda)
values_from_cpu.record_stream(current_stream)
values_from_cpu_copy_event = torch.cuda.Event()
values_from_cpu_copy_event.record()

positions_future = policy.replace_async(missing_keys)

fallback_reader = self._fallback_feature.read_async(missing_keys)
for _ in range(
self._fallback_feature.read_async_num_stages(
missing_keys.device
)
):
missing_values_future = next(fallback_reader, None)
yield # fallback feature stages.

values_from_cpu_copy_event.wait()
reading_completed = policy.reading_completed_async(found_keys)

missing_values = missing_values_future.wait()
replace_future = cache.replace_async(
positions_future.wait(), missing_values
)

host_to_device_stream = get_host_to_device_uva_stream()
with torch.cuda.stream(host_to_device_stream):
missing_values_cuda = missing_values.to(
ids_device, non_blocking=True
)
missing_values_cuda.record_stream(current_stream)
missing_values_copy_event = torch.cuda.Event()
missing_values_copy_event.record()

yield

reading_completed.wait()
replace_future.wait()
reading_completed = policy.reading_completed_async(missing_keys)

class _Waiter:
@staticmethod
def wait():
"""Returns the stored value when invoked."""
missing_values_copy_event.wait()
reading_completed.wait()
values = torch.empty(
(ids_cuda.shape[0],) + missing_values_cuda.shape[1:],
dtype=missing_values_cuda.dtype,
device=ids_device,
)
found_index = index[: positions.size(0)]
missing_index = index[positions.size(0) :]
values[found_index] = values_from_cpu
values[missing_index] = missing_values_cuda
return values

yield _Waiter()
elif ids.is_cuda:
pass
else:
Expand Down

0 comments on commit 89defad

Please sign in to comment.