Skip to content

Commit

Permalink
Merge pull request #118 from awslabs/sjg/discrete-interpolator-refactor
Browse files Browse the repository at this point in the history
Refactor construction of discrete interpolator matrices for reuse
  • Loading branch information
sebastiangrimberg authored Oct 31, 2023
2 parents 7cc51e8 + 8c05eaf commit e5118ac
Show file tree
Hide file tree
Showing 46 changed files with 672 additions and 529 deletions.
5 changes: 0 additions & 5 deletions docs/src/config/solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,6 @@ Thus, this object is only relevant for [`config["Problem"]["Type"]: "Magnetostat
`"Linear"` : Top-level object for configuring the linear solver employed by all simulation
types.

### Advanced solver options

- `"PartialAssemblyInterpolators" [true]`

## `solver["Eigenmode"]`

```json
Expand Down Expand Up @@ -456,7 +452,6 @@ vectors in Krylov subspace methods or other parts of the code.
### Advanced linear solver options

- `"InitialGuess" [true]`
- `"MGLegacyTransfer" [false]`
- `"MGAuxiliarySmoother" [true]`
- `"MGSmoothEigScaleMax" [1.0]`
- `"MGSmoothEigScaleMin" [0.0]`
Expand Down
7 changes: 4 additions & 3 deletions palace/drivers/basesolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <mfem.hpp>
#include <nlohmann/json.hpp>
#include "fem/errorindicator.hpp"
#include "fem/fespace.hpp"
#include "linalg/ksp.hpp"
#include "models/domainpostoperator.hpp"
#include "models/postoperator.hpp"
Expand Down Expand Up @@ -84,17 +85,17 @@ BaseSolver::BaseSolver(const IoData &iodata, bool root, int size, int num_thread
}
}

void BaseSolver::SaveMetadata(const mfem::ParFiniteElementSpaceHierarchy &fespaces) const
void BaseSolver::SaveMetadata(const FiniteElementSpaceHierarchy &fespaces) const
{
if (post_dir.length() == 0)
{
return;
}
const mfem::ParFiniteElementSpace &fespace = fespaces.GetFinestFESpace();
const auto &fespace = fespaces.GetFinestFESpace();
HYPRE_BigInt ne = fespace.GetParMesh()->GetNE();
Mpi::GlobalSum(1, &ne, fespace.GetComm());
std::vector<HYPRE_BigInt> ndofs(fespaces.GetNumLevels());
for (int l = 0; l < fespaces.GetNumLevels(); l++)
for (std::size_t l = 0; l < fespaces.GetNumLevels(); l++)
{
ndofs[l] = fespaces.GetFESpaceAtLevel(l).GlobalTrueVSize();
}
Expand Down
4 changes: 2 additions & 2 deletions palace/drivers/basesolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
namespace mfem
{

class ParFiniteElementSpaceHierarchy;
class ParMesh;

} // namespace mfem
Expand All @@ -21,6 +20,7 @@ namespace palace
{

class ErrorIndicator;
class FiniteElementSpaceHierarchy;
class IoData;
class PostOperator;
class Timer;
Expand Down Expand Up @@ -90,7 +90,7 @@ class BaseSolver
Solve(const std::vector<std::unique_ptr<mfem::ParMesh>> &mesh) const = 0;

// These methods write different simulation metadata to a JSON file in post_dir.
void SaveMetadata(const mfem::ParFiniteElementSpaceHierarchy &fespaces) const;
void SaveMetadata(const FiniteElementSpaceHierarchy &fespaces) const;
template <typename SolverType>
void SaveMetadata(const SolverType &ksp) const;
void SaveMetadata(const Timer &timer) const;
Expand Down
20 changes: 10 additions & 10 deletions palace/drivers/drivensolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ ErrorIndicator DrivenSolver::SweepUniform(SpaceOperator &spaceop, PostOperator &
auto C = spaceop.GetDampingMatrix<ComplexOperator>(Operator::DIAG_ZERO);
auto M = spaceop.GetMassMatrix<ComplexOperator>(Operator::DIAG_ZERO);
auto A2 = spaceop.GetExtraSystemMatrix<ComplexOperator>(omega0, Operator::DIAG_ZERO);
auto Curl = spaceop.GetCurlMatrix<ComplexOperator>();
const auto &Curl = spaceop.GetCurlMatrix();

// Set up the linear solver and set operators for the first frequency step. The
// preconditioner for the complex linear system is constructed from a real approximation
Expand All @@ -132,15 +132,14 @@ ErrorIndicator DrivenSolver::SweepUniform(SpaceOperator &spaceop, PostOperator &

// Set up RHS vector for the incident field at port boundaries, and the vector for the
// first frequency step.
ComplexVector RHS(Curl->Width()), E(Curl->Width()), B(Curl->Height());
ComplexVector RHS(Curl.Width()), E(Curl.Width()), B(Curl.Height());
E = 0.0;
B = 0.0;

// Initialize structures for storing and reducing the results of error estimation.
CurlFluxErrorEstimator<ComplexVector> estimator(
spaceop.GetMaterialOp(), spaceop.GetNDSpaces(), iodata.solver.linear.estimator_tol,
iodata.solver.linear.estimator_max_it, 0, iodata.solver.pa_order_threshold,
iodata.solver.pa_discrete_interp);
iodata.solver.linear.estimator_max_it, 0, iodata.solver.pa_order_threshold);
ErrorIndicator indicator;

// Main frequency sweep loop.
Expand Down Expand Up @@ -176,7 +175,8 @@ ErrorIndicator DrivenSolver::SweepUniform(SpaceOperator &spaceop, PostOperator &
// PostOperator for all postprocessing operations.
BlockTimer bt2(Timer::POSTPRO);
double E_elec = 0.0, E_mag = 0.0;
Curl->Mult(E, B);
Curl.Mult(E.Real(), B.Real());
Curl.Mult(E.Imag(), B.Imag());
B *= -1.0 / (1i * omega);
postop.SetEGridFunction(E);
postop.SetBGridFunction(B);
Expand Down Expand Up @@ -244,16 +244,15 @@ ErrorIndicator DrivenSolver::SweepAdaptive(SpaceOperator &spaceop, PostOperator

// Allocate negative curl matrix for postprocessing the B-field and vectors for the
// high-dimensional field solution.
auto Curl = spaceop.GetCurlMatrix<ComplexOperator>();
ComplexVector E(Curl->Width()), B(Curl->Height());
const auto &Curl = spaceop.GetCurlMatrix();
ComplexVector E(Curl.Width()), B(Curl.Height());
E = 0.0;
B = 0.0;

// Initialize structures for storing and reducing the results of error estimation.
CurlFluxErrorEstimator<ComplexVector> estimator(
spaceop.GetMaterialOp(), spaceop.GetNDSpaces(), iodata.solver.linear.estimator_tol,
iodata.solver.linear.estimator_max_it, 0, iodata.solver.pa_order_threshold,
iodata.solver.pa_discrete_interp);
iodata.solver.linear.estimator_max_it, 0, iodata.solver.pa_order_threshold);
ErrorIndicator indicator;

// Configure the PROM operator which performs the parameter space sampling and basis
Expand Down Expand Up @@ -337,7 +336,8 @@ ErrorIndicator DrivenSolver::SweepAdaptive(SpaceOperator &spaceop, PostOperator
// PostOperator for all postprocessing operations.
BlockTimer bt4(Timer::POSTPRO);
double E_elec = 0.0, E_mag = 0.0;
Curl->Mult(E, B);
Curl.Mult(E.Real(), B.Real());
Curl.Mult(E.Imag(), B.Imag());
B *= -1.0 / (1i * omega);
postop.SetEGridFunction(E);
postop.SetBGridFunction(B);
Expand Down
17 changes: 9 additions & 8 deletions palace/drivers/eigensolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ EigenSolver::Solve(const std::vector<std::unique_ptr<mfem::ParMesh>> &mesh) cons
auto K = spaceop.GetStiffnessMatrix<ComplexOperator>(Operator::DIAG_ONE);
auto C = spaceop.GetDampingMatrix<ComplexOperator>(Operator::DIAG_ZERO);
auto M = spaceop.GetMassMatrix<ComplexOperator>(Operator::DIAG_ZERO);
auto Curl = spaceop.GetCurlMatrix<ComplexOperator>();
const auto &Curl = spaceop.GetCurlMatrix();
SaveMetadata(spaceop.GetNDSpaces());

// Configure objects for postprocessing.
PostOperator postop(iodata, spaceop, "eigenmode");
ComplexVector E(Curl->Width()), B(Curl->Height());
ComplexVector E(Curl.Width()), B(Curl.Height());

// Define and configure the eigensolver to solve the eigenvalue problem:
// (K + λ C + λ² M) u = 0 or K u = -λ² M u
Expand Down Expand Up @@ -166,7 +166,7 @@ EigenSolver::Solve(const std::vector<std::unique_ptr<mfem::ParMesh>> &mesh) cons
spaceop.GetMaterialOp(), spaceop.GetNDSpace(), spaceop.GetH1Spaces(),
spaceop.GetAuxBdrTDofLists(), iodata.solver.linear.divfree_tol,
iodata.solver.linear.divfree_max_it, divfree_verbose,
iodata.solver.pa_order_threshold, iodata.solver.pa_discrete_interp);
iodata.solver.pa_order_threshold);
eigen->SetDivFreeProjector(*divfree);
}

Expand All @@ -192,9 +192,10 @@ EigenSolver::Solve(const std::vector<std::unique_ptr<mfem::ParMesh>> &mesh) cons
eigen->SetInitialSpace(v0); // Copies the vector

// Debug
// auto Grad = spaceop.GetGradMatrix<ComplexOperator>();
// const auto &Grad = spaceop.GetGradMatrix();
// ComplexVector r0(Grad->Width());
// Grad->MultTranspose(v0, r0);
// Grad.MultTranspose(v0.Real(), r0.Real());
// Grad.MultTranspose(v0.Imag(), r0.Imag());
// r0.Print();
}

Expand Down Expand Up @@ -260,8 +261,7 @@ EigenSolver::Solve(const std::vector<std::unique_ptr<mfem::ParMesh>> &mesh) cons
// Calculate and record the error indicators.
CurlFluxErrorEstimator<ComplexVector> estimator(
spaceop.GetMaterialOp(), spaceop.GetNDSpaces(), iodata.solver.linear.estimator_tol,
iodata.solver.linear.estimator_max_it, 0, iodata.solver.pa_order_threshold,
iodata.solver.pa_discrete_interp);
iodata.solver.linear.estimator_max_it, 0, iodata.solver.pa_order_threshold);
ErrorIndicator indicator;
for (int i = 0; i < iodata.solver.eigenmode.n; i++)
{
Expand Down Expand Up @@ -296,7 +296,8 @@ EigenSolver::Solve(const std::vector<std::unique_ptr<mfem::ParMesh>> &mesh) cons
// Compute B = -1/(iω) ∇ x E on the true dofs, and set the internal GridFunctions in
// PostOperator for all postprocessing operations.
eigen->GetEigenvector(i, E);
Curl->Mult(E, B);
Curl.Mult(E.Real(), B.Real());
Curl.Mult(E.Imag(), B.Imag());
B *= -1.0 / (1i * omega);
postop.SetEGridFunction(E);
postop.SetBGridFunction(B);
Expand Down
16 changes: 8 additions & 8 deletions palace/drivers/electrostaticsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,21 @@ ErrorIndicator ElectrostaticSolver::Postprocess(LaplaceOperator &laplaceop,
// charges from the prescribed voltage to get C directly as:
// Q_i = ∫ ρ dV = ∫ ∇ ⋅ (ε E) dV = ∫ (ε E) ⋅ n dS
// and C_ij = Q_i/V_j. The energy formulation avoids having to locally integrate E = -∇V.
auto Grad = laplaceop.GetGradMatrix();
const auto &Grad = laplaceop.GetGradMatrix();
const std::map<int, mfem::Array<int>> &terminal_sources = laplaceop.GetSources();
int nstep = static_cast<int>(terminal_sources.size());
mfem::DenseMatrix C(nstep), Cm(nstep);
Vector E(Grad->Height()), Vij(Grad->Width());
Vector E(Grad.Height()), Vij(Grad.Width());
if (iodata.solver.electrostatic.n_post > 0)
{
Mpi::Print("\n");
}

// Calculate and record the error indicators.
GradFluxErrorEstimator estimator(
laplaceop.GetMaterialOp(), laplaceop.GetH1Spaces(),
iodata.solver.linear.estimator_tol, iodata.solver.linear.estimator_max_it, 0,
iodata.solver.pa_order_threshold, iodata.solver.pa_discrete_interp);
GradFluxErrorEstimator estimator(laplaceop.GetMaterialOp(), laplaceop.GetH1Spaces(),
iodata.solver.linear.estimator_tol,
iodata.solver.linear.estimator_max_it, 0,
iodata.solver.pa_order_threshold);
ErrorIndicator indicator;
for (int i = 0; i < nstep; i++)
{
Expand All @@ -113,7 +113,7 @@ ErrorIndicator ElectrostaticSolver::Postprocess(LaplaceOperator &laplaceop,
// Compute E = -∇V on the true dofs, and set the internal GridFunctions in PostOperator
// for all postprocessing operations.
E = 0.0;
Grad->AddMult(V[i], E, -1.0);
Grad.AddMult(V[i], E, -1.0);
postop.SetEGridFunction(E);
postop.SetVGridFunction(V[i]);
double Ue = postop.GetEFieldEnergy();
Expand Down Expand Up @@ -151,7 +151,7 @@ ErrorIndicator ElectrostaticSolver::Postprocess(LaplaceOperator &laplaceop,
{
linalg::AXPBYPCZ(1.0, V[i], 1.0, V[j], 0.0, Vij);
E = 0.0;
Grad->AddMult(Vij, E, -1.0);
Grad.AddMult(Vij, E, -1.0);
postop.SetEGridFunction(E);
double Ue = postop.GetEFieldEnergy();
C(i, j) = Ue - 0.5 * (C(i, i) + C(j, j));
Expand Down
10 changes: 5 additions & 5 deletions palace/drivers/magnetostaticsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ ErrorIndicator MagnetostaticSolver::Postprocess(CurlCurlOperator &curlcurlop,
// Φ_i = ∫ B ⋅ n_j dS
// and M_ij = Φ_i/I_j. The energy formulation avoids having to locally integrate B =
// ∇ x A.
auto Curl = curlcurlop.GetCurlMatrix();
const auto &Curl = curlcurlop.GetCurlMatrix();
const SurfaceCurrentOperator &surf_j_op = curlcurlop.GetSurfaceCurrentOp();
int nstep = static_cast<int>(surf_j_op.Size());
mfem::DenseMatrix M(nstep), Mm(nstep);
Vector B(Curl->Height()), Aij(Curl->Width());
Vector B(Curl.Height()), Aij(Curl.Width());
Vector Iinc(nstep);
if (iodata.solver.magnetostatic.n_post > 0)
{
Expand All @@ -103,7 +103,7 @@ ErrorIndicator MagnetostaticSolver::Postprocess(CurlCurlOperator &curlcurlop,
CurlFluxErrorEstimator<Vector> estimator(
curlcurlop.GetMaterialOp(), curlcurlop.GetNDSpaces(),
iodata.solver.linear.estimator_tol, iodata.solver.linear.estimator_max_it, 0,
iodata.solver.pa_order_threshold, iodata.solver.pa_discrete_interp);
iodata.solver.pa_order_threshold);
ErrorIndicator indicator;
for (int i = 0; i < nstep; i++)
{
Expand All @@ -120,7 +120,7 @@ ErrorIndicator MagnetostaticSolver::Postprocess(CurlCurlOperator &curlcurlop,

// Compute B = ∇ x A on the true dofs, and set the internal GridFunctions in
// PostOperator for all postprocessing operations.
Curl->Mult(A[i], B);
Curl.Mult(A[i], B);
postop.SetBGridFunction(B);
postop.SetAGridFunction(A[i]);
double Um = postop.GetHFieldEnergy();
Expand Down Expand Up @@ -157,7 +157,7 @@ ErrorIndicator MagnetostaticSolver::Postprocess(CurlCurlOperator &curlcurlop,
else if (j > i)
{
linalg::AXPBYPCZ(1.0, A[i], 1.0, A[j], 0.0, Aij);
Curl->Mult(Aij, B);
Curl.Mult(Aij, B);
postop.SetBGridFunction(B);
double Um = postop.GetHFieldEnergy();
M(i, j) = Um / (Iinc(i) * Iinc(j)) -
Expand Down
3 changes: 1 addition & 2 deletions palace/drivers/transientsolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ TransientSolver::Solve(const std::vector<std::unique_ptr<mfem::ParMesh>> &mesh)
// Initialize structures for storing and reducing the results of error estimation.
CurlFluxErrorEstimator<Vector> estimator(
spaceop.GetMaterialOp(), spaceop.GetNDSpaces(), iodata.solver.linear.estimator_tol,
iodata.solver.linear.estimator_max_it, 0, iodata.solver.pa_order_threshold,
iodata.solver.pa_discrete_interp);
iodata.solver.linear.estimator_max_it, 0, iodata.solver.pa_order_threshold);
ErrorIndicator indicator;

// Main time integration loop.
Expand Down
94 changes: 94 additions & 0 deletions palace/fem/fespace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

#include "fespace.hpp"

#include "fem/bilinearform.hpp"
#include "fem/integrator.hpp"
#include "linalg/rap.hpp"
#include "utils/omp.hpp"

namespace palace
Expand All @@ -24,4 +27,95 @@ std::size_t FiniteElementSpace::GetId() const
return id;
}

const Operator &AuxiliaryFiniteElementSpace::BuildDiscreteInterpolator() const
{
// G is always partially assembled.
const int dim = GetParMesh()->Dimension();
const auto aux_map_type = FEColl()->GetMapType(dim);
const auto primal_map_type = primal_fespace.FEColl()->GetMapType(dim);
if (aux_map_type == mfem::FiniteElement::VALUE &&
primal_map_type == mfem::FiniteElement::H_CURL)
{
// Discrete gradient interpolator
DiscreteLinearOperator interp(*this, primal_fespace);
interp.AddDomainInterpolator<GradientInterpolator>();
G = std::make_unique<ParOperator>(interp.Assemble(), *this, primal_fespace, true);
}
else if (primal_map_type == mfem::FiniteElement::VALUE &&
aux_map_type == mfem::FiniteElement::H_CURL)
{
// Discrete gradient interpolator (spaces reversed)
DiscreteLinearOperator interp(primal_fespace, *this);
interp.AddDomainInterpolator<GradientInterpolator>();
G = std::make_unique<ParOperator>(interp.Assemble(), primal_fespace, *this, true);
}
else if (aux_map_type == mfem::FiniteElement::H_CURL &&
primal_map_type == mfem::FiniteElement::H_DIV)
{
// Discrete curl interpolator
DiscreteLinearOperator interp(*this, primal_fespace);
interp.AddDomainInterpolator<CurlInterpolator>();
G = std::make_unique<ParOperator>(interp.Assemble(), *this, primal_fespace, true);
}
else if (primal_map_type == mfem::FiniteElement::H_CURL &&
aux_map_type == mfem::FiniteElement::H_DIV)
{
// Discrete curl interpolator (spaces reversed)
DiscreteLinearOperator interp(primal_fespace, *this);
interp.AddDomainInterpolator<CurlInterpolator>();
G = std::make_unique<ParOperator>(interp.Assemble(), primal_fespace, *this, true);
}
else if (aux_map_type == mfem::FiniteElement::H_DIV &&
primal_map_type == mfem::FiniteElement::INTEGRAL)
{
// Discrete divergence interpolator
DiscreteLinearOperator interp(*this, primal_fespace);
interp.AddDomainInterpolator<DivergenceInterpolator>();
G = std::make_unique<ParOperator>(interp.Assemble(), *this, primal_fespace, true);
}
else if (primal_map_type == mfem::FiniteElement::H_DIV &&
aux_map_type == mfem::FiniteElement::INTEGRAL)
{
// Discrete divergence interpolator (spaces reversed)
DiscreteLinearOperator interp(primal_fespace, *this);
interp.AddDomainInterpolator<DivergenceInterpolator>();
G = std::make_unique<ParOperator>(interp.Assemble(), primal_fespace, *this, true);
}
else
{
MFEM_ABORT("Unsupported trial/test FE spaces for AuxiliaryFiniteElementSpace discrete "
"interpolator!");
}

return *G;
}

template <typename FESpace>
const Operator &
BaseFiniteElementSpaceHierarchy<FESpace>::BuildProlongationAtLevel(std::size_t l) const
{
// P is always partially assembled.
MFEM_VERIFY(l >= 0 && l < GetNumLevels() - 1,
"Can only construct a finite element space prolongation with more than one "
"space in the hierarchy!");
if (fespaces[l]->GetParMesh() != fespaces[l + 1]->GetParMesh())
{
P[l] = std::make_unique<ParOperator>(
std::make_unique<mfem::TransferOperator>(*fespaces[l], *fespaces[l + 1]),
*fespaces[l], *fespaces[l + 1], true);
}
else
{
DiscreteLinearOperator p(*fespaces[l], *fespaces[l + 1]);
p.AddDomainInterpolator<IdentityInterpolator>();
P[l] =
std::make_unique<ParOperator>(p.Assemble(), *fespaces[l], *fespaces[l + 1], true);
}

return *P[l];
}

template class BaseFiniteElementSpaceHierarchy<FiniteElementSpace>;
template class BaseFiniteElementSpaceHierarchy<AuxiliaryFiniteElementSpace>;

} // namespace palace
Loading

0 comments on commit e5118ac

Please sign in to comment.