Skip to content

Commit

Permalink
Make flux metadata accessible at WithFlux variable creation point
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Sep 4, 2024
1 parent 8e8f201 commit 3ff2e8a
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 58 deletions.
47 changes: 38 additions & 9 deletions src/interface/metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,13 @@ MetadataFlag Metadata::GetUserFlag(const std::string &flagname) {
}

namespace parthenon {
Metadata::Metadata(const std::vector<MetadataFlag> &bits, const std::vector<int> &shape,
Metadata::Metadata(const std::vector<MetadataFlag> &bits,
const std::vector<MetadataFlag> &flux_bits,
const std::vector<int> &shape,
const std::vector<std::string> &component_labels,
const std::string &associated,
const refinement::RefinementFunctions_t ref_funcs_)
const refinement::RefinementFunctions_t ref_funcs_,
const refinement::RefinementFunctions_t flux_ref_funcs_)
: shape_(shape), component_labels_(component_labels), associated_(associated) {
// set flags
for (const auto f : bits) {
Expand Down Expand Up @@ -164,6 +167,39 @@ Metadata::Metadata(const std::vector<MetadataFlag> &bits, const std::vector<int>
deallocation_threshold_ = 0.0;
default_value_ = 0.0;
}

// Now create the flux metadata if required
if (IsSet(WithFluxes)) {
std::set<MetadataFlag> flux_flags;
for (const auto f : flux_bits)
flux_flags.insert(f);

// Set some standard defaults for the flux metadata if no
// flags were provided
if (flux_flags.size() == 0) {
flux_flags.insert(OneCopy);
if (IsSet(Fine)) flux_flags.insert(Fine);
if (IsSet(Cell)) flux_flags.insert(CellMemAligned);
if (IsSet(Sparse)) flux_flags.insert(Sparse);
}

// These flags are automatically propagated for fluxes
flux_flags.insert(Flux);
if (IsSet(Cell)) {
flux_flags.insert(Face);
} else if (IsSet(Face)) {
flux_flags.insert(Edge);
} else if (IsSet(Edge)) {
flux_flags.insert(Node);
}

if (IsSet(Tensor)) flux_flags.insert(Tensor);
if (IsSet(Vector)) flux_flags.insert(Vector);

std::vector<MetadataFlag> flux_flags_vec(flux_flags.begin(), flux_flags.end());
flux_metadata = std::make_shared<Metadata>(flux_flags_vec, shape, component_labels,
std::string(), flux_ref_funcs_);
}
}

std::ostream &operator<<(std::ostream &os, const parthenon::Metadata &m) {
Expand Down Expand Up @@ -271,13 +307,6 @@ bool Metadata::IsValid(bool throw_on_fail) const {
}
}

// Associated fluxes
if (IsSet(FluxNotOneCopy)) {
PARTHENON_REQUIRE(
IsSet(WithFluxes),
"Asking for non-OneCopy associated fluxes without asking for associated fluxes.");
}

return valid;
}

Expand Down
30 changes: 24 additions & 6 deletions src/interface/metadata.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@
PARTHENON_INTERNAL_FOR_FLAG(Fine) \
/** this variable is the flux for another variable **/ \
PARTHENON_INTERNAL_FOR_FLAG(Flux) \
/** allocate a separate flux array for each stage if WithFluxes is specified**/ \
PARTHENON_INTERNAL_FOR_FLAG(FluxNotOneCopy) \
/** Align memory of fields to cell centered memory \
(Field will be missing one layer of ghosts if it is not cell centered) **/ \
PARTHENON_INTERNAL_FOR_FLAG(CellMemAligned) \
Expand Down Expand Up @@ -330,28 +328,48 @@ class Metadata {
// 4 constructors, this is the general constructor called by all other constructors, so
// we do some sanity checks here
Metadata(
const std::vector<MetadataFlag> &bits, const std::vector<int> &shape = {},
const std::vector<MetadataFlag> &bits, const std::vector<MetadataFlag> &flux_bits,
const std::vector<int> &shape = {},
const std::vector<std::string> &component_labels = {},
const std::string &associated = "",
const refinement::RefinementFunctions_t ref_funcs_ =
refinement::RefinementFunctions_t::RegisterOps<
refinement_ops::ProlongateSharedMinMod, refinement_ops::RestrictAverage>(),
const refinement::RefinementFunctions_t flux_ref_funcs_ =
refinement::RefinementFunctions_t::RegisterOps<
refinement_ops::ProlongateSharedMinMod, refinement_ops::RestrictAverage>());

// 1 constructor
Metadata(
const std::vector<MetadataFlag> &bits, const std::vector<int> &shape = {},
const std::vector<std::string> &component_labels = {},
const std::string &associated = "",
const refinement::RefinementFunctions_t ref_funcs_ =
refinement::RefinementFunctions_t::RegisterOps<
refinement_ops::ProlongateSharedMinMod, refinement_ops::RestrictAverage>())
: Metadata(bits, {}, shape, component_labels, associated, ref_funcs_, ref_funcs_) {}

Metadata(const std::vector<MetadataFlag> &bits, const std::vector<int> &shape,
const std::string &associated)
: Metadata(bits, shape, {}, associated) {}

// 2 constructors
Metadata(const std::vector<MetadataFlag> &bits,
const std::vector<std::string> component_labels,
const std::string &associated = "")
: Metadata(bits, {1}, component_labels, associated) {}

// 1 constructor
Metadata(const std::vector<MetadataFlag> &bits, const std::string &associated)
: Metadata(bits, {1}, {}, associated) {}

std::shared_ptr<Metadata> GetSPtrFluxMetadata() {
PARTHENON_REQUIRE(IsSet(WithFluxes),
"Asking for flux metadata from metadata that doesn't have it.");
return flux_metadata;
}

private:
std::shared_ptr<Metadata> flux_metadata;

public:
// Static routines
static MetadataFlag AddUserFlag(const std::string &name);
static bool FlagNameExists(const std::string &flagname);
Expand Down
60 changes: 38 additions & 22 deletions src/interface/sparse_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
// the public, perform publicly and display publicly, and to permit others to do so.
//========================================================================================

#include <memory>

#include "interface/sparse_pool.hpp"

#include "interface/metadata.hpp"
Expand Down Expand Up @@ -48,44 +50,58 @@ SparsePool::SparsePool(const std::string &base_name, const Metadata &metadata,
}
}

const Metadata &SparsePool::AddImpl(int sparse_id, const std::vector<int> &shape,
const MetadataFlag *vector_tensor,
const std::vector<std::string> &component_labels) {
PARTHENON_REQUIRE_THROWS(sparse_id != InvalidSparseID,
"Tried to add InvalidSparseID to sparse pool " + base_name_);

std::shared_ptr<Metadata>
MakeSparseVarMetadataImpl(Metadata *in, const std::vector<int> &shape,
const MetadataFlag *vector_tensor,
const std::vector<std::string> &component_labels) {
// copy shared metadata
Metadata this_metadata(
shared_metadata_.Flags(), shape.size() > 0 ? shape : shared_metadata_.Shape(),
component_labels.size() > 0 ? component_labels
: shared_metadata_.getComponentLabels(),
shared_metadata_.getAssociated(), shared_metadata_.GetRefinementFunctions());
auto this_metadata = std::make_shared<Metadata>(
in->Flags(), shape.size() > 0 ? shape : in->Shape(),
component_labels.size() > 0 ? component_labels : in->getComponentLabels(),
in->getAssociated(), in->GetRefinementFunctions());

this_metadata.SetSparseThresholds(shared_metadata_.GetAllocationThreshold(),
shared_metadata_.GetDeallocationThreshold(),
shared_metadata_.GetDefaultValue());
this_metadata->SetSparseThresholds(in->GetAllocationThreshold(),
in->GetDeallocationThreshold(),
in->GetDefaultValue());

// if vector_tensor is set, apply it
if (vector_tensor != nullptr) {
if (*vector_tensor == Metadata::Vector) {
this_metadata.Unset(Metadata::Tensor);
this_metadata.Set(Metadata::Vector);
this_metadata->Unset(Metadata::Tensor);
this_metadata->Set(Metadata::Vector);
} else if (*vector_tensor == Metadata::Tensor) {
this_metadata.Unset(Metadata::Vector);
this_metadata.Set(Metadata::Tensor);
this_metadata->Unset(Metadata::Vector);
this_metadata->Set(Metadata::Tensor);
} else if (*vector_tensor == Metadata::None) {
this_metadata.Unset(Metadata::Vector);
this_metadata.Unset(Metadata::Tensor);
this_metadata->Unset(Metadata::Vector);
this_metadata->Unset(Metadata::Tensor);
} else {
PARTHENON_THROW("Expected MetadataFlag Vector, Tensor, or None, but got " +
vector_tensor->Name());
}
}

// just in case
this_metadata.IsValid(true);
this_metadata->IsValid(true);

return this_metadata;
}

const Metadata &SparsePool::AddImpl(int sparse_id, const std::vector<int> &shape,
const MetadataFlag *vector_tensor,
const std::vector<std::string> &component_labels) {
PARTHENON_REQUIRE_THROWS(sparse_id != InvalidSparseID,
"Tried to add InvalidSparseID to sparse pool " + base_name_);

auto this_metadata = MakeSparseVarMetadataImpl(&shared_metadata_, shape, vector_tensor,
component_labels);
if (this_metadata->IsSet(Metadata::WithFluxes)) {
this_metadata->GetSPtrFluxMetadata() =
MakeSparseVarMetadataImpl(shared_metadata_.GetSPtrFluxMetadata().get(), shape,
vector_tensor, component_labels);
}

const auto ins = pool_.insert({sparse_id, this_metadata});
const auto ins = pool_.insert({sparse_id, *this_metadata});
PARTHENON_REQUIRE_THROWS(ins.second, "Tried to add sparse ID " +
std::to_string(sparse_id) +
" to sparse pool '" + base_name_ +
Expand Down
22 changes: 1 addition & 21 deletions src/interface/state_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,29 +274,9 @@ bool StateDescriptor::AddFieldImpl(const VarID &vid, const Metadata &m_in,
return false; // this field has already been added
} else {
if (m.IsSet(Metadata::WithFluxes) && m.GetFluxName() == "") {
std::vector<MetadataFlag> mFlags = {Metadata::Flux};
if (!m.IsSet(Metadata::FluxNotOneCopy)) mFlags.push_back(Metadata::OneCopy);
if (m.IsSet(Metadata::Sparse)) mFlags.push_back(Metadata::Sparse);
if (m.IsSet(Metadata::Fine)) mFlags.push_back(Metadata::Fine);
if (m.IsSet(Metadata::Cell)) {
mFlags.push_back(Metadata::Face);
mFlags.push_back(Metadata::CellMemAligned);
} else if (m.IsSet(Metadata::Face)) {
mFlags.push_back(Metadata::Edge);
} else if (m.IsSet(Metadata::Edge)) {
mFlags.push_back(Metadata::Node);
}
Metadata mf;
if (m.GetRefinementFunctions().label().size() > 0) {
// Propagate custom refinement ops to flux field
mf = Metadata(mFlags, m.Shape(), std::vector<std::string>(), std::string(),
m.GetRefinementFunctions());
} else {
mf = Metadata(mFlags, m.Shape());
}
auto fId = VarID{internal_fluxname + internal_varname_seperator + vid.base_name,
vid.sparse_id};
AddFieldImpl(fId, mf, control_vid);
AddFieldImpl(fId, *(m.GetSPtrFluxMetadata()), control_vid);
m.SetFluxName(fId.label());
}
metadataMap_.insert({vid, m});
Expand Down

0 comments on commit 3ff2e8a

Please sign in to comment.