diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index fa23a48c8..676d234d4 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -25,7 +25,7 @@ WorkerState, ) from .multithread import run_worker_in_own_thread -from .task import RunPlan, Task, _lookup_params +from .task import Task from .worker import TrackableTask, Worker from .worker_busy_error import WorkerBusyError @@ -114,8 +114,7 @@ def begin_task(self, task_id: str) -> None: raise KeyError(f"No pending task with ID {task_id}") def submit_task(self, task: Task) -> str: - if isinstance(task, RunPlan): - task.set_clean_params(_lookup_params(self._ctx, task)) + task.prepare_params(self._ctx) task_id: str = str(uuid.uuid4()) trackable_task = TrackableTask(task_id=task_id, task=task) self._pending_tasks[task_id] = trackable_task diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index ed59fc485..d09add5c7 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -1,5 +1,4 @@ import logging -from abc import ABC, abstractmethod from typing import Any, Mapping, Optional from pydantic import BaseModel, Field @@ -7,27 +6,10 @@ from blueapi.core import BlueskyContext from blueapi.utils import BlueapiBaseModel - -# TODO: Make a TaggedUnion -class Task(ABC, BlueapiBaseModel): - """ - Object that can run with a TaskContext - """ - - @abstractmethod - def do_task(self, __ctx: BlueskyContext) -> None: - """ - Perform the task using the context - - Args: - ctx: Context for the task, holds plans/device/etc - """ - - LOGGER = logging.getLogger(__name__) -class RunPlan(Task): +class Task(BlueapiBaseModel): """ Task that will run a plan """ @@ -36,21 +18,36 @@ class RunPlan(Task): params: Mapping[str, Any] = Field( description="Values for parameters to plan, if any", default_factory=dict ) - _sanitized_params: Optional[BaseModel] = Field(default=None) + _prepared_params: Optional[BaseModel] = None - def set_clean_params(self, model: BaseModel): - self._sanitized_params = model + def prepare_params(self, ctx: BlueskyContext) -> None: + self._ensure_params(ctx) def do_task(self, ctx: BlueskyContext) -> None: LOGGER.info(f"Asked to run plan {self.name} with {self.params}") func = ctx.plan_functions[self.name] - sanitized_params = self._sanitized_params or _lookup_params(ctx, self) - plan_generator = func(**sanitized_params.dict()) + prepared_params = self._ensure_params(ctx) + plan_generator = func(**prepared_params.dict()) ctx.run_engine(plan_generator) + def _ensure_params(self, ctx: BlueskyContext) -> BaseModel: + if self._prepared_params is None: + self._prepared_params = _lookup_params(ctx, self) + return self._prepared_params + + +# Here for backward compatibility pending +# https://github.com/DiamondLightSource/blueapi/issues/253 +class RunPlan(Task): + """ + Task that will run a plan + """ + + ... + -def _lookup_params(ctx: BlueskyContext, task: RunPlan) -> BaseModel: +def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: """ Checks plan parameters against context