From d7461cd10a91ab24dacdc88d07168dfd855f1b0c Mon Sep 17 00:00:00 2001 From: pyynb <52124938+pyynb@users.noreply.github.com> Date: Thu, 1 Aug 2024 16:36:43 +0800 Subject: [PATCH 1/5] [GraphBolt]support torch2.4&cuda12.4 (#7629) --- script/create_dev_conda_env.sh | 2 +- script/dgl_dev.yml.template | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/script/create_dev_conda_env.sh b/script/create_dev_conda_env.sh index 01a8795f63ec..891be0595400 100644 --- a/script/create_dev_conda_env.sh +++ b/script/create_dev_conda_env.sh @@ -1,6 +1,6 @@ #!/bin/bash -readonly CUDA_VERSIONS="11.8,12.1" +readonly CUDA_VERSIONS="11.8,12.1,12.4" readonly TORCH_VERSION="2.1.0" readonly PYTHON_VERSION="3.10" diff --git a/script/dgl_dev.yml.template b/script/dgl_dev.yml.template index b41d4a3db8c9..708df84aa1e5 100644 --- a/script/dgl_dev.yml.template +++ b/script/dgl_dev.yml.template @@ -10,7 +10,7 @@ dependencies: - pandoc - pygraphviz - pip: - - --find-links https://download.pytorch.org/whl/torch_stable.html + - --find-links https://download.pytorch.org/whl/torch/ - cmake>=3.18 - cython - filelock From 69eef91422903cbc26cd8aef4e40aab4bb016569 Mon Sep 17 00:00:00 2001 From: "Hongzhi (Steve), Chen" Date: Thu, 1 Aug 2024 17:39:17 +0800 Subject: [PATCH 2/5] Update Jenkinsfile (#7630) --- Jenkinsfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 07fb149dda25..8328384e7e11 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -157,8 +157,7 @@ def is_authorized(name) { 'rudongyu', 'classicsong', 'HuXiangkun', 'hetong007', 'kylasa', 'frozenbugs', 'peizhou001', 'zheng-da', 'czkkkkkk', 'thvasilo', // Intern: - 'keli-wen', 'caojy1998', 'RamonZhou', 'xiangyuzhi', 'Skeleton003', 'yxy235', - 'hutiechuan', 'pyynb', 'az15240', 'BowenYao18', 'kec020', + 'pyynb', 'az15240', 'BowenYao18', 'kec020', 'Liu-rj', // Friends: 'nv-dlasalle', 'yaox12', 'chang-l', 'Kh4L', 'VibhuJawa', 'kkranen', 'TristonC', 'mfbalin', From 105f0c6834c52a03d3fcc8d3f586153e20f8b985 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Thu, 1 Aug 2024 06:37:25 -0400 Subject: [PATCH 3/5] [GraphBolt] Parallel backend experimentation (#7596) --- .gitmodules | 3 + graphbolt/CMakeLists.txt | 5 ++ graphbolt/include/graphbolt/async.h | 113 ++++++++++++++++++++++++---- graphbolt/src/cnumpy.cc | 2 +- third_party/taskflow | 1 + 5 files changed, 107 insertions(+), 17 deletions(-) create mode 160000 third_party/taskflow diff --git a/.gitmodules b/.gitmodules index 52f9fc22876e..a0d9d0856ca8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -34,3 +34,6 @@ [submodule "third_party/GKlib"] path = third_party/GKlib url = https://github.com/KarypisLab/GKlib.git +[submodule "third_party/taskflow"] + path = third_party/taskflow + url = https://github.com/taskflow/taskflow.git diff --git a/graphbolt/CMakeLists.txt b/graphbolt/CMakeLists.txt index e6a286724ba4..bde937150b32 100644 --- a/graphbolt/CMakeLists.txt +++ b/graphbolt/CMakeLists.txt @@ -45,6 +45,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb") set(LIB_GRAPHBOLT_NAME "graphbolt_pytorch_${TORCH_VER}") +option(BUILD_WITH_TASKFLOW "Use taskflow as parallel backend" ON) set(BOLT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src") set(BOLT_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ -79,6 +80,10 @@ include_directories(BEFORE ${BOLT_DIR} "../third_party/pcg/include" "../third_party/phmap") target_link_libraries(${LIB_GRAPHBOLT_NAME} "${TORCH_LIBRARIES}") +if(BUILD_WITH_TASKFLOW) + 
target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE "../third_party/taskflow") + target_compile_definitions(${LIB_GRAPHBOLT_NAME} PRIVATE BUILD_WITH_TASKFLOW=1) +endif() if(CMAKE_SYSTEM_NAME MATCHES "Linux") if(USE_LIBURING) diff --git a/graphbolt/include/graphbolt/async.h b/graphbolt/include/graphbolt/async.h index 1e4780b0672a..b720e64d1e07 100644 --- a/graphbolt/include/graphbolt/async.h +++ b/graphbolt/include/graphbolt/async.h @@ -27,10 +27,62 @@ #include #include #include +#include #include +#ifdef BUILD_WITH_TASKFLOW +#include +#include +#endif + namespace graphbolt { +enum ThreadPool { intraop, interop }; + +#ifdef BUILD_WITH_TASKFLOW + +template +inline tf::Executor& _get_thread_pool() { + static std::unique_ptr pool; + static std::once_flag flag; + std::call_once(flag, [&] { + const int num_threads = pool_type == ThreadPool::intraop + ? torch::get_num_threads() + : torch::get_num_interop_threads(); + pool = std::make_unique(num_threads); + }); + return *pool.get(); +} + +inline tf::Executor& intraop_pool() { + return _get_thread_pool(); +} + +inline tf::Executor& interop_pool() { + return _get_thread_pool(); +} + +inline tf::Executor& get_thread_pool(ThreadPool pool_type) { + return pool_type == ThreadPool::intraop ? intraop_pool() : interop_pool(); +} +#endif // BUILD_WITH_TASKFLOW + +inline int get_num_threads() { +#ifdef BUILD_WITH_TASKFLOW + return intraop_pool().num_workers(); +#else + return torch::get_num_threads(); +#endif +} + +inline int get_num_interop_threads() { +#ifdef BUILD_WITH_TASKFLOW + return interop_pool().num_workers(); +#else + return torch::get_num_interop_threads(); +#endif +} + template class Future : public torch::CustomClassHolder { public: @@ -52,6 +104,9 @@ class Future : public torch::CustomClassHolder { template inline auto async(F function) { using T = decltype(function()); +#ifdef BUILD_WITH_TASKFLOW + auto future = interop_pool().async(function); +#else auto promise = std::make_shared>(); auto future = promise->get_future(); at::launch([=]() { @@ -61,27 +116,16 @@ inline auto async(F function) { } else promise->set_value(function()); }); +#endif return c10::make_intrusive>(std::move(future)); } -/** - * @brief GraphBolt's version of torch::parallel_for. Since torch::parallel_for - * uses OpenMP threadpool, async tasks can not make use of it due to multiple - * OpenMP threadpools being created for each async thread. Moreover, inside - * graphbolt::parallel_for, we should not make use of any native CPU torch ops - * as they will spawn an OpenMP threadpool. 
- */ -template -inline void parallel_for( +template +inline void _parallel_for( const int64_t begin, const int64_t end, const int64_t grain_size, const F& f) { if (begin >= end) return; - std::promise promise; - std::future future; - std::atomic_flag err_flag = ATOMIC_FLAG_INIT; - std::exception_ptr eptr; - - int64_t num_threads = torch::get_num_threads(); + int64_t num_threads = get_num_threads(); const auto num_iter = end - begin; const bool use_parallel = (num_iter > grain_size && num_iter > 1 && num_threads > 1); @@ -93,12 +137,27 @@ inline void parallel_for( num_threads = std::min(num_threads, at::divup(end - begin, grain_size)); } int64_t chunk_size = at::divup((end - begin), num_threads); +#ifdef BUILD_WITH_TASKFLOW + tf::Taskflow flow; + flow.for_each_index(int64_t{0}, num_threads, int64_t{1}, [=](int64_t tid) { + const int64_t begin_tid = begin + tid * chunk_size; + if (begin_tid < end) { + const int64_t end_tid = std::min(end, begin_tid + chunk_size); + f(begin_tid, end_tid); + } + }); + _get_thread_pool().run(flow).wait(); +#else + std::promise promise; + std::future future; + std::atomic_flag err_flag = ATOMIC_FLAG_INIT; + std::exception_ptr eptr; int num_launched = 0; std::atomic num_finished = 0; for (int tid = num_threads - 1; tid >= 0; tid--) { const int64_t begin_tid = begin + tid * chunk_size; - const int64_t end_tid = std::min(end, begin_tid + chunk_size); if (begin_tid < end) { + const int64_t end_tid = std::min(end, begin_tid + chunk_size); if (tid == 0) { // Launch the thread 0's work inline. f(begin_tid, end_tid); @@ -132,6 +191,28 @@ inline void parallel_for( std::rethrow_exception(eptr); } } +#endif +} + +/** + * @brief GraphBolt's version of torch::parallel_for. Since torch::parallel_for + * uses OpenMP threadpool, async tasks can not make use of it due to multiple + * OpenMP threadpools being created for each async thread. Moreover, inside + * graphbolt::parallel_for, we should not make use of any native CPU torch ops + * as they will spawn an OpenMP threadpool. + */ +template +inline void parallel_for( + const int64_t begin, const int64_t end, const int64_t grain_size, + const F& f) { + _parallel_for(begin, end, grain_size, f); +} + +template +inline void parallel_for_interop( + const int64_t begin, const int64_t end, const int64_t grain_size, + const F& f) { + _parallel_for(begin, end, grain_size, f); } } // namespace graphbolt diff --git a/graphbolt/src/cnumpy.cc b/graphbolt/src/cnumpy.cc index 457028cddba1..bc1fdf713c4a 100644 --- a/graphbolt/src/cnumpy.cc +++ b/graphbolt/src/cnumpy.cc @@ -173,7 +173,7 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) { // Consume a slot so that parallel_for is called only if there are available // queues. semaphore_.acquire(); - graphbolt::parallel_for(0, num_thread_, 1, [&](int thread_id, int) { + graphbolt::parallel_for_interop(0, num_thread_, 1, [&](int thread_id, int) { // The completion queue might contain 4 * kGroupSize while we may submit // 4 * kGroupSize more. No harm in overallocation here. CircularQueue read_queue(8 * kGroupSize); diff --git a/third_party/taskflow b/third_party/taskflow new file mode 160000 index 000000000000..7d9e85b6b2e9 --- /dev/null +++ b/third_party/taskflow @@ -0,0 +1 @@ +Subproject commit 7d9e85b6b2e9bf501021f857f2f3cbe43bc37c85 From ea3716d562138ff81c20566c5fa287c3b8e7111a Mon Sep 17 00:00:00 2001 From: "Hongzhi (Steve), Chen" Date: Thu, 1 Aug 2024 19:13:09 +0800 Subject: [PATCH 4/5] Use torch core instead of torchdata modules. 
(#7609) Co-authored-by: Muhammed Fatih BALIN --- python/dgl/graphbolt/base.py | 3 +-- python/dgl/graphbolt/dataloader.py | 9 ++++----- python/dgl/graphbolt/impl/neighbor_sampler.py | 2 +- python/dgl/graphbolt/minibatch_transformer.py | 2 +- tests/python/pytorch/graphbolt/test_feature_fetcher.py | 2 +- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py index 8e7e6365f413..69c812c54b26 100644 --- a/python/dgl/graphbolt/base.py +++ b/python/dgl/graphbolt/base.py @@ -17,8 +17,7 @@ ) # pylint: disable=wrong-import-position -from torch.utils.data import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe +from torch.utils.data import functional_datapipe, IterDataPipe from .internal_utils import ( get_nonproperty_attributes, diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py index c81a6b0a7e6a..0711a24fcee4 100644 --- a/python/dgl/graphbolt/dataloader.py +++ b/python/dgl/graphbolt/dataloader.py @@ -4,9 +4,8 @@ from concurrent.futures import ThreadPoolExecutor import torch -import torch.utils.data +import torch.utils.data as torch_data import torchdata.dataloader2.graph as dp_utils -import torchdata.datapipes as dp from .base import CopyTo, get_host_to_device_uva_stream from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker @@ -70,7 +69,7 @@ def _set_worker_id(worked_id): torch.ops.graphbolt.set_worker_id(worked_id) -class MultiprocessingWrapper(dp.iter.IterDataPipe): +class MultiprocessingWrapper(torch_data.IterDataPipe): """Wraps a datapipe with multiprocessing. Parameters @@ -88,7 +87,7 @@ class MultiprocessingWrapper(dp.iter.IterDataPipe): def __init__(self, datapipe, num_workers=0, persistent_workers=True): self.datapipe = datapipe - self.dataloader = torch.utils.data.DataLoader( + self.dataloader = torch_data.DataLoader( datapipe, batch_size=None, num_workers=num_workers, @@ -100,7 +99,7 @@ def __iter__(self): yield from self.dataloader -class DataLoader(torch.utils.data.DataLoader): +class DataLoader(torch_data.DataLoader): """Multiprocessing DataLoader. 
Iterates over the data pipeline with everything before feature fetching diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 1552c3c83333..59e65f722f2c 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -5,7 +5,7 @@ import torch from torch.utils.data import functional_datapipe -from torchdata.datapipes.iter import Mapper +from torch.utils.data.datapipes.iter import Mapper from ..base import ORIGINAL_EDGE_ID from ..internal import compact_csc_format, unique_and_compact_csc_formats diff --git a/python/dgl/graphbolt/minibatch_transformer.py b/python/dgl/graphbolt/minibatch_transformer.py index b7b00b7a1b29..9163ee150671 100644 --- a/python/dgl/graphbolt/minibatch_transformer.py +++ b/python/dgl/graphbolt/minibatch_transformer.py @@ -2,7 +2,7 @@ from torch.utils.data import functional_datapipe -from torchdata.datapipes.iter import Mapper +from torch.utils.data.datapipes.iter import Mapper from .minibatch import MiniBatch diff --git a/tests/python/pytorch/graphbolt/test_feature_fetcher.py b/tests/python/pytorch/graphbolt/test_feature_fetcher.py index b6b906356d54..e945d90d2389 100644 --- a/tests/python/pytorch/graphbolt/test_feature_fetcher.py +++ b/tests/python/pytorch/graphbolt/test_feature_fetcher.py @@ -4,7 +4,7 @@ import dgl.graphbolt as gb import pytest import torch -from torchdata.datapipes.iter import Mapper +from torch.utils.data.datapipes.iter import Mapper from . import gb_test_utils From 489ab1a81e3491c8d134b899d08289499b3e7efb Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Thu, 1 Aug 2024 10:23:18 -0400 Subject: [PATCH 5/5] [GraphBolt] `QueryAndReplace` replaces `QueryAndThenReplace`. (#7628) --- graphbolt/src/cache_policy.cc | 75 +++++++++---------- graphbolt/src/cache_policy.h | 57 ++++++++------ graphbolt/src/partitioned_cache_policy.cc | 12 +-- graphbolt/src/partitioned_cache_policy.h | 6 +- graphbolt/src/python_binding.cc | 8 +- .../dgl/graphbolt/impl/cpu_cached_feature.py | 10 +-- python/dgl/graphbolt/impl/feature_cache.py | 4 +- .../graphbolt/impl/test_feature_cache.py | 30 ++++---- 8 files changed, 100 insertions(+), 102 deletions(-) diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index a707b1a2f280..d6dfeeae849d 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -75,8 +75,7 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) { template std::tuple -BaseCachePolicy::QueryAndThenReplaceImpl( - CachePolicy& policy, torch::Tensor keys) { +BaseCachePolicy::QueryAndReplaceImpl(CachePolicy& policy, torch::Tensor keys) { auto positions = torch::empty_like( keys, keys.options() .dtype(torch::kInt64) @@ -100,9 +99,9 @@ BaseCachePolicy::QueryAndThenReplaceImpl( auto pointers_ptr = reinterpret_cast(pointers.data_ptr()); auto missing_keys_ptr = missing_keys.data_ptr(); - auto iterators = std::unique_ptr( - new typename CachePolicy::map_iterator[keys.size(0)]); - // QueryImpl here. + set_t position_set; + position_set.reserve(keys.size(0)); + // Query and Replace combined. for (int64_t i = 0; i < keys.size(0); i++) { const auto key = keys_ptr[i]; const auto [it, can_read] = policy.Emplace(key); @@ -114,28 +113,20 @@ BaseCachePolicy::QueryAndThenReplaceImpl( } else { indices_ptr[--missing_cnt] = i; missing_keys_ptr[missing_cnt] = key; - iterators[missing_cnt] = it; - } - } - // ReplaceImpl here. 
- set_t position_set; - position_set.reserve(keys.size(0)); - for (int64_t i = missing_cnt; i < missing_keys.size(0); i++) { - auto it = iterators[i]; - if (it->second == policy.getMapSentinelValue()) { - policy.Insert(it); - // After Insert, it->second is not nullptr anymore. - TORCH_CHECK( - // If there are duplicate values and the key was just inserted, - // we do not have to check for the uniqueness of the positions. - std::get<1>(position_set.insert(it->second->getPos())), - "Can't insert all, larger cache capacity is needed."); - } else { - policy.MarkExistingWriting(it); + CacheKey* cache_key_ptr; + if (it->second == policy.getMapSentinelValue()) { + cache_key_ptr = policy.Insert(it); + TORCH_CHECK( + // We check for the uniqueness of the positions. + std::get<1>(position_set.insert(cache_key_ptr->getPos())), + "Can't insert all, larger cache capacity is needed."); + } else { + cache_key_ptr = &*it->second; + policy.MarkExistingWriting(it); + } + positions_ptr[missing_cnt] = cache_key_ptr->getPos(); + pointers_ptr[missing_cnt] = cache_key_ptr; } - auto& cache_key = *it->second; - positions_ptr[i] = cache_key.getPos(); - pointers_ptr[i] = &cache_key; } })); return {positions, indices, pointers, missing_keys.slice(0, found_cnt)}; @@ -192,15 +183,16 @@ void BaseCachePolicy::ReadingWritingCompletedImpl( } S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity) - : small_queue_(capacity), - main_queue_(capacity), + // We sometimes first insert and then evict. + 1 is to compensate for that. + : small_queue_(capacity + 1), + main_queue_(capacity + 1), ghost_queue_(capacity - capacity / 10), capacity_(capacity), cache_usage_(0), small_queue_size_target_(capacity / 10) { TORCH_CHECK(small_queue_size_target_ > 0, "Capacity is not large enough."); ghost_set_.reserve(ghost_queue_.Capacity()); - key_to_cache_key_.reserve(kCapacityFactor * capacity); + key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1)); } std::tuple @@ -209,8 +201,8 @@ S3FifoCachePolicy::Query(torch::Tensor keys) { } std::tuple -S3FifoCachePolicy::QueryAndThenReplace(torch::Tensor keys) { - return QueryAndThenReplaceImpl(*this, keys); +S3FifoCachePolicy::QueryAndReplace(torch::Tensor keys) { + return QueryAndReplaceImpl(*this, keys); } std::tuple S3FifoCachePolicy::Replace( @@ -230,7 +222,7 @@ SieveCachePolicy::SieveCachePolicy(int64_t capacity) // Ensure that queue_ is constructed first before accessing its `.end()`. 
: queue_(), hand_(queue_.end()), capacity_(capacity), cache_usage_(0) { TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); - key_to_cache_key_.reserve(kCapacityFactor * capacity); + key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1)); } std::tuple @@ -239,8 +231,8 @@ SieveCachePolicy::Query(torch::Tensor keys) { } std::tuple -SieveCachePolicy::QueryAndThenReplace(torch::Tensor keys) { - return QueryAndThenReplaceImpl(*this, keys); +SieveCachePolicy::QueryAndReplace(torch::Tensor keys) { + return QueryAndReplaceImpl(*this, keys); } std::tuple SieveCachePolicy::Replace( @@ -259,7 +251,7 @@ void SieveCachePolicy::WritingCompleted(torch::Tensor keys) { LruCachePolicy::LruCachePolicy(int64_t capacity) : capacity_(capacity), cache_usage_(0) { TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); - key_to_cache_key_.reserve(kCapacityFactor * capacity); + key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1)); } std::tuple @@ -268,8 +260,8 @@ LruCachePolicy::Query(torch::Tensor keys) { } std::tuple -LruCachePolicy::QueryAndThenReplace(torch::Tensor keys) { - return QueryAndThenReplaceImpl(*this, keys); +LruCachePolicy::QueryAndReplace(torch::Tensor keys) { + return QueryAndReplaceImpl(*this, keys); } std::tuple LruCachePolicy::Replace( @@ -286,9 +278,10 @@ void LruCachePolicy::WritingCompleted(torch::Tensor keys) { } ClockCachePolicy::ClockCachePolicy(int64_t capacity) - : queue_(capacity), capacity_(capacity), cache_usage_(0) { + // We sometimes first insert and then evict. + 1 is to compensate for that. + : queue_(capacity + 1), capacity_(capacity), cache_usage_(0) { TORCH_CHECK(capacity > 0, "Capacity needs to be positive."); - key_to_cache_key_.reserve(kCapacityFactor * capacity); + key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1)); } std::tuple @@ -297,8 +290,8 @@ ClockCachePolicy::Query(torch::Tensor keys) { } std::tuple -ClockCachePolicy::QueryAndThenReplace(torch::Tensor keys) { - return QueryAndThenReplaceImpl(*this, keys); +ClockCachePolicy::QueryAndReplace(torch::Tensor keys) { + return QueryAndReplaceImpl(*this, keys); } std::tuple ClockCachePolicy::Replace( diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index 024ceef059b6..f2bf7914a9c0 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -33,6 +33,8 @@ namespace graphbolt { namespace storage { struct CacheKey { + CacheKey(int64_t key) : CacheKey(key, -1) {} + CacheKey(int64_t key, int64_t position) : freq_(0), key_(key), @@ -50,6 +52,11 @@ struct CacheKey { auto getPos() const { return position_in_cache_; } + CacheKey& setPos(int64_t pos) { + position_in_cache_ = pos; + return *this; + } + CacheKey& Increment() { freq_ = std::min(3, static_cast(freq_ + 1)); return *this; @@ -144,7 +151,7 @@ class BaseCachePolicy { * identical to missing_keys. */ virtual std::tuple - QueryAndThenReplace(torch::Tensor keys) = 0; + QueryAndReplace(torch::Tensor keys) = 0; /** * @brief The policy replace function. @@ -182,7 +189,7 @@ class BaseCachePolicy { template static std::tuple - QueryAndThenReplaceImpl(CachePolicy& policy, torch::Tensor keys); + QueryAndReplaceImpl(CachePolicy& policy, torch::Tensor keys); template static std::tuple ReplaceImpl( @@ -220,10 +227,10 @@ class S3FifoCachePolicy : public BaseCachePolicy { torch::Tensor keys); /** - * @brief See BaseCachePolicy::QueryAndThenReplace. + * @brief See BaseCachePolicy::QueryAndReplace. 
*/ std::tuple - QueryAndThenReplace(torch::Tensor keys); + QueryAndReplace(torch::Tensor keys); /** * @brief See BaseCachePolicy::Replace. @@ -286,12 +293,13 @@ class S3FifoCachePolicy : public BaseCachePolicy { return {pos, cache_key_ptr}; } - void Insert(map_iterator it) { + CacheKey* Insert(map_iterator it) { const auto key = it->first; - const auto pos = Evict(); const auto in_ghost_queue = ghost_set_.erase(key); auto& queue = in_ghost_queue ? main_queue_ : small_queue_; - it->second = queue.Push(CacheKey(key, pos)); + auto cache_key_ptr = queue.Push(CacheKey(key)); + it->second = cache_key_ptr; + return &cache_key_ptr->setPos(Evict()); } void MarkExistingWriting(map_iterator it) { @@ -380,10 +388,10 @@ class SieveCachePolicy : public BaseCachePolicy { torch::Tensor keys); /** - * @brief See BaseCachePolicy::QueryAndThenReplace. + * @brief See BaseCachePolicy::QueryAndReplace. */ std::tuple - QueryAndThenReplace(torch::Tensor keys); + QueryAndReplace(torch::Tensor keys); /** * @brief See BaseCachePolicy::Replace. @@ -437,11 +445,12 @@ class SieveCachePolicy : public BaseCachePolicy { return {pos, cache_key_ptr}; } - void Insert(map_iterator it) { + CacheKey* Insert(map_iterator it) { const auto key = it->first; - const auto pos = Evict(); - queue_.push_front(CacheKey(key, pos)); - it->second = &queue_.front(); + queue_.push_front(CacheKey(key)); + auto cache_key_ptr = &queue_.front(); + it->second = cache_key_ptr; + return &cache_key_ptr->setPos(Evict()); } void MarkExistingWriting(map_iterator it) { @@ -507,10 +516,10 @@ class LruCachePolicy : public BaseCachePolicy { torch::Tensor keys); /** - * @brief See BaseCachePolicy::QueryAndThenReplace. + * @brief See BaseCachePolicy::QueryAndReplace. */ std::tuple - QueryAndThenReplace(torch::Tensor keys); + QueryAndReplace(torch::Tensor keys); /** * @brief See BaseCachePolicy::Replace. @@ -577,11 +586,12 @@ class LruCachePolicy : public BaseCachePolicy { return {pos, &queue_.front()}; } - void Insert(map_iterator it) { + CacheKey* Insert(map_iterator it) { const auto key = it->first; - const auto pos = Evict(); - queue_.push_front(CacheKey(key, pos)); + queue_.push_front(CacheKey(key)); it->second = queue_.begin(); + auto cache_key_ptr = &*queue_.begin(); + return &cache_key_ptr->setPos(Evict()); } void MarkExistingWriting(map_iterator it) { @@ -649,10 +659,10 @@ class ClockCachePolicy : public BaseCachePolicy { torch::Tensor keys); /** - * @brief See BaseCachePolicy::QueryAndThenReplace. + * @brief See BaseCachePolicy::QueryAndReplace. */ std::tuple - QueryAndThenReplace(torch::Tensor keys); + QueryAndReplace(torch::Tensor keys); /** * @brief See BaseCachePolicy::Replace. 
@@ -705,10 +715,11 @@ class ClockCachePolicy : public BaseCachePolicy { return {pos, cache_key_ptr}; } - void Insert(map_iterator it) { + CacheKey* Insert(map_iterator it) { const auto key = it->first; - const auto pos = Evict(); - it->second = queue_.Push(CacheKey(key, pos)); + auto cache_key_ptr = queue_.Push(CacheKey(key)); + it->second = cache_key_ptr; + return &cache_key_ptr->setPos(Evict()); } void MarkExistingWriting(map_iterator it) { diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc index c41618141368..2d15821dabe3 100644 --- a/graphbolt/src/partitioned_cache_policy.cc +++ b/graphbolt/src/partitioned_cache_policy.cc @@ -47,7 +47,7 @@ PartitionedCachePolicy::Partition(torch::Tensor keys) { torch::Tensor offsets = torch::empty( num_parts * num_parts + 1, keys.options().dtype(torch::kInt64)); auto offsets_ptr = offsets.data_ptr(); - std::memset(offsets_ptr, 0, offsets.size(0) * offsets.element_size()); + std::fill_n(offsets_ptr, offsets.size(0), int64_t{}); auto indices = torch::empty_like(keys, keys.options().dtype(torch::kInt64)); auto part_id = torch::empty_like(keys, keys.options().dtype(torch::kInt32)); const auto num_keys = keys.size(0); @@ -244,11 +244,11 @@ PartitionedCachePolicy::QueryAsync(torch::Tensor keys) { std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> -PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) { +PartitionedCachePolicy::QueryAndReplace(torch::Tensor keys) { if (policies_.size() == 1) { std::lock_guard lock(mtx_); auto [positions, output_indices, pointers, missing_keys] = - policies_[0]->QueryAndThenReplace(keys); + policies_[0]->QueryAndReplace(keys); auto found_and_missing_offsets = torch::empty(4, pointers.options()); auto found_and_missing_offsets_ptr = found_and_missing_offsets.data_ptr(); @@ -282,7 +282,7 @@ PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) { const auto tid = begin; begin = offsets_ptr[tid]; end = offsets_ptr[tid + 1]; - results[tid] = policies_.at(tid)->QueryAndThenReplace( + results[tid] = policies_.at(tid)->QueryAndReplace( permuted_keys.slice(0, begin, end)); const auto missing_cnt = std::get<3>(results[tid]).size(0); result_offsets[tid] = end - begin - missing_cnt; @@ -361,11 +361,11 @@ PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) { } c10::intrusive_ptr>> -PartitionedCachePolicy::QueryAndThenReplaceAsync(torch::Tensor keys) { +PartitionedCachePolicy::QueryAndReplaceAsync(torch::Tensor keys) { return async([=] { auto [positions, output_indices, pointers, missing_keys, found_offsets, - missing_offsets] = QueryAndThenReplace(keys); + missing_offsets] = QueryAndReplace(keys); return std::vector{positions, output_indices, pointers, missing_keys, found_offsets, missing_offsets}; }); diff --git a/graphbolt/src/partitioned_cache_policy.h b/graphbolt/src/partitioned_cache_policy.h index 3ee81dd9c550..030aab4f4054 100644 --- a/graphbolt/src/partitioned_cache_policy.h +++ b/graphbolt/src/partitioned_cache_policy.h @@ -92,10 +92,10 @@ class PartitionedCachePolicy : public torch::CustomClassHolder { std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> - QueryAndThenReplace(torch::Tensor keys); + QueryAndReplace(torch::Tensor keys); - c10::intrusive_ptr>> - QueryAndThenReplaceAsync(torch::Tensor keys); + c10::intrusive_ptr>> QueryAndReplaceAsync( + torch::Tensor keys); /** * @brief The policy replace function. 
diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 223b7b20b267..e447c60f9617 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -108,11 +108,11 @@ TORCH_LIBRARY(graphbolt, m) { .def("query", &storage::PartitionedCachePolicy::Query) .def("query_async", &storage::PartitionedCachePolicy::QueryAsync) .def( - "query_and_then_replace", - &storage::PartitionedCachePolicy::QueryAndThenReplace) + "query_and_replace", + &storage::PartitionedCachePolicy::QueryAndReplace) .def( - "query_and_then_replace_async", - &storage::PartitionedCachePolicy::QueryAndThenReplaceAsync) + "query_and_replace_async", + &storage::PartitionedCachePolicy::QueryAndReplaceAsync) .def("replace", &storage::PartitionedCachePolicy::Replace) .def("replace_async", &storage::PartitionedCachePolicy::ReplaceAsync) .def( diff --git a/python/dgl/graphbolt/impl/cpu_cached_feature.py b/python/dgl/graphbolt/impl/cpu_cached_feature.py index 2fb85d781ad7..acc2ad84cedc 100644 --- a/python/dgl/graphbolt/impl/cpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/cpu_cached_feature.py @@ -75,9 +75,7 @@ def read(self, ids: torch.Tensor = None): """ if ids is None: return self._fallback_feature.read() - return self._feature.query_and_then_replace( - ids, self._fallback_feature.read - ) + return self._feature.query_and_replace(ids, self._fallback_feature.read) def read_async(self, ids: torch.Tensor): """Read the feature by index asynchronously. @@ -121,7 +119,7 @@ def read_async(self, ids: torch.Tensor): yield # first stage is done. ids_copy_event.synchronize() - policy_future = policy.query_and_then_replace_async(ids) + policy_future = policy.query_and_replace_async(ids) yield @@ -235,7 +233,7 @@ def wait(self): yield # first stage is done. ids_copy_event.synchronize() - policy_future = policy.query_and_then_replace_async(ids) + policy_future = policy.query_and_replace_async(ids) yield @@ -313,7 +311,7 @@ def wait(self): yield _Waiter([values_copy_event, writing_completed], values) else: - policy_future = policy.query_and_then_replace_async(ids) + policy_future = policy.query_and_replace_async(ids) yield diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py index ba185c58d4a0..136ad2a7314c 100644 --- a/python/dgl/graphbolt/impl/feature_cache.py +++ b/python/dgl/graphbolt/impl/feature_cache.py @@ -92,7 +92,7 @@ def query(self, keys): missing_index = index[positions.size(0) :] return values, missing_index, missing_keys, missing_offsets - def query_and_then_replace(self, keys, reader_fn): + def query_and_replace(self, keys, reader_fn): """Queries the cache. 
Then inserts the keys that are not found by reading them by calling `reader_fn(missing_keys)`, which are then inserted into the cache using the selected caching policy algorithm @@ -120,7 +120,7 @@ def query_and_then_replace(self, keys, reader_fn): missing_keys, found_offsets, missing_offsets, - ) = self._policy.query_and_then_replace(keys) + ) = self._policy.query_and_replace(keys) found_cnt = keys.size(0) - missing_keys.size(0) found_positions = positions[:found_cnt] values = self._cache.query(found_positions, index, keys.shape[0]) diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py index 08e817b4fd6c..69b09e2796c1 100644 --- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py @@ -6,8 +6,8 @@ from dgl import graphbolt as gb -def _test_query_and_then_replace(policy1, policy2, keys): - # Testing query_and_then_replace equivalence to query and then replace. +def _test_query_and_replace(policy1, policy2, keys): + # Testing query_and_replace equivalence to query and then replace. ( positions, index, @@ -15,17 +15,15 @@ def _test_query_and_then_replace(policy1, policy2, keys): missing_keys, found_offsets, missing_offsets, - ) = policy1.query_and_then_replace(keys) + ) = policy1.query_and_replace(keys) found_cnt = keys.size(0) - missing_keys.size(0) - found_positions = positions[:found_cnt] found_pointers = pointers[:found_cnt] policy1.reading_completed(found_pointers, found_offsets) - missing_positions = positions[found_cnt:] missing_pointers = pointers[found_cnt:] policy1.writing_completed(missing_pointers, missing_offsets) ( - found_positions2, + _, index2, missing_keys2, found_pointers2, @@ -33,13 +31,11 @@ def _test_query_and_then_replace(policy1, policy2, keys): missing_offsets2, ) = policy2.query(keys) policy2.reading_completed(found_pointers2, found_offsets2) - (missing_positions2, missing_pointers2, missing_offsets2) = policy2.replace( + (_, missing_pointers2, missing_offsets2) = policy2.replace( missing_keys2, missing_offsets2 ) policy2.writing_completed(missing_pointers2, missing_offsets2) - assert torch.equal(found_positions, found_positions2) - assert torch.equal(missing_positions, missing_positions2) assert torch.equal(index, index2) assert torch.equal(missing_keys, missing_keys2) @@ -95,9 +91,9 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): cache.replace(missing_keys, missing_values, missing_offsets) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys]) + assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys]) - _test_query_and_then_replace(policy1, policy2, keys) + _test_query_and_replace(policy1, policy2, keys) pin_memory = F._default_context_str == "gpu" @@ -115,9 +111,9 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): cache.replace(missing_keys, missing_values, missing_offsets) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys]) + assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys]) - _test_query_and_then_replace(policy1, policy2, keys) + _test_query_and_replace(policy1, policy2, keys) values, missing_index, missing_keys, missing_offsets = cache.query(keys) if not offsets: @@ -128,9 +124,9 @@ def test_feature_cache(offsets, dtype, feature_size, 
num_parts, policy): cache.replace(missing_keys, missing_values, missing_offsets) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys]) + assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys]) - _test_query_and_then_replace(policy1, policy2, keys) + _test_query_and_replace(policy1, policy2, keys) values, missing_index, missing_keys, missing_offsets = cache.query(keys) if not offsets: @@ -141,9 +137,9 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): cache.replace(missing_keys, missing_values, missing_offsets) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys]) + assert torch.equal(cache2.query_and_replace(keys, reader_fn), a[keys]) - _test_query_and_then_replace(policy1, policy2, keys) + _test_query_and_replace(policy1, policy2, keys) assert cache.miss_rate == cache2.miss_rate
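
Editor's note on the taskflow backend introduced in patch 3: the BUILD_WITH_TASKFLOW path of graphbolt::_parallel_for follows the standard taskflow idiom of describing work in a tf::Taskflow and submitting it to a long-lived tf::Executor. The following is a minimal standalone sketch of that pattern, not DGL code; it assumes only that taskflow's single header taskflow/taskflow.hpp is on the include path, and it mirrors the per-worker chunking used by _parallel_for above.

// Minimal sketch of the taskflow-backed parallel-for pattern (not DGL code).
// Assumes taskflow/taskflow.hpp is available on the include path.
#include <taskflow/taskflow.hpp>

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

// Divide [begin, end) into one contiguous chunk per worker and run f on each
// chunk, mirroring the chunking logic of graphbolt::_parallel_for.
template <typename F>
void parallel_for_sketch(int64_t begin, int64_t end, const F& f) {
  // A single shared executor; its worker count plays the role of
  // torch::get_num_threads() in the patch.
  static tf::Executor executor(
      std::max(1u, std::thread::hardware_concurrency()));
  const int64_t num_threads = static_cast<int64_t>(executor.num_workers());
  const int64_t chunk_size = (end - begin + num_threads - 1) / num_threads;
  tf::Taskflow flow;
  flow.for_each_index(int64_t{0}, num_threads, int64_t{1}, [=](int64_t tid) {
    const int64_t begin_tid = begin + tid * chunk_size;
    if (begin_tid < end) f(begin_tid, std::min(end, begin_tid + chunk_size));
  });
  executor.run(flow).wait();  // Blocks until every chunk has been processed.
}

int main() {
  std::vector<int64_t> data(1 << 20, 1);
  std::atomic<int64_t> sum{0};
  parallel_for_sketch(
      0, static_cast<int64_t>(data.size()), [&](int64_t lo, int64_t hi) {
        int64_t local = 0;
        for (int64_t i = lo; i < hi; ++i) local += data[i];
        sum.fetch_add(local, std::memory_order_relaxed);
      });
  std::cout << sum.load() << "\n";  // Prints 1048576.
  return 0;
}

Routing all parallel regions through one shared executor is what lets async tasks and parallel_for coexist: unlike the OpenMP pool described in the parallel_for comment above, the same worker set is reused no matter which thread submits the taskflow.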