diff --git a/.github/pages/make_switcher.py b/.github/pages/make_switcher.py index 2b81e7696..6d90f4905 100755 --- a/.github/pages/make_switcher.py +++ b/.github/pages/make_switcher.py @@ -3,7 +3,6 @@ from argparse import ArgumentParser from pathlib import Path from subprocess import CalledProcessError, check_output -from typing import Optional def report_output(stdout: bytes, label: str) -> list[str]: @@ -24,7 +23,7 @@ def get_sorted_tags_list() -> list[str]: return report_output(stdout, "Tags list") -def get_versions(ref: str, add: Optional[str]) -> list[str]: +def get_versions(ref: str, add: str | None) -> list[str]: """Generate the file containing the list of all GitHub Pages builds.""" # Get the directories (i.e. builds) from the GitHub Pages branch try: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 27769757f..f45e2aaad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: runs-on: ["ubuntu-latest"] # can add windows-latest, macos-latest - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] include: # Include one that runs in the dev environment - runs-on: "ubuntu-latest" diff --git a/pyproject.toml b/pyproject.toml index 7681cb933..cb0bb3987 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ name = "blueapi" classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ] @@ -32,7 +31,7 @@ dependencies = [ dynamic = ["version"] license.file = "LICENSE" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" [project.optional-dependencies] dev = [ diff --git a/src/blueapi/cli/amq.py b/src/blueapi/cli/amq.py index b9b6c0b3c..face01b4b 100644 --- a/src/blueapi/cli/amq.py +++ b/src/blueapi/cli/amq.py @@ -1,5 +1,5 @@ import threading -from typing import Callable, Optional, Union +from collections.abc import Callable from bluesky.callbacks.best_effort import BestEffortCallback @@ -15,13 +15,13 @@ def __init__(self, message: str) -> None: super().__init__(message) -_Event = Union[WorkerEvent, ProgressEvent, DataEvent] +_Event = WorkerEvent | ProgressEvent | DataEvent class AmqClient: app: MessagingTemplate complete: threading.Event - timed_out: Optional[bool] + timed_out: bool | None def __init__(self, app: MessagingTemplate) -> None: self.app = app @@ -37,7 +37,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: def subscribe_to_topics( self, correlation_id: str, - on_event: Optional[Callable[[WorkerEvent], None]] = None, + on_event: Callable[[WorkerEvent], None] | None = None, ) -> None: """Run callbacks on events/progress events with a given correlation id.""" @@ -70,7 +70,7 @@ def subscribe_to_all_events( on_event, ) - def wait_for_complete(self, timeout: Optional[float] = None) -> None: + def wait_for_complete(self, timeout: float | None = None) -> None: self.timed_out = not self.complete.wait(timeout=timeout) self.complete.clear() diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index ab2f6549e..5bad55768 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -4,7 +4,6 @@ from functools import wraps from pathlib import Path from pprint import pprint -from typing import Optional, Union import click from requests.exceptions import ConnectionError @@ -40,7 +39,7 @@ "-c", "--config", type=Path, help="Path to configuration YAML file", multiple=True ) @click.pass_context -def main(ctx: click.Context, config: Union[Optional[Path], tuple[Path, ...]]) -> None: +def main(ctx: click.Context, config: Path | None | tuple[Path, ...]) -> None: # if no command is supplied, run with the options passed config_loader = ConfigLoader(ApplicationConfig) @@ -71,7 +70,7 @@ def main(ctx: click.Context, config: Union[Optional[Path], tuple[Path, ...]]) -> is_flag=True, help="[Development only] update the schema in the documentation", ) -def schema(output: Optional[Path] = None, update: bool = False) -> None: +def schema(output: Path | None = None, update: bool = False) -> None: """Generate the schema for the REST API""" schema = generate_schema() @@ -167,7 +166,7 @@ def is_allowed(event: Union[WorkerEvent, ProgressEvent, DataEvent]) -> bool: def on_event( context: MessageContext, - event: Union[WorkerEvent, ProgressEvent, DataEvent], + event: WorkerEvent | ProgressEvent | DataEvent, ) -> None: if is_allowed(event): converted = json.dumps(event.dict(), indent=2) @@ -195,7 +194,7 @@ def on_event( @check_connection @click.pass_obj def run_plan( - obj: dict, name: str, parameters: Optional[str], timeout: Optional[float] + obj: dict, name: str, parameters: str | None, timeout: float | None ) -> None: """Run a plan with parameters""" config: ApplicationConfig = obj["config"] @@ -265,7 +264,7 @@ def resume(obj: dict) -> None: @check_connection @click.argument("reason", type=str, required=False) @click.pass_obj -def abort(obj: dict, reason: Optional[str] = None) -> None: +def abort(obj: dict, reason: str | None = None) -> None: """ Abort the execution of the current task, marking any ongoing runs as failed, with optional reason diff --git a/src/blueapi/cli/rest.py b/src/blueapi/cli/rest.py index 111986d2e..5e363faee 100644 --- a/src/blueapi/cli/rest.py +++ b/src/blueapi/cli/rest.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping -from typing import Any, Callable, Literal, Optional, TypeVar +from collections.abc import Callable, Mapping +from typing import Any, Literal, TypeVar import requests from pydantic import parse_obj_as @@ -27,7 +27,7 @@ def _is_exception(response: requests.Response) -> bool: class BlueapiRestClient: _config: RestConfig - def __init__(self, config: Optional[RestConfig] = None) -> None: + def __init__(self, config: RestConfig | None = None) -> None: self._config = config or RestConfig() def get_plans(self) -> PlanResponse: @@ -48,7 +48,7 @@ def get_state(self) -> WorkerState: def set_state( self, state: Literal[WorkerState.RUNNING, WorkerState.PAUSED], - defer: Optional[bool] = False, + defer: bool | None = False, ): return self._request_and_deserialize( "/worker/state", @@ -87,7 +87,7 @@ def update_worker_task(self, task: WorkerTask) -> WorkerTask: def cancel_current_task( self, state: Literal[WorkerState.ABORTING, WorkerState.STOPPING], - reason: Optional[str] = None, + reason: str | None = None, ): return self._request_and_deserialize( "/worker/state", @@ -100,7 +100,7 @@ def _request_and_deserialize( self, suffix: str, target_type: type[T], - data: Optional[Mapping[str, Any]] = None, + data: Mapping[str, Any] | None = None, method="GET", raise_if: Callable[[requests.Response], bool] = _is_exception, ) -> T: diff --git a/src/blueapi/cli/updates.py b/src/blueapi/cli/updates.py index 530abfa2e..7b7d73dd0 100644 --- a/src/blueapi/cli/updates.py +++ b/src/blueapi/cli/updates.py @@ -1,6 +1,5 @@ import itertools from collections.abc import Mapping -from typing import Optional, Union from tqdm import tqdm @@ -44,13 +43,13 @@ def _update(self, name: str, view: StatusView) -> None: class CliEventRenderer: - _task_id: Optional[str] + _task_id: str | None _pbar_renderer: ProgressBarRenderer def __init__( self, - task_id: Optional[str] = None, - pbar_renderer: Optional[ProgressBarRenderer] = None, + task_id: str | None = None, + pbar_renderer: ProgressBarRenderer | None = None, ) -> None: self._task_id = task_id if pbar_renderer is None: @@ -65,7 +64,7 @@ def on_worker_event(self, event: WorkerEvent) -> None: if self._relates_to_task(event): print(str(event.state)) - def _relates_to_task(self, event: Union[WorkerEvent, ProgressEvent]) -> bool: + def _relates_to_task(self, event: WorkerEvent | ProgressEvent) -> bool: if self._task_id is None: return True elif isinstance(event, WorkerEvent): diff --git a/src/blueapi/config.py b/src/blueapi/config.py index f109bb36a..88180cb5e 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from enum import Enum from pathlib import Path -from typing import Any, Generic, Literal, Optional, TypeVar, Union +from typing import Any, Generic, Literal, TypeVar import yaml from pydantic import BaseModel, Field, ValidationError, parse_obj_as, validator @@ -20,7 +20,7 @@ class SourceKind(str, Enum): class Source(BaseModel): kind: SourceKind - module: Union[Path, str] + module: Path | str class BasicAuthentication(BaseModel): @@ -48,11 +48,11 @@ class StompConfig(BaseModel): host: str = "localhost" port: int = 61613 - auth: Optional[BasicAuthentication] = None + auth: BasicAuthentication | None = None class DataWritingConfig(BlueapiBaseModel): - visit_service_url: Optional[str] = None # e.g. "http://localhost:8088/api" + visit_service_url: str | None = None # e.g. "http://localhost:8088/api" visit_directory: Path = Path("/tmp/0-0") group_name: str = "example" diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index 6514da056..ce5fd66d3 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -1,11 +1,8 @@ import inspect -from collections.abc import Mapping +from collections.abc import Callable, Mapping from typing import ( Any, - Callable, - Optional, Protocol, - Union, get_type_hints, runtime_checkable, ) @@ -37,23 +34,23 @@ #: An object that encapsulates the device to do useful things to produce # data (e.g. move and read) -Device = Union[ - Checkable, - Flyable, - HasHints, - HasName, - HasParent, - Movable, - Pausable, - Readable, - Stageable, - Stoppable, - Subscribable, - WritesExternalAssets, - Configurable, - Triggerable, - AsyncDevice, -] +Device = ( + Checkable + | Flyable + | HasHints + | HasName + | HasParent + | Movable + | Pausable + | Readable + | Stageable + | Stoppable + | Subscribable + | WritesExternalAssets + | Configurable + | Triggerable + | AsyncDevice +) #: Protocols defining interface to hardware BLUESKY_PROTOCOLS = list(Device.__args__) # type: ignore @@ -90,7 +87,7 @@ class Plan(BlueapiBaseModel): """ name: str = Field(description="Referenceable name of the plan") - description: Optional[str] = Field( + description: str | None = Field( description="Description/docstring of the plan", default=None ) model: type[BaseModel] = Field( diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index a4ac20d7e..07dc90b10 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -1,17 +1,14 @@ import functools import logging -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass, field from importlib import import_module from inspect import Parameter, signature from types import ModuleType from typing import ( Any, - Callable, Generic, - Optional, TypeVar, - Union, get_args, get_origin, get_type_hints, @@ -71,7 +68,7 @@ def wrap(self, plan: MsgGenerator) -> MsgGenerator: ) yield from wrapped_plan - def find_device(self, addr: Union[str, list[str]]) -> Optional[Device]: + def find_device(self, addr: str | list[str]) -> Device | None: """ Find a device in this context, allows for recursive search. @@ -168,7 +165,7 @@ def my_plan(a: int, b: str): self.plan_functions[plan.__name__] = plan return plan - def device(self, device: Device, name: Optional[str] = None) -> None: + def device(self, device: Device, name: str | None = None) -> None: """ Register an device in the context. The device needs to be registered with a name. If the device is Readable, Movable or Flyable it has a `name` @@ -223,7 +220,7 @@ def valid(cls, value): @classmethod def __modify_schema__( - cls, field_schema: dict[str, Any], field: Optional[ModelField] + cls, field_schema: dict[str, Any], field: ModelField | None ): if field: field_schema.update({field.name: repr(target)}) diff --git a/src/blueapi/core/device_lookup.py b/src/blueapi/core/device_lookup.py index 28e616ede..1bace1676 100644 --- a/src/blueapi/core/device_lookup.py +++ b/src/blueapi/core/device_lookup.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from .bluesky_types import Device, is_bluesky_compatible_device @@ -6,7 +6,7 @@ D = TypeVar("D", bound=Device) -def find_component(obj: Any, addr: list[str]) -> Optional[D]: +def find_component(obj: Any, addr: list[str]) -> D | None: """ Best effort function to locate a child device, either in a dictionary of devices or a device with child attributes. diff --git a/src/blueapi/core/event.py b/src/blueapi/core/event.py index dbda0694d..fff9c833d 100644 --- a/src/blueapi/core/event.py +++ b/src/blueapi/core/event.py @@ -1,6 +1,7 @@ import itertools from abc import ABC, abstractmethod -from typing import Callable, Generic, Optional, TypeVar +from collections.abc import Callable +from typing import Generic, TypeVar #: Event type E = TypeVar("E") @@ -15,7 +16,7 @@ class EventStream(ABC, Generic[E, S]): """ @abstractmethod - def subscribe(self, __callback: Callable[[E, Optional[str]], None]) -> S: + def subscribe(self, __callback: Callable[[E, str | None], None]) -> S: """ Subscribe to new events with a callback @@ -47,14 +48,14 @@ class EventPublisher(EventStream[E, int]): Simple Observable that can be fed values to publish """ - _subscriptions: dict[int, Callable[[E, Optional[str]], None]] + _subscriptions: dict[int, Callable[[E, str | None], None]] _count: itertools.count def __init__(self) -> None: self._subscriptions = {} self._count = itertools.count() - def subscribe(self, callback: Callable[[E, Optional[str]], None]) -> int: + def subscribe(self, callback: Callable[[E, str | None], None]) -> int: sub_id = next(self._count) self._subscriptions[sub_id] = callback return sub_id @@ -65,7 +66,7 @@ def unsubscribe(self, subscription: int) -> None: def unsubscribe_all(self) -> None: self._subscriptions = {} - def publish(self, event: E, correlation_id: Optional[str] = None) -> None: + def publish(self, event: E, correlation_id: str | None = None) -> None: """ Publish a new event to all subscribers diff --git a/src/blueapi/data_management/visit_directory_provider.py b/src/blueapi/data_management/visit_directory_provider.py index 9b27e9d9c..bc05040b7 100644 --- a/src/blueapi/data_management/visit_directory_provider.py +++ b/src/blueapi/data_management/visit_directory_provider.py @@ -1,7 +1,6 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional from aiohttp import ClientSession from ophyd_async.core import DirectoryInfo, DirectoryProvider @@ -75,8 +74,8 @@ class VisitDirectoryProvider(DirectoryProvider): _data_directory: Path _client: VisitServiceClientBase - _current_collection: Optional[DirectoryInfo] - _session: Optional[ClientSession] + _current_collection: DirectoryInfo | None + _session: ClientSession | None def __init__( self, diff --git a/src/blueapi/messaging/base.py b/src/blueapi/messaging/base.py index bd517a6a2..6c350639a 100644 --- a/src/blueapi/messaging/base.py +++ b/src/blueapi/messaging/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod +from collections.abc import Callable from concurrent.futures import Future -from typing import Any, Callable, Optional +from typing import Any from .context import MessageContext @@ -87,7 +88,7 @@ def send_and_receive( destination: str, obj: Any, reply_type: type = str, - correlation_id: Optional[str] = None, + correlation_id: str | None = None, ) -> Future: """ Send a message expecting a single reply. @@ -118,8 +119,8 @@ def send( self, destination: str, obj: Any, - on_reply: Optional[MessageListener] = None, - correlation_id: Optional[str] = None, + on_reply: MessageListener | None = None, + correlation_id: str | None = None, ) -> None: """ Send a message to a destination diff --git a/src/blueapi/messaging/context.py b/src/blueapi/messaging/context.py index be79a3958..d202b700e 100644 --- a/src/blueapi/messaging/context.py +++ b/src/blueapi/messaging/context.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional @dataclass @@ -9,5 +8,5 @@ class MessageContext: """ destination: str - reply_destination: Optional[str] - correlation_id: Optional[str] + reply_destination: str | None + correlation_id: str | None diff --git a/src/blueapi/messaging/stomptemplate.py b/src/blueapi/messaging/stomptemplate.py index e56c72435..bf107c36b 100644 --- a/src/blueapi/messaging/stomptemplate.py +++ b/src/blueapi/messaging/stomptemplate.py @@ -3,9 +3,10 @@ import logging import time import uuid +from collections.abc import Callable from dataclasses import dataclass from threading import Event -from typing import Any, Callable, Optional +from typing import Any import stomp from pydantic import parse_obj_as @@ -84,8 +85,8 @@ class StompMessagingTemplate(MessagingTemplate): def __init__( self, conn: stomp.Connection, - reconnect_policy: Optional[StompReconnectPolicy] = None, - authentication: Optional[BasicAuthentication] = None, + reconnect_policy: StompReconnectPolicy | None = None, + authentication: BasicAuthentication | None = None, ) -> None: self._conn = conn self._reconnect_policy = reconnect_policy or StompReconnectPolicy() @@ -117,8 +118,8 @@ def send( self, destination: str, obj: Any, - on_reply: Optional[MessageListener] = None, - correlation_id: Optional[str] = None, + on_reply: MessageListener | None = None, + correlation_id: str | None = None, ) -> None: self._send_str( destination, json.dumps(serialize(obj)), on_reply, correlation_id @@ -128,8 +129,8 @@ def _send_str( self, destination: str, message: str, - on_reply: Optional[MessageListener] = None, - correlation_id: Optional[str] = None, + on_reply: MessageListener | None = None, + correlation_id: str | None = None, ) -> None: LOGGER.info(f"SENDING {message} to {destination}") @@ -193,7 +194,7 @@ def finished_connecting(_: Frame): self._ensure_subscribed() - def _ensure_subscribed(self, sub_ids: Optional[list[str]] = None) -> None: + def _ensure_subscribed(self, sub_ids: list[str] | None = None) -> None: # We must defer subscription until after connection, because stomp literally # sends a SUB to the broker. But it still nice to be able to call subscribe # on template before it connects, then just run the subscribes after connection. diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index be2016262..f38dcf434 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -1,6 +1,5 @@ import logging from collections.abc import Mapping -from typing import Optional from blueapi.config import ApplicationConfig from blueapi.core import BlueskyContext @@ -33,10 +32,10 @@ class Handler(BlueskyHandler): def __init__( self, - config: Optional[ApplicationConfig] = None, - context: Optional[BlueskyContext] = None, - messaging_template: Optional[MessagingTemplate] = None, - worker: Optional[Worker] = None, + config: ApplicationConfig | None = None, + context: BlueskyContext | None = None, + messaging_template: MessagingTemplate | None = None, + worker: Worker | None = None, ) -> None: self._config = config or ApplicationConfig() self._context = context or BlueskyContext() @@ -115,27 +114,27 @@ def begin_task(self, task: WorkerTask) -> WorkerTask: return task @property - def active_task(self) -> Optional[TrackableTask]: + def active_task(self) -> TrackableTask | None: return self._worker.get_active_task() @property def state(self) -> WorkerState: return self._worker.state - def pause_worker(self, defer: Optional[bool]) -> None: + def pause_worker(self, defer: bool | None) -> None: self._worker.pause(defer) def resume_worker(self) -> None: self._worker.resume() - def cancel_active_task(self, failure: bool, reason: Optional[str]): + def cancel_active_task(self, failure: bool, reason: str | None): self._worker.cancel_active_task(failure, reason) @property def tasks(self) -> list[TrackableTask]: return self._worker.get_tasks() - def get_task_by_id(self, task_id: str) -> Optional[TrackableTask]: + def get_task_by_id(self, task_id: str) -> TrackableTask | None: return self._worker.get_task_by_id(task_id) @property @@ -143,11 +142,11 @@ def initialized(self) -> bool: return self._initialized -HANDLER: Optional[Handler] = None +HANDLER: Handler | None = None def setup_handler( - config: Optional[ApplicationConfig] = None, + config: ApplicationConfig | None = None, ) -> None: global HANDLER diff --git a/src/blueapi/service/handler_base.py b/src/blueapi/service/handler_base.py index 36e8090e9..0671ebad8 100644 --- a/src/blueapi/service/handler_base.py +++ b/src/blueapi/service/handler_base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional from blueapi.service.model import DeviceModel, PlanModel, WorkerTask from blueapi.worker.event import WorkerState @@ -52,7 +51,7 @@ def begin_task(self, task: WorkerTask) -> WorkerTask: @property @abstractmethod - def active_task(self) -> Optional[TrackableTask]: + def active_task(self) -> TrackableTask | None: """Task the worker is currently running""" @property @@ -61,7 +60,7 @@ def state(self) -> WorkerState: """State of the worker""" @abstractmethod - def pause_worker(self, defer: Optional[bool]) -> None: + def pause_worker(self, defer: bool | None) -> None: """Command the worker to pause""" @abstractmethod @@ -69,7 +68,7 @@ def resume_worker(self) -> None: """Command the worker to resume""" @abstractmethod - def cancel_active_task(self, failure: bool, reason: Optional[str]) -> None: + def cancel_active_task(self, failure: bool, reason: str | None) -> None: """Remove the currently active task from the worker if there is one Returns the task_id of the active task""" @@ -80,7 +79,7 @@ def tasks(self) -> list[TrackableTask]: any one of which can be triggered with begin_task""" @abstractmethod - def get_task_by_id(self, task_id: str) -> Optional[TrackableTask]: + def get_task_by_id(self, task_id: str) -> TrackableTask | None: """Returns a task matching the task ID supplied, if the worker knows of it""" diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index e41c992d8..9c0551cd7 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -1,5 +1,4 @@ from contextlib import asynccontextmanager -from typing import Optional from fastapi import ( BackgroundTasks, @@ -33,7 +32,7 @@ REST_API_VERSION = "0.0.5" -HANDLER: Optional[BlueskyHandler] = None +HANDLER: BlueskyHandler | None = None def get_handler() -> BlueskyHandler: @@ -42,7 +41,7 @@ def get_handler() -> BlueskyHandler: return HANDLER -def setup_handler(config: Optional[ApplicationConfig] = None): +def setup_handler(config: ApplicationConfig | None = None): global HANDLER handler = SubprocessHandler(config) handler.start() diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index 625ecc1e4..f8c42b33b 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Any, Optional +from typing import Any from bluesky.protocols import HasName from pydantic import Field @@ -55,10 +55,8 @@ class PlanModel(BlueapiBaseModel): """ name: str = Field(description="Name of the plan") - description: Optional[str] = Field( - description="Docstring of the plan", default=None - ) - parameter_schema: Optional[dict[str, Any]] = Field( + description: str | None = Field(description="Docstring of the plan", default=None) + parameter_schema: dict[str, Any] | None = Field( description="Schema of the plan's parameters", alias="schema", default_factory=dict, @@ -102,7 +100,7 @@ class WorkerTask(BlueapiBaseModel): Worker's active task ID, can be None """ - task_id: Optional[str] = Field( + task_id: str | None = Field( description="The ID of the current task, None if the worker is idle" ) @@ -121,11 +119,11 @@ class StateChangeRequest(BlueapiBaseModel): """ new_state: WorkerState = Field() - defer: Optional[bool] = Field( + defer: bool | None = Field( description="Should worker defer Pausing until the next checkpoint", default=False, ) - reason: Optional[str] = Field( + reason: str | None = Field( description="The reason for the current run to be aborted", default=None, ) diff --git a/src/blueapi/service/subprocess_handler.py b/src/blueapi/service/subprocess_handler.py index 13fc127b0..32c3d729a 100644 --- a/src/blueapi/service/subprocess_handler.py +++ b/src/blueapi/service/subprocess_handler.py @@ -1,9 +1,8 @@ import logging import signal -from collections.abc import Iterable +from collections.abc import Callable, Iterable from multiprocessing import Pool, set_start_method from multiprocessing.pool import Pool as PoolClass -from typing import Callable, Optional from blueapi.config import ApplicationConfig from blueapi.service.handler import get_handler, setup_handler, teardown_handler @@ -24,12 +23,12 @@ def _init_worker(): class SubprocessHandler(BlueskyHandler): _config: ApplicationConfig - _subprocess: Optional[PoolClass] + _subprocess: PoolClass | None _initialized: bool = False def __init__( self, - config: Optional[ApplicationConfig] = None, + config: ApplicationConfig | None = None, ) -> None: self._config = config or ApplicationConfig() self._subprocess = None @@ -56,9 +55,7 @@ def reload_context(self): self.start() LOGGER.info("Context reloaded") - def _run_in_subprocess( - self, function: Callable, arguments: Optional[Iterable] = None - ): + def _run_in_subprocess(self, function: Callable, arguments: Iterable | None = None): if arguments is None: arguments = [] if self._subprocess is None: @@ -89,27 +86,27 @@ def begin_task(self, task: WorkerTask) -> WorkerTask: return self._run_in_subprocess(begin_task, [task]) @property - def active_task(self) -> Optional[TrackableTask]: + def active_task(self) -> TrackableTask | None: return self._run_in_subprocess(active_task) @property def state(self) -> WorkerState: return self._run_in_subprocess(state) - def pause_worker(self, defer: Optional[bool]) -> None: + def pause_worker(self, defer: bool | None) -> None: return self._run_in_subprocess(pause_worker, [defer]) def resume_worker(self) -> None: return self._run_in_subprocess(resume_worker) - def cancel_active_task(self, failure: bool, reason: Optional[str]) -> None: + def cancel_active_task(self, failure: bool, reason: str | None) -> None: return self._run_in_subprocess(cancel_active_task, [failure, reason]) @property def tasks(self) -> list[TrackableTask]: return self._run_in_subprocess(tasks) - def get_task_by_id(self, task_id: str) -> Optional[TrackableTask]: + def get_task_by_id(self, task_id: str) -> TrackableTask | None: return self._run_in_subprocess(get_task_by_id, [task_id]) @property @@ -148,7 +145,7 @@ def begin_task(task: WorkerTask) -> WorkerTask: return get_handler().begin_task(task) -def active_task() -> Optional[TrackableTask]: +def active_task() -> TrackableTask | None: return get_handler().active_task @@ -156,7 +153,7 @@ def state() -> WorkerState: return get_handler().state -def pause_worker(defer: Optional[bool]) -> None: +def pause_worker(defer: bool | None) -> None: return get_handler().pause_worker(defer) @@ -164,7 +161,7 @@ def resume_worker() -> None: return get_handler().resume_worker() -def cancel_active_task(failure: bool, reason: Optional[str]) -> None: +def cancel_active_task(failure: bool, reason: str | None) -> None: return get_handler().cancel_active_task(failure, reason) @@ -172,5 +169,5 @@ def tasks() -> list[TrackableTask]: return get_handler().tasks -def get_task_by_id(task_id: str) -> Optional[TrackableTask]: +def get_task_by_id(task_id: str) -> TrackableTask | None: return get_handler().get_task_by_id(task_id) diff --git a/src/blueapi/startup/simmotor.py b/src/blueapi/startup/simmotor.py index 49f143920..84067ba21 100644 --- a/src/blueapi/startup/simmotor.py +++ b/src/blueapi/startup/simmotor.py @@ -1,6 +1,6 @@ import threading import time as ttime -from typing import Callable, Optional +from collections.abc import Callable from ophyd.sim import SynAxis from ophyd.status import MoveStatus, Status @@ -13,7 +13,7 @@ def __init__( self, *, name: str, - readback_func: Optional[Callable] = None, + readback_func: Callable | None = None, value: float = 0.0, delay: float = 0.0, events_per_move: int = 6, diff --git a/src/blueapi/utils/thread_exception.py b/src/blueapi/utils/thread_exception.py index 988ee6dae..1151917fb 100644 --- a/src/blueapi/utils/thread_exception.py +++ b/src/blueapi/utils/thread_exception.py @@ -1,11 +1,12 @@ import sys import traceback +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, Optional +from typing import Any def handle_all_exceptions( - func: Callable[..., Any], callback: Optional[Callable[[Exception], None]] = None + func: Callable[..., Any], callback: Callable[[Exception], None] | None = None ) -> Callable: """ Ensure any uncaught exception traceback is printed to stdout. This does not diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 2c14de30f..0ec8bb516 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,6 +1,5 @@ from collections.abc import Mapping from enum import Enum -from typing import Optional, Union from bluesky.run_engine import RunEngineStateMachine from pydantic import Field @@ -9,7 +8,7 @@ from blueapi.utils import BlueapiBaseModel # The RunEngine can return any of these three types as its state -RawRunEngineState = Union[PropertyMachine, ProxyString, str] +RawRunEngineState = type[PropertyMachine | ProxyString | str] class WorkerState(str, Enum): @@ -53,13 +52,13 @@ class StatusView(BlueapiBaseModel): description="Human-readable name indicating what this status describes", default="Unknown", ) - current: Optional[float] = Field( + current: float | None = Field( description="Current value of operation progress, if known", default=None ) - initial: Optional[float] = Field( + initial: float | None = Field( description="Initial value of operation progress, if known", default=None ) - target: Optional[float] = Field( + target: float | None = Field( description="Target value operation of progress, if known", default=None ) unit: str = Field(description="Units of progress", default="units") @@ -70,14 +69,14 @@ class StatusView(BlueapiBaseModel): description="Whether the operation this status describes is complete", default=False, ) - percentage: Optional[float] = Field( + percentage: float | None = Field( description="Percentage of status completion, if known", default=None ) - time_elapsed: Optional[float] = Field( + time_elapsed: float | None = Field( description="Time elapsed since status operation beginning, if known", default=None, ) - time_remaining: Optional[float] = Field( + time_remaining: float | None = Field( description="Estimated time remaining until operation completion, if known", default=None, ) @@ -110,7 +109,7 @@ class WorkerEvent(BlueapiBaseModel): """ state: WorkerState - task_status: Optional[TaskStatus] = None + task_status: TaskStatus | None = None errors: list[str] = Field(default_factory=list) warnings: list[str] = Field(default_factory=list) diff --git a/src/blueapi/worker/multithread.py b/src/blueapi/worker/multithread.py index ce2420ba2..cc02f884f 100644 --- a/src/blueapi/worker/multithread.py +++ b/src/blueapi/worker/multithread.py @@ -1,6 +1,6 @@ import logging from concurrent.futures import Future, ThreadPoolExecutor -from typing import Optional, TypeVar +from typing import TypeVar from blueapi.core import configure_bluesky_event_loop from blueapi.utils import handle_all_exceptions @@ -13,7 +13,7 @@ def run_worker_in_own_thread( - worker: Worker[T], executor: Optional[ThreadPoolExecutor] = None + worker: Worker[T], executor: ThreadPoolExecutor | None = None ) -> Future: """ Helper function, make a worker run in a new thread managed by a ThreadPoolExecutor diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index f284943db..97a523613 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -5,7 +5,7 @@ from functools import partial from queue import Full, Queue from threading import Event, RLock -from typing import Any, Optional, Union +from typing import Any from bluesky.protocols import Status from super_state_machine.errors import TransitionError @@ -55,7 +55,7 @@ class TaskWorker(Worker[Task]): _errors: list[str] _warnings: list[str] _task_channel: Queue # type: ignore - _current: Optional[TrackableTask] + _current: TrackableTask | None _status_lock: RLock _status_snapshot: dict[str, StatusView] _completed_statuses: set[str] @@ -101,7 +101,7 @@ def clear_task(self, task_id: str) -> str: def cancel_active_task( self, failure: bool = False, - reason: Optional[str] = None, + reason: str | None = None, ) -> str: if self._current is None: # Persuades mypy that self._current is not None @@ -116,10 +116,10 @@ def cancel_active_task( def get_tasks(self) -> list[TrackableTask[Task]]: return list(self._tasks.values()) - def get_task_by_id(self, task_id: str) -> Optional[TrackableTask[Task]]: + def get_task_by_id(self, task_id: str) -> TrackableTask[Task] | None: return self._tasks.get(task_id) - def get_active_task(self) -> Optional[TrackableTask[Task]]: + def get_active_task(self) -> TrackableTask[Task] | None: return self._current def begin_task(self, task_id: str) -> None: @@ -142,7 +142,7 @@ def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: task_started = Event() - def mark_task_as_started(event: WorkerEvent, _: Optional[str]) -> None: + def mark_task_as_started(event: WorkerEvent, _: str | None) -> None: if ( event.task_status is not None and event.task_status.task_id == trackable_task.task_id @@ -228,7 +228,7 @@ def _cycle_with_error_handling(self) -> None: def _cycle(self) -> None: try: LOGGER.info("Awaiting task") - next_task: Union[TrackableTask, KillSignal] = self._task_channel.get() + next_task: TrackableTask | KillSignal = self._task_channel.get() if isinstance(next_task, TrackableTask): LOGGER.info(f"Got new task: {next_task}") self._current = next_task # Informing mypy that the task is not None @@ -266,7 +266,7 @@ def data_events(self) -> EventStream[DataEvent, int]: def _on_state_change( self, raw_new_state: RawRunEngineState, - raw_old_state: Optional[RawRunEngineState] = None, + raw_old_state: RawRunEngineState | None = None, ) -> None: new_state = WorkerState.from_bluesky_state(raw_new_state) if raw_old_state: @@ -286,7 +286,7 @@ def _report_error(self, err: Exception) -> None: def _report_status( self, ) -> None: - task_status: Optional[TaskStatus] + task_status: TaskStatus | None errors = self._errors warnings = self._warnings if self._current is not None: @@ -319,7 +319,7 @@ def _on_document(self, name: str, document: Mapping[str, Any]) -> None: "Trying to emit a document despite the fact that the RunEngine is idle" ) - def _waiting_hook(self, statuses: Optional[Iterable[Status]]) -> None: + def _waiting_hook(self, statuses: Iterable[Status] | None) -> None: if statuses is not None: with self._status_lock: for status in statuses: @@ -347,15 +347,15 @@ def _on_status_event( status: Status, status_uuid: str, *, - name: Optional[str] = None, - current: Optional[float] = None, - initial: Optional[float] = None, - target: Optional[float] = None, - unit: Optional[str] = None, - precision: Optional[int] = None, - fraction: Optional[float] = None, - time_elapsed: Optional[float] = None, - time_remaining: Optional[float] = None, + name: str | None = None, + current: float | None = None, + initial: float | None = None, + target: float | None = None, + unit: str | None = None, + precision: int | None = None, + fraction: float | None = None, + time_elapsed: float | None = None, + time_remaining: float | None = None, ) -> None: if not status.done: percentage = float(1.0 - fraction) if fraction is not None else None diff --git a/src/blueapi/worker/worker.py b/src/blueapi/worker/worker.py index 84d73055c..026806074 100644 --- a/src/blueapi/worker/worker.py +++ b/src/blueapi/worker/worker.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from pydantic import Field @@ -40,7 +40,7 @@ def get_tasks(self) -> list[TrackableTask[T]]: """ @abstractmethod - def get_task_by_id(self, task_id: str) -> Optional[TrackableTask[T]]: + def get_task_by_id(self, task_id: str) -> TrackableTask[T] | None: """ Returns a task matching the task ID supplied, if the worker knows of it. @@ -53,7 +53,7 @@ def get_task_by_id(self, task_id: str) -> Optional[TrackableTask[T]]: None if the task ID is unknown to the worker. """ - def get_active_task(self) -> Optional[TrackableTask[T]]: + def get_active_task(self) -> TrackableTask[T] | None: """ Returns the task the worker is currently running @@ -77,7 +77,7 @@ def clear_task(self, task_id: str) -> str: def cancel_active_task( self, failure: bool = False, - reason: Optional[str] = None, + reason: str | None = None, ) -> str: """ Remove the currently active task from the worker if there is one diff --git a/tests/preprocessors/test_attach_metadata.py b/tests/preprocessors/test_attach_metadata.py index 878015922..8f879fc26 100644 --- a/tests/preprocessors/test_attach_metadata.py +++ b/tests/preprocessors/test_attach_metadata.py @@ -1,6 +1,6 @@ -from collections.abc import Mapping +from collections.abc import Callable, Mapping from pathlib import Path -from typing import Any, Callable +from typing import Any import bluesky.plan_stubs as bps import bluesky.plans as bp @@ -383,7 +383,7 @@ def assert_all_detectors_used_collection_numbers( descriptors = find_descriptor_docs(docs) assert len(descriptors) == len(source_history) - for descriptor, expected_source in zip(descriptors, source_history): + for descriptor, expected_source in zip(descriptors, source_history, strict=False): for detector in detectors: source = descriptor.doc.get("data_keys", {}).get(f"{detector.name}_data")[ "source" diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index 9c37b5359..ca416bae8 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -1,6 +1,5 @@ import json from dataclasses import dataclass -from typing import Optional from unittest.mock import MagicMock, call import pytest @@ -346,7 +345,7 @@ def test_pausing_while_idle_denied( @pytest.mark.parametrize("defer", [True, False, None]) def test_calls_pause_if_running( - mockable_state_machine: Handler, client: TestClient, defer: Optional[bool] + mockable_state_machine: Handler, client: TestClient, defer: bool | None ) -> None: re = mockable_state_machine._context.run_engine mockable_state_machine._worker._on_state_change( # type: ignore diff --git a/tests/service/test_subprocess_handler.py b/tests/service/test_subprocess_handler.py index 147e845b4..4f9b00b65 100644 --- a/tests/service/test_subprocess_handler.py +++ b/tests/service/test_subprocess_handler.py @@ -1,4 +1,3 @@ -from typing import Optional from unittest.mock import MagicMock, patch import pytest @@ -73,18 +72,18 @@ def begin_task(self, task: WorkerTask) -> WorkerTask: return WorkerTask(task_id=task.task_id) @property - def active_task(self) -> Optional[TrackableTask]: + def active_task(self) -> TrackableTask | None: return None @property def state(self) -> WorkerState: return WorkerState.IDLE - def pause_worker(self, defer: Optional[bool]) -> None: ... + def pause_worker(self, defer: bool | None) -> None: ... def resume_worker(self) -> None: ... - def cancel_active_task(self, failure: bool, reason: Optional[str]) -> None: ... + def cancel_active_task(self, failure: bool, reason: str | None) -> None: ... @property def tasks(self) -> list[TrackableTask]: @@ -92,7 +91,7 @@ def tasks(self) -> list[TrackableTask]: TrackableTask(task_id="abc", task=Task(name="sleep", params={"time": 0.0})) ] - def get_task_by_id(self, task_id: str) -> Optional[TrackableTask]: + def get_task_by_id(self, task_id: str) -> TrackableTask | None: return None def start(self): ... diff --git a/tests/worker/test_reworker.py b/tests/worker/test_reworker.py index 2c65ef656..a31d7f3ec 100644 --- a/tests/worker/test_reworker.py +++ b/tests/worker/test_reworker.py @@ -1,9 +1,9 @@ import itertools import threading -from collections.abc import Iterable +from collections.abc import Callable, Iterable from concurrent.futures import Future from queue import Full -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar from unittest.mock import MagicMock, patch import pytest @@ -211,7 +211,7 @@ def test_produces_worker_events(worker: Worker, num_runs: int) -> None: task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)] event_sequences = [_sleep_events(task_id) for task_id in task_ids] - for task_id, events in zip(task_ids, event_sequences): + for task_id, events in zip(task_ids, event_sequences, strict=False): assert_run_produces_worker_events(events, worker, task_id) @@ -344,7 +344,7 @@ def test_worker_and_data_events_produce_in_order(worker: Worker) -> None: def assert_running_count_plan_produces_ordered_worker_and_data_events( - expected_events: list[Union[WorkerEvent, DataEvent]], + expected_events: list[WorkerEvent | DataEvent], worker: Worker, task: Task = Task(name="count", params={"detectors": ["image_det"], "num": 1}), # noqa: B008 timeout: float = 5.0, @@ -392,7 +392,7 @@ def take_events( events: list[E] = [] future: "Future[list[E]]" = Future() - def on_event(event: E, event_id: Optional[str]) -> None: + def on_event(event: E, event_id: str | None) -> None: events.append(event) if cutoff_predicate(event): future.set_result(events) @@ -428,7 +428,7 @@ def take_events_from_streams( events: list[Any] = [] future: "Future[list[Any]]" = Future() - def on_event(event: Any, event_id: Optional[str]) -> None: + def on_event(event: Any, event_id: str | None) -> None: print(event) events.append(event) if cutoff_predicate(event):