Skip to content

Commit

Permalink
MLMG interface (AMReX-Codes#2858)
Browse files Browse the repository at this point in the history
These changes are made to support a generic type (i.e., amrex::Any) in MLMG.
This is still work in progress.  But it should not break any existing codes.
  • Loading branch information
WeiqunZhang authored Aug 1, 2022
1 parent 5a3b303 commit 9469329
Show file tree
Hide file tree
Showing 21 changed files with 1,801 additions and 892 deletions.
5 changes: 4 additions & 1 deletion Src/Base/AMReX_Any.H
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,18 @@ public:
private:
struct innards_base {
virtual const std::type_info& Type () const = 0;
virtual ~innards_base () = default;
};

template <typename MF>
struct innards : innards_base
{
innards(MF && mf)
innards (MF && mf)
: m_mf(std::forward<MF>(mf))
{}

virtual ~innards () = default;

virtual const std::type_info& Type () const override {
return typeid(MF);
}
Expand Down
2 changes: 1 addition & 1 deletion Src/Base/AMReX_BaseFab.H
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ public:
*/
void clear () noexcept;

// Release ownership of memory
//! Release ownership of memory
std::unique_ptr<T,DataDeleter> release () noexcept;

//! Returns how many bytes used
Expand Down
2 changes: 1 addition & 1 deletion Src/Base/AMReX_FabArray.H
Original file line number Diff line number Diff line change
Expand Up @@ -2848,7 +2848,7 @@ FabArray<FAB>::SumBoundary_nowait (int scomp, int ncomp, IntVect const& src_ngho

FabArray<FAB>* tmp = new FabArray<FAB>( boxArray(), DistributionMap(), ncomp, src_nghost, MFInfo(), Factory() );
amrex::Copy(*tmp, *this, scomp, 0, ncomp, src_nghost);
this->setVal(0.0, scomp, ncomp, dst_nghost);
this->setVal(typename FAB::value_type(0), scomp, ncomp, dst_nghost);
this->ParallelCopy_nowait(*tmp,0,scomp,ncomp,src_nghost,dst_nghost,period,FabArrayBase::ADD);

// All local. Operation complete.
Expand Down
4 changes: 2 additions & 2 deletions Src/LinearSolvers/MLMG/AMReX_MLABecLaplacian.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,10 @@ MLABecLaplacian::applyMetricTermsCoeffs ()
for (int alev = 0; alev < m_num_amr_levels; ++alev)
{
const int mglev = 0;
applyMetricTerm(alev, mglev, m_a_coeffs[alev][mglev]);
applyMetricTermToMF(alev, mglev, m_a_coeffs[alev][mglev]);
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim)
{
applyMetricTerm(alev, mglev, m_b_coeffs[alev][mglev][idim]);
applyMetricTermToMF(alev, mglev, m_b_coeffs[alev][mglev][idim]);
}
}
#endif
Expand Down
5 changes: 5 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public:
Real eps_rel,
Real eps_abs);

int solve (Any& solnL,
const Any& rhsL,
Real eps_rel,
Real eps_abs);

void setVerbose (int _verbose) { verbose = _verbose; }
int getVerbose () const { return verbose; }

Expand Down
7 changes: 7 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ MLCGSolver::solve (MultiFab& sol,
}
}

int
MLCGSolver::solve (Any& sol, const Any& rhs, Real eps_rel, Real eps_abs)
{
AMREX_ASSERT(sol.is<MultiFab>()); // xxxxx TODO: MLCGSolver Any
return solve(sol.get<MultiFab>(), rhs.get<MultiFab>(), eps_rel, eps_abs);
}

int
MLCGSolver::solve_bicgstab (MultiFab& sol,
const MultiFab& rhs,
Expand Down
4 changes: 2 additions & 2 deletions Src/LinearSolvers/MLMG/AMReX_MLCellABecLap.H
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ public:
virtual MultiFab const* getACoeffs (int amrlev, int mglev) const = 0;
virtual Array<MultiFab const*,AMREX_SPACEDIM> getBCoeffs (int amrlev, int mglev) const = 0;

virtual void applyInhomogNeumannTerm (int amrlev, MultiFab& rhs) const final override;
virtual void applyInhomogNeumannTerm (int amrlev, Any& rhs) const final override;

virtual void applyOverset (int amlev, MultiFab& rhs) const override;
virtual void applyOverset (int amlev, Any& rhs) const override;

#if defined(AMREX_USE_HYPRE) && (AMREX_SPACEDIM > 1)
virtual std::unique_ptr<Hypre> makeHypre (Hypre::Interface hypre_interface) const override;
Expand Down
11 changes: 8 additions & 3 deletions Src/LinearSolvers/MLMG/AMReX_MLCellABecLap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ MLCellABecLap::define (const Vector<Geometry>& a_geom,
amrlev = 0;
for (int mglev = 1; mglev < m_num_mg_levels[amrlev]; ++mglev) {
MultiFab foo(m_grids[amrlev][mglev], m_dmap[amrlev][mglev], 1, 0, MFInfo().SetAlloc(false));
if (! isMFIterSafe(*m_overset_mask[amrlev][mglev], foo)) {
if (! amrex::isMFIterSafe(*m_overset_mask[amrlev][mglev], foo)) {
auto osm = std::make_unique<iMultiFab>(m_grids[amrlev][mglev],
m_dmap[amrlev][mglev], 1, 1);
osm->ParallelCopy(*m_overset_mask[amrlev][mglev]);
Expand Down Expand Up @@ -193,13 +193,16 @@ MLCellABecLap::getFluxes (const Vector<Array<MultiFab*,AMREX_SPACEDIM> >& a_flux
}

void
MLCellABecLap::applyInhomogNeumannTerm (int amrlev, MultiFab& rhs) const
MLCellABecLap::applyInhomogNeumannTerm (int amrlev, Any& a_rhs) const
{
bool has_inhomog_neumann = hasInhomogNeumannBC();
bool has_robin = hasRobinBC();

if (!has_inhomog_neumann && !has_robin) return;

AMREX_ASSERT(a_rhs.is<MultiFab>());
MultiFab& rhs = a_rhs.get<MultiFab>();

int ncomp = getNComp();
const int mglev = 0;

Expand Down Expand Up @@ -414,9 +417,11 @@ MLCellABecLap::applyInhomogNeumannTerm (int amrlev, MultiFab& rhs) const
}

void
MLCellABecLap::applyOverset (int amrlev, MultiFab& rhs) const
MLCellABecLap::applyOverset (int amrlev, Any& a_rhs) const
{
if (m_overset_mask[amrlev][0]) {
AMREX_ASSERT(a_rhs.is<MultiFab>());
auto& rhs = a_rhs.get<MultiFab>();
const int ncomp = getNComp();
#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
Expand Down
27 changes: 25 additions & 2 deletions Src/LinearSolvers/MLMG/AMReX_MLCellLinOp.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <AMReX_Config.H>

#include <AMReX_MLLinOp.H>
#include <AMReX_iMultiFab.H>

namespace amrex {

Expand Down Expand Up @@ -109,6 +110,8 @@ public:

virtual void interpolation (int amrlev, int fmglev, MultiFab& fine, const MultiFab& crse) const override;

virtual void interpAssign (int amrlev, int fmglev, MultiFab& fine, MultiFab& crse) const override;

virtual void averageDownSolutionRHS (int camrlev, MultiFab& crse_sol, MultiFab& crse_rhs,
const MultiFab& fine_sol, const MultiFab& fine_rhs) override;

Expand All @@ -132,9 +135,12 @@ public:
virtual void compGrad (int amrlev, const Array<MultiFab*,AMREX_SPACEDIM>& grad,
MultiFab& sol, Location loc) const override;

virtual void applyMetricTerm (int amrlev, int mglev, MultiFab& rhs) const final override;
virtual void applyMetricTerm (int amrlev, int mglev, Any& rhs) const final override;
virtual void unapplyMetricTerm (int amrlev, int mglev, MultiFab& rhs) const final override;
virtual void fillSolutionBC (int amrlev, MultiFab& sol, const MultiFab* crse_bcdata=nullptr) final override;
virtual Vector<Real> getSolvabilityOffset (int amrlev, int mglev,
Any const& rhs) const override;
virtual void fixSolvabilityByOffset (int amrlev, int mglev, Any& rhs,
Vector<Real> const& offset) const override;

virtual void prepareForSolve () override;

Expand All @@ -146,6 +152,18 @@ public:
const Array<FArrayBox*,AMREX_SPACEDIM>& flux,
const FArrayBox& sol, Location loc, const int face_only=0) const = 0;

// This could be turned into template if needed.
void applyMetricTermToMF (int amrlev, int mglev, MultiFab& rhs) const;

virtual Real AnyNormInfMask (int amrlev, Any const& a, bool local) const override;

virtual void AnyAvgDownResAmr (int clev, Any& cres, Any const& fres) const override;

virtual void AnyInterpolationAmr (int famrlev, Any& fine, const Any& crse,
IntVect const& /*nghost*/) const override;

virtual void AnyAverageDownAndSync (Vector<Any>& sol) const override;

struct BCTL {
BoundCond type;
Real location;
Expand Down Expand Up @@ -210,12 +228,17 @@ protected:
// boundary cell flags for covered, not_covered, outside_domain
Vector<Vector<Array<MultiMask,2*AMREX_SPACEDIM> > > m_maskvals;

Vector<std::unique_ptr<iMultiFab> > m_norm_fine_mask;

mutable Vector<YAFluxRegister> m_fluxreg;

private:

void defineAuxData ();
void defineBC ();

void computeVolInv () const;
mutable Vector<Vector<Real> > m_volinv; // used by solvability fix
};

}
Expand Down
Loading

0 comments on commit 9469329

Please sign in to comment.