From ea84dcf19d4dee9ff3cb526e9952a750d55ecf12 Mon Sep 17 00:00:00 2001 From: Sebastian Blauth Date: Wed, 1 Feb 2023 08:41:39 +0100 Subject: [PATCH] Correctly use MPI communicators --- cashocs/_constraints/solvers.py | 3 +- cashocs/_database/geometry_database.py | 1 + cashocs/_forms/shape_form_handler.py | 1 + cashocs/_forms/shape_regularization.py | 1 + cashocs/_pde_problems/adjoint_problem.py | 10 +++++-- .../_pde_problems/control_gradient_problem.py | 1 + cashocs/_pde_problems/hessian_problems.py | 3 ++ .../_pde_problems/shape_gradient_problem.py | 1 + cashocs/_pde_problems/state_problem.py | 1 + cashocs/_utils/linalg.py | 28 ++++++++++++++++++- cashocs/geometry/boundary_distance.py | 8 ++++-- cashocs/geometry/mesh_handler.py | 1 + cashocs/geometry/mesh_testing.py | 5 ++-- cashocs/geometry/quality.py | 16 ++++++----- cashocs/io/mesh.py | 2 +- cashocs/nonlinear_solvers/newton_solver.py | 10 +++++-- cashocs/nonlinear_solvers/picard_solver.py | 4 ++- 17 files changed, 74 insertions(+), 22 deletions(-) diff --git a/cashocs/_constraints/solvers.py b/cashocs/_constraints/solvers.py index 4619e9d1..0a6fcf8f 100644 --- a/cashocs/_constraints/solvers.py +++ b/cashocs/_constraints/solvers.py @@ -207,8 +207,9 @@ def _project_pointwise_multiplier( lhs = trial * test * measure rhs = project_term * test * measure + comm = self.cg_function_space.mesh().mpi_comm() _utils.assemble_and_solve_linear( - lhs, rhs, A=A_tensor, b=b_tensor, fun=multiplier + lhs, rhs, A=A_tensor, b=b_tensor, fun=multiplier, comm=comm ) def _update_cost_functional(self) -> None: diff --git a/cashocs/_database/geometry_database.py b/cashocs/_database/geometry_database.py index 38a299af..11bd0f28 100644 --- a/cashocs/_database/geometry_database.py +++ b/cashocs/_database/geometry_database.py @@ -39,3 +39,4 @@ def __init__(self, function_db: function_database.FunctionDatabase) -> None: """ self.mesh: fenics.Mesh = function_db.state_spaces[0].mesh() self.dx: fenics.Measure = fenics.Measure("dx", self.mesh) + self.mpi_comm = self.mesh.mpi_comm() diff --git a/cashocs/_forms/shape_form_handler.py b/cashocs/_forms/shape_form_handler.py index 9baf1ed9..3c3a3da4 100644 --- a/cashocs/_forms/shape_form_handler.py +++ b/cashocs/_forms/shape_form_handler.py @@ -189,6 +189,7 @@ def compute(self) -> None: A=self.A_mu_matrix, b=self.b_mu, ksp_options=self.options_mu, + comm=self.mesh.mpi_comm(), ) if self.config.getboolean("ShapeGradient", "use_sqrt_mu"): diff --git a/cashocs/_forms/shape_regularization.py b/cashocs/_forms/shape_regularization.py index 1b6501c0..3d7d98f5 100644 --- a/cashocs/_forms/shape_regularization.py +++ b/cashocs/_forms/shape_regularization.py @@ -609,6 +609,7 @@ def _compute_curvature(self) -> None: A=self.a_curvature_matrix.mat(), b=self.b_curvature.vec(), fun=self.kappa_curvature, + comm=self.db.geometry_db.mpi_comm, ) def scale(self) -> None: diff --git a/cashocs/_pde_problems/adjoint_problem.py b/cashocs/_pde_problems/adjoint_problem.py index f071ac49..0519926e 100755 --- a/cashocs/_pde_problems/adjoint_problem.py +++ b/cashocs/_pde_problems/adjoint_problem.py @@ -68,14 +68,17 @@ def __init__( # pylint: disable=invalid-name self.A_tensors = [ - fenics.PETScMatrix() for _ in range(self.db.parameter_db.state_dim) + fenics.PETScMatrix(db.geometry_db.mpi_comm) + for _ in range(self.db.parameter_db.state_dim) ] self.b_tensors = [ - fenics.PETScVector() for _ in range(self.db.parameter_db.state_dim) + fenics.PETScVector(db.geometry_db.mpi_comm) + for _ in range(self.db.parameter_db.state_dim) ] self.res_j_tensors = [ - fenics.PETScVector() for _ in range(self.db.parameter_db.state_dim) + fenics.PETScVector(db.geometry_db.mpi_comm) + for _ in range(self.db.parameter_db.state_dim) ] self._number_of_solves = 0 @@ -119,6 +122,7 @@ def solve(self) -> List[fenics.Function]: b=self.b_tensors[-1 - i], fun=self.adjoints[-1 - i], ksp_options=self.db.parameter_db.adjoint_ksp_options[-1 - i], + comm=self.db.geometry_db.mpi_comm, ) else: diff --git a/cashocs/_pde_problems/control_gradient_problem.py b/cashocs/_pde_problems/control_gradient_problem.py index 947411f4..09aa1b88 100755 --- a/cashocs/_pde_problems/control_gradient_problem.py +++ b/cashocs/_pde_problems/control_gradient_problem.py @@ -111,6 +111,7 @@ def solve(self) -> List[fenics.Function]: b=self.b_tensors[i].vec(), fun=self.db.function_db.gradient[i], ksp_options=self.riesz_ksp_options[i], + comm=self.db.geometry_db.mpi_comm, ) self.has_solution = True diff --git a/cashocs/_pde_problems/hessian_problems.py b/cashocs/_pde_problems/hessian_problems.py index 8c017778..b3804b62 100755 --- a/cashocs/_pde_problems/hessian_problems.py +++ b/cashocs/_pde_problems/hessian_problems.py @@ -174,6 +174,7 @@ def hessian_application( self.bcs_list_ad[i], fun=self.db.function_db.states_prime[i], ksp_options=self.db.parameter_db.state_ksp_options[i], + comm=self.db.geometry_db.mpi_comm, ) for i in range(self.state_dim): @@ -185,6 +186,7 @@ def hessian_application( self.bcs_list_ad[-1 - i], fun=self.db.function_db.adjoints_prime[-1 - i], ksp_options=self.db.parameter_db.adjoint_ksp_options[-1 - i], + comm=self.db.geometry_db.mpi_comm, ) else: @@ -234,6 +236,7 @@ def hessian_application( b=b, fun=out[i], ksp_options=self.riesz_ksp_options[i], + comm=self.db.geometry_db.mpi_comm, ) self.no_sensitivity_solves += 2 diff --git a/cashocs/_pde_problems/shape_gradient_problem.py b/cashocs/_pde_problems/shape_gradient_problem.py index db0dbd47..5ce5ec24 100755 --- a/cashocs/_pde_problems/shape_gradient_problem.py +++ b/cashocs/_pde_problems/shape_gradient_problem.py @@ -154,6 +154,7 @@ def solve(self) -> List[fenics.Function]: b=self.form_handler.fe_shape_derivative_vector.vec(), fun=self.db.function_db.gradient[0], ksp_options=self.ksp_options, + comm=self.db.geometry_db.mpi_comm, ) self.has_solution = True diff --git a/cashocs/_pde_problems/state_problem.py b/cashocs/_pde_problems/state_problem.py index 1aa5cc1c..b01c2e73 100755 --- a/cashocs/_pde_problems/state_problem.py +++ b/cashocs/_pde_problems/state_problem.py @@ -128,6 +128,7 @@ def solve(self) -> List[fenics.Function]: b=self.b_tensors[i], fun=self.states[i], ksp_options=self.db.parameter_db.state_ksp_options[i], + comm=self.db.geometry_db.mpi_comm, ) else: diff --git a/cashocs/_utils/linalg.py b/cashocs/_utils/linalg.py index bb01ed2d..fbc836ff 100644 --- a/cashocs/_utils/linalg.py +++ b/cashocs/_utils/linalg.py @@ -31,6 +31,8 @@ from cashocs._utils import forms as forms_module if TYPE_CHECKING: + from mpi4py import MPI + from cashocs import _typing iterative_ksp_options: _typing.KspOption = { @@ -218,6 +220,24 @@ def setup_fieldsplit_preconditioner( pc.setFieldSplitIS(*idx_tuples) +def _initialize_comm(comm: Optional[MPI.Comm] = None) -> MPI.Comm: + """Initializes the MPI communicator. + + If the supplied communicator is `None`, return MPI.comm_world. + + Args: + comm: The supplied communicator or `None` + + Returns: + The resulting communicator. + + """ + if comm is None: + comm = fenics.MPI.comm_world + + return comm + + def solve_linear_problem( A: Optional[PETSc.Mat] = None, # pylint: disable=invalid-name b: Optional[PETSc.Vec] = None, @@ -225,6 +245,7 @@ def solve_linear_problem( ksp_options: Optional[_typing.KspOption] = None, rtol: Optional[float] = None, atol: Optional[float] = None, + comm: Optional[MPI.Comm] = None, ) -> PETSc.Vec: """Solves a finite dimensional linear problem. @@ -245,12 +266,14 @@ def solve_linear_problem( atol: The absolute tolerance used in case an iterative solver is used for solving the linear problem. Overrides the specification in the ksp object and ksp_options. + comm: The MPI communicator for the problem. Returns: The solution vector. """ - ksp = PETSc.KSP().create() + comm = _initialize_comm(comm) + ksp = PETSc.KSP().create(comm=comm) if A is not None: ksp.setOperators(A) @@ -306,6 +329,7 @@ def assemble_and_solve_linear( ksp_options: Optional[_typing.KspOption] = None, rtol: Optional[float] = None, atol: Optional[float] = None, + comm: Optional[MPI.Comm] = None, ) -> PETSc.Vec: """Assembles and solves a linear system. @@ -325,6 +349,7 @@ def assemble_and_solve_linear( atol: The absolute tolerance used in case an iterative solver is used for solving the linear problem. Overrides the specification in the ksp object and ksp_options. + comm: The MPI communicator for solving the problem. Returns: A PETSc vector containing the solution x. @@ -341,6 +366,7 @@ def assemble_and_solve_linear( ksp_options=ksp_options, rtol=rtol, atol=atol, + comm=comm, ) return solution diff --git a/cashocs/geometry/boundary_distance.py b/cashocs/geometry/boundary_distance.py index 71ca2aab..e5e0db40 100644 --- a/cashocs/geometry/boundary_distance.py +++ b/cashocs/geometry/boundary_distance.py @@ -65,6 +65,8 @@ def compute_boundary_distance( function_space = fenics.FunctionSpace(mesh, "CG", 1) dx = measure.NamedMeasure("dx", mesh) + comm = mesh.mpi_comm() + ksp_options = copy.deepcopy(_utils.linalg.iterative_ksp_options) u = fenics.TrialFunction(function_space) @@ -95,7 +97,9 @@ def compute_boundary_distance( lhs = fenics.dot(fenics.grad(u), fenics.grad(v)) * dx rhs = fenics.Constant(1.0) * v * dx - _utils.assemble_and_solve_linear(lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options) + _utils.assemble_and_solve_linear( + lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options, comm=comm + ) rhs = fenics.dot(fenics.grad(u_prev) / norm_u_prev, fenics.grad(v)) * dx @@ -114,7 +118,7 @@ def compute_boundary_distance( u_prev.vector().vec().aypx(0.0, u_curr.vector().vec()) u_prev.vector().apply("") _utils.assemble_and_solve_linear( - lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options + lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options, comm=comm ) res = np.sqrt(fenics.assemble(residual_form)) diff --git a/cashocs/geometry/mesh_handler.py b/cashocs/geometry/mesh_handler.py index 5f7061b1..650c4554 100644 --- a/cashocs/geometry/mesh_handler.py +++ b/cashocs/geometry/mesh_handler.py @@ -363,6 +363,7 @@ def compute_decreases( self.a_frobenius, self.l_frobenius, ksp_options=self.options_frobenius, + comm=self.db.geometry_db.mpi_comm, ) frobenius_norm = x.max()[1] diff --git a/cashocs/geometry/mesh_testing.py b/cashocs/geometry/mesh_testing.py index 8e3ca961..789a2960 100644 --- a/cashocs/geometry/mesh_testing.py +++ b/cashocs/geometry/mesh_testing.py @@ -92,14 +92,13 @@ def test(self, transformation: fenics.Function, volume_change: float) -> bool: A boolean that indicates whether the desired transformation is feasible. """ + comm = self.transformation_container.function_space().mesh().mpi_comm() self.transformation_container.vector().vec().aypx( 0.0, transformation.vector().vec() ) self.transformation_container.vector().apply("") x = _utils.assemble_and_solve_linear( - self.A_prior, - self.l_prior, - ksp_options=self.options_prior, + self.A_prior, self.l_prior, ksp_options=self.options_prior, comm=comm ) min_det = float(x.min()[1]) diff --git a/cashocs/geometry/quality.py b/cashocs/geometry/quality.py index e4a99e27..96625eac 100644 --- a/cashocs/geometry/quality.py +++ b/cashocs/geometry/quality.py @@ -294,7 +294,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray: The element wise skewness of the mesh on process 0. """ - comm = fenics.MPI.comm_world + comm = mesh.mpi_comm() skewness_array = self._quality_object.skewness(mesh).array() skewness_list: np.ndarray = comm.gather(skewness_array, root=0) if comm.rank == 0: @@ -329,7 +329,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray: The maximum angle quality measure for each element on process 0. """ - comm = fenics.MPI.comm_world + comm = mesh.mpi_comm() maximum_angle_array = self._quality_object.maximum_angle(mesh).array() maximum_angle_list: np.ndarray = comm.gather(maximum_angle_array, root=0) if comm.rank == 0: @@ -361,7 +361,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray: The radius ratios of the mesh elements on process 0. """ - comm = fenics.MPI.comm_world + comm = mesh.mpi_comm() radius_ratios_array = fenics.MeshQuality.radius_ratios(mesh).array() radius_ratios_list: np.ndarray = comm.gather(radius_ratios_array, root=0) if comm.rank == 0: @@ -388,7 +388,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray: The condition numbers of the elements on process 0. """ - comm = fenics.MPI.comm_world + comm = mesh.mpi_comm() function_space_dg0 = fenics.FunctionSpace(mesh, "DG", 0) jac = ufl.Jacobian(mesh) inv = ufl.JacobianInverse(mesh) @@ -417,7 +417,9 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray: cond = fenics.Function(function_space_dg0) - _utils.assemble_and_solve_linear(lhs, rhs, fun=cond, ksp_options=options) + _utils.assemble_and_solve_linear( + lhs, rhs, fun=cond, ksp_options=options, comm=comm + ) cond.vector().vec().reciprocal() cond.vector().apply("") cond.vector().vec().scale(np.sqrt(mesh.geometric_dimension())) @@ -454,7 +456,7 @@ def min(cls, calculator: MeshQualityCalculator, mesh: fenics.Mesh) -> float: """ quality_list = calculator.compute(mesh) - comm = fenics.MPI.comm_world + comm = mesh.mpi_comm() if comm.rank == 0: qual = float(np.min(quality_list)) else: @@ -477,7 +479,7 @@ def avg(cls, calculator: MeshQualityCalculator, mesh: fenics.Mesh) -> float: """ quality_list = calculator.compute(mesh) - comm = fenics.MPI.comm_world + comm = mesh.mpi_comm() if comm.rank == 0: qual = float(np.average(quality_list)) diff --git a/cashocs/io/mesh.py b/cashocs/io/mesh.py index f7d28ae1..52f07eb1 100644 --- a/cashocs/io/mesh.py +++ b/cashocs/io/mesh.py @@ -207,7 +207,7 @@ def gather_coordinates(mesh: fenics.Mesh) -> np.ndarray: A numpy array which contains the vertex coordinates of the mesh """ - comm = fenics.MPI.comm_world + comm = mesh.mpi_comm() rank = comm.Get_rank() top = mesh.topology() global_vertex_indices = top.global_indices(0) diff --git a/cashocs/nonlinear_solvers/newton_solver.py b/cashocs/nonlinear_solvers/newton_solver.py index 6bbbaab4..1b522293 100644 --- a/cashocs/nonlinear_solvers/newton_solver.py +++ b/cashocs/nonlinear_solvers/newton_solver.py @@ -148,9 +148,11 @@ def __init__( self.derivative, -self.nonlinear_form, self.bcs_hom ) self.assembler.keep_diagonal = True + + self.comm = self.u.function_space().mesh().mpi_comm() # pylint: disable=invalid-name - self.A_fenics = self.A_tensor or fenics.PETScMatrix() - self.residual = self.b_tensor or fenics.PETScVector() + self.A_fenics = self.A_tensor or fenics.PETScMatrix(self.comm) + self.residual = self.b_tensor or fenics.PETScVector(self.comm) self.b = fenics.as_backend_type(self.residual).vec() self.A_matrix = fenics.as_backend_type(self.A_fenics).mat() @@ -160,7 +162,7 @@ def __init__( self.assembler_shift = fenics.SystemAssembler( self.derivative, self.shift, self.bcs_hom ) - self.residual_shift = fenics.PETScVector() + self.residual_shift = fenics.PETScVector(self.comm) self.breakdown = False self.res = 1.0 @@ -266,6 +268,7 @@ def solve(self) -> fenics.Function: ksp_options=self.ksp_options, rtol=self.eta, atol=self.atol / 10.0, + comm=self.comm, ) if self.is_linear: @@ -370,6 +373,7 @@ def _backtracking_line_search(self) -> None: ksp_options=self.ksp_options, rtol=self.eta, atol=self.atol / 10.0, + comm=self.comm, ) if ( diff --git a/cashocs/nonlinear_solvers/picard_solver.py b/cashocs/nonlinear_solvers/picard_solver.py index 43678df9..26fd1181 100644 --- a/cashocs/nonlinear_solvers/picard_solver.py +++ b/cashocs/nonlinear_solvers/picard_solver.py @@ -128,9 +128,11 @@ def picard_iteration( bcs_list = _utils.check_and_enlist_bcs(bcs_list) bcs_list_hom = _create_homogenized_bcs(bcs_list) + comm = u_list[0].function_space().mesh().mpi_comm() + prefix = "Picard iteration: " - res_tensor = [fenics.PETScVector() for _ in range(len(u_list))] + res_tensor = [fenics.PETScVector(comm) for _ in u_list] eta_max = 0.9 gamma = 0.9 res_0 = 1.0