Skip to content

Commit

Permalink
Experimentation on cancellable awaiter.
Browse files Browse the repository at this point in the history
  • Loading branch information
hwc0919 committed Dec 29, 2023
1 parent 1fd5c7e commit 55929b1
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ if (BUILD_BROTLI)
endif (BUILD_BROTLI)

set(DROGON_SOURCES
lib/src/coroutine.cc
lib/src/AOPAdvice.cc
lib/src/AccessLogger.cc
lib/src/CacheFile.cc
Expand Down
56 changes: 55 additions & 1 deletion lib/inc/drogon/utils/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,27 @@
#include <type_traits>
#include <optional>

namespace drogon
{
struct CancelHandle;
using CancelHandlePtr = std::shared_ptr<CancelHandle>;

struct CancelHandle
{
static CancelHandlePtr create();

virtual void cancel() = 0;
virtual bool isCancelRequested() = 0;
virtual void registerCancelCallback(std::function<void()> callback) = 0;
};

class TaskCancelledException final : public std::runtime_error
{
public:
using std::runtime_error::runtime_error;
};
} // namespace drogon

namespace drogon
{
namespace internal
Expand Down Expand Up @@ -596,6 +617,30 @@ struct [[nodiscard]] TimerAwaiter : CallbackAwaiter<void>
double delay_;
};

struct [[nodiscard]] CancellableTimeAwaiter : CallbackAwaiter<void>
{
CancellableTimeAwaiter(trantor::EventLoop *loop,
const std::chrono::duration<double> &delay,
CancelHandlePtr cancelHandle)
: CancellableTimeAwaiter(loop, delay.count(), std::move(cancelHandle))
{
}

CancellableTimeAwaiter(trantor::EventLoop *loop,
double delay,
CancelHandlePtr cancelHandle)
: loop_(loop), delay_(delay), cancelHandle_(std::move(cancelHandle))
{
}

void await_suspend(std::coroutine_handle<> handle);

private:
trantor::EventLoop *loop_;
double delay_;
CancelHandlePtr cancelHandle_;
};

struct [[nodiscard]] LoopAwaiter : CallbackAwaiter<void>
{
LoopAwaiter(trantor::EventLoop *workLoop,
Expand Down Expand Up @@ -684,6 +729,15 @@ inline internal::TimerAwaiter sleepCoro(trantor::EventLoop *loop,
return {loop, delay};
}

inline internal::CancellableTimeAwaiter sleepCoro(
trantor::EventLoop *loop,
double delay,
CancelHandlePtr cancelHandle) noexcept
{
assert(loop);
return {loop, delay, std::move(cancelHandle)};
}

inline internal::LoopAwaiter queueInLoopCoro(
trantor::EventLoop *workLoop,
std::function<void()> taskFunc,
Expand Down Expand Up @@ -749,7 +803,7 @@ void async_run(Coro &&coro)

/**
* @brief returns a function that calls a coroutine
* @param coro A coroutine that is awaitable
* @param Coro A coroutine that is awaitable
*/
template <typename Coro>
std::function<void()> async_func(Coro &&coro)
Expand Down
95 changes: 95 additions & 0 deletions lib/src/coroutine.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//
// Created by wanchen.he on 2023/12/29.
//
#ifdef __cpp_impl_coroutine
#include <drogon/utils/coroutine.h>

namespace drogon
{
class CancelHandleImpl : public CancelHandle
{
public:
CancelHandleImpl() = default;

void cancel() override
{
std::function<void()> handle;
{
std::lock_guard<std::mutex> lock(mutex_);
cancelRequested_ = true;
handle = std::move(cancelHandle_);
}
if (handle)
handle();
}

bool isCancelRequested() override
{
std::lock_guard<std::mutex> lock(mutex_);
return cancelRequested_;
}

void registerCancelCallback(std::function<void()> callback) override
{
}

void setCancelHandle(std::function<void()> handle)
{
bool cancelled{false};
{
std::lock_guard<std::mutex> lock(mutex_);
if (cancelRequested_)
{
cancelled = true;
}
else
{
cancelHandle_ = std::move(handle);
}
}
if (cancelled)
{
handle();
}
}

private:
std::mutex mutex_;
bool cancelRequested_{false};

std::shared_ptr<std::atomic_bool> flagPtr_;
std::function<void()> cancelHandle_;
};

CancelHandlePtr CancelHandle::create()
{
return std::make_shared<CancelHandleImpl>();
}

void internal::CancellableTimeAwaiter::await_suspend(
std::coroutine_handle<> handle)
{
auto execFlagPtr = std::make_shared<std::atomic_bool>(false);
if (cancelHandle_)
{
static_cast<CancelHandleImpl *>(cancelHandle_.get())
->setCancelHandle([this, handle, execFlagPtr, loop = loop_]() {
if (!execFlagPtr->exchange(true))
{
setException(std::make_exception_ptr(
TaskCancelledException("Task cancelled")));
loop->queueInLoop([handle]() { handle.resume(); });
return;
}
});
}
loop_->runAfter(delay_, [handle, execFlagPtr = std::move(execFlagPtr)]() {
if (!execFlagPtr->exchange(true))
{
handle.resume();
}
});
}
} // namespace drogon

#endif
34 changes: 34 additions & 0 deletions lib/tests/unittests/CoroutineTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,37 @@ DROGON_TEST(SwitchThread)
sync_wait(switch_thread());
thread.wait();
}

DROGON_TEST(Cancellation)
{
using namespace drogon::internal;

trantor::EventLoopThread thread; // helper thread
thread.run();

auto testCancelTask = [TEST_CTX, loop = thread.getLoop()]() -> Task<> {
auto cancelHandle = CancelHandle::create();

// wait coro for 10 seconds, but cancel after 1 second
loop->runAfter(1, [cancelHandle]() { cancelHandle->cancel(); });

int64_t start = time(nullptr);
try
{
LOG_INFO << "Waiting for 10 seconds...";
co_await sleepCoro(loop, 10, cancelHandle);
CHECK(false); // should not reach here
}
catch (const TaskCancelledException &ex)
{
int64_t waitTime = time(nullptr) - start;
CHECK(waitTime < 2);
LOG_INFO << "Oops... only waited for " << waitTime << " second(s)";
}

loop->quit();
};

sync_wait(testCancelTask());
thread.wait();
}

0 comments on commit 55929b1

Please sign in to comment.