diff --git a/CMakeLists.txt b/CMakeLists.txt index b6906f4407..850b683f61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -245,6 +245,7 @@ if (BUILD_BROTLI) endif (BUILD_BROTLI) set(DROGON_SOURCES + lib/src/coroutine.cc lib/src/AOPAdvice.cc lib/src/AccessLogger.cc lib/src/CacheFile.cc diff --git a/lib/inc/drogon/utils/coroutine.h b/lib/inc/drogon/utils/coroutine.h index 19c22dd20b..fabcd8c6ac 100644 --- a/lib/inc/drogon/utils/coroutine.h +++ b/lib/inc/drogon/utils/coroutine.h @@ -27,6 +27,27 @@ #include #include +namespace drogon +{ +struct CancelHandle; +using CancelHandlePtr = std::shared_ptr; + +struct CancelHandle +{ + static CancelHandlePtr newHandle(); + static CancelHandlePtr newSharedHandle(); + + virtual void cancel() = 0; + virtual bool isCancelRequested() = 0; +}; + +class TaskCancelledException final : public std::runtime_error +{ + public: + using std::runtime_error::runtime_error; +}; +} // namespace drogon + namespace drogon { namespace internal @@ -596,6 +617,44 @@ struct [[nodiscard]] TimerAwaiter : CallbackAwaiter double delay_; }; +struct [[nodiscard]] CancellableAwaiter : CallbackAwaiter +{ + CancellableAwaiter(trantor::EventLoop *loop, CancelHandlePtr cancelHandle) + : loop_(loop), cancelHandle_(std::move(cancelHandle)) + { + } + + void await_suspend(std::coroutine_handle<> handle); + + private: + trantor::EventLoop *loop_; + CancelHandlePtr cancelHandle_; +}; + +struct [[nodiscard]] CancellableTimeAwaiter : CallbackAwaiter +{ + CancellableTimeAwaiter(trantor::EventLoop *loop, + const std::chrono::duration &delay, + CancelHandlePtr cancelHandle) + : CancellableTimeAwaiter(loop, delay.count(), std::move(cancelHandle)) + { + } + + CancellableTimeAwaiter(trantor::EventLoop *loop, + double delay, + CancelHandlePtr cancelHandle) + : loop_(loop), delay_(delay), cancelHandle_(std::move(cancelHandle)) + { + } + + void await_suspend(std::coroutine_handle<> handle); + + private: + trantor::EventLoop *loop_; + double delay_; + CancelHandlePtr cancelHandle_; +}; + struct [[nodiscard]] LoopAwaiter : CallbackAwaiter { LoopAwaiter(trantor::EventLoop *workLoop, @@ -684,6 +743,25 @@ inline internal::TimerAwaiter sleepCoro(trantor::EventLoop *loop, return {loop, delay}; } +inline internal::CancellableTimeAwaiter sleepCoro( + trantor::EventLoop *loop, + double delay, + CancelHandlePtr cancelHandle) noexcept +{ + assert(loop); + assert(cancelHandle); + return {loop, delay, std::move(cancelHandle)}; +} + +inline internal::CancellableAwaiter sleepForeverCoro( + trantor::EventLoop *loop, + CancelHandlePtr cancelHandle) noexcept +{ + assert(loop); + assert(cancelHandle); + return {loop, std::move(cancelHandle)}; +} + inline internal::LoopAwaiter queueInLoopCoro( trantor::EventLoop *workLoop, std::function taskFunc, diff --git a/lib/src/coroutine.cc b/lib/src/coroutine.cc new file mode 100644 index 0000000000..2f02f69926 --- /dev/null +++ b/lib/src/coroutine.cc @@ -0,0 +1,163 @@ +// +// Created by wanchen.he on 2023/12/29. +// +#ifdef __cpp_impl_coroutine +#include + +namespace drogon +{ +class CancelHandleImpl : public CancelHandle +{ + public: + virtual void setCancelHandle(std::function handle) = 0; +}; + +class SimpleCancelHandleImpl : public CancelHandleImpl +{ + public: + SimpleCancelHandleImpl() = default; + + void cancel() override + { + std::function handle; + { + std::lock_guard lock(mutex_); + cancelRequested_ = true; + handle = std::move(cancelHandle_); + } + if (handle) + handle(); + } + + bool isCancelRequested() override + { + std::lock_guard lock(mutex_); + return cancelRequested_; + } + + void setCancelHandle(std::function handle) override + { + bool cancelled{false}; + { + std::lock_guard lock(mutex_); + if (cancelRequested_) + { + cancelled = true; + } + else + { + cancelHandle_ = std::move(handle); + } + } + if (cancelled) + { + handle(); + } + } + + private: + std::mutex mutex_; + bool cancelRequested_{false}; + std::function cancelHandle_; +}; + +class SharedCancelHandleImpl : public CancelHandleImpl +{ + public: + SharedCancelHandleImpl() = default; + + void cancel() override + { + std::vector> handles; + { + std::lock_guard lock(mutex_); + cancelRequested_ = true; + handles.swap(cancelHandles_); + } + if (!handles.empty()) + { + for (auto &handle : handles) + { + handle(); + } + } + } + + bool isCancelRequested() override + { + std::lock_guard lock(mutex_); + return cancelRequested_; + } + + void setCancelHandle(std::function handle) override + { + bool cancelled{false}; + { + std::lock_guard lock(mutex_); + if (cancelRequested_) + { + cancelled = true; + } + else + { + cancelHandles_.emplace_back(std::move(handle)); + } + } + if (cancelled) + { + handle(); + } + } + + private: + std::mutex mutex_; + bool cancelRequested_{false}; + std::vector> cancelHandles_; +}; + +CancelHandlePtr CancelHandle::newHandle() +{ + return std::make_shared(); +} + +CancelHandlePtr CancelHandle::newSharedHandle() +{ + return std::make_shared(); +} + +void internal::CancellableAwaiter::await_suspend(std::coroutine_handle<> handle) +{ + static_cast(cancelHandle_.get()) + ->setCancelHandle([this, handle, loop = loop_]() { + setException(std::make_exception_ptr( + TaskCancelledException("Task cancelled"))); + loop->queueInLoop([handle]() { handle.resume(); }); + return; + }); +} + +void internal::CancellableTimeAwaiter::await_suspend( + std::coroutine_handle<> handle) +{ + auto execFlagPtr = std::make_shared(false); + static_cast(cancelHandle_.get()) + ->setCancelHandle([this, handle, execFlagPtr, loop = loop_]() { + if (!execFlagPtr->exchange(true)) + { + setException(std::make_exception_ptr( + TaskCancelledException("Task cancelled"))); + loop->queueInLoop([handle]() { handle.resume(); }); + return; + } + }); + loop_->runAfter(delay_, [handle, execFlagPtr = std::move(execFlagPtr)]() { + if (!execFlagPtr->exchange(true)) + { + handle.resume(); + } + }); +} + +} // namespace drogon + +#endif diff --git a/lib/tests/unittests/CoroutineTest.cc b/lib/tests/unittests/CoroutineTest.cc index 9768e62658..de37fcfdc6 100644 --- a/lib/tests/unittests/CoroutineTest.cc +++ b/lib/tests/unittests/CoroutineTest.cc @@ -212,3 +212,102 @@ DROGON_TEST(SwitchThread) sync_wait(switch_thread()); thread.wait(); } + +DROGON_TEST(Cancellation) +{ + using namespace drogon::internal; + + trantor::EventLoopThread thread; // helper thread + thread.run(); + + auto testCancelTask = [TEST_CTX, loop = thread.getLoop()]() -> Task<> { + auto cancelHandle = CancelHandle::newHandle(); + + // wait coro for 10 seconds, but cancel after 1 second + loop->runAfter(1, [cancelHandle]() { cancelHandle->cancel(); }); + + int64_t start = time(nullptr); + try + { + LOG_INFO << "Waiting for 10 seconds..."; + co_await sleepCoro(loop, 10, cancelHandle); + CHECK(false); // should not reach here + } + catch (const TaskCancelledException &ex) + { + int64_t waitTime = time(nullptr) - start; + CHECK(waitTime < 2); + LOG_INFO << "Oops... only waited for " << waitTime << " second(s)"; + } + }; + + sync_wait(testCancelTask()); + thread.getLoop()->quit(); + thread.wait(); +} + +DROGON_TEST(SharedCancellation) +{ + using namespace drogon::internal; + + trantor::EventLoopThread thread; // helper thread + thread.run(); + auto loop = thread.getLoop(); + auto sharedHandle = CancelHandle::newSharedHandle(); + + auto testSharedCancelTask1 = [TEST_CTX, sharedHandle, loop]() -> Task<> { + int64_t start = time(nullptr); + try + { + LOG_INFO << "Waiting for 10 seconds..."; + co_await sleepCoro(loop, 10, sharedHandle); + CHECK(false); // should not reach here + } + catch (const TaskCancelledException &ex) + { + int64_t waitTime = time(nullptr) - start; + CHECK(waitTime < 2); + LOG_INFO << "Oops... only waited for " << waitTime << " second(s)"; + } + }; + + auto testSharedCancelTask2 = [TEST_CTX, sharedHandle, loop]() -> Task<> { + int64_t start = time(nullptr); + try + { + LOG_INFO << "Sleep forever..."; + co_await sleepForeverCoro(loop, sharedHandle); + CHECK(false); // should not reach here + } + catch (const TaskCancelledException &ex) + { + int64_t waitTime = time(nullptr) - start; + CHECK(waitTime < 2); + LOG_INFO << "Oops... only slept for " << waitTime << " second(s)"; + } + }; + + auto testSharedCancelTask3 = [TEST_CTX, sharedHandle, loop]() -> Task<> { + co_await sleepCoro(loop, 1.5); + int64_t start = time(nullptr); + try + { + LOG_INFO << "Try sleep after cancel..."; + co_await sleepForeverCoro(loop, sharedHandle); + CHECK(false); // should not reach here + } + catch (const TaskCancelledException &ex) + { + int64_t waitTime = time(nullptr) - start; + CHECK(waitTime < 2); + LOG_INFO << "Oops... only slept for " << waitTime << " second(s)"; + } + }; + // cancel both tasks after 1 second + loop->runAfter(1, [sharedHandle]() { sharedHandle->cancel(); }); + sync_wait(testSharedCancelTask1()); + sync_wait(testSharedCancelTask2()); + sync_wait(testSharedCancelTask3()); + loop->quit(); + thread.wait(); +}