From f36e8b56f66799541053b915afff486425a5ea14 Mon Sep 17 00:00:00 2001
From: Russ Allbery
Date: Wed, 10 Jul 2024 14:47:14 -0700
Subject: [PATCH] Add timeout support

Switch the backend worker to ProcessPoolExecutor and use a really ugly
hack to terminate the worker process and rebuild the pool on timeout or
job cancellation. Restore support for execution duration, change the
default execution duration back to 10 minutes, and re-add the
CUTOUT_TIMEOUT environment variable setting to change the default
timeout.
---
 changelog.d/20240710_144457_rra_DM_45138.md |  4 +
 src/vocutouts/config.py                     | 14 +--
 src/vocutouts/uws/config.py                 | 20 ++---
 src/vocutouts/uws/service.py                | 12 ++-
 src/vocutouts/uws/uwsworker.py              | 99 ++++++++++++---------
 tests/handlers/async_test.py                |  4 +-
 tests/support/uws.py                        |  1 +
 tests/uws/job_api_test.py                   | 13 ++-
 tests/uws/job_error_test.py                 |  2 +-
 tests/uws/long_polling_test.py              |  6 +-
 tests/uws/workers_test.py                   | 97 +++++++++++++++++---
 11 files changed, 184 insertions(+), 88 deletions(-)
 create mode 100644 changelog.d/20240710_144457_rra_DM_45138.md

diff --git a/changelog.d/20240710_144457_rra_DM_45138.md b/changelog.d/20240710_144457_rra_DM_45138.md
new file mode 100644
index 0000000..2ff84ed
--- /dev/null
+++ b/changelog.d/20240710_144457_rra_DM_45138.md
@@ -0,0 +1,4 @@
+### New features
+
+- Restore support for execution duration and change the default execution duration back to 10 minutes. Use a very ugly hack to enforce a timeout in the backend worker that will hopefully not be too fragile.
+- Re-add the `CUTOUT_TIMEOUT` configuration option to change the default and maximum execution duration for cutout jobs.
diff --git a/src/vocutouts/config.py b/src/vocutouts/config.py
index c5a348b..c4e9ad3 100644
--- a/src/vocutouts/config.py
+++ b/src/vocutouts/config.py
@@ -105,12 +105,11 @@ class Config(BaseSettings):
     )
 
     sync_timeout: timedelta = Field(
-        timedelta(minutes=1),
-        title="Timeout for sync requests",
-        description=(
-            "The job will continue running as an async job beyond this"
-            " timeout since cancellation of jobs is not currently supported."
-        ),
+        timedelta(minutes=1), title="Timeout for sync requests"
+    )
+
+    timeout: timedelta = Field(
+        timedelta(minutes=10), title="Timeout for cutout jobs"
     )
 
     tmpdir: Path = Field(Path("/tmp"), title="Temporary directory for workers")
@@ -180,7 +179,7 @@ def _validate_arq_queue_url(cls, v: RedisDsn) -> RedisDsn:
         )
         return v
 
-    @field_validator("lifetime", "sync_timeout", mode="before")
+    @field_validator("lifetime", "sync_timeout", "timeout", mode="before")
     @classmethod
     def _parse_timedelta(cls, v: str | float | timedelta) -> float | timedelta:
         """Support human-readable timedeltas."""
@@ -211,6 +210,7 @@ def uws_config(self) -> UWSConfig:
         return UWSConfig(
             arq_mode=self.arq_mode,
             arq_redis_settings=self.arq_redis_settings,
+            execution_duration=self.timeout,
             lifetime=self.lifetime,
             parameters_type=CutoutParameters,
             signing_service_account=self.service_account,
diff --git a/src/vocutouts/uws/config.py b/src/vocutouts/uws/config.py
index 48e3539..725b0a3 100644
--- a/src/vocutouts/uws/config.py
+++ b/src/vocutouts/uws/config.py
@@ -98,6 +98,13 @@ class encapsulates the configuration of the UWS component that may vary by
     database_url: str
     """URL for the metadata database."""
 
+    execution_duration: timedelta
+    """Maximum execution time in seconds.
+
+    Jobs that run longer than this length of time will be automatically
+    aborted.
+    """
+
     lifetime: timedelta
     """The lifetime of jobs.
 
@@ -127,15 +134,6 @@ class encapsulates the configuration of the UWS component that may vary by
     database_password: SecretStr | None = None
     """Password for the database."""
 
-    execution_duration: timedelta = timedelta(seconds=0)
-    """Maximum execution time in seconds.
-
-    Jobs that run longer than this length of time should be automatically
-    aborted. However, currently the backend does not support cancelling jobs,
-    and therefore the only correct value is 0, which indicates that the
-    execution duration of the job is unlimited.
-    """
-
     slack_webhook: SecretStr | None = None
     """Slack incoming webhook for reporting errors."""
 
@@ -176,9 +174,7 @@ class encapsulates the configuration of the UWS component that may vary by
 
     If provided, called with the requested execution duration and the current
     job record and should return the new execution duration time. Otherwise,
-    the execution duration may not be changed. Note that the current backend
-    does not support cancelling jobs and therefore does not support execution
-    duration values other than 0.
+    the execution duration may not be changed.
     """
 
     wait_timeout: timedelta = timedelta(minutes=1)
diff --git a/src/vocutouts/uws/service.py b/src/vocutouts/uws/service.py
index bf7d577..be63749 100644
--- a/src/vocutouts/uws/service.py
+++ b/src/vocutouts/uws/service.py
@@ -432,15 +432,13 @@ async def update_execution_duration(
         if job.owner != user:
             raise PermissionDeniedError(f"Access to job {job_id} denied")
 
-        # Validate the new value. Only support changes to execution duration
-        # if a validator is set, which is a signal that the application
-        # supports cancellation of jobs. The current implementation does not
-        # support cancelling jobs and therefore cannot enforce a timeout, so
-        # an execution duration of 0 is currently the only correct value.
+        # Validate the new value.
         if validator := self._config.validate_execution_duration:
             duration = validator(duration, job)
-        else:
-            return None
+        if duration > self._config.execution_duration:
+            duration = self._config.execution_duration
+
+        # Update the duration in the job.
         if duration == job.execution_duration:
             return None
         await self._storage.update_execution_duration(job_id, duration)
diff --git a/src/vocutouts/uws/uwsworker.py b/src/vocutouts/uws/uwsworker.py
index ba77b9a..2b0b493 100644
--- a/src/vocutouts/uws/uwsworker.py
+++ b/src/vocutouts/uws/uwsworker.py
@@ -3,9 +3,11 @@
 from __future__ import annotations
 
 import asyncio
+import os
+import signal
 import uuid
 from collections.abc import Callable, Sequence
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ProcessPoolExecutor
 from dataclasses import dataclass
 from datetime import UTC, datetime, timedelta
 from enum import Enum
@@ -35,6 +37,7 @@
     "WorkerJobInfo",
     "WorkerResult",
     "WorkerSettings",
+    "WorkerTimeoutError",
     "WorkerTransientError",
     "WorkerUsageError",
     "build_worker",
@@ -102,6 +105,9 @@ class returned by other functions.
     max_jobs: int
     """Maximum number of jobs that can be run at one time."""
 
+    allow_abort_jobs: bool = False
+    """Whether to allow jobs to be aborted."""
+
     queue_name: str = default_queue_name
     """Name of arq queue to listen to for jobs."""
 
@@ -129,11 +135,7 @@ class WorkerJobInfo:
     """Delegated Gafaelfawr token to act on behalf of the user."""
 
     timeout: timedelta
-    """Maximum execution time for the job.
-
-    Currently, this is ignored, since the backend workers do not support
-    cancellation.
-    """
+    """Maximum execution time for the job."""
 
     run_id: str | None = None
     """User-supplied run ID, if any."""
@@ -259,6 +261,20 @@ class WorkerTransientError(WorkerError):
     error_type = WorkerErrorType.TRANSIENT
 
 
+class WorkerTimeoutError(WorkerTransientError):
+    """Worker job exceeded its maximum execution time.
+
+    The job may be retried with the same parameters and may succeed.
+    """
+
+    def __init__(self, elapsed: timedelta, timeout: timedelta) -> None:
+        msg = (
+            f"Job timed out after {elapsed.total_seconds()}s"
+            f" (timeout: {timeout.total_seconds()}s)"
+        )
+        super().__init__(msg)
+
+
 class WorkerUsageError(WorkerError):
     """Parameters sent by the user were invalid.
 
@@ -271,6 +287,20 @@ class WorkerUsageError(WorkerError):
     error_type = WorkerErrorType.USAGE
 
 
+def _restart_pool(pool: ProcessPoolExecutor) -> ProcessPoolExecutor:
+    """Restart the pool after timeout or job cancellation.
+
+    This is a horrible, fragile hack, but it appears to be the only way to
+    enforce a timeout currently in Python since there is no way to abort a
+    job already in progress. Find the processes underlying the pool, kill
+    them, and then shut down and recreate the pool.
+    """
+    for pid in pool._processes:  # noqa: SLF001
+        os.kill(pid, signal.SIGINT)
+    pool.shutdown(wait=True)
+    return ProcessPoolExecutor(1)
+
+
 def build_worker(
     worker: Callable[[T, WorkerJobInfo, BoundLogger], list[WorkerResult]],
     config: WorkerConfig[T],
@@ -298,27 +328,6 @@ def build_worker(
         UWS worker configuration.
     logger
         Logger to use for messages.
-
-    Notes
-    -----
-    Timeouts and aborting jobs unfortunately are not supported due to
-    limitations in `concurrent.futures.ThreadPoolExecutor`. Once a thread has
-    been started, there is no way to stop it until it completes on its own.
-    Therefore, no job timeout is set or supported, and the timeout set on the
-    job (which comes from executionduration) is ignored.
-
-    Fixing this appears to be difficult since Python's `threading.Thread`
-    simply does not support cancellation. It would probably require rebuilding
-    the worker model on top of processes and killing those processes on
-    timeout. That would pose problems for cleanup of any temporary resources
-    created by the process such as temporary files, since Python cleanup code
-    would not be run.
-
-    The best fix would be for backend code to be rewritten to be async, so
-    await would become a cancellation point (although this still may not be
-    enough for compute-heavy code that doesn't use await frequently). However,
-    the Rubin pipelines code is all sync, so async worker support has not yet
-    been added due to lack of demand.
     """
 
     async def startup(ctx: dict[Any, Any]) -> None:
@@ -336,13 +345,13 @@ async def startup(ctx: dict[Any, Any]) -> None:
 
         ctx["arq"] = arq
         ctx["logger"] = logger
-        ctx["pool"] = ThreadPoolExecutor(1)
+        ctx["pool"] = ProcessPoolExecutor(1)
 
         logger.info("Worker startup complete")
 
     async def shutdown(ctx: dict[Any, Any]) -> None:
         logger: BoundLogger = ctx["logger"]
-        pool: ThreadPoolExecutor = ctx["pool"]
+        pool: ProcessPoolExecutor = ctx["pool"]
         pool.shutdown(wait=True, cancel_futures=True)
 
@@ -353,7 +362,7 @@ async def run(
     ) -> list[WorkerResult]:
         arq: ArqQueue = ctx["arq"]
         logger: BoundLogger = ctx["logger"]
-        pool: ThreadPoolExecutor = ctx["pool"]
+        pool: ProcessPoolExecutor = ctx["pool"]
 
         params = config.parameters_class.model_validate(params_raw)
         logger = logger.bind(
@@ -365,28 +374,34 @@ async def run(
         if info.run_id:
             logger = logger.bind(run_id=info.run_id)
 
-        await arq.enqueue("uws_job_started", info.job_id, datetime.now(tz=UTC))
+        start = datetime.now(tz=UTC)
+        await arq.enqueue("uws_job_started", info.job_id, start)
         loop = asyncio.get_running_loop()
         try:
-            return await loop.run_in_executor(
-                pool, worker, params, info, logger
-            )
+            async with asyncio.timeout(info.timeout.total_seconds()):
+                return await loop.run_in_executor(
+                    pool, worker, params, info, logger
+                )
+        except asyncio.CancelledError:
+            ctx["pool"] = _restart_pool(pool)
+            raise
+        except TimeoutError:
+            elapsed = datetime.now(tz=UTC) - start
+            ctx["pool"] = _restart_pool(pool)
+            raise WorkerTimeoutError(elapsed, info.timeout) from None
        finally:
             await arq.enqueue("uws_job_completed", info.job_id)
 
-    # Job timeouts are not actually supported since we have no way of stopping
-    # the sync worker. A timeout will just leave the previous worker running
-    # and will block all future jobs. Set it to an extremely long value, since
-    # it can't be disabled entirely.
-    #
     # Since the worker is running sync jobs, run one job per pod since they
-    # will be serialized anyway and no parallelism is possible. If async
-    # worker support is added, consider making this configurable.
+    # will be serialized anyway and no parallelism is possible. This also
+    # allows us to easily restart the job pool on timeout or job abort. If
+    # async worker support is added, consider making this configurable.
     return WorkerSettings(
         functions=[func(run, name=worker.__qualname__)],
         redis_settings=config.arq_redis_settings,
-        job_timeout=3600,
+        job_timeout=config.timeout,
         max_jobs=1,
+        allow_abort_jobs=True,
         on_startup=startup,
         on_shutdown=shutdown,
     )
diff --git a/tests/handlers/async_test.py b/tests/handlers/async_test.py
index bd9349a..2e318d6 100644
--- a/tests/handlers/async_test.py
+++ b/tests/handlers/async_test.py
@@ -27,7 +27,7 @@
   someone
   PENDING
   [DATE]
-  0
+  600
   [DATE]
 
   1:2:band:value
@@ -51,7 +51,7 @@
   [DATE]
   [DATE]
   [DATE]
-  0
+  600
   [DATE]
 
   1:2:band:value
diff --git a/tests/support/uws.py b/tests/support/uws.py
index 50682fe..286ddf5 100644
--- a/tests/support/uws.py
+++ b/tests/support/uws.py
@@ -74,6 +74,7 @@ def build_uws_config() -> UWSConfig:
         async_post_dependency=_post_dependency,
         database_url=database_url,
         database_password=SecretStr(os.environ["POSTGRES_PASSWORD"]),
+        execution_duration=timedelta(minutes=10),
         lifetime=timedelta(days=1),
         parameters_type=SimpleParameters,
         signing_service_account="signer@example.com",
diff --git a/tests/uws/job_api_test.py b/tests/uws/job_api_test.py
index 46c6ad2..46b7679 100644
--- a/tests/uws/job_api_test.py
+++ b/tests/uws/job_api_test.py
@@ -37,7 +37,7 @@
   user
   {}
   {}
-  0
+  {}
   {}
 
   Jane
@@ -60,7 +60,7 @@
   {}
   {}
   {}
-  0
+  {}
   {}
 
   Jane
@@ -127,6 +127,7 @@ async def test_job_run(
         "1",
         "PENDING",
         isodatetime(job.creation_time),
+        "600",
         isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)),
     )
 
@@ -161,6 +162,7 @@ async def test_job_run(
         "1",
         "QUEUED",
         isodatetime(job.creation_time),
+        "600",
         isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)),
     )
     await runner.mark_in_progress("user", "1")
@@ -207,6 +209,7 @@ async def test_job_run(
         isodatetime(job.creation_time),
         isodatetime(job.start_time),
         isodatetime(job.end_time),
+        "600",
         isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)),
     )
 
@@ -261,6 +264,7 @@ async def test_job_api(
         "1",
         "PENDING",
         isodatetime(job.creation_time),
+        "600",
         isodatetime(destruction_time),
     )
 
@@ -278,7 +282,7 @@ async def test_job_api(
     )
     assert r.status_code == 200
     assert r.headers["Content-Type"] == "text/plain; charset=utf-8"
-    assert r.text == "0"
+    assert r.text == "600"
 
     r = await client.get(
         "/test/jobs/1/owner", headers={"X-Auth-Request-User": "user"}
@@ -318,7 +322,6 @@ async def test_job_api(
     assert r.status_code == 303
     assert r.headers["Location"] == "https://example.com/test/jobs/1"
 
-    # Changing the execution duration is not supported.
     r = await client.post(
         "/test/jobs/1/executionduration",
         headers={"X-Auth-Request-User": "user"},
@@ -337,6 +340,7 @@ async def test_job_api(
         "1",
         "PENDING",
         isodatetime(job.creation_time),
+        "300",
         isodatetime(now),
     )
 
@@ -367,6 +371,7 @@ async def test_job_api(
         "2",
         "PENDING",
         isodatetime(job.creation_time),
+        "600",
         isodatetime(job.destruction_time),
     )
     r = await client.post(
diff --git a/tests/uws/job_error_test.py b/tests/uws/job_error_test.py
index 6f75964..1001903 100644
--- a/tests/uws/job_error_test.py
+++ b/tests/uws/job_error_test.py
@@ -28,7 +28,7 @@
   {}
   {}
   {}
-  0
+  600
   {}
 
   Sarah
diff --git a/tests/uws/long_polling_test.py b/tests/uws/long_polling_test.py
index 44b96fc..72d2574 100644
--- a/tests/uws/long_polling_test.py
+++ b/tests/uws/long_polling_test.py
@@ -26,7 +26,7 @@
   user
   {}
   {}
-  0
+  600
   {}
 
   Naomi
@@ -47,7 +47,7 @@
   EXECUTING
   {}
   {}
-  0
+  600
   {}
 
   Naomi
@@ -69,7 +69,7 @@
   {}
   {}
   {}
-  0
+  600
   {}
 
   Naomi
diff --git a/tests/uws/workers_test.py b/tests/uws/workers_test.py
index efbce46..9d30408 100644
--- a/tests/uws/workers_test.py
+++ b/tests/uws/workers_test.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import asyncio
+import time
 from datetime import timedelta
 from typing import Any
 from unittest.mock import ANY
@@ -33,26 +34,29 @@
     WorkerFatalError,
     WorkerJobInfo,
     WorkerResult,
+    WorkerTimeoutError,
     build_worker,
 )
 
 from ..support.uws import SimpleParameters
 
 
+def hello(
+    params: SimpleParameters, info: WorkerJobInfo, logger: BoundLogger
+) -> list[WorkerResult]:
+    if params.name == "Timeout":
+        time.sleep(120)
+    return [
+        WorkerResult(
+            result_id="greeting", url=f"https://example.com/{params.name}"
+        )
+    ]
+
+
 @pytest.mark.asyncio
 async def test_build_worker(
     uws_config: UWSConfig, logger: BoundLogger
 ) -> None:
-    def hello(
-        params: SimpleParameters, info: WorkerJobInfo, logger: BoundLogger
-    ) -> list[WorkerResult]:
-        return [
-            WorkerResult(
-                result_id="greeting", url=f"https://example.com/{params.name}"
-            )
-        ]
-
-    # Construct the arq configuration and check it.
     redis_settings = uws_config.arq_redis_settings
     worker_config = WorkerConfig(
         arq_mode=uws_config.arq_mode,
@@ -70,6 +74,7 @@ def hello(
     assert isinstance(settings.functions[0], Function)
     assert settings.functions[0].name == hello.__qualname__
     assert settings.redis_settings == uws_config.arq_redis_settings
+    assert settings.allow_abort_jobs
     assert settings.queue_name == default_queue_name
     assert settings.on_startup
     assert settings.on_shutdown
@@ -123,6 +128,77 @@ def hello(
     await shutdown(ctx)
 
 
+@pytest.mark.asyncio
+async def test_timeout(uws_config: UWSConfig, logger: BoundLogger) -> None:
+    redis_settings = uws_config.arq_redis_settings
+    worker_config = WorkerConfig(
+        arq_mode=uws_config.arq_mode,
+        arq_queue_url=(
+            f"redis://{redis_settings.host}:{redis_settings.port}"
+            f"/{redis_settings.database}"
+        ),
+        arq_queue_password=redis_settings.password,
+        parameters_class=SimpleParameters,
+        timeout=uws_config.execution_duration,
+    )
+    settings = build_worker(hello, worker_config, logger)
+    assert isinstance(settings.functions[0], Function)
+
+    # Run the startup hook.
+    ctx: dict[Any, Any] = {}
+    startup = settings.on_startup
+    assert startup
+    await startup(ctx)
+    arq = ctx["arq"]
+
+    # Run the worker.
+    function = settings.functions[0].coroutine
+    params = SimpleParameters(name="Timeout")
+    info = WorkerJobInfo(
+        job_id="42",
+        user="someuser",
+        token="some-token",
+        timeout=timedelta(seconds=1),
+        run_id="some-run-id",
+    )
+    with pytest.raises(WorkerTimeoutError):
+        await function(ctx, params, info)
+    assert list(arq._job_metadata[UWS_QUEUE_NAME].values()) == [
+        JobMetadata(
+            id=ANY,
+            name="uws_job_started",
+            args=("42", ANY),
+            kwargs={},
+            enqueue_time=ANY,
+            status=JobStatus.queued,
+            queue_name=UWS_QUEUE_NAME,
+        ),
+        JobMetadata(
+            id=ANY,
+            name="uws_job_completed",
+            args=("42",),
+            kwargs={},
+            enqueue_time=ANY,
+            status=JobStatus.queued,
+            queue_name=UWS_QUEUE_NAME,
+        ),
+    ]
+
+    # Make sure that handling the timeout didn't break the worker and we can
+    # run another job successfully.
+    params = SimpleParameters(name="Roger")
+    info.job_id = "43"
+    result = await function(ctx, params, info)
+    assert result == [
+        WorkerResult(result_id="greeting", url="https://example.com/Roger")
+    ]
+
+    # Run the shutdown hook.
+    shutdown = settings.on_shutdown
+    assert shutdown
+    await shutdown(ctx)
+
+
 @pytest.mark.asyncio
 async def test_build_uws_worker(
     arq_queue: MockArqQueue,
@@ -156,6 +232,7 @@ async def test_build_uws_worker(
     expire_jobs = expire_cron.coroutine
     assert callable(expire_jobs)
     assert settings.redis_settings == uws_config.arq_redis_settings
+    assert not settings.allow_abort_jobs
     assert settings.queue_name == UWS_QUEUE_NAME
     assert settings.on_startup
     assert settings.on_shutdown
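
As an illustration of the approach this patch takes (run the sync job in a single-process ProcessPoolExecutor under asyncio.timeout, then kill and rebuild the pool when the deadline expires), here is a minimal standalone sketch. It is not part of the patch: the names slow_job, run_with_timeout, and _rebuild_pool are illustrative only, and it assumes Python 3.11 or later for asyncio.timeout.

import asyncio
import os
import signal
import time
from concurrent.futures import ProcessPoolExecutor


def slow_job(seconds: float) -> str:
    # Stand-in for a sync backend worker that cannot be cancelled
    # cooperatively once it has started.
    time.sleep(seconds)
    return "done"


def _rebuild_pool(pool: ProcessPoolExecutor) -> ProcessPoolExecutor:
    # Same idea as the patch's _restart_pool hack: kill the worker
    # processes, shut the old pool down, and hand back a fresh
    # single-process pool.
    for pid in pool._processes:  # noqa: SLF001
        os.kill(pid, signal.SIGINT)
    pool.shutdown(wait=True)
    return ProcessPoolExecutor(1)


async def run_with_timeout(
    pool: ProcessPoolExecutor, timeout: float, seconds: float
) -> tuple[ProcessPoolExecutor, str | None]:
    loop = asyncio.get_running_loop()
    try:
        async with asyncio.timeout(timeout):
            return pool, await loop.run_in_executor(pool, slow_job, seconds)
    except TimeoutError:
        # The sync job is still running in the child process; the only way
        # to stop it is to kill the process and start over with a new pool.
        return _rebuild_pool(pool), None


async def main() -> None:
    pool = ProcessPoolExecutor(1)
    pool, result = await run_with_timeout(pool, timeout=1.0, seconds=5.0)
    print(result)  # None: the job was killed after one second
    pool, result = await run_with_timeout(pool, timeout=5.0, seconds=1.0)
    print(result)  # "done": the rebuilt pool still works
    pool.shutdown(wait=True)


if __name__ == "__main__":
    asyncio.run(main())

Sending SIGINT rather than SIGKILL appears to be a deliberate choice here: the child gets a KeyboardInterrupt and so still has a chance to run normal exception handling and cleanup before it exits.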