From 5c7a5c99c6b0c632a5523462829a6f7eba575fc4 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Fri, 13 Jan 2023 14:45:52 +0000 Subject: [PATCH 01/11] first attempt at allowing indexing in replace_{test,trial}_function --- gusto/labels.py | 155 +++++++++++++++++++++++++----------------------- 1 file changed, 82 insertions(+), 73 deletions(-) diff --git a/gusto/labels.py b/gusto/labels.py index 23fd2d1a7..1466960e2 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -7,7 +7,78 @@ from types import MethodType -def replace_test_function(new_test): +def _replace_dict(old, new, idx, replace_type): + """ + Build a dictionary to pass to the ufl.replace routine + The dictionary matches variables in the old term with those in the new + + Consider cases that old is normal Function or MixedFunction + vs cases of new being Function vs MixedFunction vs tuple + Ideally catch all cases or fail gracefully + """ + + replace_dict = {} + + if type(old.ufl_element()) is MixedElement: + if type(new) == tuple: + assert len(new) == len(old.function_space()) + for k, v in zip(split(old), new): + replace_dict[k] = v + + elif type(new) == ufl.algebra.Sum: + replace_dict[old] = new + + elif isinstance(new, ufl.indexed.Indexed): + if idx is None: + raise ValueError('idx must be specified to replace_{replace_type}' + + ' when {replace_type} is Mixed and new is a single component') + replace_dict[split(old)[idx]] = new + + # Otherwise fail if new is not a function + elif not isinstance(new, type(old)): + raise ValueError(f'new must be a tuple or {type(old)}, not type {type(new)}') + + # Now handle MixedElements separately as these need indexing + elif type(new.ufl_element()) is MixedElement: + assert len(new.function_space()) == len(old.function_space()) + # If idx specified, replace only that component + if idx is not None: + replace_dict[split(old)[idx]] = split(new)[idx] + # Otherwise replace all components + else: + for k, v in zip(split(old), split(new)): + replace_dict[k] = v + + # Otherwise 'new' is a normal Function + else: + if idx is None: + raise ValueError('idx must be specified to replace_{replace_type}' + + ' when {replace_type} is Mixed and new is a single component') + replace_dict[split(old)[idx]] = new + + # old is a normal Function + else: + if type(new) is tuple: + if idx is None: + raise ValueError('idx must be specified to replace_{replace_type}' + + ' when new is a tuple') + replace_dict[old] = new[idx] + elif isinstance(new, ufl.indexed.Indexed): + replace_dict[old] = new + elif not isinstance(new, type(old)): + raise ValueError(f'new must be a {type(old)}, not type {type(new)}') + elif type(new.ufl_element()) == MixedElement: + if idx is None: + raise ValueError('idx must be specified to replace_{replace_type}' + + ' when new is a tuple') + replace_dict[old] = split(new)[idx] + else: + replace_dict[old] = new + + return replace_dict + + +def replace_test_function(new_test, idx=None): """ A routine to replace the test function in a term with a new test function. @@ -30,14 +101,15 @@ def repl(t): Returns: :class:`Term`: the new term. """ - test = t.form.arguments()[0] - new_form = ufl.replace(t.form, {test: new_test}) + old_test = t.form.arguments()[0] + replace_dict = _replace_dict(old_test, new_test, idx, 'test') + new_form = ufl.replace(t.form, replace_dict) return Term(new_form, t.labels) return repl -def replace_trial_function(new): +def replace_trial_function(new_trial, idx=None): """ A routine to replace the trial function in a term with a new expression. @@ -65,14 +137,15 @@ def repl(t): """ if len(t.form.arguments()) != 2: raise TypeError('Trying to replace trial function of a form that is not linear') - trial = t.form.arguments()[1] - new_form = ufl.replace(t.form, {trial: new}) + old_trial = t.form.arguments()[1] + replace_dict = _replace_dict(old_trial, new_trial, idx, 'trial') + new_form = ufl.replace(t.form, replace_dict) return Term(new_form, t.labels) return repl -def replace_subject(new, idx=None): +def replace_subject(new_subj, idx=None): """ A routine to replace the subject in a term with a new variable. @@ -97,73 +170,9 @@ def repl(t): :class:`Term`: the new term. """ - subj = t.get(subject) - - # Build a dictionary to pass to the ufl.replace routine - # The dictionary matches variables in the old term with those in the new - replace_dict = {} - - # Consider cases that subj is normal Function or MixedFunction - # vs cases of new being Function vs MixedFunction vs tuple - # Ideally catch all cases or fail gracefully - if type(subj.ufl_element()) is MixedElement: - if type(new) == tuple: - assert len(new) == len(subj.function_space()) - for k, v in zip(split(subj), new): - replace_dict[k] = v - - elif type(new) == ufl.algebra.Sum: - replace_dict[subj] = new - - elif isinstance(new, ufl.indexed.Indexed): - if idx is None: - raise ValueError('idx must be specified to replace_subject' - + ' when subject is Mixed and new is a single component') - replace_dict[split(subj)[idx]] = new - - # Otherwise fail if new is not a function - elif not isinstance(new, Function): - raise ValueError(f'new must be a tuple or Function, not type {type(new)}') - - # Now handle MixedElements separately as these need indexing - elif type(new.ufl_element()) is MixedElement: - assert len(new.function_space()) == len(subj.function_space()) - # If idx specified, replace only that component - if idx is not None: - replace_dict[split(subj)[idx]] = split(new)[idx] - # Otherwise replace all components - else: - for k, v in zip(split(subj), split(new)): - replace_dict[k] = v - - # Otherwise 'new' is a normal Function - else: - if idx is None: - raise ValueError('idx must be specified to replace_subject' - + ' when subject is Mixed and new is a single component') - replace_dict[split(subj)[idx]] = new - - # subj is a normal Function - else: - if type(new) is tuple: - if idx is None: - raise ValueError('idx must be specified to replace_subject' - + ' when new is a tuple') - replace_dict[subj] = new[idx] - elif isinstance(new, ufl.indexed.Indexed): - replace_dict[subj] = new - elif not isinstance(new, Function): - raise ValueError(f'new must be a Function, not type {type(new)}') - elif type(new.ufl_element()) == MixedElement: - if idx is None: - raise ValueError('idx must be specified to replace_subject' - + ' when new is a tuple') - replace_dict[subj] = split(new)[idx] - else: - replace_dict[subj] = new - + old_subj = t.get(subject) + replace_dict = _replace_dict(old_subj, new_subj, idx, 'subject') new_form = ufl.replace(t.form, replace_dict) - return Term(new_form, t.labels) return repl From 4613dc003c4482bba5dc884e259f7f521ffa4d50 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Mon, 16 Jan 2023 16:19:40 +0000 Subject: [PATCH 02/11] replace_*: allow ufl.algebra.Sum to be replacement for non-mixed functionspaces --- gusto/labels.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gusto/labels.py b/gusto/labels.py index 1466960e2..5a854d1cd 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -42,8 +42,10 @@ def _replace_dict(old, new, idx, replace_type): elif type(new.ufl_element()) is MixedElement: assert len(new.function_space()) == len(old.function_space()) # If idx specified, replace only that component + if idx is not None: replace_dict[split(old)[idx]] = split(new)[idx] + # Otherwise replace all components else: for k, v in zip(split(old), split(new)): @@ -63,15 +65,22 @@ def _replace_dict(old, new, idx, replace_type): raise ValueError('idx must be specified to replace_{replace_type}' + ' when new is a tuple') replace_dict[old] = new[idx] + + elif type(new) == ufl.algebra.Sum: + replace_dict[old] = new + elif isinstance(new, ufl.indexed.Indexed): replace_dict[old] = new + elif not isinstance(new, type(old)): raise ValueError(f'new must be a {type(old)}, not type {type(new)}') + elif type(new.ufl_element()) == MixedElement: if idx is None: raise ValueError('idx must be specified to replace_{replace_type}' + ' when new is a tuple') replace_dict[old] = split(new)[idx] + else: replace_dict[old] = new From c32a7740a28a609b546250e6e786c8cd247b4a00 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Mon, 16 Jan 2023 16:59:42 +0000 Subject: [PATCH 03/11] replace_* label maps: check new function is of acceptable type earlier --- gusto/labels.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/gusto/labels.py b/gusto/labels.py index 5a854d1cd..0d934b31d 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -19,12 +19,20 @@ def _replace_dict(old, new, idx, replace_type): replace_dict = {} + acceptable_types = (type(old), ufl.algebra.Sum, ufl.indexed.Indexed) + if replace_type == 'trial': + acceptable_types = (*acceptable_types, Function) + if type(old.ufl_element()) is MixedElement: if type(new) == tuple: assert len(new) == len(old.function_space()) for k, v in zip(split(old), new): replace_dict[k] = v + # Otherwise fail if new is not a function + elif not isinstance(new, acceptable_types): + raise ValueError(f'new must be a tuple or {type(old)}, not type {type(new)}') + elif type(new) == ufl.algebra.Sum: replace_dict[old] = new @@ -34,10 +42,6 @@ def _replace_dict(old, new, idx, replace_type): + ' when {replace_type} is Mixed and new is a single component') replace_dict[split(old)[idx]] = new - # Otherwise fail if new is not a function - elif not isinstance(new, type(old)): - raise ValueError(f'new must be a tuple or {type(old)}, not type {type(new)}') - # Now handle MixedElements separately as these need indexing elif type(new.ufl_element()) is MixedElement: assert len(new.function_space()) == len(old.function_space()) @@ -66,15 +70,15 @@ def _replace_dict(old, new, idx, replace_type): + ' when new is a tuple') replace_dict[old] = new[idx] + elif not isinstance(new, acceptable_types): + raise ValueError(f'new must be a {type(old)}, not type {type(new)}') + elif type(new) == ufl.algebra.Sum: replace_dict[old] = new elif isinstance(new, ufl.indexed.Indexed): replace_dict[old] = new - elif not isinstance(new, type(old)): - raise ValueError(f'new must be a {type(old)}, not type {type(new)}') - elif type(new.ufl_element()) == MixedElement: if idx is None: raise ValueError('idx must be specified to replace_{replace_type}' From e6f7ab04b42cd8b9b28c8ff6bf0ab42332adf3a3 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Mon, 16 Jan 2023 17:23:25 +0000 Subject: [PATCH 04/11] replace_* label maps: better error messages --- gusto/labels.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gusto/labels.py b/gusto/labels.py index 0d934b31d..7a16571a3 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -23,6 +23,8 @@ def _replace_dict(old, new, idx, replace_type): if replace_type == 'trial': acceptable_types = (*acceptable_types, Function) + type_error_message = f'new must be a {tuple} or '+' or '.join((f"{t}" for t in acceptable_types))+f', not {type(new)}' + if type(old.ufl_element()) is MixedElement: if type(new) == tuple: assert len(new) == len(old.function_space()) @@ -31,7 +33,7 @@ def _replace_dict(old, new, idx, replace_type): # Otherwise fail if new is not a function elif not isinstance(new, acceptable_types): - raise ValueError(f'new must be a tuple or {type(old)}, not type {type(new)}') + raise TypeError(type_error_message) elif type(new) == ufl.algebra.Sum: replace_dict[old] = new @@ -67,11 +69,11 @@ def _replace_dict(old, new, idx, replace_type): if type(new) is tuple: if idx is None: raise ValueError('idx must be specified to replace_{replace_type}' - + ' when new is a tuple') + + ' when new is a tuple and {replace_type} is not Mixed') replace_dict[old] = new[idx] elif not isinstance(new, acceptable_types): - raise ValueError(f'new must be a {type(old)}, not type {type(new)}') + raise TypeError(type_error_message) elif type(new) == ufl.algebra.Sum: replace_dict[old] = new @@ -82,7 +84,7 @@ def _replace_dict(old, new, idx, replace_type): elif type(new.ufl_element()) == MixedElement: if idx is None: raise ValueError('idx must be specified to replace_{replace_type}' - + ' when new is a tuple') + + ' when new is a tuple and {replace_type} is not Mixed') replace_dict[old] = split(new)[idx] else: From 17b1ec9019453d0ec4a1fddda19015c102bb389c Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 24 Jan 2023 10:04:43 +0000 Subject: [PATCH 05/11] replace_* label maps: leave type checking to ufl.replace when building replace dictionary --- gusto/labels.py | 114 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 1 deletion(-) diff --git a/gusto/labels.py b/gusto/labels.py index 7a16571a3..2ce07ff76 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -7,7 +7,7 @@ from types import MethodType -def _replace_dict(old, new, idx, replace_type): +def _replace_dict_old(old, new, idx, replace_type): """ Build a dictionary to pass to the ufl.replace routine The dictionary matches variables in the old term with those in the new @@ -93,6 +93,118 @@ def _replace_dict(old, new, idx, replace_type): return replace_dict +def _replace_dict_dumb1(old, new, idx, replace_type): + """ + Build a dictionary to pass to the ufl.replace routine + The dictionary matches variables in the old term with those in the new + + Does not check types unless indexing is required (leave type-checking to ufl) + """ + + replace_dict = {} + + if type(old.ufl_element()) is MixedElement: + + if type(new) is tuple: + if len(new) != len(old.function_space()): + raise ValueError(f"tuple of new {replace_type} must be same length as replaced mixed {replace_type} of type {old}") + if idx is None: + for k, v in zip(split(old), new): + replace_dict[k] = v + else: + replace_dict[split(old)[idx]] = new[idx] + + elif type(new.ufl_element()) is MixedElement: + if len(new.function_space()) != len(old.function_space()): + raise ValueError(f"New mixed {replace_type} of type {new} must be same length as replaced mixed {replace_type} of type {old}") + if idx is None: + for k, v in zip(split(old), split(new)): + replace_dict[k] = v + else: + replace_dict[split(old)[idx]] = split(new)[idx] + + else: # new is not indexable + if idx is None: + raise ValueError(f"idx must be specified to replace_{replace_type}" + + f" when replaced {replace_type} of type {old} is mixed and new {replace_type} of type {new} is a single component") + replace_dict[split(old)[idx]] = new + + else: # old is not mixed + + if type(new) is tuple: + if idx is None: + raise ValueError(f"idx must be specified to replace_{replace_type}" + + f" when replaced {replace_type} of type {old} is not mixed and new {replace_type} is a tuple") + replace_dict[old] = new[idx] + + elif type(new.ufl_element()) is MixedElement: + if idx is None: + raise ValueError(f"idx must be specified to replace_{replace_type}" + + f" when replaced {replace_type} of type {old} is not mixed and new {replace_type} of type {new} is mixed") + replace_dict[old] = split(new)[idx] + + else: # new is not mixed + replace_dict[old] = new + + return replace_dict + + +def _replace_dict(old, new, idx, replace_type): + """ + Build a dictionary to pass to the ufl.replace routine + The dictionary matches variables in the old term with those in the new + + Does not check types unless indexing is required (leave type-checking to ufl) + """ + + replace_dict = {} + + if type(old.ufl_element()) is MixedElement: + + mixed_new = (hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement) + indexable_new = type(new) is tuple or mixed_new + + if indexable_new: + new_len = len(new) if type(new) is tuple else len(new.function_space()) + split_new = new if type(new) is tuple else split(new) + + if new_len != len(old.function_space()): + raise ValueError(f"new {replace_type} of type {new} must be same length as replaced mixed {replace_type} of type {old}") + + if idx is None: + for k, v in zip(split(old), split_new): + replace_dict[k] = v + else: + replace_dict[split(old)[idx]] = split_new[idx] + + else: # new is not indexable + if idx is None: + raise ValueError(f"idx must be specified to replace_{replace_type}" + + f" when replaced {replace_type} of type {old} is mixed and new {replace_type} of type {new} is a single component") + + replace_dict[split(old)[idx]] = new + + else: # old is not mixed + + # or will short-circuit + mixed_new = (hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement) + indexable_new = type(new) is tuple or mixed_new + + if indexable_new: + split_new = new if type(new) is tuple else split(new) + + if idx is None: + raise ValueError(f"idx must be specified to replace_{replace_type}" + + f" when replaced {replace_type} of type {old} is not mixed and new {replace_type} is of indexable type {new}") + + replace_dict[old] = split_new[idx] + + else: + replace_dict[old] = new + + return replace_dict + + def replace_test_function(new_test, idx=None): """ A routine to replace the test function in a term with a new test function. From ca1ed0c6f8a3aad488aecd16ef91e3153ad64f13 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 24 Jan 2023 11:02:40 +0000 Subject: [PATCH 06/11] replace_* label maps: tidy up error messages and test skipping --- gusto/labels.py | 18 ++++++++++-------- unit-tests/fml_tests/test_replace_subject.py | 8 ++++---- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/gusto/labels.py b/gusto/labels.py index 2ce07ff76..47ec5bc46 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -161,7 +161,7 @@ def _replace_dict(old, new, idx, replace_type): if type(old.ufl_element()) is MixedElement: - mixed_new = (hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement) + mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement indexable_new = type(new) is tuple or mixed_new if indexable_new: @@ -169,7 +169,8 @@ def _replace_dict(old, new, idx, replace_type): split_new = new if type(new) is tuple else split(new) if new_len != len(old.function_space()): - raise ValueError(f"new {replace_type} of type {new} must be same length as replaced mixed {replace_type} of type {old}") + raise ValueError(f"new {replace_type} of type {new} must be same length" + + f"as replaced mixed {replace_type} of type {old}") if idx is None: for k, v in zip(split(old), split_new): @@ -179,23 +180,24 @@ def _replace_dict(old, new, idx, replace_type): else: # new is not indexable if idx is None: - raise ValueError(f"idx must be specified to replace_{replace_type}" - + f" when replaced {replace_type} of type {old} is mixed and new {replace_type} of type {new} is a single component") + raise ValueError(f"idx must be specified to replace_{replace_type} when" + + f" replaced {replace_type} of type {old} is mixed and" + + f" new {replace_type} of type {new} is a single component") replace_dict[split(old)[idx]] = new else: # old is not mixed - # or will short-circuit - mixed_new = (hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement) + mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement indexable_new = type(new) is tuple or mixed_new if indexable_new: split_new = new if type(new) is tuple else split(new) if idx is None: - raise ValueError(f"idx must be specified to replace_{replace_type}" - + f" when replaced {replace_type} of type {old} is not mixed and new {replace_type} is of indexable type {new}") + raise ValueError(f"idx must be specified to replace_{replace_type} when" + + f" replaced {replace_type} of type {old} is not mixed" + + f" and new {replace_type} of type {new} is indexable") replace_dict[old] = split_new[idx] diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replace_subject.py index 1756c4860..8da0ca28e 100644 --- a/unit-tests/fml_tests/test_replace_subject.py +++ b/unit-tests/fml_tests/test_replace_subject.py @@ -20,15 +20,15 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed): # ------------------------------------------------------------------------ # if subject_type == 'vector' and replacement_type != 'vector': - return + pytest.skip("invalid option combination") elif replacement_type == 'vector' and subject_type != 'vector': - return + pytest.skip("invalid option combination") if replacement_type == 'mixed-component': if subject_type != 'mixed': - return + pytest.skip("invalid option combination") elif function_or_indexed != 'indexed': - return + pytest.skip("invalid option combination") # ------------------------------------------------------------------------ # # Set up From b8a33be39e57cf37df64b613179869b04ffdf869 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 24 Jan 2023 12:07:08 +0000 Subject: [PATCH 07/11] replace_* label maps: add more information to ufl.replace exception message --- gusto/labels.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/gusto/labels.py b/gusto/labels.py index 47ec5bc46..4d5be32ab 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -154,7 +154,7 @@ def _replace_dict(old, new, idx, replace_type): Build a dictionary to pass to the ufl.replace routine The dictionary matches variables in the old term with those in the new - Does not check types unless indexing is required (leave type-checking to ufl) + Does not check types unless indexing is required (leave type-checking to ufl.replace) """ replace_dict = {} @@ -232,7 +232,14 @@ def repl(t): """ old_test = t.form.arguments()[0] replace_dict = _replace_dict(old_test, new_test, idx, 'test') - new_form = ufl.replace(t.form, replace_dict) + + try: + new_form = ufl.replace(t.form, replace_dict) + except Exception as err: + error_message = f"{err} raised by ufl.replace when trying to" \ + + f" replace_test_function with {new_test}" + raise type(err)(error_message) from err + return Term(new_form, t.labels) return repl @@ -268,7 +275,14 @@ def repl(t): raise TypeError('Trying to replace trial function of a form that is not linear') old_trial = t.form.arguments()[1] replace_dict = _replace_dict(old_trial, new_trial, idx, 'trial') - new_form = ufl.replace(t.form, replace_dict) + + try: + new_form = ufl.replace(t.form, replace_dict) + except Exception as err: + error_message = f"{err} raised by ufl.replace when trying to" \ + + f" replace_trial_function with {new_trial}" + raise type(err)(error_message) from err + return Term(new_form, t.labels) return repl @@ -301,7 +315,14 @@ def repl(t): old_subj = t.get(subject) replace_dict = _replace_dict(old_subj, new_subj, idx, 'subject') - new_form = ufl.replace(t.form, replace_dict) + + try: + new_form = ufl.replace(t.form, replace_dict) + except Exception as err: + error_message = f"{err} raised by ufl.replace when trying to" \ + + f" replace_subject with {new_subj}" + raise type(err)(error_message) from err + return Term(new_form, t.labels) return repl From 788dbb7f9fe18c67ea731c8514d35a79d040999c Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 24 Jan 2023 13:40:35 +0000 Subject: [PATCH 08/11] replace_* label maps: remove old _replace_dict impl --- gusto/labels.py | 62 +++---------------------------------------------- 1 file changed, 3 insertions(+), 59 deletions(-) diff --git a/gusto/labels.py b/gusto/labels.py index 4d5be32ab..2cd77ddd4 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -93,62 +93,6 @@ def _replace_dict_old(old, new, idx, replace_type): return replace_dict -def _replace_dict_dumb1(old, new, idx, replace_type): - """ - Build a dictionary to pass to the ufl.replace routine - The dictionary matches variables in the old term with those in the new - - Does not check types unless indexing is required (leave type-checking to ufl) - """ - - replace_dict = {} - - if type(old.ufl_element()) is MixedElement: - - if type(new) is tuple: - if len(new) != len(old.function_space()): - raise ValueError(f"tuple of new {replace_type} must be same length as replaced mixed {replace_type} of type {old}") - if idx is None: - for k, v in zip(split(old), new): - replace_dict[k] = v - else: - replace_dict[split(old)[idx]] = new[idx] - - elif type(new.ufl_element()) is MixedElement: - if len(new.function_space()) != len(old.function_space()): - raise ValueError(f"New mixed {replace_type} of type {new} must be same length as replaced mixed {replace_type} of type {old}") - if idx is None: - for k, v in zip(split(old), split(new)): - replace_dict[k] = v - else: - replace_dict[split(old)[idx]] = split(new)[idx] - - else: # new is not indexable - if idx is None: - raise ValueError(f"idx must be specified to replace_{replace_type}" - + f" when replaced {replace_type} of type {old} is mixed and new {replace_type} of type {new} is a single component") - replace_dict[split(old)[idx]] = new - - else: # old is not mixed - - if type(new) is tuple: - if idx is None: - raise ValueError(f"idx must be specified to replace_{replace_type}" - + f" when replaced {replace_type} of type {old} is not mixed and new {replace_type} is a tuple") - replace_dict[old] = new[idx] - - elif type(new.ufl_element()) is MixedElement: - if idx is None: - raise ValueError(f"idx must be specified to replace_{replace_type}" - + f" when replaced {replace_type} of type {old} is not mixed and new {replace_type} of type {new} is mixed") - replace_dict[old] = split(new)[idx] - - else: # new is not mixed - replace_dict[old] = new - - return replace_dict - - def _replace_dict(old, new, idx, replace_type): """ Build a dictionary to pass to the ufl.replace routine @@ -236,7 +180,7 @@ def repl(t): try: new_form = ufl.replace(t.form, replace_dict) except Exception as err: - error_message = f"{err} raised by ufl.replace when trying to" \ + error_message = f"{type(err)} raised by ufl.replace when trying to" \ + f" replace_test_function with {new_test}" raise type(err)(error_message) from err @@ -279,7 +223,7 @@ def repl(t): try: new_form = ufl.replace(t.form, replace_dict) except Exception as err: - error_message = f"{err} raised by ufl.replace when trying to" \ + error_message = f"{type(err)} raised by ufl.replace when trying to" \ + f" replace_trial_function with {new_trial}" raise type(err)(error_message) from err @@ -319,7 +263,7 @@ def repl(t): try: new_form = ufl.replace(t.form, replace_dict) except Exception as err: - error_message = f"{err} raised by ufl.replace when trying to" \ + error_message = f"{type(err)} raised by ufl.replace when trying to" \ + f" replace_subject with {new_subj}" raise type(err)(error_message) from err From 259f2d1e9286569725d618620cdf984e7a62300a Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 24 Jan 2023 18:17:04 +0000 Subject: [PATCH 09/11] replace_* label maps: parametrize tests for replace_{subject,trial,test} --- gusto/labels.py | 3 +- unit-tests/fml_tests/test_replace_subject.py | 34 ++++++++++++++++---- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/gusto/labels.py b/gusto/labels.py index 2cd77ddd4..a4248f6b0 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -109,10 +109,9 @@ def _replace_dict(old, new, idx, replace_type): indexable_new = type(new) is tuple or mixed_new if indexable_new: - new_len = len(new) if type(new) is tuple else len(new.function_space()) split_new = new if type(new) is tuple else split(new) - if new_len != len(old.function_space()): + if len(split_new) != len(old.function_space()): raise ValueError(f"new {replace_type} of type {new} must be same length" + f"as replaced mixed {replace_type} of type {old}") diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replace_subject.py index 8da0ca28e..90bb45a77 100644 --- a/unit-tests/fml_tests/test_replace_subject.py +++ b/unit-tests/fml_tests/test_replace_subject.py @@ -4,16 +4,23 @@ from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, VectorFunctionSpace, MixedFunctionSpace, dx, inner, - TrialFunctions, split) + TrialFunctions, TrialFunction, split) from gusto.fml import Label -from gusto import subject, replace_subject +from gusto import subject, replace_subject, replace_test_function, replace_trial_function import pytest +replace_funcs = [ + pytest.param((Function, replace_subject), id="replace_subj"), + pytest.param((TestFunction, replace_test_function), id="replace_test"), + pytest.param((TrialFunction, replace_trial_function), id="replace_trial") +] + @pytest.mark.parametrize('subject_type', ['normal', 'mixed', 'vector']) @pytest.mark.parametrize('replacement_type', ['normal', 'mixed', 'mixed-component', 'vector', 'tuple']) @pytest.mark.parametrize('function_or_indexed', ['function', 'indexed']) -def test_replace_subject(subject_type, replacement_type, function_or_indexed): +@pytest.mark.parametrize('replace_func', replace_funcs) +def test_replace_subject(subject_type, replacement_type, function_or_indexed, replace_func): # ------------------------------------------------------------------------ # # Only certain combinations of options are valid @@ -52,6 +59,9 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed): # Choose subject # ------------------------------------------------------------------------ # + FunctionType = replace_func[0] + replace_map = replace_func[1] + if subject_type == 'normal': V = V0 elif subject_type == 'mixed': @@ -64,7 +74,12 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed): raise ValueError the_subject = Function(V) - not_subject = Function(V) + + if replace_map is replace_trial_function: + not_subject = TrialFunction(V) + else: + not_subject = Function(V) + test = TestFunction(V) form_1 = inner(the_subject, test)*dx @@ -94,7 +109,7 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed): else: raise ValueError - the_replacement = Function(V) + the_replacement = FunctionType(V) if function_or_indexed == 'indexed' and replacement_type != 'vector': the_replacement = split(the_replacement) @@ -111,7 +126,12 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed): # Test replace_subject # ------------------------------------------------------------------------ # + if replace_map is replace_trial_function: + match_label = bar_label + else: + match_label = subject + labelled_form = labelled_form.label_map( - lambda t: t.has_label(subject), - map_if_true=replace_subject(the_replacement, idx=idx) + lambda t: t.has_label(match_label), + map_if_true=replace_map(the_replacement, idx=idx) ) From 1d7852e02a98c44a9fe2486c3ea723ed63b97e7b Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Tue, 24 Jan 2023 18:33:10 +0000 Subject: [PATCH 10/11] replace_* label maps: remove mixed-component test parameter for in-test if-statement --- unit-tests/fml_tests/test_replace_subject.py | 42 +++++++++----------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replace_subject.py index 90bb45a77..374913648 100644 --- a/unit-tests/fml_tests/test_replace_subject.py +++ b/unit-tests/fml_tests/test_replace_subject.py @@ -17,7 +17,7 @@ @pytest.mark.parametrize('subject_type', ['normal', 'mixed', 'vector']) -@pytest.mark.parametrize('replacement_type', ['normal', 'mixed', 'mixed-component', 'vector', 'tuple']) +@pytest.mark.parametrize('replacement_type', ['normal', 'mixed', 'vector', 'tuple']) @pytest.mark.parametrize('function_or_indexed', ['function', 'indexed']) @pytest.mark.parametrize('replace_func', replace_funcs) def test_replace_subject(subject_type, replacement_type, function_or_indexed, replace_func): @@ -26,16 +26,9 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re # Only certain combinations of options are valid # ------------------------------------------------------------------------ # - if subject_type == 'vector' and replacement_type != 'vector': + # only makes sense to replace a vector with a vector + if (subject_type == 'vector') ^ (replacement_type == 'vector'): pytest.skip("invalid option combination") - elif replacement_type == 'vector' and subject_type != 'vector': - pytest.skip("invalid option combination") - - if replacement_type == 'mixed-component': - if subject_type != 'mixed': - pytest.skip("invalid option combination") - elif function_or_indexed != 'indexed': - pytest.skip("invalid option combination") # ------------------------------------------------------------------------ # # Set up @@ -59,9 +52,6 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re # Choose subject # ------------------------------------------------------------------------ # - FunctionType = replace_func[0] - replace_map = replace_func[1] - if subject_type == 'normal': V = V0 elif subject_type == 'mixed': @@ -74,12 +64,7 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re raise ValueError the_subject = Function(V) - - if replace_map is replace_trial_function: - not_subject = TrialFunction(V) - else: - not_subject = Function(V) - + not_subject = TrialFunction(V) test = TestFunction(V) form_1 = inner(the_subject, test)*dx @@ -99,9 +84,6 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re V = Vmixed if subject_type != 'mixed': idx = 0 - elif replacement_type == 'mixed-component': - V = Vmixed - idx = 0 elif replacement_type == 'vector': V = V2 elif replacement_type == 'tuple': @@ -109,12 +91,14 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re else: raise ValueError + FunctionType = replace_func[0] + the_replacement = FunctionType(V) if function_or_indexed == 'indexed' and replacement_type != 'vector': the_replacement = split(the_replacement) - if len(the_replacement) == 1 or replacement_type == 'mixed-component': + if len(the_replacement) == 1: the_replacement = the_replacement[0] if replacement_type == 'tuple': @@ -126,6 +110,8 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re # Test replace_subject # ------------------------------------------------------------------------ # + replace_map = replace_func[1] + if replace_map is replace_trial_function: match_label = bar_label else: @@ -135,3 +121,13 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re lambda t: t.has_label(match_label), map_if_true=replace_map(the_replacement, idx=idx) ) + + # also test indexed + if subject_type == 'mixed' and function_or_indexed == 'indexed': + idx = 0 + the_replacement = split(FunctionType(Vmixed))[idx] + + labelled_form = labelled_form.label_map( + lambda t: t.has_label(match_label), + map_if_true=replace_map(the_replacement, idx=idx) + ) From 648effa617633a043b58ca5a95f1e9b14d05e375 Mon Sep 17 00:00:00 2001 From: JHopeCollins Date: Wed, 25 Jan 2023 09:30:32 +0000 Subject: [PATCH 11/11] replace_* label maps: remove old dictionary builder --- gusto/labels.py | 86 ------------------------------------------------- 1 file changed, 86 deletions(-) diff --git a/gusto/labels.py b/gusto/labels.py index a4248f6b0..331fa32cf 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -7,92 +7,6 @@ from types import MethodType -def _replace_dict_old(old, new, idx, replace_type): - """ - Build a dictionary to pass to the ufl.replace routine - The dictionary matches variables in the old term with those in the new - - Consider cases that old is normal Function or MixedFunction - vs cases of new being Function vs MixedFunction vs tuple - Ideally catch all cases or fail gracefully - """ - - replace_dict = {} - - acceptable_types = (type(old), ufl.algebra.Sum, ufl.indexed.Indexed) - if replace_type == 'trial': - acceptable_types = (*acceptable_types, Function) - - type_error_message = f'new must be a {tuple} or '+' or '.join((f"{t}" for t in acceptable_types))+f', not {type(new)}' - - if type(old.ufl_element()) is MixedElement: - if type(new) == tuple: - assert len(new) == len(old.function_space()) - for k, v in zip(split(old), new): - replace_dict[k] = v - - # Otherwise fail if new is not a function - elif not isinstance(new, acceptable_types): - raise TypeError(type_error_message) - - elif type(new) == ufl.algebra.Sum: - replace_dict[old] = new - - elif isinstance(new, ufl.indexed.Indexed): - if idx is None: - raise ValueError('idx must be specified to replace_{replace_type}' - + ' when {replace_type} is Mixed and new is a single component') - replace_dict[split(old)[idx]] = new - - # Now handle MixedElements separately as these need indexing - elif type(new.ufl_element()) is MixedElement: - assert len(new.function_space()) == len(old.function_space()) - # If idx specified, replace only that component - - if idx is not None: - replace_dict[split(old)[idx]] = split(new)[idx] - - # Otherwise replace all components - else: - for k, v in zip(split(old), split(new)): - replace_dict[k] = v - - # Otherwise 'new' is a normal Function - else: - if idx is None: - raise ValueError('idx must be specified to replace_{replace_type}' - + ' when {replace_type} is Mixed and new is a single component') - replace_dict[split(old)[idx]] = new - - # old is a normal Function - else: - if type(new) is tuple: - if idx is None: - raise ValueError('idx must be specified to replace_{replace_type}' - + ' when new is a tuple and {replace_type} is not Mixed') - replace_dict[old] = new[idx] - - elif not isinstance(new, acceptable_types): - raise TypeError(type_error_message) - - elif type(new) == ufl.algebra.Sum: - replace_dict[old] = new - - elif isinstance(new, ufl.indexed.Indexed): - replace_dict[old] = new - - elif type(new.ufl_element()) == MixedElement: - if idx is None: - raise ValueError('idx must be specified to replace_{replace_type}' - + ' when new is a tuple and {replace_type} is not Mixed') - replace_dict[old] = split(new)[idx] - - else: - replace_dict[old] = new - - return replace_dict - - def _replace_dict(old, new, idx, replace_type): """ Build a dictionary to pass to the ufl.replace routine