Skip to content

Commit

Permalink
Add tests for known bad requests (#230)
Browse files Browse the repository at this point in the history
Use status enum for ease of read

Do not cache handler and test client
  • Loading branch information
DiamondJoseph authored May 26, 2023
1 parent 3d85b15 commit 3a07017
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 50 deletions.
9 changes: 0 additions & 9 deletions docs/user/reference/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,6 @@ paths:
schema:
$ref: '#/components/schemas/DeviceModel'
description: Successful Response
'404':
description: Not Found
detail: item not found
'422':
content:
application/json:
Expand Down Expand Up @@ -258,9 +255,6 @@ paths:
schema:
$ref: '#/components/schemas/PlanModel'
description: Successful Response
'404':
description: Not Found
detail: item not found
'422':
content:
application/json:
Expand Down Expand Up @@ -315,9 +309,6 @@ paths:
schema:
$ref: '#/components/schemas/TrackableTask'
description: Successful Response
'404':
description: Not Found
item: not found
'422':
content:
application/json:
Expand Down
51 changes: 31 additions & 20 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Dict, Set

from fastapi import Body, Depends, FastAPI, HTTPException, Request, Response, status
from pydantic import ValidationError
from starlette.responses import JSONResponse

from blueapi.config import ApplicationConfig
from blueapi.worker import RunPlan, TrackableTask, WorkerState
Expand Down Expand Up @@ -37,6 +39,14 @@ async def lifespan(app: FastAPI):
)


@app.exception_handler(KeyError)
async def on_key_error_404(_: Request, __: KeyError):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={"detail": "Item not found"},
)


@app.get("/plans", response_model=PlanResponse)
def get_plans(handler: Handler = Depends(get_handler)):
"""Retrieve information about all available plans."""
Expand All @@ -48,14 +58,10 @@ def get_plans(handler: Handler = Depends(get_handler)):
@app.get(
"/plans/{name}",
response_model=PlanModel,
responses={status.HTTP_404_NOT_FOUND: {"detail": "item not found"}},
)
def get_plan_by_name(name: str, handler: Handler = Depends(get_handler)):
"""Retrieve information about a plan by its (unique) name."""
try:
return PlanModel.from_plan(handler.context.plans[name])
except KeyError:
raise HTTPException(status_code=404, detail="Item not found")
return PlanModel.from_plan(handler.context.plans[name])


@app.get("/devices", response_model=DeviceResponse)
Expand All @@ -72,17 +78,17 @@ def get_devices(handler: Handler = Depends(get_handler)):
@app.get(
"/devices/{name}",
response_model=DeviceModel,
responses={status.HTTP_404_NOT_FOUND: {"detail": "item not found"}},
)
def get_device_by_name(name: str, handler: Handler = Depends(get_handler)):
"""Retrieve information about a devices by its (unique) name."""
try:
return DeviceModel.from_device(handler.context.devices[name])
except KeyError:
raise HTTPException(status_code=404, detail="Item not found")
return DeviceModel.from_device(handler.context.devices[name])


@app.post("/tasks", response_model=TaskResponse, status_code=201)
@app.post(
"/tasks",
response_model=TaskResponse,
status_code=status.HTTP_201_CREATED,
)
def submit_task(
request: Request,
response: Response,
Expand All @@ -92,9 +98,14 @@ def submit_task(
handler: Handler = Depends(get_handler),
):
"""Submit a task to the worker."""
task_id: str = handler.worker.submit_task(task)
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
try:
task_id: str = handler.worker.submit_task(task)
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
except ValidationError as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors()
)


@app.put(
Expand All @@ -108,7 +119,9 @@ def update_task(
) -> WorkerTask:
active_task = handler.worker.get_active_task()
if active_task is not None and not active_task.is_complete:
raise HTTPException(status_code=409, detail="Worker already active")
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="Worker already active"
)
elif task.task_id is not None:
handler.worker.begin_task(task.task_id)
return task
Expand All @@ -117,7 +130,6 @@ def update_task(
@app.get(
"/tasks/{task_id}",
response_model=TrackableTask,
responses={status.HTTP_404_NOT_FOUND: {"item": "not found"}},
)
def get_task(
task_id: str,
Expand All @@ -126,10 +138,9 @@ def get_task(
"""Retrieve a task"""

task = handler.worker.get_pending_task(task_id)
if task is not None:
return task
else:
raise HTTPException(status_code=404, detail="Item not found")
if task is None:
raise KeyError
return task


@app.get("/worker/task")
Expand Down
1 change: 1 addition & 0 deletions src/blueapi/utils/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class BlueapiModelConfig(BaseConfig):

extra = Extra.forbid
allow_population_by_field_name = True
underscore_attrs_are_private = True


class BlueapiPlanModelConfig(BaseConfig):
Expand Down
4 changes: 3 additions & 1 deletion src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
WorkerState,
)
from .multithread import run_worker_in_own_thread
from .task import Task
from .task import RunPlan, Task, _lookup_params
from .worker import TrackableTask, Worker
from .worker_busy_error import WorkerBusyError

Expand Down Expand Up @@ -113,6 +113,8 @@ def begin_task(self, task_id: str) -> None:
raise KeyError(f"No pending task with ID {task_id}")

def submit_task(self, task: Task) -> str:
if isinstance(task, RunPlan):
task.set_clean_params(_lookup_params(self._ctx, task))
task_id: str = str(uuid.uuid4())
trackable_task = TrackableTask(task_id=task_id, task=task)
self._pending_tasks[task_id] = trackable_task
Expand Down
20 changes: 11 additions & 9 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Mapping
from typing import Any, Mapping, Optional

from pydantic import BaseModel, Field, parse_obj_as
from pydantic import BaseModel, Field

from blueapi.core import BlueskyContext, Plan
from blueapi.core import BlueskyContext
from blueapi.utils import BlueapiBaseModel


Expand Down Expand Up @@ -36,20 +36,21 @@ class RunPlan(Task):
params: Mapping[str, Any] = Field(
description="Values for parameters to plan, if any", default_factory=dict
)
_sanitized_params: Optional[BaseModel] = Field(default=None)

def set_clean_params(self, model: BaseModel):
self._sanitized_params = model

def do_task(self, ctx: BlueskyContext) -> None:
LOGGER.info(f"Asked to run plan {self.name} with {self.params}")

plan = ctx.plans[self.name]
func = ctx.plan_functions[self.name]
sanitized_params = _lookup_params(ctx, plan, self.params)
sanitized_params = self._sanitized_params or _lookup_params(ctx, self)
plan_generator = func(**sanitized_params.dict())
ctx.run_engine(plan_generator)


def _lookup_params(
ctx: BlueskyContext, plan: Plan, params: Mapping[str, Any]
) -> BaseModel:
def _lookup_params(ctx: BlueskyContext, task: RunPlan) -> BaseModel:
"""
Checks plan parameters against context
Expand All @@ -62,5 +63,6 @@ def _lookup_params(
Mapping[str, Any]: _description_
"""

plan = ctx.plans[task.name]
model = plan.model
return parse_obj_as(model, params)
return model.parse_obj(task.params)
10 changes: 3 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Based on https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option # noqa: E501
from typing import Iterator
from unittest.mock import MagicMock

import pytest
from bluesky.run_engine import RunEngineStateMachine
from fastapi.testclient import TestClient
from mock import Mock

from blueapi.service.handler import Handler, get_handler
from blueapi.service.main import app
Expand Down Expand Up @@ -45,14 +45,10 @@ def client(self) -> TestClient:

@pytest.fixture
def handler() -> Iterator[Handler]:
context: BlueskyContext = Mock()
context: BlueskyContext = BlueskyContext(run_engine=MagicMock())
context.run_engine.state = RunEngineStateMachine.States.IDLE
handler = Handler(context=context)
handler = Handler(context=context, messaging_template=MagicMock())

def no_op():
return

handler.start = handler.stop = no_op # type: ignore
yield handler
handler.stop()

Expand Down
65 changes: 61 additions & 4 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock
Expand Down Expand Up @@ -25,7 +26,7 @@ class MyModel(BaseModel):
handler.context.plans = {"my-plan": plan}
response = client.get("/plans")

assert response.status_code == 200
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"plans": [{"name": "my-plan"}]}


Expand All @@ -38,10 +39,17 @@ class MyModel(BaseModel):
handler.context.plans = {"my-plan": plan}
response = client.get("/plans/my-plan")

assert response.status_code == 200
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"name": "my-plan"}


def test_get_non_existant_plan_by_name(handler: Handler, client: TestClient) -> None:
response = client.get("/plans/my-plan")

assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == {"detail": "Item not found"}


def test_get_devices(handler: Handler, client: TestClient) -> None:
@dataclass
class MyDevice:
Expand All @@ -52,7 +60,7 @@ class MyDevice:
handler.context.devices = {"my-device": device}
response = client.get("/devices")

assert response.status_code == 200
assert response.status_code == status.HTTP_200_OK
assert response.json() == {
"devices": [
{
Expand All @@ -73,13 +81,20 @@ class MyDevice:
handler.context.devices = {"my-device": device}
response = client.get("/devices/my-device")

assert response.status_code == 200
assert response.status_code == status.HTTP_200_OK
assert response.json() == {
"name": "my-device",
"protocols": ["HasName"],
}


def test_get_non_existant_device_by_name(handler: Handler, client: TestClient) -> None:
response = client.get("/devices/my-device")

assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == {"detail": "Item not found"}


def test_create_task(handler: Handler, client: TestClient) -> None:
response = client.post("/tasks", json=_TASK.dict())
task_id = response.json()["task_id"]
Expand All @@ -103,6 +118,48 @@ def test_put_plan_begins_task(handler: Handler, client: TestClient) -> None:
handler.worker.stop()


def test_put_plan_with_unknown_plan_name_fails(
handler: Handler, client: TestClient
) -> None:
task_name = "foo"
task_params = {"detectors": ["x"]}
task_json = {"name": task_name, "params": task_params}

response = client.post("/tasks", json=task_json)

assert not handler.worker.get_pending_tasks()
assert response.status_code == status.HTTP_404_NOT_FOUND


def test_get_plan_returns_posted_plan(handler: Handler, client: TestClient) -> None:
handler.worker.start()
post_response = client.post("/tasks", json=_TASK.dict())
task_id = post_response.json()["task_id"]

str_map = json.load(client.get(f"/tasks/{task_id}")) # type: ignore

assert str_map["task_id"] == task_id
assert str_map["task"] == _TASK.dict()


def test_get_non_existant_plan_by_id(handler: Handler, client: TestClient) -> None:
response = client.get("/tasks/foo")

assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == {"detail": "Item not found"}


def test_put_plan_with_bad_params_fails(handler: Handler, client: TestClient) -> None:
task_name = "count"
task_params = {"motors": ["x"]}
task_json = {"name": task_name, "params": task_params}

response = client.post("/tasks", json=task_json)

assert not handler.worker.get_pending_tasks()
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


def test_get_state_updates(handler: Handler, client: TestClient) -> None:
assert client.get("/worker/state").text == f'"{WorkerState.IDLE.name}"'
handler.worker._on_state_change( # type: ignore
Expand Down

0 comments on commit 3a07017

Please sign in to comment.