Skip to content

Commit

Permalink
thresholds from python
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasaunai committed Aug 9, 2024
1 parent 1eff256 commit 1059f24
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 16 deletions.
8 changes: 8 additions & 0 deletions pyphare/pyphare/pharein/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ def as_paths(rb):
as_paths(refinement_boxes)
elif simulation.refinement == "tagging":
add_string("simulation/AMR/refinement/tagging/method", "auto")
# the two following params are hard-coded for now
# they will become configurable when we have multi-models or several methods
# per model
add_string("simulation/AMR/refinement/tagging/model", "HybridModel")
add_string("simulation/AMR/refinement/tagging/method", "default")
add_double(
"simulation/AMR/refinement/tagging/threshold", simulation.tagging_threshold
)
else:
add_string(
"simulation/AMR/refinement/tagging/method", "none"
Expand Down
7 changes: 7 additions & 0 deletions pyphare/pyphare/pharein/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ def wrapper(simulation_object, **kwargs):
"diag_export_format",
"refinement_boxes",
"refinement",
"tagging_threshold",
"clustering",
"smallest_patch_size",
"largest_patch_size",
Expand Down Expand Up @@ -677,6 +678,7 @@ def wrapper(simulation_object, **kwargs):
kwargs["max_nbr_levels"] = kwargs.get("max_nbr_levels", None)
assert kwargs["max_nbr_levels"] is not None # this needs setting otherwise
kwargs["refinement_boxes"] = None
kwargs["tagging_threshold"] = kwargs.get("tagging_threshold", 0.1)

kwargs["resistivity"] = check_resistivity(**kwargs)

Expand Down Expand Up @@ -804,6 +806,11 @@ class Simulation(object):
:Keyword Arguments:
* *nesting_buffer* (``ìnt``)--
[default=0] minimum gap in coarse cells from the border of a level and any refined patch border
kwargs["
* *refinement* (``str``)--
"boxes" (default), "tagging" type of refinement to use.
"tagging": use tagging_threshold to tag cells for refinement
"boxes": use refinement_boxes to define refinement levels
* *refinement_boxes* --
[default=None] {"L0":{"B0":[(lox,loy,loz),(upx,upy,upz)],...,"Bi":[(),()]},..."Li":{B0:[(),()]}}
* *smallest_patch_size* (``int`` or ``tuple``)--
Expand Down
20 changes: 13 additions & 7 deletions src/amr/tagging/default_hybrid_tagger_strategy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
#include "core/data/vecfield/vecfield_component.hpp"
#include "core/data/ndarray/ndarray_vector.hpp"
#include <cstddef>
#include <iostream>

#include "initializer/data_provider.hpp"

namespace PHARE::amr
{
Expand All @@ -17,8 +16,17 @@ class DefaultHybridTaggerStrategy : public HybridTaggerStrategy<HybridModel>
using gridlayout_type = typename HybridModel::gridlayout_type;
static auto constexpr dimension = HybridModel::dimension;


public:
DefaultHybridTaggerStrategy(initializer::PHAREDict const& dict)
:threshold_{[&](){return (dict.contains("threshold")) ? dict["threshold"].template to<double>() : 0.1;}()}
{

}
void tag(HybridModel& model, gridlayout_type const& layout, int* tags) const override;

private:
double threshold_ = 0.1;
};

template<typename HybridModel>
Expand All @@ -31,8 +39,6 @@ void DefaultHybridTaggerStrategy<HybridModel>::tag(HybridModel& model,

auto& N = model.state.ions.density();

double threshold = 0.1;

// we loop on cell indexes for all qties regardless of their centering
auto const& [start_x, _]
= layout.physicalStartToEnd(PHARE::core::QtyCentering::dual, PHARE::core::Direction::X);
Expand All @@ -56,7 +62,7 @@ void DefaultHybridTaggerStrategy<HybridModel>::tag(HybridModel& model,
auto crit_bz_x = (Bz(ix + 2) - Bz(ix)) / (1 + Bz(ix + 1) - Bz(ix));
auto criter = std::max(crit_by_x, crit_bz_x);

if (criter > threshold)
if (criter > threshold_)
{
tagsv(iCell) = 1;
}
Expand All @@ -82,7 +88,7 @@ void DefaultHybridTaggerStrategy<HybridModel>::tag(HybridModel& model,
auto criter_b = std::sqrt(criter_by * criter_by + criter_bz * criter_bz);
auto criter = criter_b;

if (criter > threshold)
if (criter > threshold_)
{
tagsv(iCell) = 1;
}
Expand Down Expand Up @@ -113,7 +119,7 @@ void DefaultHybridTaggerStrategy<HybridModel>::tag(HybridModel& model,
auto const& [Bz_x, Bz_y] = field_diff(Bz);
auto crit = std::max({Bx_x, Bx_y, By_x, By_y, Bz_x, Bz_y});

if (crit > threshold)
if (crit > threshold_)
{
tagsv(iTag_x, iTag_y) = 1;
}
Expand Down
10 changes: 7 additions & 3 deletions src/amr/tagging/tagger_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "hybrid_tagger_strategy.hpp"
#include "default_hybrid_tagger_strategy.hpp"
#include "core/def.hpp"
#include "initializer/data_provider.hpp"

namespace PHARE::amr
{
Expand All @@ -17,12 +18,15 @@ class TaggerFactory
{
public:
TaggerFactory() = delete;
NO_DISCARD static std::unique_ptr<Tagger> make(std::string modelName, std::string methodName);
NO_DISCARD static std::unique_ptr<Tagger> make(PHARE::initializer::PHAREDict const& dict);
};

template<typename PHARE_T>
std::unique_ptr<Tagger> TaggerFactory<PHARE_T>::make(std::string modelName, std::string methodName)
std::unique_ptr<Tagger> TaggerFactory<PHARE_T>::make(PHARE::initializer::PHAREDict const& dict)
{
auto modelName = dict["model"].template to<std::string>();
auto methodName = dict["method"].template to<std::string>();

if (modelName == "HybridModel")
{
using HybridModel = typename PHARE_T::HybridModel_t;
Expand All @@ -31,7 +35,7 @@ std::unique_ptr<Tagger> TaggerFactory<PHARE_T>::make(std::string modelName, std:
if (methodName == "default")
{
using HTS = DefaultHybridTaggerStrategy<HybridModel>;
return std::make_unique<HT>(std::make_unique<HTS>());
return std::make_unique<HT>(std::make_unique<HTS>(dict));
}
}
return nullptr;
Expand Down
12 changes: 10 additions & 2 deletions src/simulator/simulator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,16 @@ void Simulator<dim, _interp, nbRefinedPart>::hybrid_init(initializer::PHAREDict
multiphysInteg_->registerAndSetupMessengers(messengerFactory_);

// hard coded for now, should get some params later from the dict
auto hybridTagger_ = amr::TaggerFactory<PHARETypes>::make("HybridModel", "default");
multiphysInteg_->registerTagger(0, maxLevelNumber_ - 1, std::move(hybridTagger_));
if (dict["simulation"]["AMR"]["refinement"].contains("tagging"))
{
if (dict["simulation"]["AMR"]["refinement"]["tagging"]["method"].template to<std::string>()
!= "none")
{
auto hybridTagger_ = amr::TaggerFactory<PHARETypes>::make(
dict["simulation"]["AMR"]["refinement"]["tagging"]);
multiphysInteg_->registerTagger(0, maxLevelNumber_ - 1, std::move(hybridTagger_));
}
}

amr::LoadBalancerDetails lb_info
= amr::LoadBalancerDetails::FROM(dict["simulation"]["AMR"]["loadbalancing"]);
Expand Down
19 changes: 15 additions & 4 deletions tests/amr/tagging/test_tagging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,29 @@ using namespace PHARE::amr;



TEST(test_tagger, fromFactory)
TEST(test_tagger, fromFactoryValid)
{
using phare_types = PHARE::PHARE_Types<1, 1, 2>;
auto hybridTagger = TaggerFactory<phare_types>::make("HybridModel", "default");
PHARE::initializer::PHAREDict dict;
dict["model"] = std::string{"HybridModel"};
dict["method"] = std::string{"default"};
dict["threshold"] = 0.2;
auto hybridTagger = TaggerFactory<phare_types>::make(dict);
EXPECT_TRUE(hybridTagger != nullptr);
}

auto badTagger = TaggerFactory<phare_types>::make("invalidModel", "invalidStrat");
TEST(test_tagger, fromFactoryInvalid)
{
using phare_types = PHARE::PHARE_Types<1, 1, 2>;
PHARE::initializer::PHAREDict dict;
dict["model"] = std::string{"invalidModel"};
dict["method"] = std::string{"invalidStrat"};
auto hybridTagger = TaggerFactory<phare_types>::make(dict);
auto badTagger = TaggerFactory<phare_types>::make(dict);
EXPECT_TRUE(badTagger == nullptr);
}



using Param = std::vector<double>;
using RetType = std::shared_ptr<PHARE::core::Span<double>>;

Expand Down

0 comments on commit 1059f24

Please sign in to comment.