Skip to content

Commit

Permalink
Merge pull request #362 from firedrakeproject/another_perp_fix
Browse files Browse the repository at this point in the history
PR #362: rewrite replace_subject/trial/test, fix how this works with the perp operator and add tests
  • Loading branch information
tommbendall authored Apr 24, 2023
2 parents b826358 + 1054c43 commit b1c4fd5
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 104 deletions.
12 changes: 7 additions & 5 deletions gusto/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,16 +688,18 @@ def __init__(self, domain, parameters, fexpr=None, bexpr=None,
if fexpr is not None:
V = FunctionSpace(domain.mesh, "CG", 1)
f = self.prescribed_fields("coriolis", V).interpolate(fexpr)
coriolis_form = perp(
coriolis(
subject(prognostic(f*inner(u, w)*dx, "u"), self.X)
), domain.perp)
coriolis_form = coriolis(subject(
prognostic(f*inner(domain.perp(u), w)*dx, "u"), self.X))
if not domain.on_sphere:
coriolis_form = perp(coriolis_form, domain.perp)
# Add linearisation
if self.linearisation_map(coriolis_form.terms[0]):
linear_coriolis = perp(
coriolis(
subject(prognostic(f*inner(u_trial, w)*dx, "u"), self.X)
subject(prognostic(f*inner(domain.perp(u_trial), w)*dx, "u"), self.X)
), domain.perp)
if not domain.on_sphere:
linear_coriolis = perp(linear_coriolis, domain.perp)
coriolis_form = linearisation(coriolis_form, linear_coriolis)
residual += coriolis_form

Expand Down
195 changes: 118 additions & 77 deletions gusto/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,64 @@
from types import MethodType, LambdaType


def replace_test_function(new_test):
def _replace_dict(old, new, idx, replace_type):
"""
Build a dictionary to pass to the ufl.replace routine
The dictionary matches variables in the old term with those in the new
Does not check types unless indexing is required (leave type-checking to ufl.replace)
"""

replace_dict = {}

if type(old.ufl_element()) is MixedElement:

mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement
indexable_new = type(new) is tuple or mixed_new

if indexable_new:
split_new = new if type(new) is tuple else split(new)

if len(split_new) != len(old.function_space()):
raise ValueError(f"new {replace_type} of type {new} must be same length"
+ f"as replaced mixed {replace_type} of type {old}")

if idx is None:
for k, v in zip(split(old), split_new):
replace_dict[k] = v
else:
replace_dict[split(old)[idx]] = split_new[idx]

else: # new is not indexable
if idx is None:
raise ValueError(f"idx must be specified to replace_{replace_type} when"
+ f" replaced {replace_type} of type {old} is mixed and"
+ f" new {replace_type} of type {new} is a single component")

replace_dict[split(old)[idx]] = new

else: # old is not mixed

mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement
indexable_new = type(new) is tuple or mixed_new

if indexable_new:
split_new = new if type(new) is tuple else split(new)

if idx is None:
raise ValueError(f"idx must be specified to replace_{replace_type} when"
+ f" replaced {replace_type} of type {old} is not mixed"
+ f" and new {replace_type} of type {new} is indexable")

replace_dict[old] = split_new[idx]

else:
replace_dict[old] = new

return replace_dict


def replace_test_function(new_test, idx=None):
"""
A routine to replace the test function in a term with a new test function.
Expand All @@ -30,14 +87,22 @@ def repl(t):
Returns:
:class:`Term`: the new term.
"""
test = t.form.arguments()[0]
new_form = ufl.replace(t.form, {test: new_test})
old_test = t.form.arguments()[0]
replace_dict = _replace_dict(old_test, new_test, idx, 'test')

try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_test_function with {new_test}"
raise type(err)(error_message) from err

return Term(new_form, t.labels)

return repl


def replace_trial_function(new):
def replace_trial_function(new_trial, idx=None):
"""
A routine to replace the trial function in a term with a new expression.
Expand Down Expand Up @@ -65,14 +130,38 @@ def repl(t):
"""
if len(t.form.arguments()) != 2:
raise TypeError('Trying to replace trial function of a form that is not linear')
trial = t.form.arguments()[1]
new_form = ufl.replace(t.form, {trial: new})
old_trial = t.form.arguments()[1]
replace_dict = _replace_dict(old_trial, new_trial, idx, 'trial')

try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_trial_function with {new_trial}"
raise type(err)(error_message) from err

# When a term has the perp label, this indicates that replace
# cannot see that the perped object should also be
# replaced. In this case we also pass the perped object to
# replace.
if t.has_label(perp):
perp_op = t.get(perp)
perp_old = perp_op(old_trial)
perp_new = perp_op(new_trial)
try:
new_form = ufl.replace(t.form, {perp_old: perp_new})

except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_subject with {new_trial}"
raise type(err)(error_message) from err

return Term(new_form, t.labels)

return repl


def replace_subject(new, idx=None):
def replace_subject(new_subj, idx=None):
"""
A routine to replace the subject in a term with a new variable.
Expand All @@ -97,79 +186,31 @@ def repl(t):
:class:`Term`: the new term.
"""

subj = t.get(subject)

# Build a dictionary to pass to the ufl.replace routine
# The dictionary matches variables in the old term with those in the new
replace_dict = {}

# Consider cases that subj is normal Function or MixedFunction
# vs cases of new being Function vs MixedFunction vs tuple
# Ideally catch all cases or fail gracefully
if type(subj.ufl_element()) is MixedElement:
if type(new) == tuple:
assert len(new) == len(subj.function_space())
for k, v in zip(split(subj), new):
replace_dict[k] = v

elif type(new) == ufl.algebra.Sum:
replace_dict[subj] = new

elif isinstance(new, ufl.indexed.Indexed):
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when subject is Mixed and new is a single component')
replace_dict[split(subj)[idx]] = new

# Otherwise fail if new is not a function
elif not isinstance(new, Function):
raise ValueError(f'new must be a tuple or Function, not type {type(new)}')

# Now handle MixedElements separately as these need indexing
elif type(new.ufl_element()) is MixedElement:
assert len(new.function_space()) == len(subj.function_space())
# If idx specified, replace only that component
if idx is not None:
replace_dict[split(subj)[idx]] = split(new)[idx]
# Otherwise replace all components
else:
for k, v in zip(split(subj), split(new)):
replace_dict[k] = v

# Otherwise 'new' is a normal Function
else:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when subject is Mixed and new is a single component')
replace_dict[split(subj)[idx]] = new

# subj is a normal Function
else:
if type(new) is tuple:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when new is a tuple')
replace_dict[subj] = new[idx]
elif isinstance(new, ufl.indexed.Indexed):
replace_dict[subj] = new
elif not isinstance(new, Function):
raise ValueError(f'new must be a Function, not type {type(new)}')
elif type(new.ufl_element()) == MixedElement:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when new is a tuple')
replace_dict[subj] = split(new)[idx]
else:
replace_dict[subj] = new
old_subj = t.get(subject)
replace_dict = _replace_dict(old_subj, new_subj, idx, 'subject')

new_form = ufl.replace(t.form, replace_dict)
try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_subject with {new_subj}"
raise type(err)(error_message) from err

# this is necessary to defer applying the perp until after the
# subject is replaced because otherwise replace cannot find
# the subject
# When a term has the perp label, this indicates that replace
# cannot see that the perped object should also be
# replaced. In this case we also pass the perped object to
# replace.
if t.has_label(perp):
perp_function = t.get(perp)
new_form = ufl.replace(new_form, {split(new)[0]: perp_function(split(new)[0])})
perp_op = t.get(perp)
perp_old = perp_op(t.get(subject))
perp_new = perp_op(new_subj)
try:
new_form = ufl.replace(t.form, {perp_old: perp_new})

except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_subject with {new_subj}"
raise type(err)(error_message) from err

return Term(new_form, t.labels)

Expand Down
Binary file added integration-tests/data/sw_fplane_chkpt.h5
Binary file not shown.
126 changes: 126 additions & 0 deletions integration-tests/equations/test_sw_fplane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
This runs a shallow water simulation on the fplane with 3 waves
that interact and checks the results agains a known checkpointed answer.
"""

from os.path import join, abspath, dirname
from gusto import *
from firedrake import (PeriodicSquareMesh, SpatialCoordinate, Function,
norm, cos, pi)


def run_sw_fplane(tmpdir):
# Domain
Nx = 32
Ny = Nx
Lx = 10
mesh = PeriodicSquareMesh(Nx, Ny, Lx, quadrilateral=True)
dt = 0.01
domain = Domain(mesh, dt, 'RTCF', 1)

# Equation
H = 2
g = 50
parameters = ShallowWaterParameters(H=H, g=g)
f0 = 10
fexpr = Constant(f0)
eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr)

# I/O
output = OutputParameters(dirname=str(tmpdir)+"/sw_fplane",
dumpfreq=1,
log_level='INFO')

io = IO(domain, output, diagnostic_fields=[CourantNumber()])

# Transport schemes
transported_fields = []
transported_fields.append((ImplicitMidpoint(domain, "u")))
transported_fields.append((SSPRK3(domain, "D")))

# Time stepper
stepper = SemiImplicitQuasiNewton(eqns, io, transported_fields)

# ------------------------------------------------------------------------ #
# Initial conditions
# ------------------------------------------------------------------------ #

u0 = stepper.fields("u")
D0 = stepper.fields("D")
x, y = SpatialCoordinate(mesh)
N0 = 0.1
gamma = sqrt(g*H)
###############################
# Fast wave:
k1 = 5*(2*pi/Lx)

K1sq = k1**2
psi1 = sqrt(f0**2 + g*H*K1sq)
xi1 = sqrt(2*K1sq)*psi1

c1 = cos(k1*x)
s1 = sin(k1*x)
################################
# Slow wave:
k2 = -k1
l2 = k1

K2sq = k2**2 + l2**2
psi2 = sqrt(f0**2 + g*H*K2sq)

c2 = cos(k2*x + l2*y)
s2 = sin(k2*x + l2*y)
################################
# Construct the initial condition:
A1 = N0/xi1
u1 = A1*(k1*psi1*c1)
v1 = A1*(f0*k1*s1)
phi1 = A1*(K1sq*gamma*c1)

A2 = N0/psi2
u2 = A2*(l2*gamma*s2)
v2 = A2*(-k2*gamma*s2)
phi2 = A2*(f0*c2)

u_expr = as_vector([u1+u2, v1+v2])
D_expr = H + sqrt(H/g)*(phi1+phi2)

u0.project(u_expr)
D0.interpolate(D_expr)

Dbar = Function(D0.function_space()).assign(H)
stepper.set_reference_profiles([('D', Dbar)])

# ------------------------------------------------------------------------ #
# Run
# ------------------------------------------------------------------------ #

stepper.run(t=0, tmax=10*dt)

# State for checking checkpoints
checkpoint_name = 'sw_fplane_chkpt'
new_path = join(abspath(dirname(__file__)), '..', f'data/{checkpoint_name}')
check_eqn = ShallowWaterEquations(domain, parameters, fexpr=fexpr)
check_output = OutputParameters(dirname=tmpdir+"/sw_fplane",
checkpoint_pickup_filename=new_path)
check_io = IO(domain, output=check_output)
check_stepper = SemiImplicitQuasiNewton(check_eqn, check_io, [])
check_stepper.set_reference_profiles([])
check_stepper.run(t=0, tmax=0, pick_up=True)

return stepper, check_stepper


def test_sw_fplane(tmpdir):

dirname = str(tmpdir)
stepper, check_stepper = run_sw_fplane(dirname)

for variable in ['u', 'D']:
new_variable = stepper.fields(variable)
check_variable = check_stepper.fields(variable)
error = norm(new_variable - check_variable) / norm(check_variable)

# Slack values chosen to be robust to different platforms
assert error < 1e-10, f'Values for {variable} in ' + \
'shallow water fplane test do not match KGO values'
Loading

0 comments on commit b1c4fd5

Please sign in to comment.