Skip to content

Commit

Permalink
Add frame transformation overload for Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsvu committed Dec 8, 2023
1 parent 7f66ac1 commit f5a932b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/DataStructures/Variables/FrameTransform.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Distributed under the MIT License.
// See LICENSE.txt for details.

#pragma once

#include <cstddef>

#include "DataStructures/Tensor/EagerMath/FrameTransform.hpp"
#include "DataStructures/Tensor/Metafunctions.hpp"
#include "DataStructures/Tensor/TypeAliases.hpp"
#include "DataStructures/Variables.hpp"
#include "Utilities/Gsl.hpp"

namespace transform {

template <typename... ResultTags, typename... InputTags, size_t Dim,
typename SourceFrame, typename TargetFrame>
void first_index_to_different_frame(
const gsl::not_null<Variables<tmpl::list<ResultTags...>>*> result,
const Variables<tmpl::list<InputTags...>>& input,
const InverseJacobian<DataVector, Dim, SourceFrame, TargetFrame>&
inv_jacobian) {
EXPAND_PACK_LEFT_TO_RIGHT(
first_index_to_different_frame(make_not_null(&get<ResultTags>(*result)),
get<InputTags>(input), inv_jacobian));
}

namespace Tags {
template <typename Tag, typename FirstIndexFrame>
struct TransformedFirstIndex : db::SimpleTag {
using type = TensorMetafunctions::prepend_spatial_index<
TensorMetafunctions::remove_first_index<typename Tag::type>,
tmpl::front<typename Tag::type::index_list>::dim, UpLo::Up,
FirstIndexFrame>;
};
} // namespace Tags

template <typename... InputTags, size_t Dim, typename SourceFrame,
typename TargetFrame,
typename ResultVars = Variables<tmpl::list<
Tags::TransformedFirstIndex<InputTags, SourceFrame>...>>>
ResultVars first_index_to_different_frame(
const Variables<tmpl::list<InputTags...>>& input,
const InverseJacobian<DataVector, Dim, SourceFrame, TargetFrame>&
inv_jacobian) {
ResultVars result{input.number_of_grid_points()};
first_index_to_different_frame(make_not_null(&result), input, inv_jacobian);
return result;
}

} // namespace transform
35 changes: 35 additions & 0 deletions tests/Unit/DataStructures/Tensor/EagerMath/Test_FrameTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,20 @@
#include "DataStructures/Tensor/EagerMath/DeterminantAndInverse.hpp"
#include "DataStructures/Tensor/EagerMath/FrameTransform.hpp"
#include "DataStructures/Tensor/Tensor.hpp"
#include "DataStructures/Variables/FrameTransform.hpp"
#include "Framework/CheckWithRandomValues.hpp"
#include "Framework/SetupLocalPythonEnvironment.hpp"
#include "Helpers/DataStructures/MakeWithRandomValues.hpp"

namespace {
struct Var1 : db::SimpleTag {
using type = tnsr::I<DataVector, 2, Frame::Inertial>;
};
struct Var2 : db::SimpleTag {
using type = tnsr::Ij<DataVector, 2, Frame::Inertial>;
};
} // namespace

namespace transform {

template <typename SrcTensorType, typename DestTensorType, typename DataType,
Expand Down Expand Up @@ -160,6 +170,31 @@ void test_transform_first_index_to_different_frame() {
CHECK(get<1, 1>(result) == 16.5);
}
}
{
INFO("Variables");
const size_t num_points = 3;
InverseJacobian<DataVector, 2, Frame::ElementLogical, Frame::Inertial>
inv_jacobian{num_points};
get<0, 0>(inv_jacobian) = 2.0;
get<1, 1>(inv_jacobian) = 3.0;
get<0, 1>(inv_jacobian) = 0.5;
get<1, 0>(inv_jacobian) = 1.5;
Variables<tmpl::list<Var1, Var2>> input{num_points};
std::iota(get<Var1>(input).begin(), get<Var1>(input).end(), 1.0);
std::iota(get<Var2>(input).begin(), get<Var2>(input).end(), 1.0);
CAPTURE(input);
const auto result = first_index_to_different_frame(input, inv_jacobian);
const auto& var1 =
get<Tags::TransformedFirstIndex<Var1, Frame::ElementLogical>>(result);
const auto& var2 =
get<Tags::TransformedFirstIndex<Var2, Frame::ElementLogical>>(result);
CHECK(get<0>(var1) == 3.);
CHECK(get<1>(var1) == 7.5);
CHECK(get<0, 0>(var2) == 3.);
CHECK(get<1, 0>(var2) == 7.5);
CHECK(get<0, 1>(var2) == 8.);
CHECK(get<1, 1>(var2) == 16.5);
}
}

SPECTRE_TEST_CASE("Unit.PointwiseFunctions.GeneralRelativity.Transform",
Expand Down

0 comments on commit f5a932b

Please sign in to comment.