From c46f149c2c77e0f84dfc7afd45e9676d7cd0e637 Mon Sep 17 00:00:00 2001 From: fantasy-peak <82742316+fantasy-peak@users.noreply.github.com> Date: Thu, 8 Aug 2024 15:17:06 +0800 Subject: [PATCH] Add coroutine mutex (#2095) --- lib/inc/drogon/utils/coroutine.h | 176 +++++++++++++++++++++++++++ lib/tests/unittests/CoroutineTest.cc | 33 +++++ 2 files changed, 209 insertions(+) diff --git a/lib/inc/drogon/utils/coroutine.h b/lib/inc/drogon/utils/coroutine.h index 19c22dd20b..9c6678d2e4 100644 --- a/lib/inc/drogon/utils/coroutine.h +++ b/lib/inc/drogon/utils/coroutine.h @@ -811,4 +811,180 @@ inline internal::EventLoopAwaiter queueInLoopCoro(trantor::EventLoop *loop, return internal::EventLoopAwaiter(std::move(task), loop); } +class Mutex final +{ + class ScopedCoroMutexAwaiter; + class CoroMutexAwaiter; + + public: + Mutex() noexcept : state_(unlockedValue()), waiters_(nullptr) + { + } + + Mutex(const Mutex &) = delete; + Mutex(Mutex &&) = delete; + Mutex &operator=(const Mutex &) = delete; + Mutex &operator=(Mutex &&) = delete; + + ~Mutex() + { + [[maybe_unused]] auto state = state_.load(std::memory_order_relaxed); + assert(state == unlockedValue() || state == nullptr); + assert(waiters_ == nullptr); + } + + bool try_lock() noexcept + { + void *oldValue = unlockedValue(); + return state_.compare_exchange_strong(oldValue, + nullptr, + std::memory_order_acquire, + std::memory_order_relaxed); + } + + [[nodiscard]] ScopedCoroMutexAwaiter scoped_lock( + trantor::EventLoop *loop = + trantor::EventLoop::getEventLoopOfCurrentThread()) noexcept + { + return ScopedCoroMutexAwaiter(*this, loop); + } + + [[nodiscard]] CoroMutexAwaiter lock( + trantor::EventLoop *loop = + trantor::EventLoop::getEventLoopOfCurrentThread()) noexcept + { + return CoroMutexAwaiter(*this, loop); + } + + void unlock() noexcept + { + assert(state_.load(std::memory_order_relaxed) != unlockedValue()); + auto *waitersHead = waiters_; + if (waitersHead == nullptr) + { + void *currentState = state_.load(std::memory_order_relaxed); + if (currentState == nullptr) + { + const bool releasedLock = + state_.compare_exchange_strong(currentState, + unlockedValue(), + std::memory_order_release, + std::memory_order_relaxed); + if (releasedLock) + { + return; + } + } + currentState = state_.exchange(nullptr, std::memory_order_acquire); + assert(currentState != unlockedValue()); + assert(currentState != nullptr); + auto *waiter = static_cast(currentState); + do + { + auto *temp = waiter->next_; + waiter->next_ = waitersHead; + waitersHead = waiter; + waiter = temp; + } while (waiter != nullptr); + } + assert(waitersHead != nullptr); + waiters_ = waitersHead->next_; + if (waitersHead->loop_) + { + auto handle = waitersHead->handle_; + waitersHead->loop_->runInLoop([handle] { handle.resume(); }); + } + else + { + waitersHead->handle_.resume(); + } + } + + private: + class CoroMutexAwaiter + { + public: + CoroMutexAwaiter(Mutex &mutex, trantor::EventLoop *loop) noexcept + : mutex_(mutex), loop_(loop) + { + } + + bool await_ready() noexcept + { + return mutex_.try_lock(); + } + + bool await_suspend(std::coroutine_handle<> handle) noexcept + { + handle_ = handle; + return mutex_.asynclockImpl(this); + } + + void await_resume() noexcept + { + } + + private: + friend class Mutex; + + Mutex &mutex_; + trantor::EventLoop *loop_; + std::coroutine_handle<> handle_; + CoroMutexAwaiter *next_; + }; + + class ScopedCoroMutexAwaiter : public CoroMutexAwaiter + { + public: + ScopedCoroMutexAwaiter(Mutex &mutex, trantor::EventLoop *loop) + : CoroMutexAwaiter(mutex, loop) + { + } + + [[nodiscard]] auto await_resume() noexcept + { + return std::unique_lock{mutex_, std::adopt_lock}; + } + }; + + bool asynclockImpl(CoroMutexAwaiter *awaiter) + { + void *oldValue = state_.load(std::memory_order_relaxed); + while (true) + { + if (oldValue == unlockedValue()) + { + void *newValue = nullptr; + if (state_.compare_exchange_weak(oldValue, + newValue, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + return false; + } + } + else + { + void *newValue = awaiter; + awaiter->next_ = static_cast(oldValue); + if (state_.compare_exchange_weak(oldValue, + newValue, + std::memory_order_release, + std::memory_order_relaxed)) + { + return true; + } + } + } + } + + void *unlockedValue() noexcept + { + return this; + } + + std::atomic state_; + CoroMutexAwaiter *waiters_; +}; + } // namespace drogon diff --git a/lib/tests/unittests/CoroutineTest.cc b/lib/tests/unittests/CoroutineTest.cc index bb8900c32a..c2e0f9dcff 100644 --- a/lib/tests/unittests/CoroutineTest.cc +++ b/lib/tests/unittests/CoroutineTest.cc @@ -2,6 +2,10 @@ #include #include #include +#include +#include +#include +#include #include using namespace drogon; @@ -212,3 +216,32 @@ DROGON_TEST(SwitchThread) sync_wait(switch_thread()); thread.wait(); } + +DROGON_TEST(Mutex) +{ + trantor::EventLoopThreadPool pool{3}; + pool.start(); + Mutex mutex; + async_run([&]() -> Task<> { + co_await switchThreadCoro(pool.getLoop(0)); + auto guard = co_await mutex.scoped_lock(); + co_await sleepCoro(pool.getLoop(1), std::chrono::seconds(2)); + co_return; + }); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::promise done; + async_run([&]() -> Task<> { + co_await switchThreadCoro(pool.getLoop(2)); + auto id = std::this_thread::get_id(); + co_await mutex.lock(); + CHECK(id == std::this_thread::get_id()); + mutex.unlock(); + CHECK(id == std::this_thread::get_id()); + done.set_value(); + co_return; + }); + done.get_future().wait(); + for (int16_t i = 0; i < 3; i++) + pool.getLoop(i)->quit(); + pool.wait(); +}