Skip to content

Commit

Permalink
Merge pull request #183 from awslabs/sjg/openmp-improvements
Browse files Browse the repository at this point in the history
Add OpenMP parallelism where possible
  • Loading branch information
sebastiangrimberg authored Feb 7, 2024
2 parents 7d26775 + f1bfece commit ceecf5c
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 69 deletions.
44 changes: 16 additions & 28 deletions palace/linalg/ams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "fem/fespace.hpp"
#include "fem/integrator.hpp"
#include "linalg/rap.hpp"
#include "utils/omp.hpp"

namespace palace
{
Expand Down Expand Up @@ -50,9 +51,11 @@ void HypreAmsSolver::ConstructAuxiliaryMatrices(FiniteElementSpace &nd_fespace,
AuxiliaryFiniteElementSpace &h1_fespace)
{
// Set up the auxiliary space objects for the preconditioner. Mostly the same as MFEM's
// HypreAMS:Init. Start with the discrete gradient matrix.
// HypreAMS::Init. Start with the discrete gradient matrix. We don't skip zeros for the
// full assembly to accelerate things on GPU and since they shouldn't affect the sparsity
// pattern of the parallel G^T A G matrix (computed by Hypre).
const bool skip_zeros_interp = !mfem::Device::Allows(mfem::Backend::DEVICE_MASK);
{
constexpr bool skip_zeros_interp = true;
const auto *PtGP =
dynamic_cast<const ParOperator *>(&h1_fespace.GetDiscreteInterpolator());
MFEM_VERIFY(
Expand All @@ -62,41 +65,26 @@ void HypreAmsSolver::ConstructAuxiliaryMatrices(FiniteElementSpace &nd_fespace,
}

// Vertex coordinates for the lowest order case, or Nedelec interpolation matrix or
// matrices for order > 1.
// matrices for order > 1. Expects that Mesh::SetVerticesFromNodes has been called at some
// point to avoid calling GridFunction::GetNodalValues here.
mfem::ParMesh &mesh = h1_fespace.GetParMesh();
if (h1_fespace.GetMaxElementOrder() == 1)
{
mfem::ParGridFunction x_coord(&h1_fespace.Get()), y_coord(&h1_fespace.Get()),
z_coord(&h1_fespace.Get());
if (mesh.GetNodes())
MFEM_VERIFY(x_coord.Size() == mesh.GetNV(),
"Unexpected size for vertex coordinates in AMS setup!");
PalacePragmaOmp(parallel for schedule(static))
for (int i = 0; i < mesh.GetNV(); i++)
{
mesh.GetNodes()->GetNodalValues(x_coord, 1);
MFEM_VERIFY(x_coord.Size() == h1_fespace.GetVSize(),
"Unexpected size for vertex coordinates in AMS setup!");
x_coord(i) = mesh.GetVertex(i)[0];
if (space_dim > 1)
{
mesh.GetNodes()->GetNodalValues(y_coord, 2);
y_coord(i) = mesh.GetVertex(i)[1];
}
if (space_dim > 2)
{
mesh.GetNodes()->GetNodalValues(z_coord, 3);
}
}
else
{
MFEM_VERIFY(x_coord.Size() == mesh.GetNV(),
"Unexpected size for vertex coordinates in AMS setup!");
for (int i = 0; i < mesh.GetNV(); i++)
{
x_coord(i) = mesh.GetVertex(i)[0];
if (space_dim > 1)
{
y_coord(i) = mesh.GetVertex(i)[1];
}
if (space_dim > 2)
{
z_coord(i) = mesh.GetVertex(i)[2];
}
z_coord(i) = mesh.GetVertex(i)[2];
}
}
x.reset(x_coord.ParallelProject());
Expand All @@ -120,8 +108,8 @@ void HypreAmsSolver::ConstructAuxiliaryMatrices(FiniteElementSpace &nd_fespace,
mfem::DiscreteLinearOperator pi(&h1d_fespace.Get(), &nd_fespace.Get());
pi.AddDomainInterpolator(new mfem::IdentityInterpolator);
pi.SetAssemblyLevel(mfem::AssemblyLevel::LEGACY);
pi.Assemble();
pi.Finalize();
pi.Assemble(skip_zeros_interp);
pi.Finalize(skip_zeros_interp);
ParOperator RAP_Pi(std::unique_ptr<mfem::SparseMatrix>(pi.LoseMat()), h1d_fespace,
nd_fespace, true);
Pi = RAP_Pi.StealParallelAssemble();
Expand Down
132 changes: 92 additions & 40 deletions palace/utils/geodata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "utils/filesystem.hpp"
#include "utils/iodata.hpp"
#include "utils/meshio.hpp"
#include "utils/omp.hpp"
#include "utils/timer.hpp"

namespace palace
Expand Down Expand Up @@ -359,7 +360,7 @@ void RefineMesh(const IoData &iodata, std::vector<std::unique_ptr<mfem::ParMesh>

// Print some mesh information.
mfem::Vector bbmin, bbmax;
mesh[0]->GetBoundingBox(bbmin, bbmax);
GetAxisAlignedBoundingBox(*mesh[0], bbmin, bbmax);
const double Lc = iodata.DimensionalizeValue(IoData::ValueType::LENGTH, 1.0);
Mpi::Print(mesh[0]->GetComm(), "\nMesh curvature order: {}\nMesh bounding box:\n",
mesh[0]->GetNodes()
Expand Down Expand Up @@ -394,13 +395,15 @@ namespace

void ScaleMesh(mfem::Mesh &mesh, double L)
{
PalacePragmaOmp(parallel for schedule(static))
for (int i = 0; i < mesh.GetNV(); i++)
{
double *v = mesh.GetVertex(i);
std::transform(v, v + mesh.SpaceDimension(), v, [L](double val) { return val * L; });
}
if (auto *pmesh = dynamic_cast<mfem::ParMesh *>(&mesh))
{
PalacePragmaOmp(parallel for schedule(static))
for (int i = 0; i < pmesh->face_nbr_vertices.Size(); i++)
{
double *v = pmesh->face_nbr_vertices[i]();
Expand Down Expand Up @@ -519,7 +522,8 @@ void GetAxisAlignedBoundingBox(const mfem::ParMesh &mesh, const mfem::Array<int>
}
if (!mesh.GetNodes())
{
auto BBUpdate = [&mesh, &dim, &min, &max](const mfem::Array<int> &verts) -> void
auto BBUpdate =
[&mesh, &dim](const mfem::Array<int> &verts, mfem::Vector &min, mfem::Vector &max)
{
for (int j = 0; j < verts.Size(); j++)
{
Expand All @@ -537,40 +541,60 @@ void GetAxisAlignedBoundingBox(const mfem::ParMesh &mesh, const mfem::Array<int>
}
}
};
if (bdr)
PalacePragmaOmp(parallel)
{
for (int i = 0; i < mesh.GetNBE(); i++)
mfem::Vector loc_min(dim), loc_max(dim);
for (int d = 0; d < dim; d++)
{
if (!marker[mesh.GetBdrAttribute(i) - 1])
loc_min(d) = mfem::infinity();
loc_max(d) = -mfem::infinity();
}
mfem::Array<int> verts;
if (bdr)
{
PalacePragmaOmp(for schedule(static))
for (int i = 0; i < mesh.GetNBE(); i++)
{
continue;
if (!marker[mesh.GetBdrAttribute(i) - 1])
{
continue;
}
mesh.GetBdrElementVertices(i, verts);
BBUpdate(verts, loc_min, loc_max);
}
mfem::Array<int> verts;
mesh.GetBdrElementVertices(i, verts);
BBUpdate(verts);
}
}
else
{
for (int i = 0; i < mesh.GetNE(); i++)
else
{
if (!marker[mesh.GetAttribute(i) - 1])
PalacePragmaOmp(for schedule(static))
for (int i = 0; i < mesh.GetNE(); i++)
{
continue;
if (!marker[mesh.GetAttribute(i) - 1])
{
continue;
}
mesh.GetElementVertices(i, verts);
BBUpdate(verts, loc_min, loc_max);
}
}
PalacePragmaOmp(critical(BBUpdate))
{
for (int d = 0; d < dim; d++)
{
min(d) = std::min(min(d), loc_min(d));
max(d) = std::max(max(d), loc_max(d));
}
mfem::Array<int> verts;
mesh.GetElementVertices(i, verts);
BBUpdate(verts);
}
}
}
else
{
auto BBUpdate = [&min, &max](mfem::ElementTransformation &T, mfem::Geometry::Type &geom,
int ref) -> void
mesh.GetNodes()->HostRead();
const int ref = mesh.GetNodes()->FESpace()->GetMaxElementOrder();
auto BBUpdate = [&ref](mfem::GeometryRefiner refiner, mfem::Geometry::Type &geom,
mfem::ElementTransformation &T, mfem::DenseMatrix &pointmat,
mfem::Vector &min, mfem::Vector &max)
{
mfem::DenseMatrix pointmat;
mfem::RefinedGeometry *RefG = mfem::GlobGeometryRefiner.Refine(geom, ref);
mfem::RefinedGeometry *RefG = refiner.Refine(geom, ref);
T.Transform(RefG->RefPts, pointmat);
for (int j = 0; j < pointmat.Width(); j++)
{
Expand All @@ -587,32 +611,52 @@ void GetAxisAlignedBoundingBox(const mfem::ParMesh &mesh, const mfem::Array<int>
}
}
};
const int ref = mesh.GetNodes()->FESpace()->GetMaxElementOrder();
mfem::IsoparametricTransformation T;
if (bdr)
PalacePragmaOmp(parallel)
{
for (int i = 0; i < mesh.GetNBE(); i++)
mfem::Vector loc_min(dim), loc_max(dim);
for (int d = 0; d < dim; d++)
{
if (!marker[mesh.GetBdrAttribute(i) - 1])
loc_min(d) = mfem::infinity();
loc_max(d) = -mfem::infinity();
}
mfem::GeometryRefiner refiner;
mfem::IsoparametricTransformation T;
mfem::DenseMatrix pointmat;
if (bdr)
{
PalacePragmaOmp(for schedule(static))
for (int i = 0; i < mesh.GetNBE(); i++)
{
continue;
if (!marker[mesh.GetBdrAttribute(i) - 1])
{
continue;
}
mesh.GetBdrElementTransformation(i, &T);
mfem::Geometry::Type geom = mesh.GetBdrElementGeometry(i);
BBUpdate(refiner, geom, T, pointmat, loc_min, loc_max);
}
mesh.GetBdrElementTransformation(i, &T);
mfem::Geometry::Type geom = mesh.GetBdrElementGeometry(i);
BBUpdate(T, geom, ref);
}
}
else
{
for (int i = 0; i < mesh.GetNE(); i++)
else
{
if (!marker[mesh.GetAttribute(i) - 1])
PalacePragmaOmp(for schedule(static))
for (int i = 0; i < mesh.GetNE(); i++)
{
continue;
if (!marker[mesh.GetAttribute(i) - 1])
{
continue;
}
mesh.GetElementTransformation(i, &T);
mfem::Geometry::Type geom = mesh.GetElementGeometry(i);
BBUpdate(refiner, geom, T, pointmat, loc_min, loc_max);
}
}
PalacePragmaOmp(critical(BBUpdate))
{
for (int d = 0; d < dim; d++)
{
min(d) = std::min(min(d), loc_min(d));
max(d) = std::max(max(d), loc_max(d));
}
mesh.GetElementTransformation(i, &T);
mfem::Geometry::Type geom = mesh.GetElementGeometry(i);
BBUpdate(T, geom, ref);
}
}
}
Expand All @@ -629,6 +673,14 @@ void GetAxisAlignedBoundingBox(const mfem::ParMesh &mesh, int attr, bool bdr,
GetAxisAlignedBoundingBox(mesh, marker, bdr, min, max);
}

void GetAxisAlignedBoundingBox(const mfem::ParMesh &mesh, mfem::Vector &min,
                               mfem::Vector &max)
{
  // Bounding box without attribute filtering: flag every attribute index up to the
  // largest one present in mesh.attributes as selected, then defer to the marker-based
  // overload using domain (not boundary) elements.
  constexpr bool use_bdr = false;
  mfem::Array<int> all_attr(mesh.attributes.Max());
  all_attr = 1;
  GetAxisAlignedBoundingBox(mesh, all_attr, use_bdr, min, max);
}

double BoundingBox::Area() const
{
return 4.0 *
Expand Down
2 changes: 2 additions & 0 deletions palace/utils/geodata.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ void GetAxisAlignedBoundingBox(const mfem::ParMesh &mesh, const mfem::Array<int>
bool bdr, mfem::Vector &min, mfem::Vector &max);
void GetAxisAlignedBoundingBox(const mfem::ParMesh &mesh, int attr, bool bdr,
mfem::Vector &min, mfem::Vector &max);
// Overload without attribute selection: marks every domain attribute and computes the
// bounding box over the domain (non-boundary) elements of the mesh.
void GetAxisAlignedBoundingBox(const mfem::ParMesh &mesh, mfem::Vector &min,
mfem::Vector &max);

// Struct describing a bounding box in terms of the center and face normals. The normals
// specify the direction from the center of the box.
Expand Down
2 changes: 1 addition & 1 deletion palace/utils/iodata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ void IoData::NondimensionalizeInputs(mfem::ParMesh &mesh)
else
{
mfem::Vector bbmin, bbmax;
mesh.GetBoundingBox(bbmin, bbmax);
mesh::GetAxisAlignedBoundingBox(mesh, bbmin, bbmax);
bbmax -= bbmin;
bbmax *= model.L0; // [m]
Lc = *std::max_element(bbmax.begin(), bbmax.end());
Expand Down

0 comments on commit ceecf5c

Please sign in to comment.