Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly use MPI communicators #181

Merged
merged 1 commit into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cashocs/_constraints/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,9 @@ def _project_pointwise_multiplier(
lhs = trial * test * measure
rhs = project_term * test * measure

comm = self.cg_function_space.mesh().mpi_comm()
_utils.assemble_and_solve_linear(
lhs, rhs, A=A_tensor, b=b_tensor, fun=multiplier
lhs, rhs, A=A_tensor, b=b_tensor, fun=multiplier, comm=comm
)

def _update_cost_functional(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions cashocs/_database/geometry_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ def __init__(self, function_db: function_database.FunctionDatabase) -> None:
"""
self.mesh: fenics.Mesh = function_db.state_spaces[0].mesh()
self.dx: fenics.Measure = fenics.Measure("dx", self.mesh)
self.mpi_comm = self.mesh.mpi_comm()
1 change: 1 addition & 0 deletions cashocs/_forms/shape_form_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def compute(self) -> None:
A=self.A_mu_matrix,
b=self.b_mu,
ksp_options=self.options_mu,
comm=self.mesh.mpi_comm(),
)

if self.config.getboolean("ShapeGradient", "use_sqrt_mu"):
Expand Down
1 change: 1 addition & 0 deletions cashocs/_forms/shape_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ def _compute_curvature(self) -> None:
A=self.a_curvature_matrix.mat(),
b=self.b_curvature.vec(),
fun=self.kappa_curvature,
comm=self.db.geometry_db.mpi_comm,
)

def scale(self) -> None:
Expand Down
10 changes: 7 additions & 3 deletions cashocs/_pde_problems/adjoint_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,17 @@ def __init__(

# pylint: disable=invalid-name
self.A_tensors = [
fenics.PETScMatrix() for _ in range(self.db.parameter_db.state_dim)
fenics.PETScMatrix(db.geometry_db.mpi_comm)
for _ in range(self.db.parameter_db.state_dim)
]
self.b_tensors = [
fenics.PETScVector() for _ in range(self.db.parameter_db.state_dim)
fenics.PETScVector(db.geometry_db.mpi_comm)
for _ in range(self.db.parameter_db.state_dim)
]

self.res_j_tensors = [
fenics.PETScVector() for _ in range(self.db.parameter_db.state_dim)
fenics.PETScVector(db.geometry_db.mpi_comm)
for _ in range(self.db.parameter_db.state_dim)
]

self._number_of_solves = 0
Expand Down Expand Up @@ -119,6 +122,7 @@ def solve(self) -> List[fenics.Function]:
b=self.b_tensors[-1 - i],
fun=self.adjoints[-1 - i],
ksp_options=self.db.parameter_db.adjoint_ksp_options[-1 - i],
comm=self.db.geometry_db.mpi_comm,
)

else:
Expand Down
1 change: 1 addition & 0 deletions cashocs/_pde_problems/control_gradient_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def solve(self) -> List[fenics.Function]:
b=self.b_tensors[i].vec(),
fun=self.db.function_db.gradient[i],
ksp_options=self.riesz_ksp_options[i],
comm=self.db.geometry_db.mpi_comm,
)

self.has_solution = True
Expand Down
3 changes: 3 additions & 0 deletions cashocs/_pde_problems/hessian_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def hessian_application(
self.bcs_list_ad[i],
fun=self.db.function_db.states_prime[i],
ksp_options=self.db.parameter_db.state_ksp_options[i],
comm=self.db.geometry_db.mpi_comm,
)

for i in range(self.state_dim):
Expand All @@ -185,6 +186,7 @@ def hessian_application(
self.bcs_list_ad[-1 - i],
fun=self.db.function_db.adjoints_prime[-1 - i],
ksp_options=self.db.parameter_db.adjoint_ksp_options[-1 - i],
comm=self.db.geometry_db.mpi_comm,
)

else:
Expand Down Expand Up @@ -234,6 +236,7 @@ def hessian_application(
b=b,
fun=out[i],
ksp_options=self.riesz_ksp_options[i],
comm=self.db.geometry_db.mpi_comm,
)

self.no_sensitivity_solves += 2
Expand Down
1 change: 1 addition & 0 deletions cashocs/_pde_problems/shape_gradient_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def solve(self) -> List[fenics.Function]:
b=self.form_handler.fe_shape_derivative_vector.vec(),
fun=self.db.function_db.gradient[0],
ksp_options=self.ksp_options,
comm=self.db.geometry_db.mpi_comm,
)

self.has_solution = True
Expand Down
1 change: 1 addition & 0 deletions cashocs/_pde_problems/state_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def solve(self) -> List[fenics.Function]:
b=self.b_tensors[i],
fun=self.states[i],
ksp_options=self.db.parameter_db.state_ksp_options[i],
comm=self.db.geometry_db.mpi_comm,
)

else:
Expand Down
28 changes: 27 additions & 1 deletion cashocs/_utils/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from cashocs._utils import forms as forms_module

if TYPE_CHECKING:
from mpi4py import MPI

from cashocs import _typing

iterative_ksp_options: _typing.KspOption = {
Expand Down Expand Up @@ -218,13 +220,32 @@ def setup_fieldsplit_preconditioner(
pc.setFieldSplitIS(*idx_tuples)


def _initialize_comm(comm: Optional[MPI.Comm] = None) -> MPI.Comm:
"""Initializes the MPI communicator.

If the supplied communicator is `None`, return MPI.comm_world.

Args:
comm: The supplied communicator or `None`

Returns:
The resulting communicator.

"""
if comm is None:
comm = fenics.MPI.comm_world

return comm


def solve_linear_problem(
A: Optional[PETSc.Mat] = None, # pylint: disable=invalid-name
b: Optional[PETSc.Vec] = None,
fun: Optional[fenics.Function] = None,
ksp_options: Optional[_typing.KspOption] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
comm: Optional[MPI.Comm] = None,
) -> PETSc.Vec:
"""Solves a finite dimensional linear problem.

Expand All @@ -245,12 +266,14 @@ def solve_linear_problem(
atol: The absolute tolerance used in case an iterative solver is used for
solving the linear problem. Overrides the specification in the ksp object
and ksp_options.
comm: The MPI communicator for the problem.

Returns:
The solution vector.

"""
ksp = PETSc.KSP().create()
comm = _initialize_comm(comm)
ksp = PETSc.KSP().create(comm=comm)

if A is not None:
ksp.setOperators(A)
Expand Down Expand Up @@ -306,6 +329,7 @@ def assemble_and_solve_linear(
ksp_options: Optional[_typing.KspOption] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
comm: Optional[MPI.Comm] = None,
) -> PETSc.Vec:
"""Assembles and solves a linear system.

Expand All @@ -325,6 +349,7 @@ def assemble_and_solve_linear(
atol: The absolute tolerance used in case an iterative solver is used for
solving the linear problem. Overrides the specification in the ksp object
and ksp_options.
comm: The MPI communicator for solving the problem.

Returns:
A PETSc vector containing the solution x.
Expand All @@ -341,6 +366,7 @@ def assemble_and_solve_linear(
ksp_options=ksp_options,
rtol=rtol,
atol=atol,
comm=comm,
)

return solution
Expand Down
8 changes: 6 additions & 2 deletions cashocs/geometry/boundary_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def compute_boundary_distance(
function_space = fenics.FunctionSpace(mesh, "CG", 1)
dx = measure.NamedMeasure("dx", mesh)

comm = mesh.mpi_comm()

ksp_options = copy.deepcopy(_utils.linalg.iterative_ksp_options)

u = fenics.TrialFunction(function_space)
Expand Down Expand Up @@ -95,7 +97,9 @@ def compute_boundary_distance(
lhs = fenics.dot(fenics.grad(u), fenics.grad(v)) * dx
rhs = fenics.Constant(1.0) * v * dx

_utils.assemble_and_solve_linear(lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options)
_utils.assemble_and_solve_linear(
lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options, comm=comm
)

rhs = fenics.dot(fenics.grad(u_prev) / norm_u_prev, fenics.grad(v)) * dx

Expand All @@ -114,7 +118,7 @@ def compute_boundary_distance(
u_prev.vector().vec().aypx(0.0, u_curr.vector().vec())
u_prev.vector().apply("")
_utils.assemble_and_solve_linear(
lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options
lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options, comm=comm
)
res = np.sqrt(fenics.assemble(residual_form))

Expand Down
1 change: 1 addition & 0 deletions cashocs/geometry/mesh_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def compute_decreases(
self.a_frobenius,
self.l_frobenius,
ksp_options=self.options_frobenius,
comm=self.db.geometry_db.mpi_comm,
)

frobenius_norm = x.max()[1]
Expand Down
5 changes: 2 additions & 3 deletions cashocs/geometry/mesh_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ def test(self, transformation: fenics.Function, volume_change: float) -> bool:
A boolean that indicates whether the desired transformation is feasible.

"""
comm = self.transformation_container.function_space().mesh().mpi_comm()
self.transformation_container.vector().vec().aypx(
0.0, transformation.vector().vec()
)
self.transformation_container.vector().apply("")
x = _utils.assemble_and_solve_linear(
self.A_prior,
self.l_prior,
ksp_options=self.options_prior,
self.A_prior, self.l_prior, ksp_options=self.options_prior, comm=comm
)

min_det = float(x.min()[1])
Expand Down
16 changes: 9 additions & 7 deletions cashocs/geometry/quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:
The element wise skewness of the mesh on process 0.

"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
skewness_array = self._quality_object.skewness(mesh).array()
skewness_list: np.ndarray = comm.gather(skewness_array, root=0)
if comm.rank == 0:
Expand Down Expand Up @@ -329,7 +329,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:
The maximum angle quality measure for each element on process 0.

"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
maximum_angle_array = self._quality_object.maximum_angle(mesh).array()
maximum_angle_list: np.ndarray = comm.gather(maximum_angle_array, root=0)
if comm.rank == 0:
Expand Down Expand Up @@ -361,7 +361,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:
The radius ratios of the mesh elements on process 0.

"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
radius_ratios_array = fenics.MeshQuality.radius_ratios(mesh).array()
radius_ratios_list: np.ndarray = comm.gather(radius_ratios_array, root=0)
if comm.rank == 0:
Expand All @@ -388,7 +388,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:
The condition numbers of the elements on process 0.

"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
function_space_dg0 = fenics.FunctionSpace(mesh, "DG", 0)
jac = ufl.Jacobian(mesh)
inv = ufl.JacobianInverse(mesh)
Expand Down Expand Up @@ -417,7 +417,9 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:

cond = fenics.Function(function_space_dg0)

_utils.assemble_and_solve_linear(lhs, rhs, fun=cond, ksp_options=options)
_utils.assemble_and_solve_linear(
lhs, rhs, fun=cond, ksp_options=options, comm=comm
)
cond.vector().vec().reciprocal()
cond.vector().apply("")
cond.vector().vec().scale(np.sqrt(mesh.geometric_dimension()))
Expand Down Expand Up @@ -454,7 +456,7 @@ def min(cls, calculator: MeshQualityCalculator, mesh: fenics.Mesh) -> float:

"""
quality_list = calculator.compute(mesh)
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
if comm.rank == 0:
qual = float(np.min(quality_list))
else:
Expand All @@ -477,7 +479,7 @@ def avg(cls, calculator: MeshQualityCalculator, mesh: fenics.Mesh) -> float:

"""
quality_list = calculator.compute(mesh)
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()

if comm.rank == 0:
qual = float(np.average(quality_list))
Expand Down
2 changes: 1 addition & 1 deletion cashocs/io/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def gather_coordinates(mesh: fenics.Mesh) -> np.ndarray:
A numpy array which contains the vertex coordinates of the mesh

"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
rank = comm.Get_rank()
top = mesh.topology()
global_vertex_indices = top.global_indices(0)
Expand Down
10 changes: 7 additions & 3 deletions cashocs/nonlinear_solvers/newton_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,11 @@ def __init__(
self.derivative, -self.nonlinear_form, self.bcs_hom
)
self.assembler.keep_diagonal = True

self.comm = self.u.function_space().mesh().mpi_comm()
# pylint: disable=invalid-name
self.A_fenics = self.A_tensor or fenics.PETScMatrix()
self.residual = self.b_tensor or fenics.PETScVector()
self.A_fenics = self.A_tensor or fenics.PETScMatrix(self.comm)
self.residual = self.b_tensor or fenics.PETScVector(self.comm)
self.b = fenics.as_backend_type(self.residual).vec()
self.A_matrix = fenics.as_backend_type(self.A_fenics).mat()

Expand All @@ -160,7 +162,7 @@ def __init__(
self.assembler_shift = fenics.SystemAssembler(
self.derivative, self.shift, self.bcs_hom
)
self.residual_shift = fenics.PETScVector()
self.residual_shift = fenics.PETScVector(self.comm)

self.breakdown = False
self.res = 1.0
Expand Down Expand Up @@ -266,6 +268,7 @@ def solve(self) -> fenics.Function:
ksp_options=self.ksp_options,
rtol=self.eta,
atol=self.atol / 10.0,
comm=self.comm,
)

if self.is_linear:
Expand Down Expand Up @@ -370,6 +373,7 @@ def _backtracking_line_search(self) -> None:
ksp_options=self.ksp_options,
rtol=self.eta,
atol=self.atol / 10.0,
comm=self.comm,
)

if (
Expand Down
4 changes: 3 additions & 1 deletion cashocs/nonlinear_solvers/picard_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ def picard_iteration(
bcs_list = _utils.check_and_enlist_bcs(bcs_list)
bcs_list_hom = _create_homogenized_bcs(bcs_list)

comm = u_list[0].function_space().mesh().mpi_comm()

prefix = "Picard iteration: "

res_tensor = [fenics.PETScVector() for _ in range(len(u_list))]
res_tensor = [fenics.PETScVector(comm) for _ in u_list]
eta_max = 0.9
gamma = 0.9
res_0 = 1.0
Expand Down