1
0
forked from 0ad/0ad

Split Receiver from SharedState

The Function is not restricted to std::function anymore. Move only
function become possible.

Differential Revision: https://code.wildfiregames.com/D4840
This was SVN commit r27962.
This commit is contained in:
phosit 2023-11-30 09:20:35 +00:00
parent 5ce3478317
commit 6ee136dd11
2 changed files with 69 additions and 40 deletions

View File

@ -27,7 +27,7 @@
#include <optional> #include <optional>
#include <type_traits> #include <type_traits>
template<typename ResultType> template<typename Callback>
class PackagedTask; class PackagedTask;
namespace FutureSharedStateDetail namespace FutureSharedStateDetail
@ -48,15 +48,14 @@ using ResultHolder = std::conditional_t<std::is_void_v<T>, std::nullopt_t, std::
* Holds all relevant data. * Holds all relevant data.
*/ */
template<typename ResultType> template<typename ResultType>
class SharedState : public ResultHolder<ResultType> class Receiver : public ResultHolder<ResultType>
{ {
static constexpr bool VoidResult = std::is_same_v<ResultType, void>; static constexpr bool VoidResult = std::is_same_v<ResultType, void>;
public: public:
SharedState(std::function<ResultType()>&& func) : Receiver() :
ResultHolder<ResultType>{std::nullopt}, ResultHolder<ResultType>{std::nullopt}
m_Func(std::move(func))
{} {}
~SharedState() ~Receiver()
{ {
// For safety, wait on started task completion, but not on pending ones (auto-cancelled). // For safety, wait on started task completion, but not on pending ones (auto-cancelled).
if (!Cancel()) if (!Cancel())
@ -66,8 +65,8 @@ public:
} }
} }
SharedState(const SharedState&) = delete; Receiver(const Receiver&) = delete;
SharedState(SharedState&&) = delete; Receiver(Receiver&&) = delete;
bool IsDoneOrCanceled() const bool IsDoneOrCanceled() const
{ {
@ -122,8 +121,17 @@ public:
std::atomic<Status> m_Status = Status::PENDING; std::atomic<Status> m_Status = Status::PENDING;
std::mutex m_Mutex; std::mutex m_Mutex;
std::condition_variable m_ConditionVariable; std::condition_variable m_ConditionVariable;
};
std::function<ResultType()> m_Func; template<typename Callback>
struct SharedState
{
SharedState(Callback&& callbackFunc) :
callback{std::forward<Callback>(callbackFunc)}
{}
Callback callback;
Receiver<std::invoke_result_t<Callback>> receiver;
}; };
} // namespace FutureSharedStateDetail } // namespace FutureSharedStateDetail
@ -150,7 +158,6 @@ class Future
static constexpr bool VoidResult = std::is_same_v<ResultType, void>; static constexpr bool VoidResult = std::is_same_v<ResultType, void>;
using Status = FutureSharedStateDetail::Status; using Status = FutureSharedStateDetail::Status;
using SharedState = FutureSharedStateDetail::SharedState<ResultType>;
public: public:
Future() = default; Future() = default;
Future(const Future& o) = delete; Future(const Future& o) = delete;
@ -161,8 +168,8 @@ public:
/** /**
* Make the future wait for the result of @a func. * Make the future wait for the result of @a func.
*/ */
template<typename T> template<typename Callback>
PackagedTask<ResultType> Wrap(T&& func); PackagedTask<Callback> Wrap(Callback&& callback);
/** /**
* Move the result out of the future, and invalidate the future. * Move the result out of the future, and invalidate the future.
@ -172,17 +179,17 @@ public:
template<typename SfinaeType = ResultType> template<typename SfinaeType = ResultType>
std::enable_if_t<!std::is_same_v<SfinaeType, void>, ResultType> Get() std::enable_if_t<!std::is_same_v<SfinaeType, void>, ResultType> Get()
{ {
ENSURE(!!m_SharedState); ENSURE(!!m_Receiver);
Wait(); Wait();
if constexpr (VoidResult) if constexpr (VoidResult)
return; return;
else else
{ {
ENSURE(m_SharedState->m_Status != Status::CANCELED); ENSURE(m_Receiver->m_Status != Status::CANCELED);
// This mark the state invalid - can't call Get again. // This mark the state invalid - can't call Get again.
return m_SharedState->GetResult(); return m_Receiver->GetResult();
} }
} }
@ -191,7 +198,7 @@ public:
*/ */
bool IsReady() const bool IsReady() const
{ {
return !!m_SharedState && m_SharedState->m_Status == Status::DONE; return !!m_Receiver && m_Receiver->m_Status == Status::DONE;
} }
/** /**
@ -199,13 +206,13 @@ public:
*/ */
bool Valid() const bool Valid() const
{ {
return !!m_SharedState && m_SharedState->m_Status != Status::CANCELED; return !!m_Receiver && m_Receiver->m_Status != Status::CANCELED;
} }
void Wait() void Wait()
{ {
if (Valid()) if (Valid())
m_SharedState->Wait(); m_Receiver->Wait();
} }
/** /**
@ -217,13 +224,13 @@ public:
{ {
if (!Valid()) if (!Valid())
return; return;
if (!m_SharedState->Cancel()) if (!m_Receiver->Cancel())
m_SharedState->Wait(); m_Receiver->Wait();
m_SharedState.reset(); m_Receiver.reset();
} }
protected: protected:
std::shared_ptr<SharedState> m_SharedState; std::shared_ptr<FutureSharedStateDetail::Receiver<ResultType>> m_Receiver;
}; };
/** /**
@ -232,35 +239,39 @@ protected:
* This type is mostly just the shared state and the call operator, * This type is mostly just the shared state and the call operator,
* handling the promise & continuation logic. * handling the promise & continuation logic.
*/ */
template<typename ResultType> template<typename Callback>
class PackagedTask class PackagedTask
{ {
static constexpr bool VoidResult = std::is_same_v<ResultType, void>;
public: public:
PackagedTask() = delete; PackagedTask() = delete;
PackagedTask(std::shared_ptr<typename Future<ResultType>::SharedState> ss) : m_SharedState(std::move(ss)) {} PackagedTask(std::shared_ptr<FutureSharedStateDetail::SharedState<Callback>> ss) :
m_SharedState(std::move(ss))
{}
void operator()() void operator()()
{ {
typename Future<ResultType>::Status expected = Future<ResultType>::Status::PENDING; FutureSharedStateDetail::Status expected = FutureSharedStateDetail::Status::PENDING;
if (!m_SharedState->m_Status.compare_exchange_strong(expected, Future<ResultType>::Status::STARTED)) if (!m_SharedState->receiver.m_Status.compare_exchange_strong(expected,
FutureSharedStateDetail::Status::STARTED))
{
return; return;
}
if constexpr (VoidResult) if constexpr (std::is_void_v<std::invoke_result_t<Callback>>)
m_SharedState->m_Func(); m_SharedState->callback();
else else
m_SharedState->emplace(m_SharedState->m_Func()); m_SharedState->receiver.emplace(m_SharedState->callback());
// Because we might have threads waiting on us, we need to make sure that they either: // Because we might have threads waiting on us, we need to make sure that they either:
// - don't wait on our condition variable // - don't wait on our condition variable
// - receive the notification when we're done. // - receive the notification when we're done.
// This requires locking the mutex (@see Wait). // This requires locking the mutex (@see Wait).
{ {
std::lock_guard<std::mutex> lock(m_SharedState->m_Mutex); std::lock_guard<std::mutex> lock(m_SharedState->receiver.m_Mutex);
m_SharedState->m_Status = Future<ResultType>::Status::DONE; m_SharedState->receiver.m_Status = FutureSharedStateDetail::Status::DONE;
} }
m_SharedState->m_ConditionVariable.notify_all(); m_SharedState->receiver.m_ConditionVariable.notify_all();
// We no longer need the shared state, drop it immediately. // We no longer need the shared state, drop it immediately.
m_SharedState.reset(); m_SharedState.reset();
@ -272,18 +283,19 @@ public:
m_SharedState.reset(); m_SharedState.reset();
} }
protected: private:
std::shared_ptr<typename Future<ResultType>::SharedState> m_SharedState; std::shared_ptr<FutureSharedStateDetail::SharedState<Callback>> m_SharedState;
}; };
template<typename ResultType> template<typename ResultType>
template<typename T> template<typename Callback>
PackagedTask<ResultType> Future<ResultType>::Wrap(T&& func) PackagedTask<Callback> Future<ResultType>::Wrap(Callback&& callback)
{ {
static_assert(std::is_same_v<std::invoke_result_t<T>, ResultType>, static_assert(std::is_same_v<std::invoke_result_t<Callback>, ResultType>,
"The return type of the wrapped function is not the same as the type the Future expects."); "The return type of the wrapped function is not the same as the type the Future expects.");
m_SharedState = std::make_shared<SharedState>(std::move(func)); auto temp = std::make_shared<FutureSharedStateDetail::SharedState<Callback>>(std::move(callback));
return PackagedTask<ResultType>(m_SharedState); m_Receiver = {temp, &temp->receiver};
return PackagedTask<Callback>(std::move(temp));
} }
#endif // INCLUDED_FUTURE #endif // INCLUDED_FUTURE

View File

@ -124,4 +124,21 @@ public:
task2(); task2();
TS_ASSERT_EQUALS(future.Get(), 7); TS_ASSERT_EQUALS(future.Get(), 7);
} }
void test_move_only_function()
{
Future<void> future;
class MoveOnlyType
{
public:
MoveOnlyType() = default;
MoveOnlyType(MoveOnlyType&) = delete;
MoveOnlyType& operator=(MoveOnlyType&) = delete;
MoveOnlyType(MoveOnlyType&&) = default;
MoveOnlyType& operator=(MoveOnlyType&&) = default;
};
future.Wrap([t = MoveOnlyType{}]{});
}
}; };