Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment: cancellable awaiter #1893

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ if (BUILD_BROTLI)
endif (BUILD_BROTLI)

set(DROGON_SOURCES
lib/src/coroutine.cc
lib/src/AOPAdvice.cc
lib/src/AccessLogger.cc
lib/src/CacheFile.cc
Expand Down
78 changes: 78 additions & 0 deletions lib/inc/drogon/utils/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,27 @@
#include <type_traits>
#include <optional>

namespace drogon
{
struct CancelHandle;
using CancelHandlePtr = std::shared_ptr<CancelHandle>;

struct CancelHandle
{
static CancelHandlePtr newHandle();
static CancelHandlePtr newSharedHandle();

virtual void cancel() = 0;
virtual bool isCancelRequested() = 0;
};

class TaskCancelledException final : public std::runtime_error
{
public:
using std::runtime_error::runtime_error;
};
} // namespace drogon

namespace drogon
{
namespace internal
Expand Down Expand Up @@ -596,6 +617,44 @@ struct [[nodiscard]] TimerAwaiter : CallbackAwaiter<void>
double delay_;
};

struct [[nodiscard]] CancellableAwaiter : CallbackAwaiter<void>
{
CancellableAwaiter(trantor::EventLoop *loop, CancelHandlePtr cancelHandle)
: loop_(loop), cancelHandle_(std::move(cancelHandle))
{
}

void await_suspend(std::coroutine_handle<> handle);

private:
trantor::EventLoop *loop_;
CancelHandlePtr cancelHandle_;
};

struct [[nodiscard]] CancellableTimeAwaiter : CallbackAwaiter<void>
{
CancellableTimeAwaiter(trantor::EventLoop *loop,
const std::chrono::duration<double> &delay,
CancelHandlePtr cancelHandle)
: CancellableTimeAwaiter(loop, delay.count(), std::move(cancelHandle))
{
}

CancellableTimeAwaiter(trantor::EventLoop *loop,
double delay,
CancelHandlePtr cancelHandle)
: loop_(loop), delay_(delay), cancelHandle_(std::move(cancelHandle))
{
}

void await_suspend(std::coroutine_handle<> handle);

private:
trantor::EventLoop *loop_;
double delay_;
CancelHandlePtr cancelHandle_;
};

struct [[nodiscard]] LoopAwaiter : CallbackAwaiter<void>
{
LoopAwaiter(trantor::EventLoop *workLoop,
Expand Down Expand Up @@ -684,6 +743,25 @@ inline internal::TimerAwaiter sleepCoro(trantor::EventLoop *loop,
return {loop, delay};
}

inline internal::CancellableTimeAwaiter sleepCoro(
trantor::EventLoop *loop,
double delay,
CancelHandlePtr cancelHandle) noexcept
{
assert(loop);
assert(cancelHandle);
return {loop, delay, std::move(cancelHandle)};
}

inline internal::CancellableAwaiter sleepForeverCoro(
trantor::EventLoop *loop,
CancelHandlePtr cancelHandle) noexcept
{
assert(loop);
assert(cancelHandle);
return {loop, std::move(cancelHandle)};
}

inline internal::LoopAwaiter queueInLoopCoro(
trantor::EventLoop *workLoop,
std::function<void()> taskFunc,
Expand Down
163 changes: 163 additions & 0 deletions lib/src/coroutine.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
//
// Created by wanchen.he on 2023/12/29.
//
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>

namespace drogon
{
class CancelHandleImpl : public CancelHandle
{
public:
virtual void setCancelHandle(std::function<void()> handle) = 0;
};

class SimpleCancelHandleImpl : public CancelHandleImpl
{
public:
SimpleCancelHandleImpl() = default;

void cancel() override
{
std::function<void()> handle;
{
std::lock_guard<std::mutex> lock(mutex_);
cancelRequested_ = true;
handle = std::move(cancelHandle_);
}
if (handle)
handle();
}

bool isCancelRequested() override
{
std::lock_guard<std::mutex> lock(mutex_);
return cancelRequested_;
}

void setCancelHandle(std::function<void()> handle) override
{
bool cancelled{false};
{
std::lock_guard<std::mutex> lock(mutex_);
if (cancelRequested_)
{
cancelled = true;
}
else
{
cancelHandle_ = std::move(handle);
}
}
if (cancelled)
{
handle();
}
}

private:
std::mutex mutex_;
bool cancelRequested_{false};
std::function<void()> cancelHandle_;
};

class SharedCancelHandleImpl : public CancelHandleImpl
{
public:
SharedCancelHandleImpl() = default;

void cancel() override
{
std::vector<std::function<void()>> handles;
{
std::lock_guard<std::mutex> lock(mutex_);
cancelRequested_ = true;
handles.swap(cancelHandles_);
}
if (!handles.empty())
{
for (auto &handle : handles)
{
handle();
}
}
}

bool isCancelRequested() override
{
std::lock_guard<std::mutex> lock(mutex_);
return cancelRequested_;
}

void setCancelHandle(std::function<void()> handle) override
{
bool cancelled{false};
{
std::lock_guard<std::mutex> lock(mutex_);
if (cancelRequested_)
{
cancelled = true;
}
else
{
cancelHandles_.emplace_back(std::move(handle));
}
}
if (cancelled)
{
handle();
}
}

private:
std::mutex mutex_;
bool cancelRequested_{false};
std::vector<std::function<void()>> cancelHandles_;
};

CancelHandlePtr CancelHandle::newHandle()
{
return std::make_shared<SimpleCancelHandleImpl>();
}

CancelHandlePtr CancelHandle::newSharedHandle()
{
return std::make_shared<SharedCancelHandleImpl>();
}

void internal::CancellableAwaiter::await_suspend(std::coroutine_handle<> handle)
{
static_cast<CancelHandleImpl *>(cancelHandle_.get())
->setCancelHandle([this, handle, loop = loop_]() {
setException(std::make_exception_ptr(
TaskCancelledException("Task cancelled")));
loop->queueInLoop([handle]() { handle.resume(); });
return;
});
}

void internal::CancellableTimeAwaiter::await_suspend(
std::coroutine_handle<> handle)
{
auto execFlagPtr = std::make_shared<std::atomic_bool>(false);
static_cast<CancelHandleImpl *>(cancelHandle_.get())
->setCancelHandle([this, handle, execFlagPtr, loop = loop_]() {
if (!execFlagPtr->exchange(true))
{
setException(std::make_exception_ptr(
TaskCancelledException("Task cancelled")));
loop->queueInLoop([handle]() { handle.resume(); });
return;
}
});
loop_->runAfter(delay_, [handle, execFlagPtr = std::move(execFlagPtr)]() {
if (!execFlagPtr->exchange(true))
{
handle.resume();
}
});
}

} // namespace drogon

#endif
99 changes: 99 additions & 0 deletions lib/tests/unittests/CoroutineTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,102 @@ DROGON_TEST(SwitchThread)
sync_wait(switch_thread());
thread.wait();
}

DROGON_TEST(Cancellation)
{
using namespace drogon::internal;

trantor::EventLoopThread thread; // helper thread
thread.run();

auto testCancelTask = [TEST_CTX, loop = thread.getLoop()]() -> Task<> {
auto cancelHandle = CancelHandle::newHandle();

// wait coro for 10 seconds, but cancel after 1 second
loop->runAfter(1, [cancelHandle]() { cancelHandle->cancel(); });

int64_t start = time(nullptr);
try
{
LOG_INFO << "Waiting for 10 seconds...";
co_await sleepCoro(loop, 10, cancelHandle);
CHECK(false); // should not reach here
}
catch (const TaskCancelledException &ex)
{
int64_t waitTime = time(nullptr) - start;
CHECK(waitTime < 2);
LOG_INFO << "Oops... only waited for " << waitTime << " second(s)";
}
};

sync_wait(testCancelTask());
thread.getLoop()->quit();
thread.wait();
}

DROGON_TEST(SharedCancellation)
{
using namespace drogon::internal;

trantor::EventLoopThread thread; // helper thread
thread.run();
auto loop = thread.getLoop();
auto sharedHandle = CancelHandle::newSharedHandle();

auto testSharedCancelTask1 = [TEST_CTX, sharedHandle, loop]() -> Task<> {
int64_t start = time(nullptr);
try
{
LOG_INFO << "Waiting for 10 seconds...";
co_await sleepCoro(loop, 10, sharedHandle);
CHECK(false); // should not reach here
}
catch (const TaskCancelledException &ex)
{
int64_t waitTime = time(nullptr) - start;
CHECK(waitTime < 2);
LOG_INFO << "Oops... only waited for " << waitTime << " second(s)";
}
};

auto testSharedCancelTask2 = [TEST_CTX, sharedHandle, loop]() -> Task<> {
int64_t start = time(nullptr);
try
{
LOG_INFO << "Sleep forever...";
co_await sleepForeverCoro(loop, sharedHandle);
CHECK(false); // should not reach here
}
catch (const TaskCancelledException &ex)
{
int64_t waitTime = time(nullptr) - start;
CHECK(waitTime < 2);
LOG_INFO << "Oops... only slept for " << waitTime << " second(s)";
}
};

auto testSharedCancelTask3 = [TEST_CTX, sharedHandle, loop]() -> Task<> {
co_await sleepCoro(loop, 1.5);
int64_t start = time(nullptr);
try
{
LOG_INFO << "Try sleep after cancel...";
co_await sleepForeverCoro(loop, sharedHandle);
CHECK(false); // should not reach here
}
catch (const TaskCancelledException &ex)
{
int64_t waitTime = time(nullptr) - start;
CHECK(waitTime < 2);
LOG_INFO << "Oops... only slept for " << waitTime << " second(s)";
}
};
// cancel both tasks after 1 second
loop->runAfter(1, [sharedHandle]() { sharedHandle->cancel(); });
sync_wait(testSharedCancelTask1());
sync_wait(testSharedCancelTask2());
sync_wait(testSharedCancelTask3());
loop->quit();
thread.wait();
}