Skip to content

Commit

Permalink
Add coroutine mutex (#2095)
Browse files Browse the repository at this point in the history
  • Loading branch information
fantasy-peak committed Aug 8, 2024
1 parent 0546032 commit c46f149
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 0 deletions.
176 changes: 176 additions & 0 deletions lib/inc/drogon/utils/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -811,4 +811,180 @@ inline internal::EventLoopAwaiter<T> queueInLoopCoro(trantor::EventLoop *loop,
return internal::EventLoopAwaiter<T>(std::move(task), loop);
}

class Mutex final
{
class ScopedCoroMutexAwaiter;
class CoroMutexAwaiter;

public:
Mutex() noexcept : state_(unlockedValue()), waiters_(nullptr)
{
}

Mutex(const Mutex &) = delete;
Mutex(Mutex &&) = delete;
Mutex &operator=(const Mutex &) = delete;
Mutex &operator=(Mutex &&) = delete;

~Mutex()
{
[[maybe_unused]] auto state = state_.load(std::memory_order_relaxed);
assert(state == unlockedValue() || state == nullptr);
assert(waiters_ == nullptr);
}

bool try_lock() noexcept
{
void *oldValue = unlockedValue();
return state_.compare_exchange_strong(oldValue,
nullptr,
std::memory_order_acquire,
std::memory_order_relaxed);
}

[[nodiscard]] ScopedCoroMutexAwaiter scoped_lock(
trantor::EventLoop *loop =
trantor::EventLoop::getEventLoopOfCurrentThread()) noexcept
{
return ScopedCoroMutexAwaiter(*this, loop);
}

[[nodiscard]] CoroMutexAwaiter lock(
trantor::EventLoop *loop =
trantor::EventLoop::getEventLoopOfCurrentThread()) noexcept
{
return CoroMutexAwaiter(*this, loop);
}

void unlock() noexcept
{
assert(state_.load(std::memory_order_relaxed) != unlockedValue());
auto *waitersHead = waiters_;
if (waitersHead == nullptr)
{
void *currentState = state_.load(std::memory_order_relaxed);
if (currentState == nullptr)
{
const bool releasedLock =
state_.compare_exchange_strong(currentState,
unlockedValue(),
std::memory_order_release,
std::memory_order_relaxed);
if (releasedLock)
{
return;
}
}
currentState = state_.exchange(nullptr, std::memory_order_acquire);
assert(currentState != unlockedValue());
assert(currentState != nullptr);
auto *waiter = static_cast<CoroMutexAwaiter *>(currentState);
do
{
auto *temp = waiter->next_;
waiter->next_ = waitersHead;
waitersHead = waiter;
waiter = temp;
} while (waiter != nullptr);
}
assert(waitersHead != nullptr);
waiters_ = waitersHead->next_;
if (waitersHead->loop_)
{
auto handle = waitersHead->handle_;
waitersHead->loop_->runInLoop([handle] { handle.resume(); });
}
else
{
waitersHead->handle_.resume();
}
}

private:
class CoroMutexAwaiter
{
public:
CoroMutexAwaiter(Mutex &mutex, trantor::EventLoop *loop) noexcept
: mutex_(mutex), loop_(loop)
{
}

bool await_ready() noexcept
{
return mutex_.try_lock();
}

bool await_suspend(std::coroutine_handle<> handle) noexcept
{
handle_ = handle;
return mutex_.asynclockImpl(this);
}

void await_resume() noexcept
{
}

private:
friend class Mutex;

Mutex &mutex_;
trantor::EventLoop *loop_;
std::coroutine_handle<> handle_;
CoroMutexAwaiter *next_;
};

class ScopedCoroMutexAwaiter : public CoroMutexAwaiter
{
public:
ScopedCoroMutexAwaiter(Mutex &mutex, trantor::EventLoop *loop)
: CoroMutexAwaiter(mutex, loop)
{
}

[[nodiscard]] auto await_resume() noexcept
{
return std::unique_lock<Mutex>{mutex_, std::adopt_lock};
}
};

bool asynclockImpl(CoroMutexAwaiter *awaiter)
{
void *oldValue = state_.load(std::memory_order_relaxed);
while (true)
{
if (oldValue == unlockedValue())
{
void *newValue = nullptr;
if (state_.compare_exchange_weak(oldValue,
newValue,
std::memory_order_acquire,
std::memory_order_relaxed))
{
return false;
}
}
else
{
void *newValue = awaiter;
awaiter->next_ = static_cast<CoroMutexAwaiter *>(oldValue);
if (state_.compare_exchange_weak(oldValue,
newValue,
std::memory_order_release,
std::memory_order_relaxed))
{
return true;
}
}
}
}

void *unlockedValue() noexcept
{
return this;
}

std::atomic<void *> state_;
CoroMutexAwaiter *waiters_;
};

} // namespace drogon
33 changes: 33 additions & 0 deletions lib/tests/unittests/CoroutineTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#include <drogon/utils/coroutine.h>
#include <drogon/HttpAppFramework.h>
#include <trantor/net/EventLoopThread.h>
#include <trantor/net/EventLoopThreadPool.h>
#include <chrono>
#include <cstdint>
#include <future>
#include <type_traits>

using namespace drogon;
Expand Down Expand Up @@ -212,3 +216,32 @@ DROGON_TEST(SwitchThread)
sync_wait(switch_thread());
thread.wait();
}

DROGON_TEST(Mutex)
{
trantor::EventLoopThreadPool pool{3};
pool.start();
Mutex mutex;
async_run([&]() -> Task<> {
co_await switchThreadCoro(pool.getLoop(0));
auto guard = co_await mutex.scoped_lock();
co_await sleepCoro(pool.getLoop(1), std::chrono::seconds(2));
co_return;
});
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::promise<void> done;
async_run([&]() -> Task<> {
co_await switchThreadCoro(pool.getLoop(2));
auto id = std::this_thread::get_id();
co_await mutex.lock();
CHECK(id == std::this_thread::get_id());
mutex.unlock();
CHECK(id == std::this_thread::get_id());
done.set_value();
co_return;
});
done.get_future().wait();
for (int16_t i = 0; i < 3; i++)
pool.getLoop(i)->quit();
pool.wait();
}

0 comments on commit c46f149

Please sign in to comment.