From 3f8e507b05b802eaeacf030753786494940f7c0b Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin
Date: Thu, 15 Aug 2024 13:22:23 -0400
Subject: [PATCH 1/2] [GraphBolt][CUDA] Eliminate GPUCache synchronization.

---
 graphbolt/src/cuda/extension/gpu_cache.cu    |  8 +++++
 graphbolt/src/cuda/extension/gpu_cache.h     |  4 +++
 graphbolt/src/python_binding.cc              |  1 +
 python/dgl/graphbolt/impl/gpu_cache.py       | 29 +++++++++++++++----
 .../dgl/graphbolt/impl/gpu_cached_feature.py |  8 +++--
 5 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/graphbolt/src/cuda/extension/gpu_cache.cu b/graphbolt/src/cuda/extension/gpu_cache.cu
index 8abe5eec71f5..7e280187976a 100644
--- a/graphbolt/src/cuda/extension/gpu_cache.cu
+++ b/graphbolt/src/cuda/extension/gpu_cache.cu
@@ -76,6 +76,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
   return std::make_tuple(values, missing_index, missing_keys);
 }
 
+c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> GpuCache::QueryAsync(
+    torch::Tensor keys) {
+  return async([=] {
+    auto [values, missing_index, missing_keys] = Query(keys);
+    return std::vector{values, missing_index, missing_keys};
+  });
+}
+
 void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
   TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
   TORCH_CHECK(
diff --git a/graphbolt/src/cuda/extension/gpu_cache.h b/graphbolt/src/cuda/extension/gpu_cache.h
index 556bdaa5b5bd..6ca2b12995a7 100644
--- a/graphbolt/src/cuda/extension/gpu_cache.h
+++ b/graphbolt/src/cuda/extension/gpu_cache.h
@@ -21,6 +21,7 @@
 #ifndef GRAPHBOLT_GPU_CACHE_H_
 #define GRAPHBOLT_GPU_CACHE_H_
 
+#include <graphbolt/async.h>
 #include <torch/custom_class.h>
 #include <torch/torch.h>
 
@@ -53,6 +54,9 @@ class GpuCache : public torch::CustomClassHolder {
   std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
       torch::Tensor keys);
 
+  c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(
+      torch::Tensor keys);
+
   void Replace(torch::Tensor keys, torch::Tensor values);
 
   static c10::intrusive_ptr<GpuCache> Create(
diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc
index 9e017dd1df3d..62822c28a478 100644
--- a/graphbolt/src/python_binding.cc
+++ b/graphbolt/src/python_binding.cc
@@ -109,6 +109,7 @@ TORCH_LIBRARY(graphbolt, m) {
 #ifdef GRAPHBOLT_USE_CUDA
   m.class_<cuda::GpuCache>("GpuCache")
       .def("query", &cuda::GpuCache::Query)
+      .def("query_async", &cuda::GpuCache::QueryAsync)
       .def("replace", &cuda::GpuCache::Replace);
   m.def("gpu_cache", &cuda::GpuCache::Create);
   m.class_<cuda::GpuGraphCache>("GpuGraphCache")
diff --git a/python/dgl/graphbolt/impl/gpu_cache.py b/python/dgl/graphbolt/impl/gpu_cache.py
index 7c07e7c52a0b..3ca6d994dca4 100644
--- a/python/dgl/graphbolt/impl/gpu_cache.py
+++ b/python/dgl/graphbolt/impl/gpu_cache.py
@@ -14,13 +14,16 @@ def __init__(self, cache_shape, dtype):
         self.total_miss = 0
         self.total_queries = 0
 
-    def query(self, keys):
+    def query(self, keys, async_op=False):
         """Queries the GPU cache.
 
         Parameters
         ----------
         keys : Tensor
             The keys to query the GPU cache with.
+        async_op : bool
+            Boolean indicating whether the call is asynchronous. If so, the
+            result can be obtained by calling wait on the returned future.
 
         Returns
         -------
@@ -29,10 +32,26 @@ def query(self, keys):
             values[missing_indices] corresponds to cache misses that should be
             filled by querying another source with missing_keys.
""" - self.total_queries += keys.shape[0] - values, missing_index, missing_keys = self._cache.query(keys) - self.total_miss += missing_keys.shape[0] - return values, missing_index, missing_keys + class _Waiter: + def __init__(self, gpu_cache, future): + self.gpu_cache = gpu_cache + self.future = future + + def wait(self): + """Returns the stored value when invoked.""" + gpu_cache = self.gpu_cache + values, missing_index, missing_keys = self.future.wait() if async_op else self.future + # Ensure there is no leak. + self.gpu_cache = self.future = None + + gpu_cache.total_queries += values.shape[0] + gpu_cache.total_miss += missing_keys.shape[0] + return values, missing_index, missing_keys + + if async_op: + return _Waiter(self, self._cache.query_async(keys)) + else: + return _Waiter(self, self._cache.query(keys)).wait() def replace(self, keys, values): """Inserts key-value pairs into the GPU cache using the Least-Recently diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index d8fb36add1da..f20a343ad818 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -114,7 +114,11 @@ def read_async(self, ids: torch.Tensor): >>> assert stage + 1 == feature.read_async_num_stages(ids.device) >>> result = future.wait() # result contains the read values. """ - values, missing_index, missing_keys = self._feature.query(ids) + future = self._feature.query(ids, async_op=True) + + yield + + values, missing_index, missing_keys = future.wait() fallback_reader = self._fallback_feature.read_async(missing_keys) fallback_num_stages = self._fallback_feature.read_async_num_stages( @@ -175,7 +179,7 @@ def read_async_num_stages(self, ids_device: torch.device): The number of stages of the read_async operation. """ assert ids_device.type == "cuda" - return self._fallback_feature.read_async_num_stages(ids_device) + return 1 + self._fallback_feature.read_async_num_stages(ids_device) def size(self): """Get the size of the feature. From dfcc1d13dee68f6321b4387e78d90c12e02597d3 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 15 Aug 2024 13:24:26 -0400 Subject: [PATCH 2/2] linting --- python/dgl/graphbolt/impl/gpu_cache.py | 7 +++++-- python/dgl/graphbolt/impl/gpu_cached_feature.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/dgl/graphbolt/impl/gpu_cache.py b/python/dgl/graphbolt/impl/gpu_cache.py index 3ca6d994dca4..413fa5527a7a 100644 --- a/python/dgl/graphbolt/impl/gpu_cache.py +++ b/python/dgl/graphbolt/impl/gpu_cache.py @@ -32,6 +32,7 @@ def query(self, keys, async_op=False): values[missing_indices] corresponds to cache misses that should be filled by quering another source with missing_keys. """ + class _Waiter: def __init__(self, gpu_cache, future): self.gpu_cache = gpu_cache @@ -40,14 +41,16 @@ def __init__(self, gpu_cache, future): def wait(self): """Returns the stored value when invoked.""" gpu_cache = self.gpu_cache - values, missing_index, missing_keys = self.future.wait() if async_op else self.future + values, missing_index, missing_keys = ( + self.future.wait() if async_op else self.future + ) # Ensure there is no leak. 
                 self.gpu_cache = self.future = None
 
                 gpu_cache.total_queries += values.shape[0]
                 gpu_cache.total_miss += missing_keys.shape[0]
                 return values, missing_index, missing_keys
-
+
         if async_op:
             return _Waiter(self, self._cache.query_async(keys))
         else:
diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py
index f20a343ad818..e19c8752fa2a 100644
--- a/python/dgl/graphbolt/impl/gpu_cached_feature.py
+++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py
@@ -115,7 +115,7 @@ def read_async(self, ids: torch.Tensor):
         >>> result = future.wait()  # result contains the read values.
         """
         future = self._feature.query(ids, async_op=True)
-
+
         yield
 
         values, missing_index, missing_keys = future.wait()
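
Notes: taken together, the two patches let `read_async` overlap the GPU cache lookup with other host-side work. `query(ids, async_op=True)` returns a waiter immediately instead of blocking until the lookup completes, and `read_async_num_stages` grows by one to account for the extra `yield` stage. Below is a minimal usage sketch of the new Python surface; the `GPUCache` class name and `__init__(cache_shape, dtype)` signature follow the diff context above, but the concrete shapes, dtypes, and key values are illustrative assumptions.

```python
# Hypothetical usage sketch of the async query path added in PATCH 1/2.
# The cache_shape/dtype arguments and key tensor below are assumptions.
import torch

from dgl.graphbolt.impl.gpu_cache import GPUCache

cache = GPUCache(cache_shape=(1024, 16), dtype=torch.float32)
keys = torch.arange(128, device="cuda")

# Synchronous path: external behavior is unchanged, the
# (values, missing_index, missing_keys) tuple is returned directly.
values, missing_index, missing_keys = cache.query(keys)

# Asynchronous path: a waiter is returned without synchronizing with the
# GPU; hit/miss statistics are updated only once wait() is called.
future = cache.query(keys, async_op=True)
# ... overlap other host-side work here, e.g. launching the fallback read ...
values, missing_index, missing_keys = future.wait()
```

The `_Waiter` clears its `gpu_cache` and `future` references inside `wait()` (the "Ensure there is no leak" comment in the patch), so the waiter does not keep the result tensors alive after the caller has consumed them.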