Skip to content

Commit

Permalink
Merge branch 'master' into gb_read_async_cpu_cached_impl_branch_1
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 20, 2024
2 parents 89defad + ddef6ab commit 0f41999
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 8 deletions.
27 changes: 24 additions & 3 deletions python/dgl/graphbolt/impl/gpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def read(self, ids: torch.Tensor = None):
if ids is None:
return self._fallback_feature.read()
values, missing_index, missing_keys = self._feature.query(ids)
missing_values = self._fallback_feature.read(missing_keys).to("cuda")
missing_values = self._fallback_feature.read(missing_keys)
values[missing_index] = missing_values
self._feature.replace(missing_keys, missing_values)
return values
Expand Down Expand Up @@ -112,7 +112,27 @@ def read_async(self, ids: torch.Tensor):
... future = next(async_handle)
>>> result = future.wait() # result contains the read values.
"""
raise NotImplementedError
values, missing_index, missing_keys = self._feature.query(ids)

fallback_reader = self._fallback_feature.read_async(missing_keys)
fallback_num_stages = self._fallback_feature.read_async_num_stages(
missing_keys.device
)
for i in range(fallback_num_stages):
missing_values_future = next(fallback_reader, None)
if i < fallback_num_stages - 1:
yield # fallback feature stages.

class _Waiter:
@staticmethod
def wait():
"""Returns the stored value when invoked."""
missing_values = missing_values_future.wait()
self._feature.replace(missing_keys, missing_values)
values[missing_index] = missing_values
return values

yield _Waiter()

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
Expand All @@ -126,7 +146,8 @@ def read_async_num_stages(self, ids_device: torch.device):
int
The number of stages of the read_async operation.
"""
raise NotImplementedError
assert ids_device.type == "cuda"
return self._fallback_feature.read_async_num_stages(ids_device)

def size(self):
"""Get the size of the feature.
Expand Down
104 changes: 99 additions & 5 deletions python/dgl/graphbolt/impl/torch_based_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import numpy as np
import torch

from ..base import index_select
from ..base import (
get_device_to_host_uva_stream,
get_host_to_device_uva_stream,
index_select,
)
from ..feature_store import Feature
from ..internal_utils import gb_warning, is_wsl
from .basic_feature_store import BasicFeatureStore
Expand Down Expand Up @@ -145,7 +149,60 @@ def read_async(self, ids: torch.Tensor):
... future = next(async_handle)
>>> result = future.wait() # result contains the read values.
"""
raise NotImplementedError
assert self._tensor.device.type == "cpu"
if ids.is_cuda and self.is_pinned():
current_stream = torch.cuda.current_stream()
host_to_device_stream = get_host_to_device_uva_stream()
host_to_device_stream.wait_stream(current_stream)
with torch.cuda.stream(host_to_device_stream):
ids.record_stream(torch.cuda.current_stream())
values = index_select(self._tensor, ids)
values.record_stream(current_stream)
values_copy_event = torch.cuda.Event()
values_copy_event.record()

class _Waiter:
@staticmethod
def wait():
"""Returns the stored value when invoked."""
values_copy_event.wait()
return values

yield _Waiter()
elif ids.is_cuda:
ids_device = ids.device
current_stream = torch.cuda.current_stream()
device_to_host_stream = get_device_to_host_uva_stream()
device_to_host_stream.wait_stream(current_stream)
with torch.cuda.stream(device_to_host_stream):
ids.record_stream(torch.cuda.current_stream())
ids = ids.to(self._tensor.device, non_blocking=True)
ids_copy_event = torch.cuda.Event()
ids_copy_event.record()

yield # first stage is done.

ids_copy_event.synchronize()
values = torch.ops.graphbolt.index_select_async(self._tensor, ids)
yield

host_to_device_stream = get_host_to_device_uva_stream()
with torch.cuda.stream(host_to_device_stream):
values_cuda = values.wait().to(ids_device, non_blocking=True)
values_cuda.record_stream(current_stream)
values_copy_event = torch.cuda.Event()
values_copy_event.record()

class _Waiter:
@staticmethod
def wait():
"""Returns the stored value when invoked."""
values_copy_event.wait()
return values_cuda

yield _Waiter()
else:
yield torch.ops.graphbolt.index_select_async(self._tensor, ids)

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
Expand All @@ -159,7 +216,10 @@ def read_async_num_stages(self, ids_device: torch.device):
int
The number of stages of the read_async operation.
"""
raise NotImplementedError
if ids_device.type == "cuda":
return 1 if self.is_pinned() else 3
else:
return 1

def size(self):
"""Get the size of the feature.
Expand Down Expand Up @@ -367,7 +427,41 @@ def read_async(self, ids: torch.Tensor):
... future = next(async_handle)
>>> result = future.wait() # result contains the read values.
"""
raise NotImplementedError
assert torch.ops.graphbolt.detect_io_uring()
if ids.is_cuda:
ids_device = ids.device
current_stream = torch.cuda.current_stream()
device_to_host_stream = get_device_to_host_uva_stream()
device_to_host_stream.wait_stream(current_stream)
with torch.cuda.stream(device_to_host_stream):
ids.record_stream(torch.cuda.current_stream())
ids = ids.to(self._tensor.device, non_blocking=True)
ids_copy_event = torch.cuda.Event()
ids_copy_event.record()

yield # first stage is done.

ids_copy_event.synchronize()
values = self._ondisk_npy_array.index_select(ids)
yield

host_to_device_stream = get_host_to_device_uva_stream()
with torch.cuda.stream(host_to_device_stream):
values_cuda = values.wait().to(ids_device, non_blocking=True)
values_cuda.record_stream(current_stream)
values_copy_event = torch.cuda.Event()
values_copy_event.record()

class _Waiter:
@staticmethod
def wait():
"""Returns the stored value when invoked."""
values_copy_event.wait()
return values_cuda

yield _Waiter()
else:
yield self._ondisk_npy_array.index_select(ids)

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
Expand All @@ -381,7 +475,7 @@ def read_async_num_stages(self, ids_device: torch.device):
int
The number of stages of the read_async operation.
"""
raise NotImplementedError
return 3 if ids_device.type == "cuda" else 1

def size(self):
"""Get the size of the feature.
Expand Down

0 comments on commit 0f41999

Please sign in to comment.