diff --git a/statemachine/dispatcher.py b/statemachine/dispatcher.py index 29841a85..55993b48 100644 --- a/statemachine/dispatcher.py +++ b/statemachine/dispatcher.py @@ -37,7 +37,7 @@ def _get_func_by_attr(attr, *configs): return func, config.obj -def ensure_callable(attr, *objects): +def ensure_callable(attr, *objects): # noqa: C901 """Ensure that `attr` is a callable, if not, tries to retrieve one from any of the given `objects`. @@ -66,6 +66,15 @@ def wrapper(*args, **kwargs): return wrapper + if getattr(func, "_is_sm_event", False): + "Events already have the 'machine' parameter defined." + + def wrapper(*args, **kwargs): + kwargs.pop("machine") + return func(*args, **kwargs) + + return wrapper + return SignatureAdapter.wrap(func) diff --git a/statemachine/event.py b/statemachine/event.py index 65c2a2ac..f1deee8f 100644 --- a/statemachine/event.py +++ b/statemachine/event.py @@ -1,4 +1,5 @@ from .event_data import EventData +from .event_data import TriggerData from .exceptions import TransitionNotAllowed @@ -13,38 +14,36 @@ def __call__(self, machine, *args, **kwargs): return self.trigger(machine, *args, **kwargs) def trigger(self, machine, *args, **kwargs): - event_data = EventData(machine, self.name, *args, **kwargs) - def trigger_wrapper(): """Wrapper that captures event_data as closure.""" - return self._trigger(event_data) + trigger_data = TriggerData( + machine=machine, + event=self.name, + args=args, + kwargs=kwargs, + ) + return self._trigger(trigger_data) return machine._process(trigger_wrapper) - def _trigger(self, event_data): - event_data.source = event_data.machine.current_state - event_data.state = event_data.machine.current_state - event_data.model = event_data.machine.model - - try: - self._process(event_data) - except Exception as error: - event_data.error = error - # TODO: Log errors - # TODO: Allow exception handlers - raise + def _trigger(self, trigger_data: TriggerData): + event_data = self._process(trigger_data) return event_data.result - def _process(self, event_data): - for transition in event_data.source.transitions: - if not transition.match(event_data.event): + def _process(self, trigger_data: TriggerData): + state = trigger_data.machine.current_state + for transition in state.transitions: + if not transition.match(trigger_data.event): continue - event_data._set_transition(transition) + + event_data = EventData(trigger_data=trigger_data, transition=transition) if transition.execute(event_data): event_data.executed = True break else: - raise TransitionNotAllowed(event_data.event, event_data.state) + raise TransitionNotAllowed(trigger_data.event, state) + + return event_data def trigger_event_factory(event): @@ -56,5 +55,6 @@ def trigger_event(self, *args, **kwargs): trigger_event.name = event trigger_event.identifier = event + trigger_event._is_sm_event = True return trigger_event diff --git a/statemachine/event_data.py b/statemachine/event_data.py index 44964295..279df61a 100644 --- a/statemachine/event_data.py +++ b/statemachine/event_data.py @@ -1,43 +1,77 @@ +from dataclasses import dataclass +from dataclasses import field from typing import TYPE_CHECKING +from typing import Any if TYPE_CHECKING: + from .state import State from .statemachine import StateMachine from .transition import Transition +@dataclass +class TriggerData: + machine: "StateMachine" + event: str + """The Event that was triggered.""" + + model: Any = field(init=False) + """A reference to the underlying model that holds the current State.""" + + args: tuple = field(default_factory=tuple) + """All positional arguments provided on the Event.""" + + kwargs: dict = field(default_factory=dict) + """All keyword arguments provided on the Event.""" + + def __post_init__(self): + self.model = self.machine.model + + +@dataclass class EventData: - def __init__(self, machine: "StateMachine", event: str, *args, **kwargs): - self.machine = machine - self.event = event - self.source = kwargs.get("source", None) - self.state = kwargs.get("state", None) - self.model = kwargs.get("model", None) - self.executed = False - self.transition: Transition | None = None - self.target = None - self._set_transition(kwargs.get("transition", None)) - - # runtime and error - self.args = args - self.kwargs = kwargs - self.error = None - self.result = None - - def __repr__(self): - return f"{type(self).__name__}({self.__dict__!r})" - - def _set_transition(self, transition: "Transition"): - self.transition = transition - self.target = getattr(transition, "target", None) + trigger_data: TriggerData + transition: "Transition" + """The Transition instance that was activated by the Event.""" + + state: "State" = field(init=False) + """The current State of the state machine.""" + + source: "State" = field(init=False) + """The State the state machine was in when the Event started.""" + + target: "State" = field(init=False) + """The destination State of the transition.""" + + result: Any | None = None + executed: bool = False + + def __post_init__(self): + self.state = self.transition.source + self.source = self.transition.source + self.target = self.transition.target + + @property + def machine(self): + return self.trigger_data.machine + + @property + def event(self): + return self.trigger_data.event + + @property + def args(self): + return self.trigger_data.args @property def extended_kwargs(self): - kwargs = self.kwargs.copy() + kwargs = self.trigger_data.kwargs.copy() kwargs["event_data"] = self - kwargs["event"] = self.event - kwargs["source"] = self.source - kwargs["state"] = self.state - kwargs["model"] = self.model + kwargs["machine"] = self.trigger_data.machine + kwargs["event"] = self.trigger_data.event + kwargs["model"] = self.trigger_data.model kwargs["transition"] = self.transition + kwargs["state"] = self.state + kwargs["source"] = self.source kwargs["target"] = self.target return kwargs diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index 9252ed44..a18618bf 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -3,6 +3,7 @@ from .dispatcher import ObjectConfig from .dispatcher import resolver_factory from .event import Event +from .event_data import TriggerData from .event_data import EventData from .exceptions import InvalidStateValue from .exceptions import InvalidDefinition @@ -62,30 +63,29 @@ def _activate_initial_state(self, initial_transition): initial_transition.before.clear() initial_transition.on.clear() initial_transition.after.clear() + event_data = EventData( - self, - initial_transition.event, + trigger_data=TriggerData( + machine=self, + event=initial_transition.event, + ), transition=initial_transition, ) self._activate(event_data) def _get_protected_attrs(self): - return ( - { - "_abstract", - "model", - "state_field", - "start_value", - "initial_state", - "final_states", - "states", - "_events", - "states_map", - "send", - } - | {s.id for s in self.states} - | set(self._events.keys()) - ) + return { + "_abstract", + "model", + "state_field", + "start_value", + "initial_state", + "final_states", + "states", + "_events", + "states_map", + "send", + } | {s.id for s in self.states} def _visit_states_and_transitions(self, visitor): for state in self.states: @@ -165,7 +165,6 @@ def _process(self, trigger): def _activate(self, event_data: EventData): transition = event_data.transition - assert transition is not None source = event_data.state target = transition.target diff --git a/statemachine/transition.py b/statemachine/transition.py index 173004ab..07fbef44 100644 --- a/statemachine/transition.py +++ b/statemachine/transition.py @@ -1,11 +1,14 @@ from functools import partial +from typing import TYPE_CHECKING from .callbacks import Callbacks from .callbacks import ConditionWrapper -from .event_data import EventData from .events import Events from .exceptions import InvalidDefinition +if TYPE_CHECKING: + from .event_data import EventData + class Transition: """A transition holds reference to the source and target state. @@ -119,7 +122,7 @@ def events(self): def add_event(self, value): self._events.add(value) - def execute(self, event_data: EventData): + def execute(self, event_data: "EventData"): self.validators.call(*event_data.args, **event_data.extended_kwargs) if not self._eval_cond(event_data): return False diff --git a/tests/examples/order_control_rich_model_machine.py b/tests/examples/order_control_rich_model_machine.py index bd70c5e8..167bb5c0 100644 --- a/tests/examples/order_control_rich_model_machine.py +++ b/tests/examples/order_control_rich_model_machine.py @@ -19,7 +19,7 @@ def __init__(self): def payments_enough(self, amount): return sum(self.payments) + amount >= self.order_total - def add_to_order(self, amount): + def before_add_to_order(self, amount): self.order_total += amount return self.order_total @@ -40,7 +40,7 @@ class OrderControl(StateMachine): shipping = State() completed = State(final=True) - add_to_order = waiting_for_payment.to(waiting_for_payment, before="add_to_order") + add_to_order = waiting_for_payment.to(waiting_for_payment) receive_payment = waiting_for_payment.to( processing, cond="payments_enough" ) | waiting_for_payment.to(waiting_for_payment, unless="payments_enough") diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 1ac7b97f..80b115ac 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -1,3 +1,5 @@ +from unittest import mock + import pytest from statemachine import State @@ -231,3 +233,63 @@ class TestStateMachine(StateMachine): final = State(final=True) execute = initial.to(initial, final, internal=True) + + +@pytest.fixture() +def chained_sm_class(): # noqa: C901 + class ChainedSM(StateMachine): + a = State(initial=True) + b = State() + c = State() + + t1 = a.to(b, after="t1") | b.to(c) + + def __init__(self, *args, **kwargs): + self.spy = mock.Mock(side_effect=lambda x, **kwargs: x) + super().__init__(*args, **kwargs) + + def before_t1(self, source: State, value: int = 0): + return self.spy("before_t1", source=source.id, value=value) + + def on_t1(self, source: State, value: int = 0): + return self.spy("on_t1", source=source.id, value=value) + + def after_t1(self, source: State, value: int = 0): + return self.spy("after_t1", source=source.id, value=value) + + def on_enter_state(self, state: State, source: State, value: int = 0): + return self.spy( + "on_enter_state", + state=state.id, + source=getattr(source, "id", None), + value=value, + ) + + def on_exit_state(self, state: State, source: State, value: int = 0): + return self.spy( + "on_exit_state", state=state.id, source=source.id, value=value + ) + + return ChainedSM + + +class TestChainedTransition: + def test_should_allow_chaining_transitions_using_actions(self, chained_sm_class): + sm = chained_sm_class() + sm.t1(42) + + assert sm.c.is_active + + assert sm.spy.call_args_list == [ + mock.call("on_enter_state", state="a", source=None, value=0), + mock.call("before_t1", source="a", value=0), + mock.call("on_exit_state", state="a", source="a", value=0), + mock.call("on_t1", source="a", value=0), + mock.call("on_enter_state", state="b", source="a", value=0), + mock.call("before_t1", source="b", value=0), + mock.call("on_exit_state", state="b", source="b", value=0), + mock.call("on_t1", source="b", value=0), + mock.call("on_enter_state", state="c", source="b", value=0), + mock.call("after_t1", source="b", value=0), + mock.call("after_t1", source="a", value=0), + ]