Skip to content

Commit

Permalink
Refactor validation (#254)
Browse files Browse the repository at this point in the history
This is a refactor with no changes in behaviour. It was motivated by the
fact that `RunEngineWorker` accessed the private function
`_lookup_params` in `task.py`.
* Deprecate `RunPlan` and move its functionality into `Task`
* Add method to `Task` that can optionally be called to cache the
parameters, call it when run if not
* Remove instance check on task in worker as it is no longer needed
  • Loading branch information
callumforrester authored Jun 1, 2023
1 parent d0ffaaf commit c230f91
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 28 deletions.
5 changes: 2 additions & 3 deletions src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
WorkerState,
)
from .multithread import run_worker_in_own_thread
from .task import RunPlan, Task, _lookup_params
from .task import Task
from .worker import TrackableTask, Worker
from .worker_busy_error import WorkerBusyError

Expand Down Expand Up @@ -114,8 +114,7 @@ def begin_task(self, task_id: str) -> None:
raise KeyError(f"No pending task with ID {task_id}")

def submit_task(self, task: Task) -> str:
if isinstance(task, RunPlan):
task.set_clean_params(_lookup_params(self._ctx, task))
task.prepare_params(self._ctx)
task_id: str = str(uuid.uuid4())
trackable_task = TrackableTask(task_id=task_id, task=task)
self._pending_tasks[task_id] = trackable_task
Expand Down
47 changes: 22 additions & 25 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,15 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Mapping, Optional

from pydantic import BaseModel, Field

from blueapi.core import BlueskyContext
from blueapi.utils import BlueapiBaseModel


# TODO: Make a TaggedUnion
class Task(ABC, BlueapiBaseModel):
"""
Object that can run with a TaskContext
"""

@abstractmethod
def do_task(self, __ctx: BlueskyContext) -> None:
"""
Perform the task using the context
Args:
ctx: Context for the task, holds plans/device/etc
"""


LOGGER = logging.getLogger(__name__)


class RunPlan(Task):
class Task(BlueapiBaseModel):
"""
Task that will run a plan
"""
Expand All @@ -36,21 +18,36 @@ class RunPlan(Task):
params: Mapping[str, Any] = Field(
description="Values for parameters to plan, if any", default_factory=dict
)
_sanitized_params: Optional[BaseModel] = Field(default=None)
_prepared_params: Optional[BaseModel] = None

def set_clean_params(self, model: BaseModel):
self._sanitized_params = model
def prepare_params(self, ctx: BlueskyContext) -> None:
self._ensure_params(ctx)

def do_task(self, ctx: BlueskyContext) -> None:
LOGGER.info(f"Asked to run plan {self.name} with {self.params}")

func = ctx.plan_functions[self.name]
sanitized_params = self._sanitized_params or _lookup_params(ctx, self)
plan_generator = func(**sanitized_params.dict())
prepared_params = self._ensure_params(ctx)
plan_generator = func(**prepared_params.dict())
ctx.run_engine(plan_generator)

def _ensure_params(self, ctx: BlueskyContext) -> BaseModel:
if self._prepared_params is None:
self._prepared_params = _lookup_params(ctx, self)
return self._prepared_params


# Here for backward compatibility pending
# https://github.com/DiamondLightSource/blueapi/issues/253
class RunPlan(Task):
"""
Task that will run a plan
"""

...


def _lookup_params(ctx: BlueskyContext, task: RunPlan) -> BaseModel:
def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel:
"""
Checks plan parameters against context
Expand Down

0 comments on commit c230f91

Please sign in to comment.