Skip to content

Commit

Permalink
Tweaks for FrozenClass (#437)
Browse files Browse the repository at this point in the history
* Allow to add variables to frozen classes via class attribute

* Bugfix

* Added default value to uninitialized variables

* Small cleanup
  • Loading branch information
brownbaerchen authored May 25, 2024
1 parent 5491e18 commit e372a43
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 200 deletions.
114 changes: 32 additions & 82 deletions pySDC/core/ConvergenceController.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, controller, params, description, **kwargs):
params (dict): The params passed for this specific convergence controller
description (dict): The description object used to instantiate the controller
"""
self.controller = controller
self.params = Pars(self.setup(controller, params, description))
params_ok, msg = self.check_parameters(controller, params, description)
assert params_ok, f'{type(self).__name__} -- {msg}'
Expand Down Expand Up @@ -425,94 +426,43 @@ def Recv(self, comm, source, buffer, **kwargs):

return data

def reset_variable(self, controller, name, MPI=False, place=None, where=None, init=None):
"""
Utility function for resetting variables. This function will call the `add_variable` function with all the same
arguments, but with `allow_overwrite = True`.
Args:
controller (pySDC.Controller): The controller
name (str): The name of the variable
MPI (bool): Whether to use MPI controller
place (object): The object you want to reset the variable of
where (list): List of strings containing a path to where you want to reset the variable
init: Initial value of the variable
Returns:
None
"""
self.add_variable(controller, name, MPI, place, where, init, allow_overwrite=True)
def add_status_variable_to_step(self, key, value=None):
if type(self.controller).__name__ == 'controller_MPI':
steps = [self.controller.S]
else:
steps = self.controller.MS

def add_variable(self, controller, name, MPI=False, place=None, where=None, init=None, allow_overwrite=False):
"""
Add a variable to a frozen class.
steps[0].status.add_attr(key)

This function goes through the path to the destination of the variable recursively and adds it to all instances
that are possible in the path. For example, giving `where = ["MS", "levels", "status"]` will result in adding a
variable to the status object of all levels of all steps of the controller.
if value is not None:
self.set_step_status_variable(key, value)

Part of the functionality of the frozen class is to separate initialization and setting of variables. By
enforcing this, you can make sure not to overwrite already existing variables. Since this function is called
outside of the `__init__` function of the status objects, this can otherwise lead to bugs that are hard to find.
For this reason, you need to specifically set `allow_overwrite = True` if you want to forgo the check if the
variable already exists. This can be useful when resetting variables between steps, but make sure to set it to
`allow_overwrite = False` the first time you add a variable.
def set_step_status_variable(self, key, value):
if type(self.controller).__name__ == 'controller_MPI':
steps = [self.controller.S]
else:
steps = self.controller.MS

Args:
controller (pySDC.Controller): The controller
name (str): The name of the variable
MPI (bool): Whether to use MPI controller
place (object): The object you want to add the variable to
where (list): List of strings containing a path to where you want to add the variable
init: Initial value of the variable
allow_overwrite (bool): Allow overwriting the variables if they already exist or raise an exception
for S in steps:
S.status.__dict__[key] = value

Returns:
None
"""
where = ["S" if MPI else "MS", "levels", "status"] if where is None else where
place = controller if place is None else place
def add_status_variable_to_level(self, key, value=None):
if type(self.controller).__name__ == 'controller_MPI':
steps = [self.controller.S]
else:
steps = self.controller.MS

# check if we have arrived at the end of the path to the variable
if len(where) == 0:
variable_exitsts = name in place.__dict__.keys()
# check if the variable already exists and raise an error in case we are about to introduce a bug
if not allow_overwrite and variable_exitsts:
raise ValueError(f"Key \"{name}\" already exists in {place}! Please rename the variable in {self}")
# if we allow overwriting, but the variable does not exist already, we are violating the intended purpose
# of this function, so we also raise an error if someone should be so mad as to attempt this
elif allow_overwrite and not variable_exitsts:
raise ValueError(f"Key \"{name}\" is supposed to be overwritten in {place}, but it does not exist!")
steps[0].levels[0].status.add_attr(key)

# actually add or overwrite the variable
place.__dict__[name] = init
if value is not None:
self.set_level_status_variable(key, value)

# follow the path to the final destination recursively
def set_level_status_variable(self, key, value):
if type(self.controller).__name__ == 'controller_MPI':
steps = [self.controller.S]
else:
# get all possible new places to continue the path
new_places = place.__dict__[where[0]]

# continue all possible paths
if type(new_places) == list:
# loop through all possibilities
for new_place in new_places:
self.add_variable(
controller,
name,
MPI=MPI,
place=new_place,
where=where[1:],
init=init,
allow_overwrite=allow_overwrite,
)
else:
# go to the only possible possibility
self.add_variable(
controller,
name,
MPI=MPI,
place=new_places,
where=where[1:],
init=init,
allow_overwrite=allow_overwrite,
)
steps = self.controller.MS

for S in steps:
for L in S.levels:
L.status.__dict__[key] = value
3 changes: 1 addition & 2 deletions pySDC/core/Level.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def __init__(self, params):
class _Status(FrozenClass):
"""
This class carries the status of the level. All variables that the core SDC / PFASST functionality depend on are
initialized here, while the convergence controllers are allowed to add more variables in a controlled fashion
later on using the `add_variable` function.
initialized here.
"""

def __init__(self):
Expand Down
3 changes: 1 addition & 2 deletions pySDC/core/Step.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ def __init__(self, params):
class _Status(FrozenClass):
"""
This class carries the status of the step. All variables that the core SDC / PFASST functionality depend on are
initialized here, while the convergence controllers are allowed to add more variables in a controlled fashion
later on using the `add_variable` function.
initialized here.
"""

def __init__(self):
Expand Down
35 changes: 33 additions & 2 deletions pySDC/helpers/pysdc_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ class FrozenClass(object):
__isfrozen: Flag to freeze a class
"""

attrs = []

__isfrozen = False

def __setattr__(self, key, value):
Expand All @@ -18,10 +20,33 @@ def __setattr__(self, key, value):
"""

# check if attribute exists and if class is frozen
if self.__isfrozen and not hasattr(self, key):
raise TypeError("%r is a frozen class" % self)
if self.__isfrozen and not (key in self.attrs or hasattr(self, key)):
raise TypeError(f'{type(self).__name__!r} is a frozen class, cannot add attribute {key!r}')

object.__setattr__(self, key, value)

def __getattr__(self, key):
"""
This is needed in case the variables have not been initialized after adding.
"""
if key in self.attrs:
return None
else:
super().__getattr__(key)

@classmethod
def add_attr(cls, key, raise_error_if_exists=False):
"""
Add a key to the allowed attributes of this class.
Args:
key (str): The key to add
raise_error_if_exists (bool): Raise an error if the attribute already exists in the class
"""
if key in cls.attrs and raise_error_if_exists:
raise TypeError(f'Attribute {key!r} already exists in {cls.__name__}!')
cls.attrs += [key]

def _freeze(self):
"""
Function to freeze the class
Expand All @@ -40,3 +65,9 @@ def get(self, key, default=None):
__dict__.get(key, default)
"""
return self.__dict__.get(key, default)

def __dir__(self):
"""
My hope is that some editors can use this for dynamic autocompletion.
"""
return super().__dir__() + self.attrs
Original file line number Diff line number Diff line change
Expand Up @@ -76,36 +76,26 @@ def setup(self, controller, params, description, **kwargs):

return {**defaults, **super().setup(controller, params, description, **kwargs)}

def setup_status_variables(self, controller, **kwargs):
def setup_status_variables(self, *args, **kwargs):
"""
Add status variables for whether to restart now and how many times the step has been restarted in a row to the
Steps
Args:
controller (pySDC.Controller): The controller
reset (bool): Whether the function is called for the first time or to reset
Returns:
None
"""
where = ["S" if 'comm' in kwargs.keys() else "MS", "status"]
self.add_variable(controller, name='restart', where=where, init=False)
self.add_variable(controller, name='restarts_in_a_row', where=where, init=0)
self.add_status_variable_to_step('restart', False)
self.add_status_variable_to_step('restarts_in_a_row', 0)

def reset_status_variables(self, controller, reset=False, **kwargs):
def reset_status_variables(self, *args, **kwargs):
"""
Add status variables for whether to restart now and how many times the step has been restarted in a row to the
Steps
Args:
controller (pySDC.Controller): The controller
reset (bool): Whether the function is called for the first time or to reset
Returns:
None
"""
where = ["S" if 'comm' in kwargs.keys() else "MS", "status"]
self.reset_variable(controller, name='restart', where=where, init=False)
self.set_step_status_variable('restart', False)

def dependencies(self, controller, description, **kwargs):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,41 +39,17 @@ def dependencies(self, controller, description, **kwargs):
description=description,
)

def setup_status_variables(self, controller, **kwargs):
def setup_status_variables(self, *args, **kwargs):
"""
Add the embedded error, contraction factor and iterations to convergence variable to the status of the levels.
Args:
controller (pySDC.Controller): The controller
Returns:
None
"""
if 'comm' in kwargs.keys():
steps = [controller.S]
else:
if 'active_slots' in kwargs.keys():
steps = [controller.MS[i] for i in kwargs['active_slots']]
else:
steps = controller.MS
where = ["levels", "status"]
for S in steps:
self.add_variable(S, name='error_embedded_estimate_last_iter', where=where, init=None)
self.add_variable(S, name='contraction_factor', where=where, init=None)
if self.params.e_tol is not None:
self.add_variable(S, name='iter_to_convergence', where=where, init=None)

def reset_status_variables(self, controller, **kwargs):
"""
Reinitialize new status variables for the levels.
Args:
controller (pySDC.controller): The controller
Returns:
None
"""
self.setup_status_variables(controller, **kwargs)
self.add_status_variable_to_level('error_embedded_estimate_last_iter')
self.add_status_variable_to_level('contraction_factor')
if self.params.e_tol is not None:
self.add_status_variable_to_level('iter_to_convergence')

def post_iteration_processing(self, controller, S, **kwargs):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,7 @@ def setup_status_variables(self, controller, **kwargs):
Args:
controller (pySDC.Controller): The controller
"""
if 'comm' in kwargs.keys():
steps = [controller.S]
else:
if 'active_slots' in kwargs.keys():
steps = [controller.MS[i] for i in kwargs['active_slots']]
else:
steps = controller.MS
where = ["levels", "status"]
for S in steps:
self.add_variable(S, name='error_embedded_estimate', where=where, init=None)

def reset_status_variables(self, controller, **kwargs):
self.setup_status_variables(controller, **kwargs)
self.add_status_variable_to_level('error_embedded_estimate')

def post_iteration_processing(self, controller, S, **kwargs):
"""
Expand Down Expand Up @@ -350,7 +338,7 @@ def post_iteration_processing(self, controller, step, **kwargs):
max([np.finfo(float).eps, abs(self.status.u[-1] - self.status.u[-2])]),
)

def setup_status_variables(self, controller, **kwargs):
def setup_status_variables(self, *args, **kwargs):
"""
Add the embedded error variable to the levels and add a status variable for previous steps.
Expand All @@ -361,16 +349,4 @@ def setup_status_variables(self, controller, **kwargs):
self.status.u = [] # the solutions of converged collocation problems
self.status.iter = [] # the iteration in which the solution converged

if 'comm' in kwargs.keys():
steps = [controller.S]
else:
if 'active_slots' in kwargs.keys():
steps = [controller.MS[i] for i in kwargs['active_slots']]
else:
steps = controller.MS
where = ["levels", "status"]
for S in steps:
self.add_variable(S, name='error_embedded_estimate_collocation', where=where, init=None)

def reset_status_variables(self, controller, **kwargs):
self.setup_status_variables(controller, **kwargs)
self.add_status_variable_to_level('error_embedded_estimate_collocation')
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,7 @@ def setup_status_variables(self, controller, **kwargs):
self.coeff.u = [None] * self.params.n
self.coeff.f = [0.0] * self.params.n

self.reset_status_variables(controller, **kwargs)
return None

def reset_status_variables(self, controller, **kwargs):
"""
Add variable for extrapolated error
Args:
controller (pySDC.Controller): The controller
Returns:
None
"""
if 'comm' in kwargs.keys():
steps = [controller.S]
else:
if 'active_slots' in kwargs.keys():
steps = [controller.MS[i] for i in kwargs['active_slots']]
else:
steps = controller.MS
where = ["levels", "status"]
for S in steps:
self.add_variable(S, name='error_extrapolation_estimate', where=where, init=None)
self.add_status_variable_to_level('error_extrapolation_estimate')

def check_parameters(self, controller, params, description, **kwargs):
"""
Expand Down
Loading

0 comments on commit e372a43

Please sign in to comment.