Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Another perp fix #362

Merged
merged 26 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
5c7a5c9
first attempt at allowing indexing in replace_{test,trial}_function
JHopeCollins Jan 13, 2023
4613dc0
replace_*: allow ufl.algebra.Sum to be replacement for non-mixed func…
JHopeCollins Jan 16, 2023
c32a774
replace_* label maps: check new function is of acceptable type earlier
JHopeCollins Jan 16, 2023
e6f7ab0
replace_* label maps: better error messages
JHopeCollins Jan 16, 2023
17b1ec9
replace_* label maps: leave type checking to ufl.replace when buildin…
JHopeCollins Jan 24, 2023
ca1ed0c
replace_* label maps: tidy up error messages and test skipping
JHopeCollins Jan 24, 2023
b8a33be
replace_* label maps: add more information to ufl.replace exception m…
JHopeCollins Jan 24, 2023
788dbb7
replace_* label maps: remove old _replace_dict impl
JHopeCollins Jan 24, 2023
259f2d1
replace_* label maps: parametrize tests for replace_{subject,trial,test}
JHopeCollins Jan 24, 2023
1d7852e
replace_* label maps: remove mixed-component test parameter for in-te…
JHopeCollins Jan 24, 2023
d55b42c
Merge branch 'main' into JHopeCollins/replace_indexed
JHopeCollins Jan 24, 2023
648effa
replace_* label maps: remove old dictionary builder
JHopeCollins Jan 25, 2023
d4a0bc8
Merge branch 'main' into JHopeCollins/replace_indexed
JHopeCollins Jan 25, 2023
c6a8822
hacky fix
jshipton Jan 26, 2023
829488c
hack for replacing subject with trial functions
jshipton Jan 27, 2023
f5f653b
work towards checking for the perped subject in a different way
jshipton Feb 2, 2023
a5a6301
only label coriolis term with perp if we are not on sphere
jshipton Feb 6, 2023
cc1e7cb
adding simple test for replacing a perped subject on the plane
jshipton Feb 24, 2023
90e8136
start of SW fplane test - needs checkpointed data for KGO
jshipton Feb 25, 2023
1d93391
Merge branch 'main' of https://github.com/firedrakeproject/gusto into…
jshipton Apr 11, 2023
41d3ed8
Merge branch 'main' of https://github.com/firedrakeproject/gusto into…
jshipton Apr 12, 2023
219e31d
add test for shallow water fplane checking against known good solutio…
jshipton Apr 12, 2023
599800a
fix lint
jshipton Apr 12, 2023
bcea13d
reverting change that is not part of this PR
jshipton Apr 12, 2023
e0e3096
also label linearisation of coriolis term with perp when not on sphere
jshipton Apr 12, 2023
1054c43
updated comments relating to the perp label
jshipton Apr 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions gusto/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,16 +688,18 @@ def __init__(self, domain, parameters, fexpr=None, bexpr=None,
if fexpr is not None:
V = FunctionSpace(domain.mesh, "CG", 1)
f = self.prescribed_fields("coriolis", V).interpolate(fexpr)
coriolis_form = perp(
coriolis(
subject(prognostic(f*inner(u, w)*dx, "u"), self.X)
), domain.perp)
coriolis_form = coriolis(subject(
prognostic(f*inner(domain.perp(u), w)*dx, "u"), self.X))
if not domain.on_sphere:
coriolis_form = perp(coriolis_form, domain.perp)
# Add linearisation
if self.linearisation_map(coriolis_form.terms[0]):
linear_coriolis = perp(
coriolis(
subject(prognostic(f*inner(u_trial, w)*dx, "u"), self.X)
subject(prognostic(f*inner(domain.perp(u_trial), w)*dx, "u"), self.X)
), domain.perp)
if not domain.on_sphere:
linear_coriolis = perp(linear_coriolis, domain.perp)
coriolis_form = linearisation(coriolis_form, linear_coriolis)
residual += coriolis_form

Expand Down
184 changes: 110 additions & 74 deletions gusto/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,64 @@
from types import MethodType, LambdaType


def replace_test_function(new_test):
def _replace_dict(old, new, idx, replace_type):
"""
Build a dictionary to pass to the ufl.replace routine
The dictionary matches variables in the old term with those in the new

Does not check types unless indexing is required (leave type-checking to ufl.replace)
"""

replace_dict = {}

if type(old.ufl_element()) is MixedElement:

mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement
indexable_new = type(new) is tuple or mixed_new

if indexable_new:
split_new = new if type(new) is tuple else split(new)

if len(split_new) != len(old.function_space()):
raise ValueError(f"new {replace_type} of type {new} must be same length"
+ f"as replaced mixed {replace_type} of type {old}")

if idx is None:
for k, v in zip(split(old), split_new):
replace_dict[k] = v
else:
replace_dict[split(old)[idx]] = split_new[idx]

else: # new is not indexable
if idx is None:
raise ValueError(f"idx must be specified to replace_{replace_type} when"
+ f" replaced {replace_type} of type {old} is mixed and"
+ f" new {replace_type} of type {new} is a single component")

replace_dict[split(old)[idx]] = new

else: # old is not mixed

mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement
indexable_new = type(new) is tuple or mixed_new

if indexable_new:
split_new = new if type(new) is tuple else split(new)

if idx is None:
raise ValueError(f"idx must be specified to replace_{replace_type} when"
+ f" replaced {replace_type} of type {old} is not mixed"
+ f" and new {replace_type} of type {new} is indexable")

replace_dict[old] = split_new[idx]

else:
replace_dict[old] = new

return replace_dict


def replace_test_function(new_test, idx=None):
"""
A routine to replace the test function in a term with a new test function.

Expand All @@ -30,14 +87,22 @@ def repl(t):
Returns:
:class:`Term`: the new term.
"""
test = t.form.arguments()[0]
new_form = ufl.replace(t.form, {test: new_test})
old_test = t.form.arguments()[0]
replace_dict = _replace_dict(old_test, new_test, idx, 'test')

try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_test_function with {new_test}"
raise type(err)(error_message) from err

jshipton marked this conversation as resolved.
Show resolved Hide resolved
return Term(new_form, t.labels)

return repl


def replace_trial_function(new):
def replace_trial_function(new_trial, idx=None):
"""
A routine to replace the trial function in a term with a new expression.

Expand Down Expand Up @@ -65,14 +130,34 @@ def repl(t):
"""
if len(t.form.arguments()) != 2:
raise TypeError('Trying to replace trial function of a form that is not linear')
trial = t.form.arguments()[1]
new_form = ufl.replace(t.form, {trial: new})
old_trial = t.form.arguments()[1]
replace_dict = _replace_dict(old_trial, new_trial, idx, 'trial')

try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_trial_function with {new_trial}"
raise type(err)(error_message) from err

jshipton marked this conversation as resolved.
Show resolved Hide resolved
if t.has_label(perp):
perp_op = t.get(perp)
perp_old = perp_op(old_trial)
perp_new = perp_op(new_trial)
try:
new_form = ufl.replace(t.form, {perp_old: perp_new})

except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_subject with {new_trial}"
raise type(err)(error_message) from err

return Term(new_form, t.labels)

return repl


def replace_subject(new, idx=None):
def replace_subject(new_subj, idx=None):
"""
A routine to replace the subject in a term with a new variable.

Expand All @@ -97,79 +182,30 @@ def repl(t):
:class:`Term`: the new term.
"""

subj = t.get(subject)

# Build a dictionary to pass to the ufl.replace routine
# The dictionary matches variables in the old term with those in the new
replace_dict = {}

# Consider cases that subj is normal Function or MixedFunction
# vs cases of new being Function vs MixedFunction vs tuple
# Ideally catch all cases or fail gracefully
if type(subj.ufl_element()) is MixedElement:
if type(new) == tuple:
assert len(new) == len(subj.function_space())
for k, v in zip(split(subj), new):
replace_dict[k] = v

elif type(new) == ufl.algebra.Sum:
replace_dict[subj] = new

elif isinstance(new, ufl.indexed.Indexed):
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when subject is Mixed and new is a single component')
replace_dict[split(subj)[idx]] = new

# Otherwise fail if new is not a function
elif not isinstance(new, Function):
raise ValueError(f'new must be a tuple or Function, not type {type(new)}')

# Now handle MixedElements separately as these need indexing
elif type(new.ufl_element()) is MixedElement:
assert len(new.function_space()) == len(subj.function_space())
# If idx specified, replace only that component
if idx is not None:
replace_dict[split(subj)[idx]] = split(new)[idx]
# Otherwise replace all components
else:
for k, v in zip(split(subj), split(new)):
replace_dict[k] = v

# Otherwise 'new' is a normal Function
else:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when subject is Mixed and new is a single component')
replace_dict[split(subj)[idx]] = new

# subj is a normal Function
else:
if type(new) is tuple:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when new is a tuple')
replace_dict[subj] = new[idx]
elif isinstance(new, ufl.indexed.Indexed):
replace_dict[subj] = new
elif not isinstance(new, Function):
raise ValueError(f'new must be a Function, not type {type(new)}')
elif type(new.ufl_element()) == MixedElement:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when new is a tuple')
replace_dict[subj] = split(new)[idx]
else:
replace_dict[subj] = new
old_subj = t.get(subject)
replace_dict = _replace_dict(old_subj, new_subj, idx, 'subject')

new_form = ufl.replace(t.form, replace_dict)
try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_subject with {new_subj}"
raise type(err)(error_message) from err

# this is necessary to defer applying the perp until after the
# subject is replaced because otherwise replace cannot find
# the subject
if t.has_label(perp):
perp_function = t.get(perp)
new_form = ufl.replace(new_form, {split(new)[0]: perp_function(split(new)[0])})
perp_op = t.get(perp)
perp_old = perp_op(t.get(subject))
perp_new = perp_op(new_subj)
try:
new_form = ufl.replace(t.form, {perp_old: perp_new})

except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_subject with {new_subj}"
raise type(err)(error_message) from err

return Term(new_form, t.labels)

Expand Down
Binary file added integration-tests/data/sw_fplane_chkpt.h5
Binary file not shown.
126 changes: 126 additions & 0 deletions integration-tests/equations/test_sw_fplane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
This runs a shallow water simulation on the fplane with 3 waves
that interact and checks the results agains a known checkpointed answer.
"""

from os.path import join, abspath, dirname
from gusto import *
from firedrake import (PeriodicSquareMesh, SpatialCoordinate, Function,
norm, cos, pi)


def run_sw_fplane(tmpdir):
# Domain
Nx = 32
Ny = Nx
Lx = 10
mesh = PeriodicSquareMesh(Nx, Ny, Lx, quadrilateral=True)
dt = 0.01
domain = Domain(mesh, dt, 'RTCF', 1)

# Equation
H = 2
g = 50
parameters = ShallowWaterParameters(H=H, g=g)
f0 = 10
fexpr = Constant(f0)
eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr)

# I/O
output = OutputParameters(dirname=str(tmpdir)+"/sw_fplane",
dumpfreq=1,
log_level='INFO')

io = IO(domain, output, diagnostic_fields=[CourantNumber()])

# Transport schemes
transported_fields = []
transported_fields.append((ImplicitMidpoint(domain, "u")))
transported_fields.append((SSPRK3(domain, "D")))

# Time stepper
stepper = SemiImplicitQuasiNewton(eqns, io, transported_fields)

# ------------------------------------------------------------------------ #
# Initial conditions
# ------------------------------------------------------------------------ #

u0 = stepper.fields("u")
D0 = stepper.fields("D")
x, y = SpatialCoordinate(mesh)
N0 = 0.1
gamma = sqrt(g*H)
###############################
# Fast wave:
k1 = 5*(2*pi/Lx)

K1sq = k1**2
psi1 = sqrt(f0**2 + g*H*K1sq)
xi1 = sqrt(2*K1sq)*psi1

c1 = cos(k1*x)
s1 = sin(k1*x)
################################
# Slow wave:
k2 = -k1
l2 = k1

K2sq = k2**2 + l2**2
psi2 = sqrt(f0**2 + g*H*K2sq)

c2 = cos(k2*x + l2*y)
s2 = sin(k2*x + l2*y)
################################
# Construct the initial condition:
A1 = N0/xi1
u1 = A1*(k1*psi1*c1)
v1 = A1*(f0*k1*s1)
phi1 = A1*(K1sq*gamma*c1)

A2 = N0/psi2
u2 = A2*(l2*gamma*s2)
v2 = A2*(-k2*gamma*s2)
phi2 = A2*(f0*c2)

u_expr = as_vector([u1+u2, v1+v2])
D_expr = H + sqrt(H/g)*(phi1+phi2)

u0.project(u_expr)
D0.interpolate(D_expr)

Dbar = Function(D0.function_space()).assign(H)
stepper.set_reference_profiles([('D', Dbar)])

# ------------------------------------------------------------------------ #
# Run
# ------------------------------------------------------------------------ #

stepper.run(t=0, tmax=10*dt)

# State for checking checkpoints
checkpoint_name = 'sw_fplane_chkpt'
new_path = join(abspath(dirname(__file__)), '..', f'data/{checkpoint_name}')
check_eqn = ShallowWaterEquations(domain, parameters, fexpr=fexpr)
check_output = OutputParameters(dirname=tmpdir+"/sw_fplane",
checkpoint_pickup_filename=new_path)
check_io = IO(domain, output=check_output)
check_stepper = SemiImplicitQuasiNewton(check_eqn, check_io, [])
check_stepper.set_reference_profiles([])
check_stepper.run(t=0, tmax=0, pick_up=True)

return stepper, check_stepper


def test_sw_fplane(tmpdir):

dirname = str(tmpdir)
stepper, check_stepper = run_sw_fplane(dirname)

for variable in ['u', 'D']:
new_variable = stepper.fields(variable)
check_variable = check_stepper.fields(variable)
error = norm(new_variable - check_variable) / norm(check_variable)

# Slack values chosen to be robust to different platforms
assert error < 1e-10, f'Values for {variable} in ' + \
'shallow water fplane test do not match KGO values'
Loading