Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WorkChain: Protect public methods from being subclassed #5779

Merged
merged 2 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 81 additions & 17 deletions aiida/engine/processes/workchains/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Components for the WorkChain concept of the workflow engine."""
from __future__ import annotations

import collections.abc
import functools
import logging
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
import typing as t

from plumpy.persistence import auto_persist
from plumpy.process_states import Continue, Wait
from plumpy.processes import ProcessStateMachineMeta
from plumpy.workchains import Stepper
from plumpy.workchains import WorkChainSpec as PlumpyWorkChainSpec
from plumpy.workchains import _PropagateReturn, if_, return_, while_
Expand All @@ -30,8 +33,8 @@
from ..process_spec import ProcessSpec
from .awaitable import Awaitable, AwaitableAction, AwaitableTarget, construct_awaitable

if TYPE_CHECKING:
from aiida.engine.runners import Runner
if t.TYPE_CHECKING:
from aiida.engine.runners import Runner # pylint: disable=unused-import

__all__ = ('WorkChain', 'if_', 'while_', 'return_')

Expand All @@ -40,8 +43,60 @@ class WorkChainSpec(ProcessSpec, PlumpyWorkChainSpec):
pass


class Protect(ProcessStateMachineMeta):
    """Metaclass that marks selected methods as final and blocks subclasses from overriding them.

    Usage as follows::

        class SomeClass(metaclass=Protect):

            @Protect.final
            def private_method(self):
                "This method cannot be overridden by a subclass."

    Defining a subclass that overrides a method marked with ``Protect.final`` raises a ``RuntimeError``
    as soon as the subclass body is executed, i.e. at import time.
    """

    # Unique marker stored on decorated methods; identity comparison makes it unforgeable.
    __SENTINEL = object()

    def __new__(cls, name, bases, namespace, **kwargs):
        """Create the new class, refusing to do so if it overrides any method marked as final.

        :raises RuntimeError: If the class body defines (i.e. overrides) a method that a base class
            protected with the ``final`` decorator.
        """
        protected = set()
        for base in bases:
            for attribute_name, attribute in vars(base).items():
                if callable(attribute) and cls.__is_final(attribute):
                    protected.add(attribute_name)

        for attribute_name in namespace:
            if attribute_name in protected:
                raise RuntimeError(f'the method `{attribute_name}` is protected cannot be overridden.')

        return super().__new__(cls, name, bases, namespace, **kwargs)

    @classmethod
    def __is_final(cls, method) -> bool:
        """Return whether the given method carries the sentinel set by the ``final`` decorator.

        :return: ``True`` if the method was marked as final, ``False`` otherwise.
        """
        try:
            marker = method.__final  # pylint: disable=protected-access
        except AttributeError:
            return False
        return marker is cls.__SENTINEL

    @classmethod
    def final(cls, method: t.Any):
        """Decorate a method with this method to protect it from being overridden.

        The sentinel is stored under the private ``__final`` attribute of the method, and the method is
        additionally wrapped in ``typing.final`` so static type checkers are aware of the restriction.
        """
        method.__final = cls.__SENTINEL  # pylint: disable=protected-access,unused-private-member
        return t.final(method)


@auto_persist('_awaitables')
class WorkChain(Process):
class WorkChain(Process, metaclass=Protect):
"""The `WorkChain` class is the principle component to implement workflows in AiiDA."""

_node_class = WorkChainNode
Expand All @@ -51,9 +106,9 @@ class WorkChain(Process):

def __init__(
self,
inputs: Optional[dict] = None,
logger: Optional[logging.Logger] = None,
runner: Optional['Runner'] = None,
inputs: dict | None = None,
logger: logging.Logger | None = None,
runner: 'Runner' | None = None,
enable_persistence: bool = True
) -> None:
"""Construct a WorkChain instance.
Expand All @@ -71,8 +126,8 @@ def __init__(

super().__init__(inputs, logger, runner, enable_persistence=enable_persistence)

self._stepper: Optional[Stepper] = None
self._awaitables: List[Awaitable] = []
self._stepper: Stepper | None = None
self._awaitables: list[Awaitable] = []
self._context = AttributeDict()

@classmethod
Expand Down Expand Up @@ -119,11 +174,12 @@ def load_instance_state(self, saved_state, load_context):
if self._awaitables:
self.action_awaitables()

@Protect.final
def on_run(self):
    """Hook invoked when the work chain starts running: record the stepper state on the process node."""
    super().on_run()
    # Persist a human-readable snapshot of the outline position so it shows up in e.g. `verdi process list`.
    self.node.set_stepper_state_info(str(self._stepper))

def _resolve_nested_context(self, key: str) -> Tuple[AttributeDict, str]:
def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]:
"""
Returns a reference to a sub-dictionary of the context and the last key,
after resolving a potentially segmented key where required sub-dictionaries are created as needed.
Expand Down Expand Up @@ -155,6 +211,7 @@ def _resolve_nested_context(self, key: str) -> Tuple[AttributeDict, str]:

return ctx, ctx_path[-1]

@Protect.final
def insert_awaitable(self, awaitable: Awaitable) -> None:
"""Insert an awaitable that should be terminated before before continuing to the next step.

Expand All @@ -178,7 +235,8 @@ def insert_awaitable(self, awaitable: Awaitable) -> None:
) # add only if everything went ok, otherwise we end up in an inconsistent state
self._update_process_status()

def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
@Protect.final
def resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None:
"""Resolve an awaitable.

Precondition: must be an awaitable that was previously inserted.
Expand Down Expand Up @@ -210,7 +268,8 @@ def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
# then we should not try to update it
self._update_process_status()

def to_context(self, **kwargs: Union[Awaitable, ProcessNode]) -> None:
@Protect.final
def to_context(self, **kwargs: Awaitable | ProcessNode) -> None:
"""Add a dictionary of awaitables to the context.

This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will
Expand All @@ -230,11 +289,12 @@ def _update_process_status(self) -> None:
self.node.set_process_status(None)

@override
def run(self) -> Any:
@Protect.final
def run(self) -> t.Any:
self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type]
return self._do_step()

def _do_step(self) -> Any:
def _do_step(self) -> t.Any:
"""Execute the next step in the outline and return the result.

If the stepper returns a non-finished status and the return value is of type ToContext, the contents of the
Expand All @@ -245,7 +305,7 @@ def _do_step(self) -> Any:
from .context import ToContext

self._awaitables = []
result: Any = None
result: t.Any = None

try:
assert self._stepper is not None
Expand Down Expand Up @@ -273,7 +333,7 @@ def _do_step(self) -> Any:

return Continue(self._do_step)

def _store_nodes(self, data: Any) -> None:
def _store_nodes(self, data: t.Any) -> None:
"""Recurse through a data structure and store any unstored nodes that are found along the way

:param data: a data structure potentially containing unstored nodes
Expand All @@ -288,6 +348,7 @@ def _store_nodes(self, data: Any) -> None:
self._store_nodes(value)

@override
@Protect.final
def on_exiting(self) -> None:
"""Ensure that any unstored nodes in the context are stored, before the state is exited

Expand All @@ -301,14 +362,16 @@ def on_exiting(self) -> None:
# An uncaught exception here will have bizarre and disastrous consequences
self.logger.exception('exception in _store_nodes called in on_exiting')

def on_wait(self, awaitables: Sequence[Awaitable]):
@Protect.final
def on_wait(self, awaitables: t.Sequence[Awaitable]):
    """Entering the WAITING state.

    If awaitables are still pending, register the callbacks that act on their completion; otherwise
    there is nothing to wait for, so schedule an immediate resume of the work chain.
    """
    super().on_wait(awaitables)
    if self._awaitables:
        self.action_awaitables()
    else:
        self.call_soon(self.resume)

@Protect.final
def action_awaitables(self) -> None:
"""Handle the awaitables that are currently registered with the work chain.

Expand All @@ -323,6 +386,7 @@ def action_awaitables(self) -> None:
else:
assert f"invalid awaitable target '{awaitable.target}'"

@Protect.final
def on_process_finished(self, awaitable: Awaitable) -> None:
"""Callback function called by the runner when the process instance identified by pk is completed.

Expand Down
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ py:class plumpy.utils.AttributesDict
py:class plumpy.process_states.State
py:class plumpy.workchains._If
py:class plumpy.workchains._While
py:class plumpy.processes.ProcessStateMachineMeta
py:class PersistenceError
py:class State
py:class Stepper
Expand Down
32 changes: 22 additions & 10 deletions tests/engine/test_work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,16 +731,15 @@ class TestWorkChain(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.run, cls.check)
spec.outline(cls.emit_report, cls.check)
spec.outputs.dynamic = True

def run(self):
orm.Log.collection.delete_all()
def emit_report(self):
self.report('Testing the report function')

def check(self):
logs = self._backend.logs.find()
assert len(logs) == 1
messages = [log.message for log in orm.Log.collection.get_logs_for(self.node)]
assert any('Testing the report function' in message for message in messages)

run_and_check_success(TestWorkChain)

Expand Down Expand Up @@ -996,12 +995,9 @@ class ExitCodeWorkChain(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.run)
spec.outline()
spec.exit_code(status, label, message)

def run(self):
pass

wc = ExitCodeWorkChain()

# The exit code can be gotten by calling it with the status or label, as well as using attribute dereferencing
Expand Down Expand Up @@ -1600,7 +1596,7 @@ def define(cls, spec):
super().define(spec)
spec.input('a', valid_type=Bool, default=lambda: Bool(True))

def run(self):
def step(self):
pass

def test_unique_default_inputs(self):
Expand All @@ -1623,3 +1619,19 @@ def test_unique_default_inputs(self):
# as both `child_one.a` and `child_two.a` should have the same UUID.
node = load_node(uuid=node.base.links.get_incoming().get_node_by_label('child_one__a').uuid)
assert len(uuids) == len(nodes), f'Only {len(uuids)} unique UUIDS for {len(nodes)} input nodes'


def test_illegal_override_run():
    """Test that overriding a protected workchain method raises a ``RuntimeError``."""
    # The ``Protect`` metaclass performs the check at class-creation time, so merely defining the
    # subclass inside the ``pytest.raises`` context is enough to trigger the exception.
    with pytest.raises(RuntimeError, match='the method `run` is protected cannot be overridden.'):

        class IllegalWorkChain(WorkChain):  # pylint: disable=unused-variable
            """Work chain that illegally overrides the ``run`` method."""

            @classmethod
            def define(cls, spec):
                super().define(spec)
                spec.outline(cls.run)

            def run(self):
                pass