diff --git a/routemaster/conftest.py b/routemaster/conftest.py index 76804c58..50b78503 100644 --- a/routemaster/conftest.py +++ b/routemaster/conftest.py @@ -326,7 +326,7 @@ def database_clear(app: TestApp) -> Iterator[None]: with app.new_session(): for table in metadata.tables: app.session.execute( - f'truncate table {table} cascade', + f'truncate table {table} cascade', # type: ignore[arg-type] # noqa: E501 {}, ) @@ -514,8 +514,7 @@ def _inner(label: LabelRef) -> str: label_name=label.name, label_state_machine=label.state_machine, ).order_by( - # TODO: use the sqlalchemy mypy plugin rather than our stubs - History.id.desc(), # type: ignore[attr-defined] + History.id.desc(), ).limit(1).scalar() return _inner diff --git a/routemaster/db/model.py b/routemaster/db/model.py index 9c52be81..919600f6 100644 --- a/routemaster/db/model.py +++ b/routemaster/db/model.py @@ -1,7 +1,7 @@ """Database model definition.""" import datetime import functools -from typing import Any +from typing import Any, Dict, List, Optional import dateutil.tz from sqlalchemy import DDL, Table @@ -16,7 +16,7 @@ ForeignKeyConstraint, func, ) -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Mapped, relationship from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.dialects.postgresql import JSONB @@ -53,7 +53,8 @@ class Label(Base): """A single label including context.""" - # Note: type annotations for this class are provided by a stubs file + # Note: type annotations provided below must be manually kept in sync with + # the fields defined in the Table. __table__ = Table( 'labels', @@ -74,6 +75,15 @@ class Label(Base): ], ) + name: Mapped[str] + state_machine: Mapped[str] + metadata: Mapped[Dict[str, Any]] + metadata_triggers_processed: Mapped[bool] + deleted: Mapped[bool] + updated: Mapped[datetime.datetime] + + history: List['History'] = relationship('History') + def __repr__(self) -> str: """Return a useful debug representation.""" return ( @@ -84,7 +94,8 @@ def __repr__(self) -> str: class History(Base): """A single historical state transition of a label.""" - # Note: type annotations for this class are provided by a stubs file + # Note: type annotations provided below must be manually kept in sync with + # the fields defined in the Table. __table__ = Table( 'history', @@ -115,7 +126,15 @@ class History(Base): NullableColumn('new_state', String), ) - label = relationship(Label, backref='history') + id: Mapped[int] + label_name: Mapped[str] + label_state_machine: Mapped[str] + created: Mapped[datetime.datetime] + forced: Mapped[bool] + old_state: Mapped[Optional[str]] + new_state: Mapped[Optional[str]] + + label: Label = relationship(Label) def __repr__(self) -> str: """Return a useful debug representation.""" diff --git a/routemaster/db/model.pyi b/routemaster/db/model.pyi deleted file mode 100644 index 6f108bc1..00000000 --- a/routemaster/db/model.pyi +++ /dev/null @@ -1,60 +0,0 @@ -import datetime -from typing import Any, Dict, List, Union, Optional - -from sqlalchemy import MetaData - -# Imperfect JSON type (see https://github.com/python/typing/issues/182) -_JSON = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] - - -metadata: MetaData - - -class Label: - name: str - state_machine: str - metadata: _JSON - metadata_triggers_processed: bool - deleted: bool - updated: datetime.datetime - - history: List['History'] - - def __init__( - self, - *, - name: str = ..., - state_machine: str = ..., - metadata: _JSON = ..., - metadata_triggers_processed: bool = ..., - deleted: bool = ..., - updated: datetime.datetime = ..., - history: List['History'] = ..., - ) -> None: ... - - -class History: - id: int - - label_name: str - label_state_machine: str - created: datetime.datetime - forced: bool - - old_state: Optional[str] - new_state: Optional[str] - - label: Label - - def __init__( - self, - *, - id: int = ..., - label_name: str = ..., - label_state_machine: str = ..., - created: datetime.datetime = ..., - forced: bool = ..., - old_state: Optional[str] = ..., - new_state: Optional[str] = ..., - label: Label = ..., - ) -> None: ... diff --git a/routemaster/state_machine/actions.py b/routemaster/state_machine/actions.py index 0b11a71d..eea1fcae 100644 --- a/routemaster/state_machine/actions.py +++ b/routemaster/state_machine/actions.py @@ -16,7 +16,6 @@ get_label_metadata, get_current_history, ) -from routemaster.state_machine.exceptions import DeletedLabel def process_action( @@ -42,9 +41,7 @@ def process_action( action = state - metadata, deleted = get_label_metadata(app, label, state_machine) - if deleted: - raise DeletedLabel(label) + metadata = get_label_metadata(app, label) latest_history = get_current_history(app, label) diff --git a/routemaster/state_machine/api.py b/routemaster/state_machine/api.py index 4fa9f37b..9a532f7b 100644 --- a/routemaster/state_machine/api.py +++ b/routemaster/state_machine/api.py @@ -9,15 +9,11 @@ from routemaster.config import Gate, State, StateMachine from routemaster.state_machine.gates import process_gate from routemaster.state_machine.types import LabelRef, Metadata -from routemaster.state_machine.utils import ( +from routemaster.state_machine.utils import ( # noqa: F401 # re-export lock_label, get_current_state, get_state_machine, -) -from routemaster.state_machine.utils import ( - get_label_metadata as get_label_metadata_internal, -) -from routemaster.state_machine.utils import ( + get_label_metadata, needs_gate_evaluation_for_metadata_change, ) from routemaster.state_machine.exceptions import ( @@ -53,23 +49,6 @@ def get_label_state(app: App, label: LabelRef) -> Optional[State]: return get_current_state(app, label, state_machine) -def get_label_metadata(app: App, label: LabelRef) -> Metadata: - """Returns the metadata associated with a label.""" - state_machine = get_state_machine(app, label) - - row = get_label_metadata_internal(app, label, state_machine) - - if row is None: - raise UnknownLabel(label) - - metadata, deleted = row - - if deleted: - raise DeletedLabel(label) - - return metadata - - def create_label(app: App, label: LabelRef, metadata: Metadata) -> Metadata: """Creates a label and starts it in a state machine.""" state_machine = get_state_machine(app, label) @@ -129,7 +108,7 @@ def update_metadata_for_label( ) # FIXME: handle cases where metadata aren't dicts. - new_metadata = dict_merge(existing_metadata, update) # type: ignore[arg-type] # noqa: E501 + new_metadata = dict_merge(existing_metadata, update) row.metadata = new_metadata row.metadata_triggers_processed = not needs_gate_evaluation diff --git a/routemaster/state_machine/gates.py b/routemaster/state_machine/gates.py index 53ac3ab2..91631a1b 100644 --- a/routemaster/state_machine/gates.py +++ b/routemaster/state_machine/gates.py @@ -6,11 +6,9 @@ from routemaster.state_machine.utils import ( choose_next_state, context_for_label, - get_state_machine, get_label_metadata, get_current_history, ) -from routemaster.state_machine.exceptions import DeletedLabel def process_gate( @@ -36,10 +34,7 @@ def process_gate( gate = state - state_machine = get_state_machine(app, label) - metadata, deleted = get_label_metadata(app, label, state_machine) - if deleted: - raise DeletedLabel(label) + metadata = get_label_metadata(app, label) history_entry = get_current_history(app, label) diff --git a/routemaster/state_machine/tests/test_state_machine.py b/routemaster/state_machine/tests/test_state_machine.py index 94c14468..4089607d 100644 --- a/routemaster/state_machine/tests/test_state_machine.py +++ b/routemaster/state_machine/tests/test_state_machine.py @@ -276,7 +276,8 @@ def test_maintains_updated_field_on_label(app, mock_test_feed): def test_continues_after_time_since_entering_gate(app, current_state): label = LabelRef('foo', 'test_machine_timing') - gate = app.config.state_machines['test_machine_timing'].states[0] + test_machine = app.config.state_machines['test_machine_timing'] + gate = test_machine.states[0] with freeze_time('2018-01-24 12:00:00'), app.new_session(): state_machine.create_label( @@ -290,7 +291,7 @@ def test_continues_after_time_since_entering_gate(app, current_state): process_gate( app=app, state=gate, - state_machine=state_machine, + state_machine=test_machine, label=label, ) @@ -301,7 +302,7 @@ def test_continues_after_time_since_entering_gate(app, current_state): process_gate( app=app, state=gate, - state_machine=state_machine, + state_machine=test_machine, label=label, ) diff --git a/routemaster/state_machine/utils.py b/routemaster/state_machine/utils.py index ccf2fc73..410a13ce 100644 --- a/routemaster/state_machine/utils.py +++ b/routemaster/state_machine/utils.py @@ -27,6 +27,7 @@ from routemaster.logging import BaseLogger from routemaster.state_machine.types import LabelRef, Metadata from routemaster.state_machine.exceptions import ( + DeletedLabel, UnknownLabel, UnknownStateMachine, ) @@ -52,11 +53,11 @@ def choose_next_state( return state_machine.get_state(next_state_name) -def get_label_metadata( +def _get_label_metadata( app: App, label: LabelRef, state_machine: StateMachine, -) -> Tuple[Dict[str, Any], bool]: +) -> Optional[Tuple[Dict[str, Any], bool]]: """Get the metadata and whether the label has been deleted.""" return app.session.query(Label.metadata, Label.deleted).filter_by( name=label.name, @@ -64,6 +65,23 @@ def get_label_metadata( ).first() +def get_label_metadata(app: App, label: LabelRef) -> Metadata: + """Returns the metadata associated with a label.""" + state_machine = get_state_machine(app, label) + + row = _get_label_metadata(app, label, state_machine) + + if row is None: + raise UnknownLabel(label) + + metadata, deleted = row + + if deleted: + raise DeletedLabel(label) + + return metadata + + def get_current_state( app: App, label: LabelRef, @@ -83,11 +101,7 @@ def get_current_history(app: App, label: LabelRef) -> History: label_name=label.name, label_state_machine=label.state_machine, ).order_by( - # Our model type stubs define the `id` attribute as `int`, yet - # sqlalchemy actually allows the attribute to be used for ordering like - # this; ignore the type check here specifically rather than complicate - # our type definitions. - History.id.desc(), # type: ignore[attr-defined] + History.id.desc(), ).first() if history_entry is None: @@ -167,14 +181,13 @@ def labels_in_state_with_metadata( metadata_lookup = Label.metadata for part in path: - metadata_lookup = metadata_lookup[part] # type: ignore[call-overload, index] # noqa: E501 + metadata_lookup = metadata_lookup[part] # type: ignore[assignment] # noqa: E501 return _labels_in_state( app, state_machine, state, - # TODO: use the sqlalchemy mypy plugin rather than our stubs file - metadata_lookup.astext.in_(values), # type: ignore[union-attr] + metadata_lookup.astext.in_(values), ) @@ -210,11 +223,7 @@ def _labels_in_state( History.label_name, History.new_state, func.row_number().over( - # Our model type stubs define the `id` attribute as `int`, yet - # sqlalchemy actually allows the attribute to be used for ordering - # like this; ignore the type check here specifically rather than - # complicate our type definitions. - order_by=History.id.desc(), # type: ignore[attr-defined] + order_by=History.id.desc(), partition_by=History.label_name, ).label('rank'), ).filter_by( diff --git a/routemaster/validation.py b/routemaster/validation.py index 7efff465..83b72246 100644 --- a/routemaster/validation.py +++ b/routemaster/validation.py @@ -78,8 +78,7 @@ def _validate_no_labels_in_nonexistent_states( History.label_name, History.new_state, func.row_number().over( - # TODO: use the sqlalchemy mypy plugin rather than our stubs file - order_by=History.id.desc(), # type: ignore[attr-defined] + order_by=History.id.desc(), partition_by=History.label_name, ).label('rank'), ).filter_by( diff --git a/setup.cfg b/setup.cfg index 0c94fe5e..782e94e9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,6 +18,7 @@ strict_optional=true show_error_codes = true enable_error_code = ignore-without-code warn_unused_ignores = true +plugins = sqlalchemy.ext.mypy.plugin [coverage:run] branch=True diff --git a/setup.py b/setup.py index e6c129e7..fb5ec3f3 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ 'jsonschema >=3, <5', 'flask', 'psycopg2-binary', - 'sqlalchemy', + 'sqlalchemy[mypy]', 'python-dateutil', 'alembic >=0.9.6', 'gunicorn >=19.7',