Skip to content

Commit

Permalink
Merge pull request #548 from firedrakeproject/TBendall/BackwardEulerW…
Browse files Browse the repository at this point in the history
…rapper

Fixes to some recovery related issues
  • Loading branch information
jshipton authored Sep 11, 2024
2 parents 9fa9cd2 + cec1ee6 commit 6eebf46
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 27 deletions.
31 changes: 16 additions & 15 deletions gusto/recovery/recovery_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def __init__(self, domain, boundary_method=None, use_vector_spaces=False):
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
boundary_method (:variable:'dict', optional): A dictionary containing the space
the boundary method is to be applied to along with specified method. Acceptable keys are "DG",
"HDiv" and "theta". acceptable values are (BoundaryMethod.taylor/hcurl/extruded),
passed as ('space', 'boundary method'). Defaults to None
use_vector_spaces (bool, optional):. Determines if we need to use DG / CG
space for the embedded and recovery space for the HDiv field instead of the usual
HDiv, HCurl spaces. Defaults to False
boundary_method (:variable:'dict', optional): A dictionary
containing the space the boundary method is to be applied to
along with specified method. Acceptable keys are "DG", "HDiv"
and "theta". Acceptable values are
(BoundaryMethod.taylor/hcurl/extruded). Defaults to None.
use_vector_spaces (bool, optional):. Determines if we need to use
the vector DG1 / CG1 space for the embedded and recovery space
for the HDiv field instead of the usual HDiv, HCurl spaces.
Defaults to False.
"""
family = domain.family
mesh = domain.mesh
Expand All @@ -36,7 +36,7 @@ def __init__(self, domain, boundary_method=None, use_vector_spaces=False):

valid_keys = ['DG', 'HDiv', 'theta']
if boundary_method is not None:
for key in boundary_method:
for key in boundary_method.keys():
if key not in valid_keys:
raise KeyError(f'Recovery spaces: boundary method key {key} not valid. Valid keys are DG, HDiv, theta')

Expand All @@ -47,7 +47,7 @@ def __init__(self, domain, boundary_method=None, use_vector_spaces=False):
# Check if extruded and if so builds theta spaces
if hasattr(mesh, "_base_mesh"):
# check if boundary method is present
if hasattr(boundary_method, 'theta'):
if boundary_method is not None and 'theta' in boundary_method.keys():
theta_boundary_method = boundary_method['theta']
else:
theta_boundary_method = None
Expand All @@ -71,7 +71,7 @@ def __init__(self, domain, boundary_method=None, use_vector_spaces=False):
# ----------------------------------------------------------------------
# Building the DG options
# ----------------------------------------------------------------------
if hasattr(boundary_method, 'DG'):
if boundary_method is not None and 'DG' in boundary_method.keys():
DG_boundary_method = boundary_method['DG']
else:
DG_boundary_method = None
Expand All @@ -91,7 +91,7 @@ def __init__(self, domain, boundary_method=None, use_vector_spaces=False):
# Building HDiv options
# ----------------------------------------------------------------------

if hasattr(boundary_method, 'HDiv'):
if boundary_method is not None and 'HDiv' in boundary_method.keys():
HDiv_boundary_method = boundary_method['HDiv']
else:
HDiv_boundary_method = None
Expand All @@ -102,16 +102,17 @@ def __init__(self, domain, boundary_method=None, use_vector_spaces=False):

HDiv_embedding_Space = Vu_DG1
HDiv_recovered_Space = Vu_CG1
project_high_method = 'interpolate'

else:

HDiv_embedding_Space = self.de_Rham.HDiv
HDiv_recovered_Space = self.de_Rham.HCurl
project_high_method = 'project'

self.HDiv_options = RecoveryOptions(embedding_space=HDiv_embedding_Space,
recovered_space=HDiv_recovered_Space,
injection_method='recover',
project_high_method='project',
project_high_method=project_high_method,
project_low_method='project',
broken_method='project',
boundary_method=HDiv_boundary_method)
5 changes: 4 additions & 1 deletion gusto/time_discretisation/imex_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from firedrake.fml import replace_subject, all_terms, drop
from firedrake.utils import cached_property
from gusto.core.labels import time_derivative, implicit, explicit
from gusto.time_discretisation.time_discretisation import TimeDiscretisation
from gusto.time_discretisation.time_discretisation import (
TimeDiscretisation, wrapper_apply
)
import numpy as np


Expand Down Expand Up @@ -209,6 +211,7 @@ def final_solver(self):
solver_name = self.field_name+self.__class__.__name__
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters, options_prefix=solver_name)

@wrapper_apply
def apply(self, x_out, x_in):
self.x1.assign(x_in)
solver_list = self.solvers
Expand Down
5 changes: 4 additions & 1 deletion gusto/time_discretisation/implicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from firedrake.utils import cached_property

from gusto.core.labels import time_derivative
from gusto.time_discretisation.time_discretisation import TimeDiscretisation
from gusto.time_discretisation.time_discretisation import (
TimeDiscretisation, wrapper_apply
)


__all__ = ["ImplicitRungeKutta", "ImplicitMidpoint", "QinZhang"]
Expand Down Expand Up @@ -142,6 +144,7 @@ def solve_stage(self, x0, stage):

self.k[stage].assign(self.x_out)

@wrapper_apply
def apply(self, x_out, x_in):

for i in range(self.nStages):
Expand Down
10 changes: 8 additions & 2 deletions gusto/time_discretisation/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ def setup(self, equation, apply_bcs=True, *active_labels):
# -------------------------------------------------------------------- #

if self.wrapper is not None:

wrapper_bcs = bcs if apply_bcs else None

if self.wrapper_name == "mixed_options":

self.wrapper.wrapper_spaces = equation.spaces
Expand All @@ -199,7 +202,7 @@ def setup(self, equation, apply_bcs=True, *active_labels):
raise ValueError(f"The option defined for {field} is for a field that does not exist in the equation set")

field_idx = equation.field_names.index(field)
subwrapper.setup(equation.spaces[field_idx])
subwrapper.setup(equation.spaces[field_idx], wrapper_bcs)

# Update the function space to that needed by the wrapper
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space
Expand All @@ -218,7 +221,7 @@ def setup(self, equation, apply_bcs=True, *active_labels):
if self.wrapper_name == "supg":
self.wrapper.setup()
else:
self.wrapper.setup(self.fs)
self.wrapper.setup(self.fs, wrapper_bcs)
self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
Expand Down Expand Up @@ -480,6 +483,7 @@ def rhs(self):

return r.form

@wrapper_apply
def apply(self, x_out, x_in):
"""
Apply the time discretisation to advance one whole time step.
Expand Down Expand Up @@ -570,6 +574,7 @@ def rhs(self):

return r.form

@wrapper_apply
def apply(self, x_out, x_in):
"""
Apply the time discretisation to advance one whole time step.
Expand Down Expand Up @@ -731,6 +736,7 @@ def solver_bdf2(self):
return NonlinearVariationalSolver(problem, solver_parameters=self.solver_parameters,
options_prefix=solver_name)

@wrapper_apply
def apply(self, x_out, x_in):
"""
Apply the time discretisation to advance one whole time step.
Expand Down
36 changes: 28 additions & 8 deletions gusto/time_discretisation/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def setup(self, original_space):
Args:
original_space (:class:`FunctionSpace`): the space that the
prognostic variable is defined on. This is a subset space of
a mixed function space when using a MixedFSWrapper.
prognostic variable is defined on. This is a subset space of
a mixed function space when using a MixedFSWrapper.
"""
self.original_space = original_space

Expand Down Expand Up @@ -85,8 +85,17 @@ class EmbeddedDGWrapper(Wrapper):
the original space.
"""

def setup(self, original_space):
"""Sets up function spaces and fields needed for this wrapper."""
def setup(self, original_space, post_apply_bcs):
"""
Sets up function spaces and fields needed for this wrapper.
Args:
original_space (:class:`FunctionSpace`): the space that the
prognostic variable is defined on.
post_apply_bcs (list of :class:`DirichletBC`): list of Dirichlet
boundary condition objects to be passed to the projector used
in the post-apply step.
"""

assert isinstance(self.options, EmbeddedDGOptions), \
'Embedded DG wrapper can only be used with Embedded DG Options'
Expand Down Expand Up @@ -121,7 +130,8 @@ def setup(self, original_space):
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])

if self.options.project_back_method == 'project':
self.x_out_projector = Projector(self.x_out, self.x_projected)
self.x_out_projector = Projector(self.x_out, self.x_projected,
bcs=post_apply_bcs)
elif self.options.project_back_method == 'recover':
self.x_out_projector = Recoverer(self.x_out, self.x_projected)
else:
Expand Down Expand Up @@ -169,8 +179,17 @@ class RecoveryWrapper(Wrapper):
field is then returned to the original space.
"""

def setup(self, original_space):
"""Sets up function spaces and fields needed for this wrapper."""
def setup(self, original_space, post_apply_bcs):
"""
Sets up function spaces and fields needed for this wrapper.
Args:
original_space (:class:`FunctionSpace`): the space that the
prognostic variable is defined on.
post_apply_bcs (list of :class:`DirichletBC`): list of Dirichlet
boundary condition objects to be passed to the projector used
in the post-apply step.
"""

assert isinstance(self.options, RecoveryOptions), \
'Recovery wrapper can only be used with Recovery Options'
Expand Down Expand Up @@ -213,7 +232,8 @@ def setup(self, original_space):
if self.options.project_low_method == 'interpolate':
self.x_out_projector = Interpolator(self.x_out, self.x_projected)
elif self.options.project_low_method == 'project':
self.x_out_projector = Projector(self.x_out, self.x_projected)
self.x_out_projector = Projector(self.x_out, self.x_projected,
bcs=post_apply_bcs)
elif self.options.project_low_method == 'recover':
self.x_out_projector = Recoverer(self.x_out, self.x_projected,
method=self.options.broken_method)
Expand Down

0 comments on commit 6eebf46

Please sign in to comment.