Skip to content

Commit

Permalink
Merge pull request #324 from firedrakeproject/FML_upgrades
Browse files Browse the repository at this point in the history
Neaten up some aspects of FML
  • Loading branch information
jshipton authored Jan 16, 2023
2 parents 83277bb + 843874d commit 0f3fdff
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
17 changes: 15 additions & 2 deletions gusto/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
TrialFunction, FacetNormal, jump, avg, dS_v,
DirichletBC, conditional, SpatialCoordinate,
split, Constant, action)
from gusto.fml.form_manipulation_labelling import Term, all_terms, identity, drop
from gusto.fml.form_manipulation_labelling import Term, all_terms, keep, drop, Label
from gusto.labels import (subject, time_derivative, transport, prognostic,
transporting_velocity, replace_subject, linearisation,
name, pressure_gradient, coriolis,
Expand Down Expand Up @@ -55,6 +55,19 @@ def __init__(self, state, function_space, field_name):

self.bcs[field_name] = []

def label_terms(self, term_filter, label):
"""
Labels terms in the equation, subject to the term filter.
Args:
term_filter (func): a function, taking terms as an argument, that
is used to filter terms.
label (:class:`Label`): the label to be applied to the terms.
"""
assert type(label, Label)
self.residual = self.residual.label_map(term_filter, map_if_true=label)


class AdvectionEquation(PrognosticEquation):
u"""Discretises the advection equation, ∂q/∂t + (u.∇)q = 0"""
Expand Down Expand Up @@ -342,7 +355,7 @@ def linearise(term, X, X_ref, du):
residual = residual.label_map(
should_linearise,
map_if_true=partial(linearise, X=self.X, X_ref=self.X_ref, du=self.trials),
map_if_false=identity, # TODO: should "keep" be an alias for identity?
map_if_false=keep,
)

return residual
Expand Down
7 changes: 4 additions & 3 deletions gusto/fml/form_manipulation_labelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
identity = lambda t: t
drop = lambda t: None
all_terms = lambda t: True
keep = identity


class Term(object):
Expand All @@ -29,7 +30,7 @@ def __init__(self, form, label_dict=None):
self.form = form
self.labels = label_dict or {}

def get(self, label, default=None):
def get(self, label):
"""
Returns the value of a label.
Expand Down Expand Up @@ -356,8 +357,8 @@ def __call__(self, target, value=None):
# if value is provided, check that we have a validator function
# and validate the value, otherwise use default value
if value is not None:
assert self.validator
assert self.validator(value)
assert self.validator, f'Label {self.label} requires a validator'
assert self.validator(value), f'Value {value} for label {self.label} does not satisfy validator'
self.value = value
else:
self.value = self.default_value
Expand Down

0 comments on commit 0f3fdff

Please sign in to comment.