Skip to content

Commit

Permalink
Make the global deformation work with new PETSc
Browse files Browse the repository at this point in the history
Replace calls to `getVec...` with `createVec` for better petsc4py / PETSc match
Refactor the way the transfer matrix is used for the global deformation (createVec does not have to be called here)
  • Loading branch information
sblauth committed Jul 11, 2024
1 parent f414851 commit 94eaf8f
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cashocs/_forms/control_form_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def scalar_product(
x = fenics.as_backend_type(a[i].vector()).vec()
y = fenics.as_backend_type(b[i].vector()).vec()

temp, _ = self.riesz_projection_matrices[i].getVecs()
temp = self.riesz_projection_matrices[i].createVecRight()
self.riesz_projection_matrices[i].mult(x, temp)
result += temp.dot(y)

Expand Down
2 changes: 1 addition & 1 deletion cashocs/_forms/shape_form_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def scalar_product(
x = fenics.as_backend_type(a[0].vector()).vec()
y = fenics.as_backend_type(b[0].vector()).vec()

temp, _ = self.scalar_product_matrix.getVecs()
temp = self.scalar_product_matrix.createVecRight()
self.scalar_product_matrix.mult(x, temp)
result = temp.dot(y)

Expand Down
6 changes: 3 additions & 3 deletions cashocs/_optimization/line_search/line_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def perform(

transfer_matrix = cast(PETSc.Mat, transfer_matrix)

_, temp = transfer_matrix.getVecs()
transfer_matrix.mult(x, temp)
self.global_deformation_vector.axpy(1.0, temp)
transfer_matrix.multAdd(
x, self.global_deformation_vector, self.global_deformation_vector
)
self.deformation_function.vector().apply("")

self.post_line_search()
Expand Down
4 changes: 2 additions & 2 deletions cashocs/_utils/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,10 @@ def solve(
A = setup_matrix_and_preconditioner(ksp, A, P)

if b is None:
return A.getVecs()[0]
return A.createVecRight()

if fun is None:
x, _ = A.getVecs()
x = A.createVecRight()
else:
x = fun.vector().vec()

Expand Down

0 comments on commit 94eaf8f

Please sign in to comment.