diff --git a/palace/linalg/solver.cpp b/palace/linalg/solver.cpp index 7ea33e07b..3eb59caf0 100644 --- a/palace/linalg/solver.cpp +++ b/palace/linalg/solver.cpp @@ -52,7 +52,26 @@ void MfemWrapperSolver::SetOperator(const ComplexOperator &op) } if (hAr && hAi) { - A.reset(mfem::Add(1.0, *hAr, 1.0, *hAi)); + // A = [Ar, -Ai] + // [Ai, Ar] + mfem::Array2D blocks(2, 2); + mfem::Array2D block_coeffs(2, 2); + blocks(0, 0) = hAr; + blocks(0, 1) = hAi; + blocks(1, 0) = hAi; + blocks(1, 1) = hAr; + block_coeffs(0, 0) = 1.0; + block_coeffs(0, 1) = -1.0; + block_coeffs(1, 0) = 1.0; + block_coeffs(1, 1) = 1.0; + A.reset(mfem::HypreParMatrixFromBlocks(blocks, &block_coeffs)); + idx1.SetSize(op.Width()); + idx2.SetSize(op.Width()); + for (int i = 0; i < op.Width(); i++) + { + idx1[i] = i; + idx2[i] = i + op.Width(); + } if (PtAPr) { PtAPr->StealParallelAssemble(); @@ -101,13 +120,32 @@ template <> void MfemWrapperSolver::Mult(const ComplexVector &x, ComplexVector &y) const { - mfem::Array X(2); - mfem::Array Y(2); - X[0] = &x.Real(); - X[1] = &x.Imag(); - Y[0] = &y.Real(); - Y[1] = &y.Imag(); - pc->ArrayMult(X, Y); + if (pc->Height() == x.Size()) + { + mfem::Array X(2); + mfem::Array Y(2); + X[0] = &x.Real(); + X[1] = &x.Imag(); + Y[0] = &y.Real(); + Y[1] = &y.Imag(); + pc->ArrayMult(X, Y); + } + else + { + Vector X(2 * x.Size()), Y(2 * y.Size()), yr, yi; + X.UseDevice(true); + Y.UseDevice(true); + yr.UseDevice(true); + yi.UseDevice(true); + X.SetSubVector(idx1, x.Real()); + X.SetSubVector(idx2, x.Imag()); + pc->Mult(X, Y); + Y.ReadWrite(); + yr.MakeRef(Y, 0, y.Size()); + yi.MakeRef(Y, y.Size(), y.Size()); + y.Real() = yr; + y.Imag() = yi; + } } } // namespace palace diff --git a/palace/linalg/solver.hpp b/palace/linalg/solver.hpp index 3dcc1096a..130f0bd1c 100644 --- a/palace/linalg/solver.hpp +++ b/palace/linalg/solver.hpp @@ -83,6 +83,9 @@ class MfemWrapperSolver : public Solver // mfem::Solver::SetOperator (some solvers copy their input). bool save_assembled; + // Indices of real and imaginary parts of the complex system RHS/solution. + mfem::Array idx1, idx2; + public: MfemWrapperSolver(std::unique_ptr &&pc, bool save_assembled = true) : Solver(pc->iterative_mode), pc(std::move(pc)),