
Commit c6a8414
Merge branch 'master' into gb_cuda_async_sample_neighbors
mfbalin committed Aug 13, 2024 · 2 parents 7d0e634 + d650422
Showing 2 changed files with 32 additions and 16 deletions.

graphbolt/include/graphbolt/async.h (22 additions, 5 deletions)

@@ -23,16 +23,22 @@
 #include <ATen/Parallel.h>
 #include <torch/script.h>
 
+#include <atomic>
+#include <exception>
 #include <future>
 #include <memory>
 #include <mutex>
+#include <type_traits>
 
 #ifdef BUILD_WITH_TASKFLOW
 #include <taskflow/algorithm/for_each.hpp>
 #include <taskflow/taskflow.hpp>
-#else
-#include <atomic>
-#include <exception>
-#include <type_traits>
 #endif
 
+#ifdef GRAPHBOLT_USE_CUDA
+#include <c10/cuda/CUDAGuard.h>
+#include <c10/cuda/CUDAStream.h>
+#endif
+
 namespace graphbolt {
@@ -104,12 +110,23 @@ class Future : public torch::CustomClassHolder {
 template <typename F>
 inline auto async(F&& function) {
   using T = decltype(function());
+#ifdef GRAPHBOLT_USE_CUDA
+  auto stream = c10::cuda::getCurrentCUDAStream();
+#endif
+  auto fn = [=, func = std::move(function)] {
+#ifdef GRAPHBOLT_USE_CUDA
+    // We make sure to use the same CUDA stream as the thread launching the
+    // async operation.
+    c10::cuda::CUDAStreamGuard guard(stream);
+#endif
+    return func();
+  };
 #ifdef BUILD_WITH_TASKFLOW
-  auto future = interop_pool().async(std::move(function));
+  auto future = interop_pool().async(std::move(fn));
 #else
   auto promise = std::make_shared<std::promise<T>>();
   auto future = promise->get_future();
-  at::launch([promise, func = std::move(function)]() {
+  at::launch([promise, func = std::move(fn)]() {
     if constexpr (std::is_void_v<T>) {
       func();
       promise->set_value();
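
The substantive change in this hunk: the callable is wrapped so that the pool thread adopts the CUDA stream of the thread that called async(). Below is a minimal self-contained sketch of the same pattern; the helper name run_on_callers_stream is illustrative, not graphbolt's API, and it assumes a libtorch build with CUDA enabled.

#include <ATen/Parallel.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

#include <future>
#include <memory>
#include <type_traits>

// Capture the launching thread's current CUDA stream, then make it current
// inside the pool thread for the duration of the call, so any kernels the
// callable launches are ordered on the caller's stream.
template <typename F>
auto run_on_callers_stream(F&& function) {
  using T = decltype(function());
  auto stream = c10::cuda::getCurrentCUDAStream();
  auto promise = std::make_shared<std::promise<T>>();
  auto future = promise->get_future();
  at::launch([promise, stream, func = std::move(function)]() {
    // Makes the captured stream current on this thread; restores the
    // previous stream when the guard goes out of scope.
    c10::cuda::CUDAStreamGuard guard(stream);
    if constexpr (std::is_void_v<T>) {
      func();
      promise->set_value();
    } else {
      promise->set_value(func());
    }
  });
  return future;
}

Without the guard, the pool thread would enqueue kernels on its own current stream (typically the default stream), which is not necessarily ordered with a side stream the caller is working on, so the asynchronous work could overlap incorrectly with kernels the caller already queued. The guard is what sequences the async operation with the caller's stream.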

graphbolt/src/cnumpy.h (10 additions, 11 deletions)

@@ -141,8 +141,8 @@ class OnDiskNpyArray : public torch::CustomClassHolder {
   static inline int num_queues_;  // Number of queues.
   static inline std::unique_ptr<::io_uring[], io_uring_queue_destroyer>
       io_uring_queue_;  // io_uring queue.
-  static inline counting_semaphore_t
-      semaphore_;  // Control access to the io_uring queues.
+  static inline counting_semaphore_t semaphore_{
+      0};  // Control access to the io_uring queues.
   static inline std::mutex available_queues_mtx_;  // available_queues_ mutex.
   static inline std::vector<int> available_queues_;
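
A note on the brace-initializer introduced above: it gives the semaphore a definite initial count of zero at its declaration, and std::counting_semaphore (which counting_semaphore_t may alias, depending on the toolchain) has no default constructor, so an explicit count is required in that configuration. A minimal sketch of the pattern, with QueuePool as a stand-in rather than the real class:

#include <semaphore>

struct QueuePool {
  // std::counting_semaphore has no default constructor, so an inline static
  // member must be brace-initialized with its starting count. Starting at 0
  // means no slot is available until setup code publishes the queues with
  // release().
  static inline std::counting_semaphore<> semaphore_{0};
};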

@@ -161,10 +161,11 @@
       }
       // If this is the first thread exiting, release the master thread's
       // ticket as well by releasing 2 slots. Otherwise, release 1 slot.
-      const auto releasing = acquirer_->exiting_first_.test_and_set() ? 1 : 2;
+      const auto releasing =
+          acquirer_->exiting_first_.test_and_set(std::memory_order_relaxed)
+              ? 1
+              : 2;
       semaphore_.release(releasing);
-      acquirer_->num_acquisitions_.fetch_add(
-          -releasing, std::memory_order_relaxed);
     }
 
     ::io_uring& get() const { return io_uring_queue_[thread_id_]; }
@@ -179,17 +180,16 @@
     }
 
     ~QueueAndBufferAcquirer() {
-      // If any of the worker threads exit early without being able to release
-      // the semaphore, we make sure to release it for them in the main thread.
-      const auto releasing = num_acquisitions_.load(std::memory_order_relaxed);
+      // If none of the worker threads acquire the semaphore, we make sure to
+      // release the ticket taken in the constructor.
+      const auto releasing =
+          exiting_first_.test_and_set(std::memory_order_relaxed) ? 0 : 1;
       semaphore_.release(releasing);
-      TORCH_CHECK(releasing == 0, "An io_uring worker thread didn't not exit.");
     }
 
     std::pair<UniqueQueue, char*> get() {
       // We consume a slot from the semaphore to use a queue.
       semaphore_.acquire();
-      num_acquisitions_.fetch_add(1, std::memory_order_relaxed);
       const auto thread_id = [&] {
         std::lock_guard lock(available_queues_mtx_);
         TORCH_CHECK(!available_queues_.empty());
@@ -205,7 +205,6 @@
    private:
     const OnDiskNpyArray* array_;
     std::atomic_flag exiting_first_ = ATOMIC_FLAG_INIT;
-    std::atomic<int> num_acquisitions_ = 1;
   };
 
 #endif  // HAVE_LIBRARY_LIBURING
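
Taken together, the cnumpy.h hunks replace the num_acquisitions_ counter with the single exiting_first_ flag: the acquirer holds one extra "master" ticket (taken in its constructor, which is outside this diff), the first worker to finish returns two slots (its own plus the master's), every later worker returns one, and the destructor returns the master ticket only if no worker ever reached the release path. A condensed sketch of that protocol in standard C++; the names and the initial count are illustrative, not graphbolt's API:

#include <atomic>
#include <semaphore>

// Four slots standing in for four io_uring queues.
std::counting_semaphore<> semaphore{4};

struct Acquirer {
  std::atomic_flag exiting_first = ATOMIC_FLAG_INIT;

  // Assumed from the destructor's comment in the diff above: the
  // constructor takes the master ticket up front.
  Acquirer() { semaphore.acquire(); }

  // Each worker takes a slot before using a queue...
  void worker_acquire() { semaphore.acquire(); }

  // ...and returns it when done. test_and_set() returns false exactly once,
  // so the first worker to exit also returns the master ticket.
  void worker_release() {
    const auto releasing =
        exiting_first.test_and_set(std::memory_order_relaxed) ? 1 : 2;
    semaphore.release(releasing);
  }

  ~Acquirer() {
    // If the flag was never set, no worker released anything, so the master
    // ticket is still held and must be returned here; release(0) is a no-op.
    const auto releasing =
        exiting_first.test_and_set(std::memory_order_relaxed) ? 0 : 1;
    semaphore.release(releasing);
  }
};

Exactly one of the first-exiting worker and the destructor wins the test_and_set race, so the master ticket is returned exactly once no matter how many workers actually ran; that invariant is what lets the commit delete the counter and its fetch_add bookkeeping.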
