Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Allow indexing in replace_test_function and replace_trial_function maps. #325

Merged
merged 13 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 89 additions & 73 deletions gusto/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,64 @@
from types import MethodType, LambdaType


def replace_test_function(new_test):
def _replace_dict(old, new, idx, replace_type):
"""
Build a dictionary to pass to the ufl.replace routine
The dictionary matches variables in the old term with those in the new

Does not check types unless indexing is required (leave type-checking to ufl.replace)
"""

replace_dict = {}

if type(old.ufl_element()) is MixedElement:

mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement
indexable_new = type(new) is tuple or mixed_new

if indexable_new:
split_new = new if type(new) is tuple else split(new)

if len(split_new) != len(old.function_space()):
raise ValueError(f"new {replace_type} of type {new} must be same length"
+ f"as replaced mixed {replace_type} of type {old}")

if idx is None:
for k, v in zip(split(old), split_new):
replace_dict[k] = v
else:
replace_dict[split(old)[idx]] = split_new[idx]

else: # new is not indexable
if idx is None:
raise ValueError(f"idx must be specified to replace_{replace_type} when"
+ f" replaced {replace_type} of type {old} is mixed and"
+ f" new {replace_type} of type {new} is a single component")

replace_dict[split(old)[idx]] = new

else: # old is not mixed

mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement
indexable_new = type(new) is tuple or mixed_new

if indexable_new:
split_new = new if type(new) is tuple else split(new)

if idx is None:
raise ValueError(f"idx must be specified to replace_{replace_type} when"
+ f" replaced {replace_type} of type {old} is not mixed"
+ f" and new {replace_type} of type {new} is indexable")

replace_dict[old] = split_new[idx]

else:
replace_dict[old] = new

return replace_dict


def replace_test_function(new_test, idx=None):
"""
A routine to replace the test function in a term with a new test function.

Expand All @@ -30,14 +87,22 @@ def repl(t):
Returns:
:class:`Term`: the new term.
"""
test = t.form.arguments()[0]
new_form = ufl.replace(t.form, {test: new_test})
old_test = t.form.arguments()[0]
replace_dict = _replace_dict(old_test, new_test, idx, 'test')

try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_test_function with {new_test}"
raise type(err)(error_message) from err

return Term(new_form, t.labels)

return repl


def replace_trial_function(new):
def replace_trial_function(new_trial, idx=None):
"""
A routine to replace the trial function in a term with a new expression.

Expand Down Expand Up @@ -65,14 +130,22 @@ def repl(t):
"""
if len(t.form.arguments()) != 2:
raise TypeError('Trying to replace trial function of a form that is not linear')
trial = t.form.arguments()[1]
new_form = ufl.replace(t.form, {trial: new})
old_trial = t.form.arguments()[1]
replace_dict = _replace_dict(old_trial, new_trial, idx, 'trial')

try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_trial_function with {new_trial}"
raise type(err)(error_message) from err

return Term(new_form, t.labels)

return repl


def replace_subject(new, idx=None):
def replace_subject(new_subj, idx=None):
"""
A routine to replace the subject in a term with a new variable.

Expand All @@ -97,79 +170,22 @@ def repl(t):
:class:`Term`: the new term.
"""

subj = t.get(subject)

# Build a dictionary to pass to the ufl.replace routine
# The dictionary matches variables in the old term with those in the new
replace_dict = {}

# Consider cases that subj is normal Function or MixedFunction
# vs cases of new being Function vs MixedFunction vs tuple
# Ideally catch all cases or fail gracefully
if type(subj.ufl_element()) is MixedElement:
if type(new) == tuple:
assert len(new) == len(subj.function_space())
for k, v in zip(split(subj), new):
replace_dict[k] = v

elif type(new) == ufl.algebra.Sum:
replace_dict[subj] = new

elif isinstance(new, ufl.indexed.Indexed):
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when subject is Mixed and new is a single component')
replace_dict[split(subj)[idx]] = new

# Otherwise fail if new is not a function
elif not isinstance(new, Function):
raise ValueError(f'new must be a tuple or Function, not type {type(new)}')

# Now handle MixedElements separately as these need indexing
elif type(new.ufl_element()) is MixedElement:
assert len(new.function_space()) == len(subj.function_space())
# If idx specified, replace only that component
if idx is not None:
replace_dict[split(subj)[idx]] = split(new)[idx]
# Otherwise replace all components
else:
for k, v in zip(split(subj), split(new)):
replace_dict[k] = v

# Otherwise 'new' is a normal Function
else:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when subject is Mixed and new is a single component')
replace_dict[split(subj)[idx]] = new

# subj is a normal Function
else:
if type(new) is tuple:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when new is a tuple')
replace_dict[subj] = new[idx]
elif isinstance(new, ufl.indexed.Indexed):
replace_dict[subj] = new
elif not isinstance(new, Function):
raise ValueError(f'new must be a Function, not type {type(new)}')
elif type(new.ufl_element()) == MixedElement:
if idx is None:
raise ValueError('idx must be specified to replace_subject'
+ ' when new is a tuple')
replace_dict[subj] = split(new)[idx]
else:
replace_dict[subj] = new
old_subj = t.get(subject)
replace_dict = _replace_dict(old_subj, new_subj, idx, 'subject')

new_form = ufl.replace(t.form, replace_dict)
try:
new_form = ufl.replace(t.form, replace_dict)
except Exception as err:
error_message = f"{type(err)} raised by ufl.replace when trying to" \
+ f" replace_subject with {new_subj}"
raise type(err)(error_message) from err

# this is necessary to defer applying the perp until after the
# subject is replaced because otherwise replace cannot find
# the subject
if t.has_label(perp):
perp_function = t.get(perp)
new_form = ufl.replace(new_form, {split(new)[0]: perp_function(split(new)[0])})
new_form = ufl.replace(new_form, {split(new_subj)[0]: perp_function(split(new_subj)[0])})

return Term(new_form, t.labels)

Expand Down
60 changes: 38 additions & 22 deletions unit-tests/fml_tests/test_replace_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,31 @@

from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction,
VectorFunctionSpace, MixedFunctionSpace, dx, inner,
TrialFunctions, split)
TrialFunctions, TrialFunction, split)
from gusto.fml import Label
from gusto import subject, replace_subject
from gusto import subject, replace_subject, replace_test_function, replace_trial_function
import pytest

replace_funcs = [
pytest.param((Function, replace_subject), id="replace_subj"),
pytest.param((TestFunction, replace_test_function), id="replace_test"),
pytest.param((TrialFunction, replace_trial_function), id="replace_trial")
]


@pytest.mark.parametrize('subject_type', ['normal', 'mixed', 'vector'])
@pytest.mark.parametrize('replacement_type', ['normal', 'mixed', 'mixed-component', 'vector', 'tuple'])
@pytest.mark.parametrize('replacement_type', ['normal', 'mixed', 'vector', 'tuple'])
@pytest.mark.parametrize('function_or_indexed', ['function', 'indexed'])
def test_replace_subject(subject_type, replacement_type, function_or_indexed):
@pytest.mark.parametrize('replace_func', replace_funcs)
def test_replace_subject(subject_type, replacement_type, function_or_indexed, replace_func):

# ------------------------------------------------------------------------ #
# Only certain combinations of options are valid
# ------------------------------------------------------------------------ #

if subject_type == 'vector' and replacement_type != 'vector':
return
elif replacement_type == 'vector' and subject_type != 'vector':
return

if replacement_type == 'mixed-component':
if subject_type != 'mixed':
return
elif function_or_indexed != 'indexed':
return
# only makes sense to replace a vector with a vector
if (subject_type == 'vector') ^ (replacement_type == 'vector'):
pytest.skip("invalid option combination")

# ------------------------------------------------------------------------ #
# Set up
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed):
raise ValueError

the_subject = Function(V)
not_subject = Function(V)
not_subject = TrialFunction(V)
test = TestFunction(V)

form_1 = inner(the_subject, test)*dx
Expand All @@ -84,22 +84,21 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed):
V = Vmixed
if subject_type != 'mixed':
idx = 0
elif replacement_type == 'mixed-component':
V = Vmixed
idx = 0
elif replacement_type == 'vector':
V = V2
elif replacement_type == 'tuple':
V = Vmixed
else:
raise ValueError

the_replacement = Function(V)
FunctionType = replace_func[0]

the_replacement = FunctionType(V)

if function_or_indexed == 'indexed' and replacement_type != 'vector':
the_replacement = split(the_replacement)

if len(the_replacement) == 1 or replacement_type == 'mixed-component':
if len(the_replacement) == 1:
the_replacement = the_replacement[0]

if replacement_type == 'tuple':
Expand All @@ -111,7 +110,24 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed):
# Test replace_subject
# ------------------------------------------------------------------------ #

replace_map = replace_func[1]

if replace_map is replace_trial_function:
match_label = bar_label
else:
match_label = subject

labelled_form = labelled_form.label_map(
lambda t: t.has_label(subject),
map_if_true=replace_subject(the_replacement, idx=idx)
lambda t: t.has_label(match_label),
map_if_true=replace_map(the_replacement, idx=idx)
)

# also test indexed
if subject_type == 'mixed' and function_or_indexed == 'indexed':
idx = 0
the_replacement = split(FunctionType(Vmixed))[idx]

labelled_form = labelled_form.label_map(
lambda t: t.has_label(match_label),
map_if_true=replace_map(the_replacement, idx=idx)
)