Skip to content

Commit

Permalink
feat: Allow using events on callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
fgmacedo committed Feb 22, 2023
1 parent 9a64f4b commit 2c82d98
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 72 deletions.
11 changes: 10 additions & 1 deletion statemachine/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_func_by_attr(attr, *configs):
return func, config.obj


def ensure_callable(attr, *objects):
def ensure_callable(attr, *objects): # noqa: C901
"""Ensure that `attr` is a callable, if not, tries to retrieve one from any of the given
`objects`.
Expand Down Expand Up @@ -66,6 +66,15 @@ def wrapper(*args, **kwargs):

return wrapper

if getattr(func, "_is_sm_event", False):
"Events already have the 'machine' parameter defined."

def wrapper(*args, **kwargs):
kwargs.pop("machine")
return func(*args, **kwargs)

return wrapper

return SignatureAdapter.wrap(func)


Expand Down
40 changes: 20 additions & 20 deletions statemachine/event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .event_data import EventData
from .event_data import TriggerData
from .exceptions import TransitionNotAllowed


Expand All @@ -13,38 +14,36 @@ def __call__(self, machine, *args, **kwargs):
return self.trigger(machine, *args, **kwargs)

def trigger(self, machine, *args, **kwargs):
event_data = EventData(machine, self.name, *args, **kwargs)

def trigger_wrapper():
"""Wrapper that captures event_data as closure."""
return self._trigger(event_data)
trigger_data = TriggerData(
machine=machine,
event=self.name,
args=args,
kwargs=kwargs,
)
return self._trigger(trigger_data)

return machine._process(trigger_wrapper)

def _trigger(self, event_data):
event_data.source = event_data.machine.current_state
event_data.state = event_data.machine.current_state
event_data.model = event_data.machine.model

try:
self._process(event_data)
except Exception as error:
event_data.error = error
# TODO: Log errors
# TODO: Allow exception handlers
raise
def _trigger(self, trigger_data: TriggerData):
event_data = self._process(trigger_data)
return event_data.result

def _process(self, event_data):
for transition in event_data.source.transitions:
if not transition.match(event_data.event):
def _process(self, trigger_data: TriggerData):
state = trigger_data.machine.current_state
for transition in state.transitions:
if not transition.match(trigger_data.event):
continue
event_data._set_transition(transition)

event_data = EventData(trigger_data=trigger_data, transition=transition)
if transition.execute(event_data):
event_data.executed = True
break
else:
raise TransitionNotAllowed(event_data.event, event_data.state)
raise TransitionNotAllowed(trigger_data.event, state)

return event_data


def trigger_event_factory(event):
Expand All @@ -56,5 +55,6 @@ def trigger_event(self, *args, **kwargs):

trigger_event.name = event
trigger_event.identifier = event
trigger_event._is_sm_event = True

return trigger_event
90 changes: 62 additions & 28 deletions statemachine/event_data.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,77 @@
from dataclasses import dataclass
from dataclasses import field
from typing import TYPE_CHECKING
from typing import Any

if TYPE_CHECKING:
from .state import State
from .statemachine import StateMachine
from .transition import Transition


@dataclass
class TriggerData:
machine: "StateMachine"
event: str
"""The Event that was triggered."""

model: Any = field(init=False)
"""A reference to the underlying model that holds the current State."""

args: tuple = field(default_factory=tuple)
"""All positional arguments provided on the Event."""

kwargs: dict = field(default_factory=dict)
"""All keyword arguments provided on the Event."""

def __post_init__(self):
self.model = self.machine.model


@dataclass
class EventData:
def __init__(self, machine: "StateMachine", event: str, *args, **kwargs):
self.machine = machine
self.event = event
self.source = kwargs.get("source", None)
self.state = kwargs.get("state", None)
self.model = kwargs.get("model", None)
self.executed = False
self.transition: Transition | None = None
self.target = None
self._set_transition(kwargs.get("transition", None))

# runtime and error
self.args = args
self.kwargs = kwargs
self.error = None
self.result = None

def __repr__(self):
return f"{type(self).__name__}({self.__dict__!r})"

def _set_transition(self, transition: "Transition"):
self.transition = transition
self.target = getattr(transition, "target", None)
trigger_data: TriggerData
transition: "Transition"
"""The Transition instance that was activated by the Event."""

state: "State" = field(init=False)
"""The current State of the state machine."""

source: "State" = field(init=False)
"""The State the state machine was in when the Event started."""

target: "State" = field(init=False)
"""The destination State of the transition."""

result: Any | None = None
executed: bool = False

def __post_init__(self):
self.state = self.transition.source
self.source = self.transition.source
self.target = self.transition.target

@property
def machine(self):
return self.trigger_data.machine

@property
def event(self):
return self.trigger_data.event

@property
def args(self):
return self.trigger_data.args

@property
def extended_kwargs(self):
kwargs = self.kwargs.copy()
kwargs = self.trigger_data.kwargs.copy()
kwargs["event_data"] = self
kwargs["event"] = self.event
kwargs["source"] = self.source
kwargs["state"] = self.state
kwargs["model"] = self.model
kwargs["machine"] = self.trigger_data.machine
kwargs["event"] = self.trigger_data.event
kwargs["model"] = self.trigger_data.model
kwargs["transition"] = self.transition
kwargs["state"] = self.state
kwargs["source"] = self.source
kwargs["target"] = self.target
return kwargs
37 changes: 18 additions & 19 deletions statemachine/statemachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .dispatcher import ObjectConfig
from .dispatcher import resolver_factory
from .event import Event
from .event_data import TriggerData
from .event_data import EventData
from .exceptions import InvalidStateValue
from .exceptions import InvalidDefinition
Expand Down Expand Up @@ -62,30 +63,29 @@ def _activate_initial_state(self, initial_transition):
initial_transition.before.clear()
initial_transition.on.clear()
initial_transition.after.clear()

event_data = EventData(
self,
initial_transition.event,
trigger_data=TriggerData(
machine=self,
event=initial_transition.event,
),
transition=initial_transition,
)
self._activate(event_data)

def _get_protected_attrs(self):
return (
{
"_abstract",
"model",
"state_field",
"start_value",
"initial_state",
"final_states",
"states",
"_events",
"states_map",
"send",
}
| {s.id for s in self.states}
| set(self._events.keys())
)
return {
"_abstract",
"model",
"state_field",
"start_value",
"initial_state",
"final_states",
"states",
"_events",
"states_map",
"send",
} | {s.id for s in self.states}

def _visit_states_and_transitions(self, visitor):
for state in self.states:
Expand Down Expand Up @@ -165,7 +165,6 @@ def _process(self, trigger):

def _activate(self, event_data: EventData):
transition = event_data.transition
assert transition is not None
source = event_data.state
target = transition.target

Expand Down
7 changes: 5 additions & 2 deletions statemachine/transition.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from functools import partial
from typing import TYPE_CHECKING

from .callbacks import Callbacks
from .callbacks import ConditionWrapper
from .event_data import EventData
from .events import Events
from .exceptions import InvalidDefinition

if TYPE_CHECKING:
from .event_data import EventData


class Transition:
"""A transition holds reference to the source and target state.
Expand Down Expand Up @@ -119,7 +122,7 @@ def events(self):
def add_event(self, value):
self._events.add(value)

def execute(self, event_data: EventData):
def execute(self, event_data: "EventData"):
self.validators.call(*event_data.args, **event_data.extended_kwargs)
if not self._eval_cond(event_data):
return False
Expand Down
4 changes: 2 additions & 2 deletions tests/examples/order_control_rich_model_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self):
def payments_enough(self, amount):
return sum(self.payments) + amount >= self.order_total

def add_to_order(self, amount):
def before_add_to_order(self, amount):
self.order_total += amount
return self.order_total

Expand All @@ -40,7 +40,7 @@ class OrderControl(StateMachine):
shipping = State()
completed = State(final=True)

add_to_order = waiting_for_payment.to(waiting_for_payment, before="add_to_order")
add_to_order = waiting_for_payment.to(waiting_for_payment)
receive_payment = waiting_for_payment.to(
processing, cond="payments_enough"
) | waiting_for_payment.to(waiting_for_payment, unless="payments_enough")
Expand Down
62 changes: 62 additions & 0 deletions tests/test_transitions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest import mock

import pytest

from statemachine import State
Expand Down Expand Up @@ -231,3 +233,63 @@ class TestStateMachine(StateMachine):
final = State(final=True)

execute = initial.to(initial, final, internal=True)


@pytest.fixture()
def chained_sm_class(): # noqa: C901
class ChainedSM(StateMachine):
a = State(initial=True)
b = State()
c = State()

t1 = a.to(b, after="t1") | b.to(c)

def __init__(self, *args, **kwargs):
self.spy = mock.Mock(side_effect=lambda x, **kwargs: x)
super().__init__(*args, **kwargs)

def before_t1(self, source: State, value: int = 0):
return self.spy("before_t1", source=source.id, value=value)

def on_t1(self, source: State, value: int = 0):
return self.spy("on_t1", source=source.id, value=value)

def after_t1(self, source: State, value: int = 0):
return self.spy("after_t1", source=source.id, value=value)

def on_enter_state(self, state: State, source: State, value: int = 0):
return self.spy(
"on_enter_state",
state=state.id,
source=getattr(source, "id", None),
value=value,
)

def on_exit_state(self, state: State, source: State, value: int = 0):
return self.spy(
"on_exit_state", state=state.id, source=source.id, value=value
)

return ChainedSM


class TestChainedTransition:
def test_should_allow_chaining_transitions_using_actions(self, chained_sm_class):
sm = chained_sm_class()
sm.t1(42)

assert sm.c.is_active

assert sm.spy.call_args_list == [
mock.call("on_enter_state", state="a", source=None, value=0),
mock.call("before_t1", source="a", value=0),
mock.call("on_exit_state", state="a", source="a", value=0),
mock.call("on_t1", source="a", value=0),
mock.call("on_enter_state", state="b", source="a", value=0),
mock.call("before_t1", source="b", value=0),
mock.call("on_exit_state", state="b", source="b", value=0),
mock.call("on_t1", source="b", value=0),
mock.call("on_enter_state", state="c", source="b", value=0),
mock.call("after_t1", source="b", value=0),
mock.call("after_t1", source="a", value=0),
]

0 comments on commit 2c82d98

Please sign in to comment.