Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add when_all to our coroutine library #1944

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
221 changes: 216 additions & 5 deletions lib/inc/drogon/utils/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cassert>
#include <condition_variable>
#include <coroutine>
#include <cstddef>
#include <exception>
#include <future>
#include <mutex>
Expand Down Expand Up @@ -54,35 +55,101 @@ auto getAwaiter(T &&value) noexcept(
{
return getAwaiterImpl(static_cast<T &&>(value));
}

} // end namespace internal

// Some concepts used in this file
// * Coroutine - Something C++ generated for us. It has promise_type, etc..
// * Awaiter - Something we wrote manually that has await_ready, await_suspend,
// etc..
template <typename T, typename = std::void_t<>>
struct coroutine_result : std::false_type
{
};

template <typename T>
struct await_result
struct coroutine_result<
T,
std::void_t<decltype(internal::getAwaiter(std::declval<T>()))>>
{
using awaiter_t = decltype(internal::getAwaiter(std::declval<T>()));
using type = decltype(std::declval<awaiter_t>().await_resume());
};

template <typename T>
using await_result_t = typename await_result<T>::type;
using coroutine_result_t = typename coroutine_result<T>::type;

template <typename T, typename = std::void_t<>>
struct is_awaitable : std::false_type
struct is_coroutine : std::false_type
{
};

template <typename T>
struct is_awaitable<
struct is_coroutine<
T,
std::void_t<decltype(internal::getAwaiter(std::declval<T>()))>>
: std::true_type
{
};

template <typename T>
constexpr bool is_coroutine_v = is_coroutine<T>::value;

template <typename T, typename = std::void_t<>>
struct awaiter_result : std::false_type
{
};

template <typename T>
struct awaiter_result<T,
std::void_t<decltype(std::declval<T>().await_ready()),
decltype(std::declval<T>().await_suspend(
std::declval<std::coroutine_handle<>>())),
decltype(std::declval<T>().await_resume())>>
{
using type = decltype(std::declval<T>().await_resume());
};

template <typename T>
using awaiter_result_t = typename awaiter_result<T>::type;

template <typename T, typename = std::void_t<>>
struct is_awaiter : std::false_type
{
};

template <typename T>
struct is_awaiter<T,
std::void_t<decltype(std::declval<T>().await_ready()),
decltype(std::declval<T>().await_suspend(
std::declval<std::coroutine_handle<>>())),
decltype(std::declval<T>().await_resume())>>
: std::true_type
{
};

template <typename T>
constexpr bool is_awaiter_v = is_awaiter<T>::value;

// More generic traits
template <typename T>
struct is_awaitable : std::bool_constant<is_awaiter_v<T> || is_coroutine_v<T>>
{
};

template <typename T>
constexpr bool is_awaitable_v = is_awaitable<T>::value;

template <typename T>
struct await_result
{
using type = std::conditional_t<is_coroutine_v<T>,
coroutine_result_t<T>,
awaiter_result_t<T>>;
};

template <typename T>
using await_result_t = typename await_result<T>::type;

/**
* @struct final_awaiter
* @brief An awaiter for `Task::promise_type::final_suspend()`. Transfer
Expand Down Expand Up @@ -798,6 +865,40 @@ struct [[nodiscard]] EventLoopAwaiter : public drogon::CallbackAwaiter<T>
std::function<T()> task_;
trantor::EventLoop *loop_;
};

struct WaitForNotify : public CallbackAwaiter<void>
{
void await_suspend(std::coroutine_handle<> handle)
{
bool should_resume = false;
{
std::lock_guard<std::mutex> lock(mtx);
if (notified)
should_resume = true;
else
handle_ = handle;
}
if (should_resume)
handle.resume();
}

void notify()
{
bool should_resume = false;
{
std::lock_guard<std::mutex> lock(mtx);
notified = true;
if (handle_)
should_resume = true;
}
if (should_resume)
handle_.resume();
}

bool notified = false;
std::coroutine_handle<> handle_;
std::mutex mtx;
};
} // namespace internal

/**
Expand All @@ -811,4 +912,114 @@ inline internal::EventLoopAwaiter<T> queueInLoopCoro(trantor::EventLoop *loop,
return internal::EventLoopAwaiter<T>(std::move(task), loop);
}

/**
* @brief Waits for all tasks to complete. Throws exception if any of the tasks
* throws. In such cases, all tasks are still waited for completion.
* @param tasks A list of tasks to wait for
* @param loop The event loop to switch to after all tasks are completed
* (default nullptr, which means to keep on whichever thread the last task is
* completed)
* @return A task that completes when all tasks are completed
*/
template <typename Awaiter,
typename = std::enable_if_t<std::is_void_v<await_result_t<Awaiter>>>>
inline Task<> when_all(std::vector<Awaiter> tasks,
trantor::EventLoop *loop = nullptr)
{
static_assert(is_awaitable_v<Awaiter>);
std::exception_ptr eptr;
std::atomic_size_t counter = tasks.size();
internal::WaitForNotify waiter;
for (auto &&task : tasks)
{
[](std::exception_ptr &eptr,
std::atomic_size_t &counter,
internal::WaitForNotify &waiter,
Awaiter task) -> AsyncTask {
try
{
co_await task;
}
catch (...)
{
eptr = std::current_exception();
}

size_t c = counter.fetch_sub(1, std::memory_order_acq_rel) - 1;
if (c == 0)
{
waiter.notify();
}
}(eptr, counter, waiter, std::move(task));
}
// In case there's no task, we should still wait for the notify
if (tasks.empty())
waiter.notify();
co_await waiter;
if (loop)
co_await switchThreadCoro(loop);

if (eptr)
std::rethrow_exception(eptr);
}

/**
* @brief Waits for all tasks to complete. Throws exception if any of the tasks
* throws. In such cases, all tasks are still waited for completion.
* @param tasks A list of tasks to wait for
* @param loop The event loop to switch to after all tasks are completed
* (default nullptr, which means to keep on whichever thread the last task is
* completed)
* @return A task that completes when all tasks are completed and yields
* results of all tasks
*/
template <typename Awaiter,
typename = std::enable_if_t<!std::is_void_v<await_result_t<Awaiter>>>>
inline Task<std::vector<await_result_t<Awaiter>>> when_all(
std::vector<Awaiter> tasks,
trantor::EventLoop *loop = nullptr)
{
static_assert(is_awaitable_v<Awaiter>);
std::exception_ptr eptr;
std::vector<await_result_t<Awaiter>> results;
results.resize(tasks.size());
std::atomic_size_t counter = tasks.size();
internal::WaitForNotify waiter;
for (size_t i = 0; i < tasks.size(); ++i)
{
[](std::exception_ptr &eptr,
std::atomic_size_t &counter,
internal::WaitForNotify &waiter,
std::vector<await_result_t<Awaiter>> &results,
size_t i,
Awaiter task) -> AsyncTask {
try
{
results[i] = co_await task;
}
catch (...)
{
eptr = std::current_exception();
}

size_t c = counter.fetch_sub(1, std::memory_order_acq_rel) - 1;
if (c == 0)
{
waiter.notify();
}
}(eptr, counter, waiter, results, i, std::move(tasks[i]));
}
// In case there's no task, we should still wait for the notify
if (tasks.empty())
waiter.notify();
co_await waiter;
if (loop)
co_await switchThreadCoro(loop);

if (eptr)
std::rethrow_exception(eptr);

co_return results;
}

} // namespace drogon
85 changes: 85 additions & 0 deletions lib/tests/unittests/CoroutineTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <drogon/utils/coroutine.h>
#include <drogon/HttpAppFramework.h>
#include <trantor/net/EventLoopThread.h>
#include <functional>
#include <type_traits>

using namespace drogon;
Expand Down Expand Up @@ -47,9 +48,20 @@ DROGON_TEST(CroutineBasics)
STATIC_REQUIRE(is_int<await_result_t<Task<int>>>::value);
STATIC_REQUIRE(is_void<await_result_t<Task<>>>::value);

// Regular functions should not be awaitable
STATIC_REQUIRE(is_awaitable_v<std::function<void()>> == false);
STATIC_REQUIRE(is_awaitable_v<std::function<int()>> == false);

// No, you cannot await AsyncTask. By design
STATIC_REQUIRE(is_awaitable_v<AsyncTask> == false);

// Coroutine bodies should be awaitable
auto empty_coro = []() -> Task<> { co_return; };
STATIC_REQUIRE(is_awaitable_v<decltype(empty_coro)> == false);

// But their return types should be
STATIC_REQUIRE(is_awaitable_v<decltype(empty_coro())>);

// AsyncTask should execute eagerly
int m = 0;
[&m]() -> AsyncTask {
Expand Down Expand Up @@ -118,6 +130,14 @@ DROGON_TEST(CroutineBasics)
});
}

DROGON_TEST(AwaiterTraits)
{
auto awaiter = sleepCoro(drogon::app().getLoop(), 0.001);
STATIC_REQUIRE(is_awaitable_v<decltype(awaiter)>);
STATIC_REQUIRE(std::is_void<await_result_t<decltype(awaiter)>>::value);
sync_wait(awaiter);
}

DROGON_TEST(CompilcatedCoroutineLifetime)
{
auto coro = []() -> Task<Task<std::string>> {
Expand Down Expand Up @@ -212,3 +232,68 @@ DROGON_TEST(SwitchThread)
sync_wait(switch_thread());
thread.wait();
}

DROGON_TEST(WhenAll)
{
// Check all tasks are executed
int counter = 0;
auto coro = [&]() -> Task<> {
counter++;
co_return;
};
auto except = []() -> Task<> {
throw std::runtime_error("test error");
co_return;
};
auto slow = []() -> Task<> {
co_await sleepCoro(drogon::app().getLoop(), 0.001);
co_return;
};
auto return42 = []() -> Task<int> { co_return 42; };

std::vector<Task<>> tasks;
for (int i = 0; i < 10; ++i)
tasks.push_back(coro());
sync_wait(when_all(std::move(tasks)));
CHECK(counter == 10);

// Check exceptions are propagated while all coroutines run until completion
counter = 0;
std::vector<Task<>> tasks2;
tasks2.push_back(coro());
tasks2.push_back(except());
tasks2.push_back(coro());

CHECK_THROWS_AS(sync_wait(when_all(std::move(tasks2))), std::runtime_error);
CHECK(counter == 2);

// Check waiting for tasks that can't complete immediately works
counter = 0;
std::vector<Task<>> tasks3;
tasks3.push_back(slow());
// tasks3.push_back(slow());
tasks3.push_back(coro());
sync_wait(when_all(std::move(tasks3)));
CHECK(counter == 1);

// Check we can get the results of the tasks
std::vector<Task<int>> tasks4;
tasks4.push_back(return42());
tasks4.push_back(return42());
auto results = sync_wait(when_all(std::move(tasks4)));
CHECK(results.size() == 2);
CHECK(results[0] == 42);
CHECK(results[1] == 42);

// Check waiting on non-task works
auto sleep = sleepCoro(drogon::app().getLoop(), 0.001);
auto sleep2 = sleepCoro(drogon::app().getLoop(), 0.001);
std::vector<decltype(sleep)> tasks5;
tasks5.emplace_back(std::move(sleep));
tasks5.emplace_back(std::move(sleep2));
sync_wait(when_all(std::move(tasks5)));

// Check waiting on empty list works
std::vector<Task<>> tasks6;
sync_wait(when_all(std::move(tasks6)));
}