Skip to content

Commit

Permalink
Fix handling non-Task coroutines in TaskFinalSuspend
Browse files Browse the repository at this point in the history
  • Loading branch information
danvratil committed Aug 20, 2024
1 parent 34a797b commit 2dd5a6c
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 24 deletions.
9 changes: 9 additions & 0 deletions qcoro/impl/taskawaiterbase.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#pragma once

#include "../qcorotask.h"
#include <coroutine>
#include <type_traits>

namespace QCoro::detail
{
Expand All @@ -18,6 +20,13 @@ inline bool TaskAwaiterBase<Promise>::await_ready() const noexcept {
return !mAwaitedCoroutine || mAwaitedCoroutine.done();
}

template<typename Promise>
template<typename T>
inline void TaskAwaiterBase<Promise>::await_suspend(std::coroutine_handle<TaskPromise<T>> awaitingCoroutine) noexcept {
auto handle = std::coroutine_handle<TaskPromiseBase>::from_address(awaitingCoroutine.address());
mAwaitedCoroutine.promise().addAwaitingCoroutine(handle);
}

template<typename Promise>
inline void TaskAwaiterBase<Promise>::await_suspend(std::coroutine_handle<> awaitingCoroutine) noexcept {
mAwaitedCoroutine.promise().addAwaitingCoroutine(awaitingCoroutine);
Expand Down
24 changes: 14 additions & 10 deletions qcoro/impl/taskfinalsuspend.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
#pragma once

#include "../qcorotask.h"
#include <coroutine>
#include <QDebug>

namespace QCoro::detail
{

inline TaskFinalSuspend::TaskFinalSuspend(const std::vector<std::coroutine_handle<>> &awaitingCoroutines)
inline TaskFinalSuspend::TaskFinalSuspend(const std::vector<CoroutineHandle> &awaitingCoroutines)
: mAwaitingCoroutines(awaitingCoroutines) {}

inline bool TaskFinalSuspend::await_ready() const noexcept {
Expand All @@ -25,16 +27,18 @@ inline void TaskFinalSuspend::await_suspend(std::coroutine_handle<Promise> finis
auto &promise = finishedCoroutine.promise();

for (auto &awaiter : mAwaitingCoroutines) {
auto handle = std::coroutine_handle<TaskPromiseBase>::from_address(awaiter.address());
auto &promise = handle.promise();
const CoroutineFeatures &features = promise.features();
if (const auto &guardedThis = features.guardedThis(); guardedThis.has_value() && guardedThis->isNull()) {
// We have a QPointer, but it's null which means that observed QObject has been destroyed,
// so just destroy the current coroutine as well.
qDebug() << "Destroy direct!";
promise.destroyCoroutine();
if (const auto qcoro_handle = std::get_if<std::coroutine_handle<TaskPromiseBase>>(&awaiter); qcoro_handle != nullptr) {
auto &promise = qcoro_handle->promise();
const CoroutineFeatures &features = promise.features();
if (const auto &guardedThis = features.guardedThis(); guardedThis.has_value() && guardedThis->isNull()) {
// We have a QPointer, but it's null which means that observed QObject has been destroyed,
// so just destroy the current coroutine as well.
promise.destroyCoroutine(true);
} else {
qcoro_handle->resume();
}
} else {
awaiter.resume();
std::get<std::coroutine_handle<>>(awaiter).resume();
}
}
mAwaitingCoroutines.clear();
Expand Down
26 changes: 20 additions & 6 deletions qcoro/impl/taskpromisebase.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "../qcorotask.h"
#include "bits/features.h"

#include <QDebug>
#include <coroutine>

namespace QCoro::detail
{
Expand Down Expand Up @@ -54,8 +54,13 @@ inline auto &TaskPromiseBase::await_transform(T &awaitable) {
return awaitable;
}

inline void TaskPromiseBase::addAwaitingCoroutine(std::coroutine_handle<> awaitingCoroutine) {
mAwaitingCoroutines.push_back(awaitingCoroutine);
template<typename T>
inline void TaskPromiseBase::addAwaitingCoroutine(std::coroutine_handle<T> awaitingCoroutine) {
if constexpr (std::is_same_v<T, TaskPromiseBase>) {
mAwaitingCoroutines.emplace_back(CoroutineHandle{std::in_place_index<0>, awaitingCoroutine});
} else {
mAwaitingCoroutines.emplace_back(CoroutineHandle{std::in_place_index<1>, awaitingCoroutine});
}
}

inline bool TaskPromiseBase::hasAwaitingCoroutine() const {
Expand All @@ -64,21 +69,30 @@ inline bool TaskPromiseBase::hasAwaitingCoroutine() const {

inline void TaskPromiseBase::derefCoroutine() {
--mRefCount;
qDebug() << this << "coroutine refcount descreased to " << mRefCount.load();
if (mRefCount == 0) {
destroyCoroutine();
}
}

inline void TaskPromiseBase::refCoroutine() {
++mRefCount;
qDebug() << this << "coroutine refcount increased to " << mRefCount.load();
}

inline void TaskPromiseBase::destroyCoroutine() {
inline void TaskPromiseBase::destroyCoroutine(bool wakeUpAwaiters) {
if (wakeUpAwaiters) {
for (auto &awaiter : mAwaitingCoroutines) {
if (const auto qcoro_handle = std::get_if<std::coroutine_handle<TaskPromiseBase>>(&awaiter); qcoro_handle != nullptr) {
qcoro_handle->resume();
} else {
std::get<std::coroutine_handle<>>(awaiter).resume();
}
}
}

mRefCount = 0;
auto handle = std::coroutine_handle<TaskPromiseBase>::from_promise(*this);
handle.destroy();

}

inline CoroutineFeatures &TaskPromiseBase::features() noexcept {
Expand Down
1 change: 1 addition & 0 deletions qcoro/qcoroasyncgenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "coroutine.h"

#include <coroutine>
#include <iterator>
#include <exception>

Expand Down
19 changes: 13 additions & 6 deletions qcoro/qcorotask.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#include "bits/features.h"

#include <atomic>
#include <coroutine>
#include <exception>
#include <variant>
#include <memory>
#include <type_traits>
#include <vector>

Expand All @@ -28,6 +28,9 @@ struct awaiter_type;
template<typename T>
using awaiter_type_t = typename awaiter_type<T>::type;

class TaskPromiseBase;
using CoroutineHandle = std::variant<std::coroutine_handle<TaskPromiseBase>, std::coroutine_handle<>>;

//! Continuation that resumes a coroutine co_awaiting on currently finished coroutine.
class TaskFinalSuspend {
public:
Expand All @@ -36,7 +39,7 @@ class TaskFinalSuspend {
* \param[in] awaitingCoroutine handle of the coroutine that is co_awaiting the current
* coroutine (continuation).
*/
explicit TaskFinalSuspend(const std::vector<std::coroutine_handle<>> &awaitingCoroutines);
explicit TaskFinalSuspend(const std::vector<CoroutineHandle> &awaitingCoroutines);

//! Returns whether the just finishing coroutine should do final suspend or not
/*!
Expand Down Expand Up @@ -68,7 +71,7 @@ class TaskFinalSuspend {
constexpr void await_resume() const noexcept;

private:
std::vector<std::coroutine_handle<>> mAwaitingCoroutines;
std::vector<CoroutineHandle> mAwaitingCoroutines;
};

//! Base class for the \c Task<T> promise_type.
Expand Down Expand Up @@ -195,7 +198,8 @@ class TaskPromiseBase {
* represented by this promise. When our coroutine finishes, it's
* our job to resume the awaiting coroutine.
*/
void addAwaitingCoroutine(std::coroutine_handle<> awaitingCoroutine);
template<typename T>
void addAwaitingCoroutine(std::coroutine_handle<T> awaitingCoroutine);

bool hasAwaitingCoroutine() const;

Expand All @@ -205,7 +209,7 @@ class TaskPromiseBase {

CoroutineFeatures &features() noexcept;

void destroyCoroutine();
void destroyCoroutine(bool wakeUpAwaiters = false);

protected:
TaskPromiseBase();
Expand All @@ -214,7 +218,7 @@ class TaskPromiseBase {
friend class TaskFinalSuspend;

//! Handle of the coroutine that is currently co_awaiting this Awaitable
std::vector<std::coroutine_handle<>> mAwaitingCoroutines;
std::vector<CoroutineHandle> mAwaitingCoroutines;

//! Indicates whether we can destroy the coroutine handle
std::atomic<int> mRefCount{0};
Expand Down Expand Up @@ -345,6 +349,9 @@ class TaskAwaiterBase {
* co_awaited coroutine has finished synchronously and the co_awaiting coroutine doesn't
* have to suspend.
*/
template<typename T>
void await_suspend(std::coroutine_handle<TaskPromise<T>> awaitingCoroutine) noexcept;

void await_suspend(std::coroutine_handle<> awaitingCoroutine) noexcept;

protected:
Expand Down
2 changes: 0 additions & 2 deletions tests/qcorotask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,9 @@ class QCoroTaskTest : public QCoro::TestObject<QCoroTaskTest>
features.guardThis(obj);

QTimer::singleShot(0, [obj]() {
qDebug() << "Destroyed!";
delete obj;
});
co_await timer();
qDebug() << "RESUMED";

notCalled = false;
};
Expand Down

0 comments on commit 2dd5a6c

Please sign in to comment.