Skip to content

Commit

Permalink
Add mutate_apply for db::Access
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsdeppe committed Feb 8, 2024
1 parent 7bf17a3 commit b4cffe3
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 65 deletions.
100 changes: 100 additions & 0 deletions src/DataStructures/DataBox/Access.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@

#include <string>

#include "DataStructures/DataBox/DataBoxTag.hpp"
#include "DataStructures/DataBox/IsApplyCallable.hpp"
#include "DataStructures/DataBox/TagName.hpp"
#include "Utilities/CleanupRoutine.hpp"
#include "Utilities/ErrorHandling/Error.hpp"
#include "Utilities/ForceInline.hpp"
#include "Utilities/Gsl.hpp"
#include "Utilities/PrettyType.hpp"
#include "Utilities/TMPL.hpp"
#include "Utilities/TypeTraits/IsCallable.hpp"

namespace db {
/*!
Expand Down Expand Up @@ -92,4 +95,101 @@ decltype(auto) mutate(Invokable&& invokable, const gsl::not_null<Access*> box,
return invokable(box->template mutate<MutateTags>()...,
std::forward<Args>(args)...);
}

namespace detail {
template <typename... ReturnTags, typename... ArgumentTags, typename F,
typename... Args>
SPECTRE_ALWAYS_INLINE constexpr decltype(auto) mutate_apply(
F&& f, const gsl::not_null<Access*> box, tmpl::list<ReturnTags...> /*meta*/,
tmpl::list<ArgumentTags...> /*meta*/, Args&&... args) {
static_assert(not(... or std::is_same_v<ArgumentTags, Tags::DataBox>),
"Cannot pass Tags::DataBox to mutate_apply when mutating "
"since the db::get won't work inside mutate_apply.");
if constexpr (detail::is_apply_callable_v<
F, const gsl::not_null<typename ReturnTags::type*>...,
const_item_type<ArgumentTags, tmpl::list<>>..., Args...>) {
return ::db::mutate<ReturnTags...>(
[](const gsl::not_null<typename ReturnTags::type*>... mutated_items,
const_item_type<ArgumentTags, tmpl::list<>>... args_items,
decltype(std::forward<Args>(args))... l_args) {
return std::decay_t<F>::apply(mutated_items..., args_items...,
std::forward<Args>(l_args)...);
},
box, db::get<ArgumentTags>(*box)..., std::forward<Args>(args)...);
} else if constexpr (::tt::is_callable_v<
F,
const gsl::not_null<typename ReturnTags::type*>...,
const_item_type<ArgumentTags, tmpl::list<>>...,
Args...>) {
return ::db::mutate<ReturnTags...>(f, box, db::get<ArgumentTags>(*box)...,
std::forward<Args>(args)...);
} else {
error_function_not_callable<F, gsl::not_null<typename ReturnTags::type*>...,
const_item_type<ArgumentTags, tmpl::list<>>...,
Args...>();
}
}
} // namespace detail

/// @{
/*!
* \ingroup DataBoxGroup
* \brief Apply the invokable `f` mutating items `MutateTags` and taking as
* additional arguments `ArgumentTags` and `args`.
*
* \details
* `f` must either be invokable with the arguments of type
* `gsl::not_null<db::item_type<MutateTags>*>...,
* db::const_item_type<ArgumentTags>..., Args...`
* where the first two pack expansions are over the elements in the typelists
* `MutateTags` and `ArgumentTags`, or have a static `apply` function that is
* callable with the same types. If the type of `f` specifies `return_tags` and
* `argument_tags` typelists, these are used for the `MutateTags` and
* `ArgumentTags`, respectively.
*
* Any return values of the invokable `f` are forwarded as returns to the
* `mutate_apply` call.
*
* \example
* An example of using `mutate_apply` with a lambda:
* \snippet Test_DataBox.cpp mutate_apply_lambda_example
*
* An example of a class with a static `apply` function
* \snippet Test_DataBox.cpp mutate_apply_struct_definition_example
* and how to use `mutate_apply` with the above class
* \snippet Test_DataBox.cpp mutate_apply_struct_example_stateful
* Note that the class exposes `return_tags` and `argument_tags` typelists, so
* we don't specify the template parameters explicitly.
* If the class `F` has no state, like in this example,
* \snippet Test_DataBox.cpp mutate_apply_struct_definition_example
* you can also use the stateless overload of `mutate_apply`:
* \snippet Test_DataBox.cpp mutate_apply_struct_example_stateless
*
* \tparam MutateTags typelist of Tags to mutate
* \tparam ArgumentTags typelist of additional items to retrieve from the
* `Access`
* \tparam F The invokable to apply
*/
template <typename MutateTags, typename ArgumentTags, typename F,
typename... Args>
SPECTRE_ALWAYS_INLINE constexpr decltype(auto) mutate_apply(
F&& f, const gsl::not_null<Access*> box, Args&&... args) {
return detail::mutate_apply(std::forward<F>(f), box, MutateTags{},
ArgumentTags{}, std::forward<Args>(args)...);
}

template <typename F, typename... Args>
SPECTRE_ALWAYS_INLINE constexpr decltype(auto) mutate_apply(
F&& f, const gsl::not_null<Access*> box, Args&&... args) {
return mutate_apply<typename std::decay_t<F>::return_tags,
typename std::decay_t<F>::argument_tags>(
std::forward<F>(f), box, std::forward<Args>(args)...);
}

template <typename F, typename... Args>
SPECTRE_ALWAYS_INLINE constexpr decltype(auto) mutate_apply(
const gsl::not_null<Access*> box, Args&&... args) {
return mutate_apply(F{}, box, std::forward<Args>(args)...);
}
/// @}
} // namespace db
168 changes: 103 additions & 65 deletions tests/Unit/DataStructures/DataBox/Test_DataBox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1384,24 +1384,9 @@ struct PointerMutateApplyBase {
}
};

void test_mutate_apply() {
INFO("test mutate apply");
auto box = db::create<
db::AddSimpleTags<
test_databox_tags::Tag0, test_databox_tags::Tag1,
test_databox_tags::Tag2,
Tags::Variables<tmpl::list<test_databox_tags::ScalarTag,
test_databox_tags::VectorTag>>,
test_databox_tags::Pointer>,
db::AddComputeTags<test_databox_tags::Tag4Compute,
test_databox_tags::Tag5Compute,
test_databox_tags::MultiplyScalarByTwoCompute,
test_databox_tags::PointerToCounterCompute,
test_databox_tags::PointerToSumCompute>>(
3.14, std::vector<double>{8.7, 93.2, 84.7}, "My Sample String"s,
Variables<tmpl::list<test_databox_tags::ScalarTag,
test_databox_tags::VectorTag>>(2, 3.),
std::make_unique<int>(3));
template <typename T>
void test_mutate_apply(T& box) {
constexpr bool using_db_access = std::is_same_v<std::decay_t<T>, db::Access>;
CHECK(approx(db::get<test_databox_tags::Tag4>(box)) == 3.14 * 2.0);
CHECK(db::get<test_databox_tags::ScalarTag>(box) ==
Scalar<DataVector>(DataVector(2, 3.)));
Expand All @@ -1417,7 +1402,9 @@ void test_mutate_apply() {
// [mutate_apply_struct_example_stateful]
db::mutate_apply(TestDataboxMutateApply{}, make_not_null(&box));
// [mutate_apply_struct_example_stateful]
db::mutate_apply(TestDataboxMutateApplyBase{}, make_not_null(&box));
if constexpr (not using_db_access) {
db::mutate_apply(TestDataboxMutateApplyBase{}, make_not_null(&box));
}

CHECK(approx(db::get<test_databox_tags::Tag4>(box)) == 3.14 * 2.0);
CHECK(db::get<test_databox_tags::ScalarTag>(box) ==
Expand All @@ -1444,14 +1431,16 @@ void test_mutate_apply() {
},
make_not_null(&box));
// [mutate_apply_lambda_example]
db::mutate_apply<tmpl::list<test_databox_tags::ScalarTag>,
tmpl::list<test_databox_tags::Tag2Base>>(
[](const gsl::not_null<Scalar<DataVector>*> scalar,
const std::string& tag2) {
CHECK(*scalar == Scalar<DataVector>(DataVector(2, 12.)));
CHECK(tag2 == "My Sample String"s);
},
make_not_null(&box));
if constexpr (not using_db_access) {
db::mutate_apply<tmpl::list<test_databox_tags::ScalarTag>,
tmpl::list<test_databox_tags::Tag2Base>>(
[](const gsl::not_null<Scalar<DataVector>*> scalar,
const std::string& tag2) {
CHECK(*scalar == Scalar<DataVector>(DataVector(2, 12.)));
CHECK(tag2 == "My Sample String"s);
},
make_not_null(&box));
}
CHECK(approx(db::get<test_databox_tags::Tag4>(box)) == 3.14 * 2.0);
CHECK(db::get<test_databox_tags::ScalarTag>(box) ==
Scalar<DataVector>(DataVector(2, 12.)));
Expand All @@ -1463,17 +1452,31 @@ void test_mutate_apply() {
CHECK(db::get<test_databox_tags::VectorTag2>(box) ==
(tnsr::I<DataVector, 3>(DataVector(2, 2.))));
// check with a forwarded return value
size_t size_of_internal_string =
db::mutate_apply<tmpl::list<test_databox_tags::ScalarTag>,
tmpl::list<test_databox_tags::Tag2Base>>(
[](const gsl::not_null<Scalar<DataVector>*> scalar,
const std::string& tag2) {
CHECK(*scalar == Scalar<DataVector>(DataVector(2, 12.)));
CHECK(tag2 == "My Sample String"s);
return tag2.size();
},
make_not_null(&box));
CHECK(size_of_internal_string == 16_st);
if constexpr (not using_db_access) {
size_t size_of_internal_string =
db::mutate_apply<tmpl::list<test_databox_tags::ScalarTag>,
tmpl::list<test_databox_tags::Tag2Base>>(
[](const gsl::not_null<Scalar<DataVector>*> scalar,
const std::string& tag2) {
CHECK(*scalar == Scalar<DataVector>(DataVector(2, 12.)));
CHECK(tag2 == "My Sample String"s);
return tag2.size();
},
make_not_null(&box));
CHECK(size_of_internal_string == 16_st);
} else {
size_t size_of_internal_string =
db::mutate_apply<tmpl::list<test_databox_tags::ScalarTag>,
tmpl::list<test_databox_tags::Tag2>>(
[](const gsl::not_null<Scalar<DataVector>*> scalar,
const std::string& tag2) {
CHECK(*scalar == Scalar<DataVector>(DataVector(2, 12.)));
CHECK(tag2 == "My Sample String"s);
return tag2.size();
},
make_not_null(&box));
CHECK(size_of_internal_string == 16_st);
}

db::mutate_apply<
tmpl::list<Tags::Variables<tmpl::list<test_databox_tags::ScalarTag,
Expand Down Expand Up @@ -1538,52 +1541,55 @@ void test_mutate_apply() {
db::mutate_apply<tmpl::list<test_databox_tags::Pointer>, tmpl::list<>>(
[](const gsl::not_null<std::unique_ptr<int>*> p) { **p = 6; },
make_not_null(&box));
db::mutate_apply<tmpl::list<>,
tmpl::list<test_databox_tags::PointerBase,
test_databox_tags::PointerToCounterBase>>(
[](const int& simple_base, const int& compute_base) {
CHECK(simple_base == 6);
CHECK(compute_base == 7);
},
make_not_null(&box));
if constexpr (not using_db_access) {
db::mutate_apply<tmpl::list<>,
tmpl::list<test_databox_tags::PointerBase,
test_databox_tags::PointerToCounterBase>>(
[](const int& simple_base, const int& compute_base) {
CHECK(simple_base == 6);
CHECK(compute_base == 7);
},
make_not_null(&box));
}

db::mutate_apply<PointerMutateApply>(make_not_null(&box));
CHECK(db::get<test_databox_tags::Pointer>(box) == 7);

db::mutate_apply<PointerMutateApplyBase>(make_not_null(&box));
CHECK(db::get<test_databox_tags::Pointer>(box) == 8);
if constexpr (not using_db_access) {
db::mutate_apply<PointerMutateApplyBase>(make_not_null(&box));
CHECK(db::get<test_databox_tags::Pointer>(box) == 8);
}
}

{
if constexpr (not using_db_access) {
INFO("Tags::DataBox");
db::mutate_apply<tmpl::list<>,
tmpl::list<::Tags::DataBox, test_databox_tags::Tag0>>(
[](const decltype(box)& /*box*/, const double& /*tag0*/) {},
make_not_null(&box));
[](const T& /*box*/, const double& /*tag0*/) {}, make_not_null(&box));
db::mutate_apply<tmpl::list<::Tags::DataBox>,
tmpl::list<test_databox_tags::Tag0>>(
[](const gsl::not_null<decltype(box)*> /*box*/,
const double& /*tag0*/) {},
[](const gsl::not_null<T*> /*box*/, const double& /*tag0*/) {},
make_not_null(&box));

struct ReadApply {
using return_tags = tmpl::list<>;
using argument_tags = tmpl::list<::Tags::DataBox>;
static void apply(const decltype(box)& /*box*/) {}
using return_tags [[maybe_unused]] = tmpl::list<>;
using argument_tags [[maybe_unused]] = tmpl::list<::Tags::DataBox>;
static void apply(const T& /*box*/) {}
};
struct ReadCall {
using return_tags = tmpl::list<>;
using argument_tags = tmpl::list<::Tags::DataBox>;
void operator()(const decltype(box)& /*box*/) const {}
using return_tags [[maybe_unused]] = tmpl::list<>;
using argument_tags [[maybe_unused]] = tmpl::list<::Tags::DataBox>;
void operator()(const T& /*box*/) const {}
};
struct WriteApply {
using return_tags = tmpl::list<::Tags::DataBox>;
using argument_tags = tmpl::list<>;
static void apply(const gsl::not_null<decltype(box)*> /*box*/) {}
using return_tags [[maybe_unused]] = tmpl::list<::Tags::DataBox>;
using argument_tags [[maybe_unused]] = tmpl::list<>;
static void apply(const gsl::not_null<T*> /*box*/) {}
};
struct WriteCall {
using return_tags = tmpl::list<::Tags::DataBox>;
using argument_tags = tmpl::list<>;
void operator()(const gsl::not_null<decltype(box)*> /*box*/) const {}
using return_tags [[maybe_unused]] = tmpl::list<::Tags::DataBox>;
using argument_tags [[maybe_unused]] = tmpl::list<>;
void operator()(const gsl::not_null<T*> /*box*/) const {}
};

db::mutate_apply<ReadApply>(make_not_null(&box));
Expand All @@ -1593,6 +1599,38 @@ void test_mutate_apply() {
}
}

void test_mutate_apply() {
INFO("test mutate apply");
const auto create_box = []() {
return db::create<
db::AddSimpleTags<
test_databox_tags::Tag0, test_databox_tags::Tag1,
test_databox_tags::Tag2,
Tags::Variables<tmpl::list<test_databox_tags::ScalarTag,
test_databox_tags::VectorTag>>,
test_databox_tags::Pointer>,
db::AddComputeTags<test_databox_tags::Tag4Compute,
test_databox_tags::Tag5Compute,
test_databox_tags::MultiplyScalarByTwoCompute,
test_databox_tags::PointerToCounterCompute,
test_databox_tags::PointerToSumCompute>>(
3.14, std::vector<double>{8.7, 93.2, 84.7}, "My Sample String"s,
Variables<tmpl::list<test_databox_tags::ScalarTag,
test_databox_tags::VectorTag>>(2, 3.),
std::make_unique<int>(3));
};
{
INFO("DataBox");
auto box = create_box();
test_mutate_apply(box);
}
{
INFO("Access");
auto box = create_box();
test_mutate_apply(db::as_access(box));
}
}

static_assert(
std::is_same_v<
db::compute_databox_type<tmpl::list<
Expand Down

0 comments on commit b4cffe3

Please sign in to comment.