Skip to content

Commit

Permalink
Adapt to PEP 604 type hint changes
Browse files Browse the repository at this point in the history
  • Loading branch information
joeshannon committed Apr 12, 2024
1 parent f7fc01d commit d21686e
Show file tree
Hide file tree
Showing 29 changed files with 160 additions and 178 deletions.
3 changes: 1 addition & 2 deletions .github/pages/make_switcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from argparse import ArgumentParser
from pathlib import Path
from subprocess import CalledProcessError, check_output
from typing import Optional


def report_output(stdout: bytes, label: str) -> list[str]:
Expand All @@ -24,7 +23,7 @@ def get_sorted_tags_list() -> list[str]:
return report_output(stdout, "Tags list")


def get_versions(ref: str, add: Optional[str]) -> list[str]:
def get_versions(ref: str, add: str | None) -> list[str]:
"""Generate the file containing the list of all GitHub Pages builds."""
# Get the directories (i.e. builds) from the GitHub Pages branch
try:
Expand Down
10 changes: 5 additions & 5 deletions src/blueapi/cli/amq.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import threading
from typing import Callable, Optional, Union
from collections.abc import Callable

from bluesky.callbacks.best_effort import BestEffortCallback

Expand All @@ -15,13 +15,13 @@ def __init__(self, message: str) -> None:
super().__init__(message)


_Event = Union[WorkerEvent, ProgressEvent, DataEvent]
_Event = WorkerEvent | ProgressEvent | DataEvent


class AmqClient:
app: MessagingTemplate
complete: threading.Event
timed_out: Optional[bool]
timed_out: bool | None

def __init__(self, app: MessagingTemplate) -> None:
self.app = app
Expand All @@ -37,7 +37,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
def subscribe_to_topics(
self,
correlation_id: str,
on_event: Optional[Callable[[WorkerEvent], None]] = None,
on_event: Callable[[WorkerEvent], None] | None = None,
) -> None:
"""Run callbacks on events/progress events with a given correlation id."""

Expand Down Expand Up @@ -70,7 +70,7 @@ def subscribe_to_all_events(
on_event,
)

def wait_for_complete(self, timeout: Optional[float] = None) -> None:
def wait_for_complete(self, timeout: float | None = None) -> None:
self.timed_out = not self.complete.wait(timeout=timeout)

self.complete.clear()
11 changes: 5 additions & 6 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import wraps
from pathlib import Path
from pprint import pprint
from typing import Optional, Union

import click
from requests.exceptions import ConnectionError
Expand Down Expand Up @@ -34,7 +33,7 @@
"-c", "--config", type=Path, help="Path to configuration YAML file", multiple=True
)
@click.pass_context
def main(ctx: click.Context, config: Union[Optional[Path], tuple[Path, ...]]) -> None:
def main(ctx: click.Context, config: Path | None | tuple[Path, ...]) -> None:
# if no command is supplied, run with the options passed

config_loader = ConfigLoader(ApplicationConfig)
Expand Down Expand Up @@ -65,7 +64,7 @@ def main(ctx: click.Context, config: Union[Optional[Path], tuple[Path, ...]]) ->
is_flag=True,
help="[Development only] update the schema in the documentation",
)
def schema(output: Optional[Path] = None, update: bool = False) -> None:
def schema(output: Path | None = None, update: bool = False) -> None:
"""Generate the schema for the REST API"""
schema = generate_schema()

Expand Down Expand Up @@ -139,7 +138,7 @@ def listen_to_events(obj: dict) -> None:

def on_event(
context: MessageContext,
event: Union[WorkerEvent, ProgressEvent, DataEvent],
event: WorkerEvent | ProgressEvent | DataEvent,
) -> None:
converted = json.dumps(event.dict(), indent=2)
print(converted)
Expand All @@ -166,7 +165,7 @@ def on_event(
@check_connection
@click.pass_obj
def run_plan(
obj: dict, name: str, parameters: Optional[str], timeout: Optional[float]
obj: dict, name: str, parameters: str | None, timeout: float | None
) -> None:
"""Run a plan with parameters"""
config: ApplicationConfig = obj["config"]
Expand Down Expand Up @@ -236,7 +235,7 @@ def resume(obj: dict) -> None:
@check_connection
@click.argument("reason", type=str, required=False)
@click.pass_obj
def abort(obj: dict, reason: Optional[str] = None) -> None:
def abort(obj: dict, reason: str | None = None) -> None:
"""
Abort the execution of the current task, marking any ongoing runs as failed,
with optional reason
Expand Down
12 changes: 6 additions & 6 deletions src/blueapi/cli/rest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Callable, Literal, Optional, TypeVar
from collections.abc import Callable, Mapping
from typing import Any, Literal, TypeVar

import requests
from pydantic import parse_obj_as
Expand Down Expand Up @@ -27,7 +27,7 @@ def _is_exception(response: requests.Response) -> bool:
class BlueapiRestClient:
_config: RestConfig

def __init__(self, config: Optional[RestConfig] = None) -> None:
def __init__(self, config: RestConfig | None = None) -> None:
self._config = config or RestConfig()

def get_plans(self) -> PlanResponse:
Expand All @@ -48,7 +48,7 @@ def get_state(self) -> WorkerState:
def set_state(
self,
state: Literal[WorkerState.RUNNING, WorkerState.PAUSED],
defer: Optional[bool] = False,
defer: bool | None = False,
):
return self._request_and_deserialize(
"/worker/state",
Expand Down Expand Up @@ -87,7 +87,7 @@ def update_worker_task(self, task: WorkerTask) -> WorkerTask:
def cancel_current_task(
self,
state: Literal[WorkerState.ABORTING, WorkerState.STOPPING],
reason: Optional[str] = None,
reason: str | None = None,
):
return self._request_and_deserialize(
"/worker/state",
Expand All @@ -100,7 +100,7 @@ def _request_and_deserialize(
self,
suffix: str,
target_type: type[T],
data: Optional[Mapping[str, Any]] = None,
data: Mapping[str, Any] | None = None,
method="GET",
raise_if: Callable[[requests.Response], bool] = _is_exception,
) -> T:
Expand Down
9 changes: 4 additions & 5 deletions src/blueapi/cli/updates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import itertools
from collections.abc import Mapping
from typing import Optional, Union

from tqdm import tqdm

Expand Down Expand Up @@ -44,13 +43,13 @@ def _update(self, name: str, view: StatusView) -> None:


class CliEventRenderer:
_task_id: Optional[str]
_task_id: str | None
_pbar_renderer: ProgressBarRenderer

def __init__(
self,
task_id: Optional[str] = None,
pbar_renderer: Optional[ProgressBarRenderer] = None,
task_id: str | None = None,
pbar_renderer: ProgressBarRenderer | None = None,
) -> None:
self._task_id = task_id
if pbar_renderer is None:
Expand All @@ -65,7 +64,7 @@ def on_worker_event(self, event: WorkerEvent) -> None:
if self._relates_to_task(event):
print(str(event.state))

def _relates_to_task(self, event: Union[WorkerEvent, ProgressEvent]) -> bool:
def _relates_to_task(self, event: WorkerEvent | ProgressEvent) -> bool:
if self._task_id is None:
return True
elif isinstance(event, WorkerEvent):
Expand Down
8 changes: 4 additions & 4 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Mapping
from enum import Enum
from pathlib import Path
from typing import Any, Generic, Literal, Optional, TypeVar, Union
from typing import Any, Generic, Literal, TypeVar

import yaml
from pydantic import BaseModel, Field, ValidationError, parse_obj_as, validator
Expand All @@ -20,7 +20,7 @@ class SourceKind(str, Enum):

class Source(BaseModel):
kind: SourceKind
module: Union[Path, str]
module: Path | str


class BasicAuthentication(BaseModel):
Expand Down Expand Up @@ -48,11 +48,11 @@ class StompConfig(BaseModel):

host: str = "localhost"
port: int = 61613
auth: Optional[BasicAuthentication] = None
auth: BasicAuthentication | None = None


class DataWritingConfig(BlueapiBaseModel):
visit_service_url: Optional[str] = None # e.g. "http://localhost:8088/api"
visit_service_url: str | None = None # e.g. "http://localhost:8088/api"
visit_directory: Path = Path("/tmp/0-0")
group_name: str = "example"

Expand Down
41 changes: 19 additions & 22 deletions src/blueapi/core/bluesky_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import inspect
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from typing import (
Any,
Callable,
Optional,
Protocol,
Union,
get_type_hints,
runtime_checkable,
)
Expand Down Expand Up @@ -37,23 +34,23 @@

#: An object that encapsulates the device to do useful things to produce
# data (e.g. move and read)
Device = Union[
Checkable,
Flyable,
HasHints,
HasName,
HasParent,
Movable,
Pausable,
Readable,
Stageable,
Stoppable,
Subscribable,
WritesExternalAssets,
Configurable,
Triggerable,
AsyncDevice,
]
Device = (
Checkable
| Flyable
| HasHints
| HasName
| HasParent
| Movable
| Pausable
| Readable
| Stageable
| Stoppable
| Subscribable
| WritesExternalAssets
| Configurable
| Triggerable
| AsyncDevice
)

#: Protocols defining interface to hardware
BLUESKY_PROTOCOLS = list(Device.__args__) # type: ignore
Expand Down Expand Up @@ -90,7 +87,7 @@ class Plan(BlueapiBaseModel):
"""

name: str = Field(description="Referenceable name of the plan")
description: Optional[str] = Field(
description: str | None = Field(
description="Description/docstring of the plan", default=None
)
model: type[BaseModel] = Field(
Expand Down
11 changes: 4 additions & 7 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import functools
import logging
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from importlib import import_module
from inspect import Parameter, signature
from types import ModuleType
from typing import (
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
Expand Down Expand Up @@ -71,7 +68,7 @@ def wrap(self, plan: MsgGenerator) -> MsgGenerator:
)
yield from wrapped_plan

def find_device(self, addr: Union[str, list[str]]) -> Optional[Device]:
def find_device(self, addr: str | list[str]) -> Device | None:
"""
Find a device in this context, allows for recursive search.
Expand Down Expand Up @@ -168,7 +165,7 @@ def my_plan(a: int, b: str):
self.plan_functions[plan.__name__] = plan
return plan

def device(self, device: Device, name: Optional[str] = None) -> None:
def device(self, device: Device, name: str | None = None) -> None:
"""
Register an device in the context. The device needs to be registered with a
name. If the device is Readable, Movable or Flyable it has a `name`
Expand Down Expand Up @@ -223,7 +220,7 @@ def valid(cls, value):

@classmethod
def __modify_schema__(
cls, field_schema: dict[str, Any], field: Optional[ModelField]
cls, field_schema: dict[str, Any], field: ModelField | None
):
if field:
field_schema.update({field.name: repr(target)})
Expand Down
4 changes: 2 additions & 2 deletions src/blueapi/core/device_lookup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Any, Optional, TypeVar
from typing import Any, TypeVar

from .bluesky_types import Device, is_bluesky_compatible_device

#: Device obeying Bluesky protocols
D = TypeVar("D", bound=Device)


def find_component(obj: Any, addr: list[str]) -> Optional[D]:
def find_component(obj: Any, addr: list[str]) -> D | None:
"""
Best effort function to locate a child device, either in a dictionary of
devices or a device with child attributes.
Expand Down
11 changes: 6 additions & 5 deletions src/blueapi/core/event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
from abc import ABC, abstractmethod
from typing import Callable, Generic, Optional, TypeVar
from collections.abc import Callable
from typing import Generic, TypeVar

#: Event type
E = TypeVar("E")
Expand All @@ -15,7 +16,7 @@ class EventStream(ABC, Generic[E, S]):
"""

@abstractmethod
def subscribe(self, __callback: Callable[[E, Optional[str]], None]) -> S:
def subscribe(self, __callback: Callable[[E, str | None], None]) -> S:
"""
Subscribe to new events with a callback
Expand Down Expand Up @@ -47,14 +48,14 @@ class EventPublisher(EventStream[E, int]):
Simple Observable that can be fed values to publish
"""

_subscriptions: dict[int, Callable[[E, Optional[str]], None]]
_subscriptions: dict[int, Callable[[E, str | None], None]]
_count: itertools.count

def __init__(self) -> None:
self._subscriptions = {}
self._count = itertools.count()

def subscribe(self, callback: Callable[[E, Optional[str]], None]) -> int:
def subscribe(self, callback: Callable[[E, str | None], None]) -> int:
sub_id = next(self._count)
self._subscriptions[sub_id] = callback
return sub_id
Expand All @@ -65,7 +66,7 @@ def unsubscribe(self, subscription: int) -> None:
def unsubscribe_all(self) -> None:
self._subscriptions = {}

def publish(self, event: E, correlation_id: Optional[str] = None) -> None:
def publish(self, event: E, correlation_id: str | None = None) -> None:
"""
Publish a new event to all subscribers
Expand Down
Loading

0 comments on commit d21686e

Please sign in to comment.