Skip to content

Commit

Permalink
[GraphBolt] Async feature fetch refactor (#7540)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jul 22, 2024
1 parent 2074cbf commit 5b4635a
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 20 deletions.
13 changes: 9 additions & 4 deletions python/dgl/graphbolt/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def read(self, ids: torch.Tensor = None):

def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
Expand All @@ -52,21 +53,25 @@ def read_async(self, ids: torch.Tensor):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
-------------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
raise NotImplementedError

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device
Expand Down
13 changes: 9 additions & 4 deletions python/dgl/graphbolt/impl/cpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def read(self, ids: torch.Tensor = None):

def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
Expand All @@ -95,14 +96,15 @@ def read_async(self, ids: torch.Tensor):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
-------------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
policy = self._feature._policy
Expand Down Expand Up @@ -309,7 +311,10 @@ def wait():

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device
Expand Down
13 changes: 9 additions & 4 deletions python/dgl/graphbolt/impl/gpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def read(self, ids: torch.Tensor = None):

def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
Expand All @@ -102,14 +103,15 @@ def read_async(self, ids: torch.Tensor):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
-------------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
values, missing_index, missing_keys = self._feature.query(ids)
Expand All @@ -136,7 +138,10 @@ def wait():

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device
Expand Down
26 changes: 18 additions & 8 deletions python/dgl/graphbolt/impl/torch_based_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def read(self, ids: torch.Tensor = None):

def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
Expand All @@ -139,14 +140,15 @@ def read_async(self, ids: torch.Tensor):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
-------------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
assert self._tensor.device.type == "cpu"
Expand Down Expand Up @@ -206,7 +208,10 @@ def wait():

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device
Expand Down Expand Up @@ -408,6 +413,7 @@ def read(self, ids: torch.Tensor = None):

def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
Expand All @@ -420,14 +426,15 @@ def read_async(self, ids: torch.Tensor):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
-------------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
assert torch.ops.graphbolt.detect_io_uring()
Expand Down Expand Up @@ -468,7 +475,10 @@ def wait():

def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device
Expand Down

0 comments on commit 5b4635a

Please sign in to comment.