diff --git a/gusto/equations.py b/gusto/equations.py index b6001ad9b..e8d51a9aa 100644 --- a/gusto/equations.py +++ b/gusto/equations.py @@ -10,7 +10,7 @@ from gusto.fml.form_manipulation_labelling import Term, all_terms, keep, drop, Label from gusto.labels import (subject, time_derivative, transport, prognostic, transporting_velocity, replace_subject, linearisation, - name, pressure_gradient, coriolis, + name, pressure_gradient, coriolis, perp, replace_trial_function, hydrostatic) from gusto.thermodynamics import exner_pressure from gusto.transport_forms import (advection_form, continuity_form, @@ -663,12 +663,17 @@ def __init__(self, domain, parameters, fexpr=None, bexpr=None, if fexpr is not None: V = FunctionSpace(domain.mesh, "CG", 1) f = self.prescribed_fields("coriolis", V).interpolate(fexpr) - coriolis_form = coriolis( - subject(prognostic(f*inner(domain.perp(u), w)*dx, "u"), self.X)) + coriolis_form = perp( + coriolis( + subject(prognostic(f*inner(u, w)*dx, "u"), self.X) + ), domain.perp) # Add linearisation - linear_coriolis = coriolis( - subject(prognostic(f*inner(domain.perp(u_trial), w)*dx, "u"), self.X)) - coriolis_form = linearisation(coriolis_form, linear_coriolis) + if self.linearisation_map(coriolis_form.terms[0]): + linear_coriolis = perp( + coriolis( + subject(prognostic(f*inner(u_trial, w)*dx, "u"), self.X) + ), domain.perp) + coriolis_form = linearisation(coriolis_form, linear_coriolis) residual += coriolis_form if bexpr is not None: diff --git a/gusto/labels.py b/gusto/labels.py index 23fd2d1a7..2f17bc6a4 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -4,7 +4,7 @@ from firedrake import Function, split, MixedElement from gusto.configuration import IntegrateByParts, TransportEquationType from gusto.fml.form_manipulation_labelling import Term, Label, LabelledForm -from types import MethodType +from types import MethodType, LambdaType def replace_test_function(new_test): @@ -164,6 +164,13 @@ def repl(t): new_form = ufl.replace(t.form, replace_dict) + # this is necessary to defer applying the perp until after the + # subject is replaced because otherwise replace cannot find + # the subject + if t.has_label(perp): + perp_function = t.get(perp) + new_form = ufl.replace(new_form, {split(new)[0]: perp_function(split(new)[0])}) + return Term(new_form, t.labels) return repl @@ -186,3 +193,4 @@ def repl(t): name = Label("name", validator=lambda value: type(value) == str) ibp_label = Label("ibp", validator=lambda value: type(value) == IntegrateByParts) hydrostatic = Label("hydrostatic", validator=lambda value: type(value) in [LabelledForm, Term]) +perp = Label("perp", validator=lambda value: isinstance(value, LambdaType))