From 3ff2e8a95f820ffd5026dfc07acd311ed8100430 Mon Sep 17 00:00:00 2001 From: Luke Roberts Date: Wed, 4 Sep 2024 14:01:06 -0600 Subject: [PATCH] Make flux metadata accessible at WithFlux variable creation point --- src/interface/metadata.cpp | 47 ++++++++++++++++++----- src/interface/metadata.hpp | 30 ++++++++++++--- src/interface/sparse_pool.cpp | 60 +++++++++++++++++++----------- src/interface/state_descriptor.cpp | 22 +---------- 4 files changed, 101 insertions(+), 58 deletions(-) diff --git a/src/interface/metadata.cpp b/src/interface/metadata.cpp index 244a1860363e..36f37adf8b1c 100644 --- a/src/interface/metadata.cpp +++ b/src/interface/metadata.cpp @@ -101,10 +101,13 @@ MetadataFlag Metadata::GetUserFlag(const std::string &flagname) { } namespace parthenon { -Metadata::Metadata(const std::vector &bits, const std::vector &shape, +Metadata::Metadata(const std::vector &bits, + const std::vector &flux_bits, + const std::vector &shape, const std::vector &component_labels, const std::string &associated, - const refinement::RefinementFunctions_t ref_funcs_) + const refinement::RefinementFunctions_t ref_funcs_, + const refinement::RefinementFunctions_t flux_ref_funcs_) : shape_(shape), component_labels_(component_labels), associated_(associated) { // set flags for (const auto f : bits) { @@ -164,6 +167,39 @@ Metadata::Metadata(const std::vector &bits, const std::vector deallocation_threshold_ = 0.0; default_value_ = 0.0; } + + // Now create the flux metadata if required + if (IsSet(WithFluxes)) { + std::set flux_flags; + for (const auto f : flux_bits) + flux_flags.insert(f); + + // Set some standard defaults for the flux metadata if no + // flags were provided + if (flux_flags.size() == 0) { + flux_flags.insert(OneCopy); + if (IsSet(Fine)) flux_flags.insert(Fine); + if (IsSet(Cell)) flux_flags.insert(CellMemAligned); + if (IsSet(Sparse)) flux_flags.insert(Sparse); + } + + // These flags are automatically propagated for fluxes + flux_flags.insert(Flux); + if (IsSet(Cell)) { + flux_flags.insert(Face); + } else if (IsSet(Face)) { + flux_flags.insert(Edge); + } else if (IsSet(Edge)) { + flux_flags.insert(Node); + } + + if (IsSet(Tensor)) flux_flags.insert(Tensor); + if (IsSet(Vector)) flux_flags.insert(Vector); + + std::vector flux_flags_vec(flux_flags.begin(), flux_flags.end()); + flux_metadata = std::make_shared(flux_flags_vec, shape, component_labels, + std::string(), flux_ref_funcs_); + } } std::ostream &operator<<(std::ostream &os, const parthenon::Metadata &m) { @@ -271,13 +307,6 @@ bool Metadata::IsValid(bool throw_on_fail) const { } } - // Associated fluxes - if (IsSet(FluxNotOneCopy)) { - PARTHENON_REQUIRE( - IsSet(WithFluxes), - "Asking for non-OneCopy associated fluxes without asking for associated fluxes."); - } - return valid; } diff --git a/src/interface/metadata.hpp b/src/interface/metadata.hpp index ca8afeddbe54..66d8e6496dc4 100644 --- a/src/interface/metadata.hpp +++ b/src/interface/metadata.hpp @@ -120,8 +120,6 @@ PARTHENON_INTERNAL_FOR_FLAG(Fine) \ /** this variable is the flux for another variable **/ \ PARTHENON_INTERNAL_FOR_FLAG(Flux) \ - /** allocate a separate flux array for each stage if WithFluxes is specified**/ \ - PARTHENON_INTERNAL_FOR_FLAG(FluxNotOneCopy) \ /** Align memory of fields to cell centered memory \ (Field will be missing one layer of ghosts if it is not cell centered) **/ \ PARTHENON_INTERNAL_FOR_FLAG(CellMemAligned) \ @@ -330,28 +328,48 @@ class Metadata { // 4 constructors, this is the general constructor called by all other constructors, so // we do some sanity checks here Metadata( - const std::vector &bits, const std::vector &shape = {}, + const std::vector &bits, const std::vector &flux_bits, + const std::vector &shape = {}, const std::vector &component_labels = {}, const std::string &associated = "", const refinement::RefinementFunctions_t ref_funcs_ = + refinement::RefinementFunctions_t::RegisterOps< + refinement_ops::ProlongateSharedMinMod, refinement_ops::RestrictAverage>(), + const refinement::RefinementFunctions_t flux_ref_funcs_ = refinement::RefinementFunctions_t::RegisterOps< refinement_ops::ProlongateSharedMinMod, refinement_ops::RestrictAverage>()); - // 1 constructor + Metadata( + const std::vector &bits, const std::vector &shape = {}, + const std::vector &component_labels = {}, + const std::string &associated = "", + const refinement::RefinementFunctions_t ref_funcs_ = + refinement::RefinementFunctions_t::RegisterOps< + refinement_ops::ProlongateSharedMinMod, refinement_ops::RestrictAverage>()) + : Metadata(bits, {}, shape, component_labels, associated, ref_funcs_, ref_funcs_) {} + Metadata(const std::vector &bits, const std::vector &shape, const std::string &associated) : Metadata(bits, shape, {}, associated) {} - // 2 constructors Metadata(const std::vector &bits, const std::vector component_labels, const std::string &associated = "") : Metadata(bits, {1}, component_labels, associated) {} - // 1 constructor Metadata(const std::vector &bits, const std::string &associated) : Metadata(bits, {1}, {}, associated) {} + std::shared_ptr GetSPtrFluxMetadata() { + PARTHENON_REQUIRE(IsSet(WithFluxes), + "Asking for flux metadata from metadata that doesn't have it."); + return flux_metadata; + } + + private: + std::shared_ptr flux_metadata; + + public: // Static routines static MetadataFlag AddUserFlag(const std::string &name); static bool FlagNameExists(const std::string &flagname); diff --git a/src/interface/sparse_pool.cpp b/src/interface/sparse_pool.cpp index 8a6e0ef213ca..7890cff964de 100644 --- a/src/interface/sparse_pool.cpp +++ b/src/interface/sparse_pool.cpp @@ -11,6 +11,8 @@ // the public, perform publicly and display publicly, and to permit others to do so. //======================================================================================== +#include + #include "interface/sparse_pool.hpp" #include "interface/metadata.hpp" @@ -48,34 +50,31 @@ SparsePool::SparsePool(const std::string &base_name, const Metadata &metadata, } } -const Metadata &SparsePool::AddImpl(int sparse_id, const std::vector &shape, - const MetadataFlag *vector_tensor, - const std::vector &component_labels) { - PARTHENON_REQUIRE_THROWS(sparse_id != InvalidSparseID, - "Tried to add InvalidSparseID to sparse pool " + base_name_); - +std::shared_ptr +MakeSparseVarMetadataImpl(Metadata *in, const std::vector &shape, + const MetadataFlag *vector_tensor, + const std::vector &component_labels) { // copy shared metadata - Metadata this_metadata( - shared_metadata_.Flags(), shape.size() > 0 ? shape : shared_metadata_.Shape(), - component_labels.size() > 0 ? component_labels - : shared_metadata_.getComponentLabels(), - shared_metadata_.getAssociated(), shared_metadata_.GetRefinementFunctions()); + auto this_metadata = std::make_shared( + in->Flags(), shape.size() > 0 ? shape : in->Shape(), + component_labels.size() > 0 ? component_labels : in->getComponentLabels(), + in->getAssociated(), in->GetRefinementFunctions()); - this_metadata.SetSparseThresholds(shared_metadata_.GetAllocationThreshold(), - shared_metadata_.GetDeallocationThreshold(), - shared_metadata_.GetDefaultValue()); + this_metadata->SetSparseThresholds(in->GetAllocationThreshold(), + in->GetDeallocationThreshold(), + in->GetDefaultValue()); // if vector_tensor is set, apply it if (vector_tensor != nullptr) { if (*vector_tensor == Metadata::Vector) { - this_metadata.Unset(Metadata::Tensor); - this_metadata.Set(Metadata::Vector); + this_metadata->Unset(Metadata::Tensor); + this_metadata->Set(Metadata::Vector); } else if (*vector_tensor == Metadata::Tensor) { - this_metadata.Unset(Metadata::Vector); - this_metadata.Set(Metadata::Tensor); + this_metadata->Unset(Metadata::Vector); + this_metadata->Set(Metadata::Tensor); } else if (*vector_tensor == Metadata::None) { - this_metadata.Unset(Metadata::Vector); - this_metadata.Unset(Metadata::Tensor); + this_metadata->Unset(Metadata::Vector); + this_metadata->Unset(Metadata::Tensor); } else { PARTHENON_THROW("Expected MetadataFlag Vector, Tensor, or None, but got " + vector_tensor->Name()); @@ -83,9 +82,26 @@ const Metadata &SparsePool::AddImpl(int sparse_id, const std::vector &shape } // just in case - this_metadata.IsValid(true); + this_metadata->IsValid(true); + + return this_metadata; +} + +const Metadata &SparsePool::AddImpl(int sparse_id, const std::vector &shape, + const MetadataFlag *vector_tensor, + const std::vector &component_labels) { + PARTHENON_REQUIRE_THROWS(sparse_id != InvalidSparseID, + "Tried to add InvalidSparseID to sparse pool " + base_name_); + + auto this_metadata = MakeSparseVarMetadataImpl(&shared_metadata_, shape, vector_tensor, + component_labels); + if (this_metadata->IsSet(Metadata::WithFluxes)) { + this_metadata->GetSPtrFluxMetadata() = + MakeSparseVarMetadataImpl(shared_metadata_.GetSPtrFluxMetadata().get(), shape, + vector_tensor, component_labels); + } - const auto ins = pool_.insert({sparse_id, this_metadata}); + const auto ins = pool_.insert({sparse_id, *this_metadata}); PARTHENON_REQUIRE_THROWS(ins.second, "Tried to add sparse ID " + std::to_string(sparse_id) + " to sparse pool '" + base_name_ + diff --git a/src/interface/state_descriptor.cpp b/src/interface/state_descriptor.cpp index 4ba65755fcb4..b468b0201091 100644 --- a/src/interface/state_descriptor.cpp +++ b/src/interface/state_descriptor.cpp @@ -274,29 +274,9 @@ bool StateDescriptor::AddFieldImpl(const VarID &vid, const Metadata &m_in, return false; // this field has already been added } else { if (m.IsSet(Metadata::WithFluxes) && m.GetFluxName() == "") { - std::vector mFlags = {Metadata::Flux}; - if (!m.IsSet(Metadata::FluxNotOneCopy)) mFlags.push_back(Metadata::OneCopy); - if (m.IsSet(Metadata::Sparse)) mFlags.push_back(Metadata::Sparse); - if (m.IsSet(Metadata::Fine)) mFlags.push_back(Metadata::Fine); - if (m.IsSet(Metadata::Cell)) { - mFlags.push_back(Metadata::Face); - mFlags.push_back(Metadata::CellMemAligned); - } else if (m.IsSet(Metadata::Face)) { - mFlags.push_back(Metadata::Edge); - } else if (m.IsSet(Metadata::Edge)) { - mFlags.push_back(Metadata::Node); - } - Metadata mf; - if (m.GetRefinementFunctions().label().size() > 0) { - // Propagate custom refinement ops to flux field - mf = Metadata(mFlags, m.Shape(), std::vector(), std::string(), - m.GetRefinementFunctions()); - } else { - mf = Metadata(mFlags, m.Shape()); - } auto fId = VarID{internal_fluxname + internal_varname_seperator + vid.base_name, vid.sparse_id}; - AddFieldImpl(fId, mf, control_vid); + AddFieldImpl(fId, *(m.GetSPtrFluxMetadata()), control_vid); m.SetFluxName(fId.label()); } metadataMap_.insert({vid, m});