Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sync_wait a proper CPO. #369

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions include/unifex/sync_wait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,32 +140,23 @@ std::optional<Result> _impl(Sender&& sender) {
}
} // namespace _sync_wait

namespace _sync_wait_cpo {
namespace _sync_wait_r_cpo {
template <typename Result>
struct _fn {
template(typename Sender)
(requires typed_sender<Sender>)
auto operator()(Sender&& sender) const
-> std::optional<sender_single_value_result_t<remove_cvref_t<Sender>>> {
using Result = sender_single_value_result_t<remove_cvref_t<Sender>>;
return _sync_wait::_impl<Result>((Sender&&) sender);
(requires sender<Sender>)
decltype(auto) operator()(Sender&& sender) const {
return tag_invoke(_fn{}, (Sender &&) sender);
}
constexpr auto operator()() const
noexcept(is_nothrow_callable_v<
tag_t<bind_back>, _fn>)
-> bind_back_result_t<_fn> {
noexcept(is_nothrow_callable_v<tag_t<bind_back>, _fn>)
-> bind_back_result_t<_fn> {
return bind_back(*this);
}
};
} // namespace _sync_wait_cpo

inline constexpr _sync_wait_cpo::_fn sync_wait {};

namespace _sync_wait_r_cpo {
template <typename Result>
struct _fn {
template(typename Sender)
(requires sender<Sender>)
decltype(auto) operator()(Sender&& sender) const {
friend decltype(auto) tag_invoke(_fn, Sender&& sender) {
using Result2 = non_void_t<wrap_reference_t<decay_rvalue_t<Result>>>;
return _sync_wait::_impl<Result2>((Sender&&) sender);
}
Expand All @@ -175,6 +166,33 @@ namespace _sync_wait_r_cpo {
template <typename Result>
inline constexpr _sync_wait_r_cpo::_fn<Result> sync_wait_r {};

namespace _sync_wait_cpo {
struct _fn {
template(typename Sender)
(requires typed_sender<Sender>)
auto operator()(Sender&& sender) const
-> std::optional<sender_single_value_result_t<remove_cvref_t<Sender>>> {
return tag_invoke(_fn{}, (Sender &&) sender);
}
constexpr auto operator()() const
noexcept(is_nothrow_callable_v<
tag_t<bind_back>, _fn>)
-> bind_back_result_t<_fn> {
return bind_back(*this);
}

template(typename Sender)
(requires typed_sender<Sender>)
friend auto tag_invoke(_fn, Sender&& sender)
-> std::optional<sender_single_value_result_t<remove_cvref_t<Sender>>> {
using Result = sender_single_value_result_t<remove_cvref_t<Sender>>;
return tag_invoke(tag_t<sync_wait_r<Result>>{}, (Sender&&) sender);
}
};
} // namespace _sync_wait_cpo

inline constexpr _sync_wait_cpo::_fn sync_wait {};

} // namespace unifex

#include <unifex/detail/epilogue.hpp>
133 changes: 133 additions & 0 deletions test/sync_wait_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language go4verning permissions and
* limitations under the License.
*/
#include <unifex/just.hpp>
#include <unifex/let_done.hpp>
#include <unifex/scheduler_concepts.hpp>
#include <unifex/sequence.hpp>
#include <unifex/stop_when.hpp>
#include <unifex/sync_wait.hpp>
#include <unifex/then.hpp>
#include <unifex/timed_single_thread_context.hpp>

#include <chrono>
#include <iostream>

#include <gtest/gtest.h>

using namespace unifex;

template <typename R>
struct CpoTestSenderOp {
void start() noexcept { set_value(std::move(rec), 12); }

R rec;
};

struct CpoTestSenderSyncWaitR {
template <
template <typename...>
class Variant,
template <typename...>
class Tuple>
using value_types = Variant<Tuple<int>>;

template <template <typename...> class Variant>
using error_types = Variant<>;

static constexpr bool sends_done = false;

friend auto tag_invoke(tag_t<sync_wait_r<int>>, CpoTestSenderSyncWaitR) {
return std::make_optional<int>(42);
}

template <typename Receiver>
friend auto
tag_invoke(tag_t<connect>, CpoTestSenderSyncWaitR, Receiver&& rec) {
return CpoTestSenderOp<Receiver>{(Receiver &&) rec};
}
};

struct CpoTestSenderSyncWait {
template <
template <typename...>
class Variant,
template <typename...>
class Tuple>
using value_types = Variant<Tuple<int>>;

template <template <typename...> class Variant>
using error_types = Variant<>;

static constexpr bool sends_done = false;

friend auto tag_invoke(tag_t<sync_wait>, CpoTestSenderSyncWait) {
return std::make_optional<int>(42);
}

template <typename Receiver>
friend auto
tag_invoke(tag_t<connect>, CpoTestSenderSyncWait, Receiver&& rec) {
return CpoTestSenderOp<Receiver>{(Receiver &&) rec};
}
};

TEST(SyncWait, CpoSyncWaitR) {
// CpoTestSenderSyncWaitR redefines `sync_wait_r<int>` and this also affects
// `sync_wait`.
std::optional<int> i = sync_wait_r<int>(CpoTestSenderSyncWaitR{});
ASSERT_TRUE(i.has_value());
EXPECT_EQ(*i, 42);

std::optional<int> j = sync_wait(CpoTestSenderSyncWaitR{});
ASSERT_TRUE(j.has_value());
EXPECT_EQ(*j, 42);
}

TEST(SyncWait, CpoSyncWaitRPiped) {
// CpoTestSenderSyncWaitR redefines `sync_wait_r<int>` and this also affects
// `sync_wait`.
std::optional<int> i = CpoTestSenderSyncWaitR{} | sync_wait_r<int>();
ASSERT_TRUE(i.has_value());
EXPECT_EQ(*i, 42);

std::optional<int> j = CpoTestSenderSyncWaitR{} | sync_wait();
ASSERT_TRUE(j.has_value());
EXPECT_EQ(*j, 42);
}

TEST(SyncWait, CpoSyncWait) {
// CpoTestSenderSyncWaitR redefines `sync_wait` and this does not affects
// `sync_wait_r<int>` which gets the default behaviour of `sync_wait_r`.
std::optional<int> i = sync_wait_r<int>(CpoTestSenderSyncWait{});
ASSERT_TRUE(i.has_value());
EXPECT_EQ(*i, 12);

std::optional<int> j = sync_wait(CpoTestSenderSyncWait{});
ASSERT_TRUE(j.has_value());
EXPECT_EQ(*j, 42);
}

TEST(SyncWait, CpoSyncWaitPiped) {
// CpoTestSenderSyncWaitR redefines `sync_wait` and this does not affects
// `sync_wait_r<int>` which gets the default behaviour of `sync_wait_r`.
std::optional<int> i = CpoTestSenderSyncWait{} | sync_wait_r<int>();
ASSERT_TRUE(i.has_value());
EXPECT_EQ(*i, 12);

std::optional<int> j = CpoTestSenderSyncWait{} | sync_wait();
ASSERT_TRUE(j.has_value());
EXPECT_EQ(*j, 42);
}