Skip to content

Commit

Permalink
Add helper class for UWS tests
Browse files Browse the repository at this point in the history
Several tests have to manually push queued jobs through their
lifecycle in the mock arq queue. Add a helper class that wraps the
data structures required to do this and provides a simple interface.
This can eventually move into Safir when we move the rest of the
support code there.
  • Loading branch information
rra committed Jun 6, 2024
1 parent 9255704 commit d643e63
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 142 deletions.
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from vocutouts.uws.dependencies import UWSFactory, uws_dependency
from vocutouts.uws.schema import Base

from .support.uws import MockJobRunner


@pytest_asyncio.fixture
async def app(arq_queue: MockArqQueue) -> AsyncIterator[FastAPI]:
Expand Down Expand Up @@ -67,6 +69,11 @@ def mock_google_storage() -> Iterator[MockStorageClient]:
)


@pytest.fixture
def runner(uws_factory: UWSFactory, arq_queue: MockArqQueue) -> MockJobRunner:
return MockJobRunner(uws_factory, arq_queue)


@pytest_asyncio.fixture
async def uws_factory(app: FastAPI) -> AsyncIterator[UWSFactory]:
"""Return a UWS component factory.
Expand Down
31 changes: 8 additions & 23 deletions tests/handlers/async_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

import asyncio
import re
from datetime import UTC, datetime

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from safir.arq import MockArqQueue

from vocutouts.uws.dependencies import UWSFactory
from vocutouts.uws.models import UWSJobResult

from ..support.uws import MockJobRunner

PENDING_JOB = """
<uws:job
version="1.1"
Expand Down Expand Up @@ -65,12 +64,7 @@


@pytest.mark.asyncio
async def test_create_job(
client: AsyncClient, arq_queue: MockArqQueue, uws_factory: UWSFactory
) -> None:
job_service = uws_factory.create_job_service()
job_storage = uws_factory.create_job_store()

async def test_create_job(client: AsyncClient, runner: MockJobRunner) -> None:
r = await client.post(
"/api/cutout/jobs",
headers={"X-Auth-Request-User": "someone"},
Expand Down Expand Up @@ -100,28 +94,19 @@ async def test_create_job(
assert r.status_code == 303
assert r.headers["Location"] == "https://example.com/api/cutout/jobs/2"

async def set_result() -> None:
await asyncio.sleep(0.2)
job = await job_service.get("someone", "2")
assert job.message_id
await arq_queue.set_in_progress(job.message_id)
await job_storage.mark_executing("2", datetime.now(tz=UTC))
result = [
async def run_job() -> None:
await runner.mark_in_progress("someone", "2", delay=0.2)
results = [
UWSJobResult(
result_id="cutout",
url="s3://some-bucket/some/path",
mime_type="application/fits",
)
]
await asyncio.sleep(0.2)
job = await job_service.get("someone", "2")
assert job.message_id
await arq_queue.set_complete(job.message_id, result=result)
job_result = await arq_queue.get_job_result(job.message_id)
await job_storage.mark_completed("2", job_result)
await runner.mark_complete("someone", "2", results, delay=0.2)

_, r = await asyncio.gather(
set_result(),
run_job(),
client.get(
"/api/cutout/jobs/2",
headers={"X-Auth-Request-User": "someone"},
Expand Down
32 changes: 9 additions & 23 deletions tests/handlers/sync_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,31 @@
from __future__ import annotations

import asyncio
from datetime import UTC, datetime

import pytest
from httpx import AsyncClient
from safir.arq import MockArqQueue

from vocutouts.uws.dependencies import UWSFactory
from vocutouts.uws.models import UWSJobResult

from ..support.uws import MockJobRunner

@pytest.mark.asyncio
async def test_sync(
client: AsyncClient, arq_queue: MockArqQueue, uws_factory: UWSFactory
) -> None:
job_service = uws_factory.create_job_service()
job_storage = uws_factory.create_job_store()

async def set_result(job_id: str) -> None:
await asyncio.sleep(0.1)
job = await job_service.get("someone", job_id)
assert job.message_id
await arq_queue.set_in_progress(job.message_id)
await job_storage.mark_executing(job_id, datetime.now(tz=UTC))
result = [
@pytest.mark.asyncio
async def test_sync(client: AsyncClient, runner: MockJobRunner) -> None:
async def run_job(job_id: str) -> None:
await runner.mark_in_progress("someone", job_id, delay=0.2)
results = [
UWSJobResult(
result_id="cutout",
url="s3://some-bucket/some/path",
mime_type="application/fits",
)
]
job = await job_service.get("someone", job_id)
assert job.message_id
await arq_queue.set_complete(job.message_id, result=result)
job_result = await arq_queue.get_job_result(job.message_id)
await job_storage.mark_completed(job_id, job_result)
await runner.mark_complete("someone", job_id, results)

# GET request.
_, r = await asyncio.gather(
set_result("1"),
run_job("1"),
client.get(
"/api/cutout/sync",
headers={"X-Auth-Request-User": "someone"},
Expand All @@ -53,7 +39,7 @@ async def set_result(job_id: str) -> None:

# POST request.
_, r = await asyncio.gather(
set_result("2"),
run_job("2"),
client.post(
"/api/cutout/sync",
headers={"X-Auth-Request-User": "someone"},
Expand Down
97 changes: 94 additions & 3 deletions tests/support/uws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@

from __future__ import annotations

import asyncio
import os
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta

from arq.connections import RedisSettings
from pydantic import SecretStr
from safir.arq import ArqMode, ArqQueue, JobMetadata
from safir.arq import ArqMode, ArqQueue, JobMetadata, MockArqQueue

from vocutouts.uws.config import UWSConfig
from vocutouts.uws.models import UWSJob, UWSJobParameter
from vocutouts.uws.dependencies import UWSFactory
from vocutouts.uws.models import UWSJob, UWSJobParameter, UWSJobResult
from vocutouts.uws.policy import UWSPolicy

__all__ = [
"MockJobRunner",
"TrivialPolicy",
"build_uws_config",
]
Expand Down Expand Up @@ -76,3 +79,91 @@ def build_uws_config() -> UWSConfig:
),
signing_service_account="",
)


class MockJobRunner:
"""Simulate execution of jobs with a mock queue.
When running the test suite, the arq queue is replaced with a mock queue
that doesn't execute workers. That execution has to be simulated by
manually updating state in the mock queue and running the UWS database
worker functions that normally would be run automatically by the queue.
This class wraps that functionality. An instance of it is normally
provided as a fixture, initialized with the same test objects as the test
suite.
Parameters
----------
factory
Factory for UWS components.
arq_queue
Mock arq queue for testing.
"""

def __init__(self, factory: UWSFactory, arq_queue: MockArqQueue) -> None:
self._service = factory.create_job_service()
self._store = factory.create_job_store()
self._arq = arq_queue

async def mark_in_progress(
self, username: str, job_id: str, *, delay: float | None = None
) -> UWSJob:
"""Mark a queued job in progress.
Parameters
----------
username
Owner of job.
job_id
Job ID.
delay
How long to delay in seconds before marking the job as complete.
Returns
-------
UWSJob
Record of the job.
"""
if delay:
await asyncio.sleep(delay)
job = await self._service.get(username, job_id)
assert job.message_id
await self._arq.set_in_progress(job.message_id)
await self._store.mark_executing(job_id, datetime.now(tz=UTC))
return await self._service.get(username, job_id)

async def mark_complete(
self,
username: str,
job_id: str,
results: list[UWSJobResult] | Exception,
*,
delay: float | None = None,
) -> UWSJob:
"""Mark an in progress job as complete.
Parameters
----------
username
Owner of job.
job_id
Job ID.
results
Results to return. May be an exception to simulate a job failure.
delay
How long to delay in seconds before marking the job as complete.
Returns
-------
UWSJob
Record of the job.
"""
if delay:
await asyncio.sleep(delay)
job = await self._service.get(username, job_id)
assert job.message_id
await self._arq.set_complete(job.message_id, result=results)
job_result = await self._arq.get_job_result(job.message_id)
await self._store.mark_completed(job_id, job_result)
return await self._service.get(username, job_id)
7 changes: 6 additions & 1 deletion tests/uws/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from vocutouts.uws.handlers import uws_router
from vocutouts.uws.schema import Base

from ..support.uws import TrivialPolicy, build_uws_config
from ..support.uws import MockJobRunner, TrivialPolicy, build_uws_config


@pytest_asyncio.fixture
Expand Down Expand Up @@ -94,6 +94,11 @@ def mock_google_storage() -> Iterator[MockStorageClient]:
)


@pytest.fixture
def runner(uws_factory: UWSFactory, arq_queue: MockArqQueue) -> MockJobRunner:
return MockJobRunner(uws_factory, arq_queue)


@pytest_asyncio.fixture
async def session(app: FastAPI) -> AsyncIterator[async_scoped_session]:
"""Return a database session with no transaction open.
Expand Down
26 changes: 7 additions & 19 deletions tests/uws/job_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from datetime import timedelta

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from safir.arq import MockArqQueue
from safir.datetime import isodatetime

from vocutouts.uws.config import UWSConfig
from vocutouts.uws.dependencies import UWSFactory
from vocutouts.uws.models import UWSJobParameter, UWSJobResult

from ..support.uws import MockJobRunner

PENDING_JOB = """
<uws:job
version="1.1"
Expand Down Expand Up @@ -96,13 +96,9 @@

@pytest.mark.asyncio
async def test_job_run(
client: AsyncClient,
arq_queue: MockArqQueue,
uws_config: UWSConfig,
uws_factory: UWSFactory,
client: AsyncClient, runner: MockJobRunner, uws_factory: UWSFactory
) -> None:
job_service = uws_factory.create_job_service()
job_storage = uws_factory.create_job_store()
job = await job_service.create(
"user",
run_id="some-run-id",
Expand Down Expand Up @@ -159,27 +155,19 @@ async def test_job_run(
"600",
isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)),
)

# Tell the queue to start the job.
job = await job_service.get("user", "1")
assert job.message_id
await arq_queue.set_in_progress(job.message_id)
await job_storage.mark_executing("1", datetime.now(tz=UTC))
await runner.mark_in_progress("user", "1")

# Tell the queue the job is finished.
result = [
results = [
UWSJobResult(
result_id="cutout",
url="s3://some-bucket/some/path",
mime_type="application/fits",
)
]
await arq_queue.set_complete(job.message_id, result=result)
job_result = await arq_queue.get_job_result(job.message_id)
await job_storage.mark_completed("1", job_result)
job = await runner.mark_complete("user", "1", results)

# Check the job results.
job = await job_service.get("user", "1")
assert job.start_time
assert job.start_time.microsecond == 0
assert job.end_time
Expand Down
Loading

0 comments on commit d643e63

Please sign in to comment.