From 9f8a39273404cd9ba7afd851f1ad6c1e3db7f027 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Fri, 16 Feb 2024 15:06:45 -0800 Subject: [PATCH] Convert to Ruff and fix most issues it uncovered Switch to the Annotated syntax for FastAPI handlers so that there is no need to whitelist the FastAPI dependency functions. Exclude the diagnostic of os.getenv in configuration for now, since I will convert to pydantic-settings in an upcoming PR. --- .flake8 | 7 - .pre-commit-config.yaml | 24 +- changelog.d/20240216_150548_rra_DM_42937.md | 3 + pyproject.toml | 147 +++++++-- src/vocutouts/actors.py | 18 +- src/vocutouts/broker.py | 4 +- src/vocutouts/cli.py | 20 +- src/vocutouts/handlers/external.py | 342 +++++++++++--------- src/vocutouts/models/parameters.py | 2 +- src/vocutouts/uws/dependencies.py | 21 +- src/vocutouts/uws/errors.py | 4 +- src/vocutouts/uws/exceptions.py | 12 +- src/vocutouts/uws/handlers.py | 220 +++++++------ src/vocutouts/uws/models.py | 13 +- src/vocutouts/uws/schema/job.py | 5 +- src/vocutouts/uws/schema/job_parameter.py | 2 + src/vocutouts/uws/schema/job_result.py | 2 + src/vocutouts/uws/service.py | 19 +- src/vocutouts/uws/storage.py | 20 +- src/vocutouts/uws/utils.py | 5 +- src/vocutouts/workers.py | 15 +- tests/conftest.py | 4 +- tests/handlers/external_test.py | 2 +- tests/handlers/internal_test.py | 2 +- tests/support/uws.py | 8 +- tests/uws/job_error_test.py | 8 +- tests/uws/job_list_test.py | 6 +- tests/uws/long_polling_test.py | 12 +- tests/uws/policy_test.py | 10 +- 29 files changed, 550 insertions(+), 407 deletions(-) delete mode 100644 .flake8 create mode 100644 changelog.d/20240216_150548_rra_DM_42937.md diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 01ee522..0000000 --- a/.flake8 +++ /dev/null @@ -1,7 +0,0 @@ -[flake8] -max-line-length = 79 -# E203: whitespace before :, flake8 disagrees with PEP-8 -# W503: line break after binary operator, flake8 disagrees with PEP-8 -ignore = E203, W503 -exclude = - docs/conf.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 59565be..eaf294c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,22 +2,14 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - - id: check-yaml + - id: check-merge-conflict - id: check-toml + - id: check-yaml + - id: trailing-whitespace - - repo: https://github.com/PyCQA/isort - rev: 5.13.2 - hooks: - - id: isort - additional_dependencies: - - toml - - - repo: https://github.com/psf/black - rev: 24.2.0 - hooks: - - id: black - - - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.1 hooks: - - id: flake8 + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format diff --git a/changelog.d/20240216_150548_rra_DM_42937.md b/changelog.d/20240216_150548_rra_DM_42937.md new file mode 100644 index 0000000..03daf69 --- /dev/null +++ b/changelog.d/20240216_150548_rra_DM_42937.md @@ -0,0 +1,3 @@ +### Other changes + +- Use [Ruff](https://docs.astral.sh/ruff/) for linting and formatting instead of Black, flake8, and isort. diff --git a/pyproject.toml b/pyproject.toml index 65f228e..ef53172 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,30 +60,6 @@ exclude_lines = [ "if TYPE_CHECKING:" ] -[tool.black] -line-length = 79 -target-version = ["py311"] -exclude = ''' -/( - \.eggs - | \.git - | \.mypy_cache - | \.tox - | \.venv - | _build - | build - | dist -)/ -''' -# Use single-quoted strings so TOML treats the string like a Python r-string -# Multi-line strings are implicitly treated by black as regular expressions - -[tool.isort] -profile = "black" -line_length = 79 -known_first_party = ["vocutouts", "tests"] -skip = ["docs/conf.py"] - [tool.mypy] disallow_untyped_defs = true disallow_incomplete_defs = true @@ -118,6 +94,129 @@ asyncio_mode = "strict" # listed in python_files. python_files = ["tests/*.py", "tests/*/*.py"] +# The rule used with Ruff configuration is to disable every lint that has +# legitimate exceptions that are not dodgy code, rather than cluttering code +# with noqa markers. This is therefore a reiatively relaxed configuration that +# errs on the side of disabling legitimate lints. +# +# Reference for settings: https://beta.ruff.rs/docs/settings/ +# Reference for rules: https://beta.ruff.rs/docs/rules/ +[tool.ruff] +exclude = [ + "docs/**", +] +line-length = 79 +target-version = "py312" + +[tool.ruff.lint] +ignore = [ + "ANN101", # self should not have a type annotation + "ANN102", # cls should not have a type annotation + "ANN401", # sometimes Any is the right type + "ARG001", # unused function arguments are often legitimate + "ARG002", # unused method arguments are often legitimate + "ARG005", # unused lambda arguments are often legitimate + "BLE001", # we want to catch and report Exception in background tasks + "C414", # nested sorted is how you sort by multiple keys with reverse + "COM812", # omitting trailing commas allows black autoreformatting + "D102", # sometimes we use docstring inheritence + "D104", # don't see the point of documenting every package + "D105", # our style doesn't require docstrings for magic methods + "D106", # Pydantic uses a nested Config class that doesn't warrant docs + "D205", # our documentation style allows a folded first line + "EM101", # justification (duplicate string in traceback) is silly + "EM102", # justification (duplicate string in traceback) is silly + "FBT003", # positional booleans are normal for Pydantic field defaults + "FIX002", # point of a TODO comment is that we're not ready to fix it + "G004", # forbidding logging f-strings is appealing, but not our style + "RET505", # disagree that omitting else always makes code more readable + "PLR0911", # often many returns is clearer and simpler style + "PLR0913", # factory pattern uses constructors with many arguments + "PLR2004", # too aggressive about magic values + "PLW0603", # yes global is discouraged but if needed, it's needed + "S105", # good idea but too many false positives on non-passwords + "S106", # good idea but too many false positives on non-passwords + "S107", # good idea but too many false positives on non-passwords + "S603", # not going to manually mark every subprocess call as reviewed + "S607", # using PATH is not a security vulnerability + "SIM102", # sometimes the formatting of nested if statements is clearer + "SIM117", # sometimes nested with contexts are clearer + "TCH001", # we decided to not maintain separate TYPE_CHECKING blocks + "TCH002", # we decided to not maintain separate TYPE_CHECKING blocks + "TCH003", # we decided to not maintain separate TYPE_CHECKING blocks + "TID252", # if we're going to use relative imports, use them always + "TRY003", # good general advice but lint is way too aggressive + "TRY301", # sometimes raising exceptions inside try is the best flow + + # The following settings should be disabled when using ruff format + # per https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules + "W191", + "E111", + "E114", + "E117", + "D206", + "D300", + "Q000", + "Q001", + "Q002", + "Q003", + "COM812", + "COM819", + "ISC001", + "ISC002", + + # TEMPORARY + "RUF009", +] +select = ["ALL"] + +[tool.ruff.lint.per-file-ignores] +"src/uws/handlers.py" = [ + "D103", # FastAPI handlers should not have docstrings +] +"src/vocutouts/config.py" = [ + "S108", # use of /tmp is safe in this context +] +"src/vocutouts/handlers/**" = [ + "D103", # FastAPI handlers should not have docstrings +] +"src/vocutouts/workers.py" = [ + "S108", # use of /tmp is safe in this context +] +"tests/**" = [ + "C901", # tests are allowed to be complex, sometimes that's convenient + "D101", # tests don't need docstrings + "D103", # tests don't need docstrings + "PLR0915", # tests are allowed to be long, sometimes that's convenient + "PT012", # way too aggressive about limiting pytest.raises blocks + "S101", # tests should use assert + "S106", # tests are allowed to hard-code dummy passwords + "SLF001", # tests are allowed to access private members +] + +[tool.ruff.lint.isort] +known-first-party = ["vocutouts", "tests"] +split-on-trailing-comma = false + +# These are too useful as attributes or methods to allow the conflict with the +# built-in to rule out their use. +[tool.ruff.lint.flake8-builtins] +builtins-ignorelist = [ + "all", + "any", + "help", + "id", + "list", + "type", +] + +[tool.ruff.lint.flake8-pytest-style] +fixture-parentheses = false +mark-parentheses = false + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + [tool.scriv] categories = [ "Backwards-incompatible changes", diff --git a/src/vocutouts/actors.py b/src/vocutouts/actors.py index 7313122..e854015 100644 --- a/src/vocutouts/actors.py +++ b/src/vocutouts/actors.py @@ -101,7 +101,7 @@ def cutout( @dramatiq.actor(queue_name="uws", priority=0) def job_started(job_id: str, message_id: str, start_time: str) -> None: - """Wrapper around the UWS function to mark a job as started. + """Call the UWS function to mark a job as started. Notes ----- @@ -113,8 +113,10 @@ def job_started(job_id: str, message_id: str, start_time: str) -> None: """ logger = structlog.get_logger(config.logger_name) start = parse_isodatetime(start_time) - assert broker.worker_session, "Worker database connection not initalized" - assert start, f"Invalid start timestamp {start_time}" + if not broker.worker_session: + raise RuntimeError("Worker database connection not initalized") + if not start: + raise RuntimeError(f"Invalid start timestamp {start_time}") uws_job_started(job_id, message_id, start, broker.worker_session, logger) @@ -122,17 +124,19 @@ def job_started(job_id: str, message_id: str, start_time: str) -> None: def job_completed( message: dict[str, Any], result: list[dict[str, str]] ) -> None: - """Wrapper around the UWS function to mark a job as completed.""" + """Call the UWS function to mark a job as completed.""" logger = structlog.get_logger(config.logger_name) job_id = message["args"][0] - assert broker.worker_session, "Worker database connection not initalized" + if not broker.worker_session: + raise RuntimeError("Worker database connection not initalized") uws_job_completed(job_id, result, broker.worker_session, logger) @dramatiq.actor(queue_name="uws", priority=20) def job_failed(message: dict[str, Any], exception: dict[str, str]) -> None: - """Wrapper around the UWS function to mark a job as errored.""" + """Call the UWS function to mark a job as errored.""" logger = structlog.get_logger(config.logger_name) job_id = message["args"][0] - assert broker.worker_session, "Worker database connection not initalized" + if not broker.worker_session: + raise RuntimeError("Worker database connection not initalized") uws_job_failed(job_id, exception, broker.worker_session, logger) diff --git a/src/vocutouts/broker.py b/src/vocutouts/broker.py index 9c702c5..3c4f273 100644 --- a/src/vocutouts/broker.py +++ b/src/vocutouts/broker.py @@ -11,8 +11,6 @@ from __future__ import annotations -from typing import Optional - import dramatiq import structlog from dramatiq import Broker, Middleware, Worker @@ -33,7 +31,7 @@ results = RedisBackend(host=config.redis_host, password=config.redis_password) """Result backend used by UWS.""" -worker_session: Optional[scoped_session] = None +worker_session: scoped_session | None = None """Shared scoped session used by the UWS worker.""" diff --git a/src/vocutouts/cli.py b/src/vocutouts/cli.py index 6cac483..946e9fe 100644 --- a/src/vocutouts/cli.py +++ b/src/vocutouts/cli.py @@ -6,6 +6,7 @@ import structlog import uvicorn from safir.asyncio import run_with_asyncio +from safir.click import display_help from safir.database import create_database_engine, initialize_database from .config import config @@ -15,11 +16,7 @@ @click.group(context_settings={"help_option_names": ["-h", "--help"]}) @click.version_option(message="%(version)s") def main() -> None: - """vo-cutouts main. - - Administrative command-line interface for vo-cutouts. - """ - pass + """Administrative command-line interface for vo-cutouts.""" @main.command() @@ -27,16 +24,7 @@ def main() -> None: @click.pass_context def help(ctx: click.Context, topic: str | None) -> None: """Show help for any command.""" - # The help command implementation is taken from - # https://www.burgundywall.com/post/having-click-help-subcommand - if topic: - if topic in main.commands: - click.echo(main.commands[topic].get_help(ctx)) - else: - raise click.UsageError(f"Unknown help topic {topic}", ctx) - else: - assert ctx.parent - click.echo(ctx.parent.get_help()) + display_help(main, ctx, topic) @main.command() @@ -55,7 +43,7 @@ def run(port: int) -> None: "--reset", is_flag=True, help="Delete all existing database data." ) @run_with_asyncio -async def init(reset: bool) -> None: +async def init(*, reset: bool) -> None: """Initialize the database storage.""" logger = structlog.get_logger(config.logger_name) engine = create_database_engine( diff --git a/src/vocutouts/handlers/external.py b/src/vocutouts/handlers/external.py index 82feca4..f4b1130 100644 --- a/src/vocutouts/handlers/external.py +++ b/src/vocutouts/handlers/external.py @@ -5,7 +5,7 @@ application knows the job parameters. """ -from typing import Literal, Optional +from typing import Annotated, Literal from fastapi import APIRouter, Depends, Form, Query, Request, Response from fastapi.responses import PlainTextResponse, RedirectResponse @@ -94,7 +94,8 @@ async def get_index() -> Index: summary="IVOA service availability", ) async def get_availability( - request: Request, uws_factory: UWSFactory = Depends(uws_dependency) + request: Request, + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() availability = await job_service.availability() @@ -194,61 +195,72 @@ async def _sync_request( summary="Synchronous cutout", ) async def get_sync( + *, request: Request, - id: list[str] = Query( - ..., - title="Source ID", - description=( - "Identifiers of images from which to make a cutout. This" - " parameter is mandatory." + id: Annotated[ + list[str], + Query( + title="Source ID", + description=( + "Identifiers of images from which to make a cutout. This" + " parameter is mandatory." + ), ), - ), - pos: Optional[list[str]] = Query( - None, - title="Cutout positions", - description=( - "Positions to cut out. Supported parameters are RANGE followed" - " by min and max ra and min and max dec; CIRCLE followed by" - " ra, dec, and radius; and POLYGON followed by a list of" - " ra/dec positions for vertices. Arguments must be separated" - " by spaces and parameters are double-precision floating point" - " numbers expressed as strings." + ], + pos: Annotated[ + list[str] | None, + Query( + title="Cutout positions", + description=( + "Positions to cut out. Supported parameters are RANGE followed" + " by min and max ra and min and max dec; CIRCLE followed by" + " ra, dec, and radius; and POLYGON followed by a list of" + " ra/dec positions for vertices. Arguments must be separated" + " by spaces and parameters are double-precision floating point" + " numbers expressed as strings." + ), ), - ), - circle: Optional[list[str]] = Query( - None, - title="Cutout circle positions", - description=( - "Circles to cut out. The value must be the ra and dec of the" - " center of the circle and then the radius, as" - " double-precision floating point numbers expressed as" - " strings and separated by spaces." + ] = None, + circle: Annotated[ + list[str] | None, + Query( + title="Cutout circle positions", + description=( + "Circles to cut out. The value must be the ra and dec of the" + " center of the circle and then the radius, as" + " double-precision floating point numbers expressed as" + " strings and separated by spaces." + ), ), - ), - polygon: Optional[list[str]] = Query( - None, - title="Cutout polygon positions", - description=( - "Polygons to cut out. The value must be ra/dec pairs for each" - " vertex, ordered so that the polygon winding direction is" - " counter-clockwise (when viewed from the origin towards the" - " sky). These parameters are double-precision floating point" - " numbers expressed as strings and separated by spaces." + ] = None, + polygon: Annotated[ + list[str] | None, + Query( + title="Cutout polygon positions", + description=( + "Polygons to cut out. The value must be ra/dec pairs for each" + " vertex, ordered so that the polygon winding direction is" + " counter-clockwise (when viewed from the origin towards the" + " sky). These parameters are double-precision floating point" + " numbers expressed as strings and separated by spaces." + ), ), - ), - runid: Optional[str] = Query( - None, - title="Run ID for job", - description=( - "An opaque string that is returned in the job metadata and" - " job listings. Maybe used by the client to associate jobs" - " with specific larger operations." + ] = None, + runid: Annotated[ + str | None, + Query( + title="Run ID for job", + description=( + "An opaque string that is returned in the job metadata and" + " job listings. Maybe used by the client to associate jobs" + " with specific larger operations." + ), ), - ), - user: str = Depends(auth_dependency), - access_token: str = Depends(auth_delegated_token_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), - logger: BoundLogger = Depends(auth_logger_dependency), + ] = None, + user: Annotated[str, Depends(auth_dependency)], + access_token: Annotated[str, Depends(auth_delegated_token_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], + logger: Annotated[BoundLogger, Depends(auth_logger_dependency)], ) -> Response: params = [ JobParameter(parameter_id=k.lower(), value=v, is_post=False) @@ -278,62 +290,73 @@ async def get_sync( summary="Synchronous cutout", ) async def post_sync( + *, request: Request, - id: Optional[str | list[str]] = Form( - None, - title="Source ID", - description=( - "Identifiers of images from which to make a cutout. This" - " parameter is mandatory." + id: Annotated[ + str | list[str] | None, + Form( + title="Source ID", + description=( + "Identifiers of images from which to make a cutout. This" + " parameter is mandatory." + ), ), - ), - pos: Optional[str | list[str]] = Form( - None, - title="Cutout positions", - description=( - "Positions to cut out. Supported parameters are RANGE followed" - " by min and max ra and min and max dec; CIRCLE followed by" - " ra, dec, and radius; and POLYGON followed by a list of" - " ra/dec positions for vertices. Arguments must be separated" - " by spaces and parameters are double-precision floating point" - " numbers expressed as strings." + ] = None, + pos: Annotated[ + str | list[str] | None, + Form( + title="Cutout positions", + description=( + "Positions to cut out. Supported parameters are RANGE followed" + " by min and max ra and min and max dec; CIRCLE followed by" + " ra, dec, and radius; and POLYGON followed by a list of" + " ra/dec positions for vertices. Arguments must be separated" + " by spaces and parameters are double-precision floating point" + " numbers expressed as strings." + ), ), - ), - circle: Optional[str | list[str]] = Form( - None, - title="Cutout circle positions", - description=( - "Circles to cut out. The value must be the ra and dec of the" - " center of the circle and then the radius, as" - " double-precision floating point numbers expressed as" - " strings and separated by spaces." + ] = None, + circle: Annotated[ + str | list[str] | None, + Form( + title="Cutout circle positions", + description=( + "Circles to cut out. The value must be the ra and dec of the" + " center of the circle and then the radius, as" + " double-precision floating point numbers expressed as" + " strings and separated by spaces." + ), ), - ), - polygon: Optional[str | list[str]] = Form( - None, - title="Cutout polygon positions", - description=( - "Polygons to cut out. The value must be ra/dec pairs for each" - " vertex, ordered so that the polygon winding direction is" - " counter-clockwise (when viewed from the origin towards the" - " sky). These parameters are double-precision floating point" - " numbers expressed as strings and separated by spaces." + ] = None, + polygon: Annotated[ + str | list[str] | None, + Form( + title="Cutout polygon positions", + description=( + "Polygons to cut out. The value must be ra/dec pairs for each" + " vertex, ordered so that the polygon winding direction is" + " counter-clockwise (when viewed from the origin towards the" + " sky). These parameters are double-precision floating point" + " numbers expressed as strings and separated by spaces." + ), ), - ), - runid: Optional[str] = Form( - None, - title="Run ID for job", - description=( - "An opaque string that is returned in the job metadata and" - " job listings. Maybe used by the client to associate jobs" - " with specific larger operations." + ] = None, + runid: Annotated[ + str | None, + Form( + title="Run ID for job", + description=( + "An opaque string that is returned in the job metadata and" + " job listings. Maybe used by the client to associate jobs" + " with specific larger operations." + ), ), - ), - params: list[JobParameter] = Depends(uws_post_params_dependency), - user: str = Depends(auth_dependency), - access_token: str = Depends(auth_delegated_token_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), - logger: BoundLogger = Depends(auth_logger_dependency), + ] = None, + params: Annotated[list[JobParameter], Depends(uws_post_params_dependency)], + user: Annotated[str, Depends(auth_dependency)], + access_token: Annotated[str, Depends(auth_delegated_token_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], + logger: Annotated[BoundLogger, Depends(auth_logger_dependency)], ) -> Response: runid = None for param in params: @@ -353,65 +376,76 @@ async def post_sync( summary="Create async job", ) async def create_job( + *, request: Request, - id: Optional[str | list[str]] = Form( - None, - title="Source ID", - description=( - "Identifiers of images from which to make a cutout. This" - " parameter is mandatory." + id: Annotated[ + str | list[str] | None, + Form( + title="Source ID", + description=( + "Identifiers of images from which to make a cutout. This" + " parameter is mandatory." + ), ), - ), - pos: Optional[str | list[str]] = Form( - None, - title="Cutout positions", - description=( - "Positions to cut out. Supported parameters are RANGE followed" - " by min and max ra and min and max dec; CIRCLE followed by" - " ra, dec, and radius; and POLYGON followed by a list of" - " ra/dec positions for vertices. Arguments must be separated" - " by spaces and parameters are double-precision floating point" - " numbers expressed as strings." + ] = None, + pos: Annotated[ + str | list[str] | None, + Form( + title="Cutout positions", + description=( + "Positions to cut out. Supported parameters are RANGE followed" + " by min and max ra and min and max dec; CIRCLE followed by" + " ra, dec, and radius; and POLYGON followed by a list of" + " ra/dec positions for vertices. Arguments must be separated" + " by spaces and parameters are double-precision floating point" + " numbers expressed as strings." + ), ), - ), - circle: Optional[str | list[str]] = Form( - None, - title="Cutout circle positions", - description=( - "Circles to cut out. The value must be the ra and dec of the" - " center of the circle and then the radius, as" - " double-precision floating point numbers expressed as" - " strings and separated by spaces." + ] = None, + circle: Annotated[ + str | list[str] | None, + Form( + title="Cutout circle positions", + description=( + "Circles to cut out. The value must be the ra and dec of the" + " center of the circle and then the radius, as" + " double-precision floating point numbers expressed as" + " strings and separated by spaces." + ), ), - ), - polygon: Optional[str | list[str]] = Form( - None, - title="Cutout polygon positions", - description=( - "Polygons to cut out. The value must be ra/dec pairs for each" - " vertex, ordered so that the polygon winding direction is" - " counter-clockwise (when viewed from the origin towards the" - " sky). These parameters are double-precision floating point" - " numbers expressed as strings and separated by spaces." + ] = None, + polygon: Annotated[ + str | list[str] | None, + Form( + title="Cutout polygon positions", + description=( + "Polygons to cut out. The value must be ra/dec pairs for each" + " vertex, ordered so that the polygon winding direction is" + " counter-clockwise (when viewed from the origin towards the" + " sky). These parameters are double-precision floating point" + " numbers expressed as strings and separated by spaces." + ), ), - ), - phase: Optional[Literal["RUN"]] = Query( - None, title="Immediately start job" - ), - runid: Optional[str] = Form( - None, - title="Run ID for job", - description=( - "An opaque string that is returned in the job metadata and" - " job listings. Maybe used by the client to associate jobs" - " with specific larger operations." + ] = None, + phase: Annotated[ + Literal["RUN"] | None, Query(title="Immediately start job") + ] = None, + runid: Annotated[ + str | None, + Form( + title="Run ID for job", + description=( + "An opaque string that is returned in the job metadata and" + " job listings. Maybe used by the client to associate jobs" + " with specific larger operations." + ), ), - ), - params: list[JobParameter] = Depends(uws_post_params_dependency), - user: str = Depends(auth_dependency), - access_token: str = Depends(auth_delegated_token_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), - logger: BoundLogger = Depends(auth_logger_dependency), + ] = None, + params: Annotated[list[JobParameter], Depends(uws_post_params_dependency)], + user: Annotated[str, Depends(auth_dependency)], + access_token: Annotated[str, Depends(auth_delegated_token_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], + logger: Annotated[BoundLogger, Depends(auth_logger_dependency)], ) -> str: runid = None for param in params: diff --git a/src/vocutouts/models/parameters.py b/src/vocutouts/models/parameters.py index 33ea3d1..e81431d 100644 --- a/src/vocutouts/models/parameters.py +++ b/src/vocutouts/models/parameters.py @@ -50,7 +50,7 @@ def from_job_parameters( f = parse_stencil(param.parameter_id.upper(), param.value) stencils.append(f) except Exception as e: - msg = f"Invalid cutout parameter: {type(e).__name__}: {str(e)}" + msg = f"Invalid cutout parameter: {type(e).__name__}: {e!s}" raise InvalidCutoutParameterError(msg, params) from e if not ids: raise InvalidCutoutParameterError("No dataset ID given", params) diff --git a/src/vocutouts/uws/dependencies.py b/src/vocutouts/uws/dependencies.py index 46f5a40..d2b542c 100644 --- a/src/vocutouts/uws/dependencies.py +++ b/src/vocutouts/uws/dependencies.py @@ -6,7 +6,7 @@ objects. """ -from typing import Optional +from typing import Annotated from fastapi import Depends, Request from safir.dependencies.db_session import db_session_dependency @@ -68,21 +68,22 @@ class UWSDependency: """Initializes UWS and provides a UWS factory as a dependency.""" def __init__(self) -> None: - self._config: Optional[UWSConfig] = None - self._policy: Optional[UWSPolicy] = None - self._result_store: Optional[ResultStore] = None + self._config: UWSConfig | None = None + self._policy: UWSPolicy | None = None + self._result_store: ResultStore | None = None async def __call__( self, - session: async_scoped_session = Depends(db_session_dependency), - logger: BoundLogger = Depends(logger_dependency), + session: Annotated[ + async_scoped_session, Depends(db_session_dependency) + ], + logger: Annotated[BoundLogger, Depends(logger_dependency)], ) -> UWSFactory: # Tell mypy that not calling initialize first is an error. This would # fail anyway without the asserts when something tried to use the None # value. - assert self._config, "UWSDependency not initialized" - assert self._policy, "UWSDependency not initialized" - assert self._result_store, "UWSDependency not initialized" + if not (self._config and self._policy and self._result_store): + raise RuntimeError("UWSDependency not initialized") return UWSFactory( config=self._config, policy=self._policy, @@ -160,7 +161,7 @@ async def uws_post_params_dependency(request: Request) -> list[JobParameter]: parameters = [] for key, value in (await request.form()).items(): if not isinstance(value, str): - raise ValueError("File upload not supported") + raise TypeError("File upload not supported") parameters.append( JobParameter(parameter_id=key.lower(), value=value, is_post=True) ) diff --git a/src/vocutouts/uws/errors.py b/src/vocutouts/uws/errors.py index 8e44b42..1e0efc7 100644 --- a/src/vocutouts/uws/errors.py +++ b/src/vocutouts/uws/errors.py @@ -19,7 +19,7 @@ async def _uws_error_handler( request: Request, exc: UWSError ) -> PlainTextResponse: - response = f"{exc.error_code.value} {str(exc)}\n" + response = f"{exc.error_code.value} {exc!s}\n" if exc.detail: response += "\n{exc.detail}" return PlainTextResponse(response, status_code=exc.status_code) @@ -28,7 +28,7 @@ async def _uws_error_handler( async def _usage_handler( request: Request, exc: RequestValidationError ) -> PlainTextResponse: - return PlainTextResponse(f"UsageError\n\n{str(exc)}", status_code=422) + return PlainTextResponse(f"UsageError\n\n{exc!s}", status_code=422) def install_error_handlers(app: FastAPI) -> None: diff --git a/src/vocutouts/uws/exceptions.py b/src/vocutouts/uws/exceptions.py index c31a18a..3c540dd 100644 --- a/src/vocutouts/uws/exceptions.py +++ b/src/vocutouts/uws/exceptions.py @@ -6,8 +6,6 @@ from __future__ import annotations -from typing import Optional - from .models import ErrorCode, ErrorType, JobError __all__ = [ @@ -33,7 +31,7 @@ class UWSError(Exception): """ def __init__( - self, error_code: ErrorCode, message: str, detail: Optional[str] = None + self, error_code: ErrorCode, message: str, detail: str | None = None ) -> None: super().__init__(message) self.error_code = error_code @@ -102,7 +100,7 @@ def __init__( error_type: ErrorType, error_code: ErrorCode, message: str, - detail: Optional[str] = None, + detail: str | None = None, ) -> None: super().__init__(error_code, message) self.error_type = error_type @@ -126,7 +124,7 @@ class TaskFatalError(TaskError): """ def __init__( - self, error_code: ErrorCode, message: str, detail: Optional[str] = None + self, error_code: ErrorCode, message: str, detail: str | None = None ) -> None: super().__init__(ErrorType.FATAL, error_code, message, detail) @@ -138,7 +136,7 @@ class TaskTransientError(TaskError): """ def __init__( - self, error_code: ErrorCode, message: str, detail: Optional[str] = None + self, error_code: ErrorCode, message: str, detail: str | None = None ) -> None: super().__init__(ErrorType.TRANSIENT, error_code, message, detail) @@ -146,7 +144,7 @@ def __init__( class UsageError(UWSError): """Invalid parameters were passed to a UWS API.""" - def __init__(self, message: str, detail: Optional[str] = None) -> None: + def __init__(self, message: str, detail: str | None = None) -> None: super().__init__(ErrorCode.USAGE_ERROR, message, detail) self.status_code = 422 diff --git a/src/vocutouts/uws/handlers.py b/src/vocutouts/uws/handlers.py index 578ef77..f3c1537 100644 --- a/src/vocutouts/uws/handlers.py +++ b/src/vocutouts/uws/handlers.py @@ -17,7 +17,7 @@ """ from datetime import datetime -from typing import Literal, Optional +from typing import Annotated, Literal from fastapi import APIRouter, Depends, Form, Query, Request, Response from fastapi.responses import PlainTextResponse, RedirectResponse @@ -53,24 +53,31 @@ summary="Async job list", ) async def get_job_list( + *, request: Request, - phase: Optional[list[ExecutionPhase]] = Query( - None, - title="Execution phase", - description="Limit results to the provided execution phases", - ), - after: Optional[datetime] = Query( - None, - title="Creation date", - description="Limit results to jobs created after this date", - ), - last: Optional[int] = Query( - None, - title="Number of jobs", - description="Return at most the given number of jobs", - ), - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + phase: Annotated[ + list[ExecutionPhase] | None, + Query( + title="Execution phase", + description="Limit results to the provided execution phases", + ), + ] = None, + after: Annotated[ + datetime | None, + Query( + title="Creation date", + description="Limit results to jobs created after this date", + ), + ] = None, + last: Annotated[ + int | None, + Query( + title="Number of jobs", + description="Return at most the given number of jobs", + ), + ] = None, + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() jobs = await job_service.list_jobs( @@ -87,28 +94,33 @@ async def get_job_list( summary="Job details", ) async def get_job( + *, job_id: str, request: Request, - wait: int = Query( - None, - title="Wait for status changes", - description=( - "Maximum number of seconds to wait or -1 to wait for as long as" - " the server permits" + wait: Annotated[ + int | None, + Query( + title="Wait for status changes", + description=( + "Maximum number of seconds to wait or -1 to wait for as long" + " as the server permits" + ), ), - ), - phase: ExecutionPhase = Query( - None, - title="Initial phase for waiting", - description=( - "When waiting for status changes, consider this to be the initial" - " execution phase. If the phase has already changed, return" - " immediately. This parameter should always be provided when" - " wait is used." + ] = None, + phase: Annotated[ + ExecutionPhase | None, + Query( + title="Initial phase for waiting", + description=( + "When waiting for status changes, consider this to be the" + " initial execution phase. If the phase has already changed," + " return immediately. This parameter should always be" + " provided when wait is used." + ), ), - ), - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + ] = None, + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id, wait=wait, wait_phase=phase) @@ -123,11 +135,12 @@ async def get_job( summary="Delete a job", ) async def delete_job( + *, job_id: str, request: Request, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), - logger: BoundLogger = Depends(auth_logger_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], + logger: Annotated[BoundLogger, Depends(auth_logger_dependency)], ) -> str: job_service = uws_factory.create_job_service() await job_service.delete(user, job_id) @@ -145,17 +158,20 @@ async def delete_job( summary="Delete a job", ) async def delete_job_via_post( + *, job_id: str, request: Request, - action: Optional[Literal["DELETE"]] = Form( - None, - title="Action to perform", - description="Mandatory, must be set to DELETE", - ), - params: list[JobParameter] = Depends(uws_post_params_dependency), - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), - logger: BoundLogger = Depends(auth_logger_dependency), + action: Annotated[ + Literal["DELETE"] | None, + Form( + title="Action to perform", + description="Mandatory, must be set to DELETE", + ), + ] = None, + params: Annotated[list[JobParameter], Depends(uws_post_params_dependency)], + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], + logger: Annotated[BoundLogger, Depends(auth_logger_dependency)], ) -> str: # Work around the obnoxious requirement for case-insensitive parameters, # which is also why the action parameter is declared as optional (but is @@ -184,8 +200,8 @@ async def delete_job_via_post( ) async def get_job_destruction( job_id: str, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id) @@ -199,18 +215,21 @@ async def get_job_destruction( summary="Change job destruction time", ) async def post_job_destruction( + *, job_id: str, request: Request, - destruction: Optional[datetime] = Form( - None, - title="New destruction time", - description="Must be in ISO 8601 format.", - example="2021-09-10T10:01:02Z", - ), - params: list[JobParameter] = Depends(uws_post_params_dependency), - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), - logger: BoundLogger = Depends(auth_logger_dependency), + destruction: Annotated[ + datetime | None, + Form( + title="New destruction time", + description="Must be in ISO 8601 format.", + example="2021-09-10T10:01:02Z", + ), + ] = None, + params: Annotated[list[JobParameter], Depends(uws_post_params_dependency)], + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], + logger: Annotated[BoundLogger, Depends(auth_logger_dependency)], ) -> str: # Work around the obnoxious requirement for case-insensitive parameters. for param in params: @@ -249,10 +268,11 @@ async def post_job_destruction( summary="Job error", ) async def get_job_error( + *, job_id: str, request: Request, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id) @@ -269,8 +289,8 @@ async def get_job_error( ) async def get_job_execution_duration( job_id: str, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id) @@ -284,18 +304,21 @@ async def get_job_execution_duration( summary="Change job execution duration", ) async def post_job_execution_duration( + *, job_id: str, request: Request, - executionduration: Optional[int] = Form( - None, - title="New execution duration", - description="Integer seconds of wall clock time.", - example=14400, - ), - params: list[JobParameter] = Depends(uws_post_params_dependency), - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), - logger: BoundLogger = Depends(auth_logger_dependency), + executionduration: Annotated[ + int | None, + Form( + title="New execution duration", + description="Integer seconds of wall clock time.", + example=14400, + ), + ] = None, + params: Annotated[list[JobParameter], Depends(uws_post_params_dependency)], + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], + logger: Annotated[BoundLogger, Depends(auth_logger_dependency)], ) -> str: # Work around the obnoxious requirement for case-insensitive parameters. for param in params: @@ -304,8 +327,8 @@ async def post_job_execution_duration( raise ParameterError(msg) try: executionduration = int(param.value) - except Exception: - raise ParameterError(f"Invalid duration {param.value}") + except Exception as e: + raise ParameterError(f"Invalid duration {param.value}") from e if executionduration <= 0: raise ParameterError(f"Invalid duration {param.value}") if not executionduration: @@ -335,8 +358,8 @@ async def post_job_execution_duration( ) async def get_job_owner( job_id: str, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id) @@ -349,10 +372,11 @@ async def get_job_owner( summary="Job parameters", ) async def get_job_parameters( + *, job_id: str, request: Request, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id) @@ -367,8 +391,8 @@ async def get_job_parameters( ) async def get_job_phase( job_id: str, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id) @@ -382,18 +406,21 @@ async def get_job_phase( summary="Start or abort job", ) async def post_job_phase( + *, job_id: str, request: Request, - phase: Optional[Literal["RUN", "ABORT"]] = Form( - None, - title="Job state change", - summary="RUN to start the job, ABORT to abort the job.", - ), - params: list[JobParameter] = Depends(uws_post_params_dependency), - user: str = Depends(auth_dependency), - access_token: str = Depends(auth_delegated_token_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), - logger: BoundLogger = Depends(auth_logger_dependency), + phase: Annotated[ + Literal["RUN", "ABORT"] | None, + Form( + title="Job state change", + summary="RUN to start the job, ABORT to abort the job.", + ), + ] = None, + params: Annotated[list[JobParameter], Depends(uws_post_params_dependency)], + user: Annotated[str, Depends(auth_dependency)], + access_token: Annotated[str, Depends(auth_delegated_token_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], + logger: Annotated[BoundLogger, Depends(auth_logger_dependency)], ) -> str: # Work around the obnoxious requirement for case-insensitive parameters. for param in params: @@ -424,8 +451,8 @@ async def post_job_phase( ) async def get_job_quote( job_id: str, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id) @@ -443,10 +470,11 @@ async def get_job_quote( summary="Job results", ) async def get_job_results( + *, job_id: str, request: Request, - user: str = Depends(auth_dependency), - uws_factory: UWSFactory = Depends(uws_dependency), + user: Annotated[str, Depends(auth_dependency)], + uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() job = await job_service.get(user, job_id) diff --git a/src/vocutouts/uws/models.py b/src/vocutouts/uws/models.py index e8b59ac..85f30d8 100644 --- a/src/vocutouts/uws/models.py +++ b/src/vocutouts/uws/models.py @@ -9,7 +9,6 @@ from dataclasses import asdict, dataclass from datetime import datetime from enum import Enum -from typing import Optional @dataclass @@ -19,7 +18,7 @@ class Availability: available: bool """Whether the service appears to be available.""" - note: Optional[str] = None + note: str | None = None """Supplemental information, usually when the service is not available.""" @@ -100,7 +99,7 @@ class JobError: use a single message and thus a sequence of length one. """ - detail: Optional[str] = None + detail: str | None = None """Extended error message with additional detail.""" @@ -114,10 +113,10 @@ class JobResult: url: str """The URL for the result, which must point into a GCS bucket.""" - size: Optional[int] = None + size: int | None = None """Size of the result in bytes.""" - mime_type: Optional[str] = None + mime_type: str | None = None """MIME type of the result.""" @@ -135,10 +134,10 @@ class JobResultURL: url: str """Signed URL to retrieve the result.""" - size: Optional[int] = None + size: int | None = None """Size of the result in bytes.""" - mime_type: Optional[str] = None + mime_type: str | None = None """MIME type of the result.""" diff --git a/src/vocutouts/uws/schema/job.py b/src/vocutouts/uws/schema/job.py index 686e7d7..ac6f67e 100644 --- a/src/vocutouts/uws/schema/job.py +++ b/src/vocutouts/uws/schema/job.py @@ -9,6 +9,7 @@ from __future__ import annotations from datetime import datetime +from typing import ClassVar from sqlalchemy import Column, DateTime, Enum, Index, Integer, String, Text from sqlalchemy.orm import Mapped, relationship @@ -22,6 +23,8 @@ class Job(Base): + """Table holding UWS jobs.""" + __tablename__ = "job" id: int = Column(Integer, primary_key=True, autoincrement=True) @@ -47,7 +50,7 @@ class Job(Base): cascade="delete", lazy="selectin", uselist=True ) - __mapper_args__ = {"eager_defaults": True} + __mapper_args__: ClassVar[dict[str, bool]] = {"eager_defaults": True} __table_args__ = ( Index("by_owner_phase", "owner", "phase", "creation_time"), Index("by_owner_time", "owner", "creation_time"), diff --git a/src/vocutouts/uws/schema/job_parameter.py b/src/vocutouts/uws/schema/job_parameter.py index 2d0ca62..5a18879 100644 --- a/src/vocutouts/uws/schema/job_parameter.py +++ b/src/vocutouts/uws/schema/job_parameter.py @@ -18,6 +18,8 @@ class JobParameter(Base): + """Table holding parameters to UWS jobs.""" + __tablename__ = "job_parameter" id: int = Column(Integer, primary_key=True, autoincrement=True) diff --git a/src/vocutouts/uws/schema/job_result.py b/src/vocutouts/uws/schema/job_result.py index 82483df..da0c1d4 100644 --- a/src/vocutouts/uws/schema/job_result.py +++ b/src/vocutouts/uws/schema/job_result.py @@ -10,6 +10,8 @@ class JobResult(Base): + """Table holding job results.""" + __tablename__ = "job_result" id: int = Column(Integer, primary_key=True, autoincrement=True) diff --git a/src/vocutouts/uws/service.py b/src/vocutouts/uws/service.py index a1ca1bf..a1ea0d2 100644 --- a/src/vocutouts/uws/service.py +++ b/src/vocutouts/uws/service.py @@ -3,8 +3,7 @@ from __future__ import annotations import asyncio -from datetime import datetime, timedelta, timezone -from typing import Optional +from datetime import UTC, datetime, timedelta from dramatiq import Message @@ -72,7 +71,7 @@ async def create( self, user: str, *, - run_id: Optional[str] = None, + run_id: str | None = None, params: list[JobParameter], ) -> Job: """Create a pending job. @@ -125,8 +124,8 @@ async def get( user: str, job_id: str, *, - wait: Optional[int] = None, - wait_phase: Optional[ExecutionPhase] = None, + wait: int | None = None, + wait_phase: ExecutionPhase | None = None, wait_for_completion: bool = False, ) -> Job: """Retrieve a job. @@ -185,7 +184,7 @@ async def get( if wait and job.phase in ACTIVE_PHASES: if wait < 0 or wait > self._config.wait_timeout: wait = self._config.wait_timeout - end_time = datetime.now(tz=timezone.utc) + timedelta(seconds=wait) + end_time = datetime.now(tz=UTC) + timedelta(seconds=wait) if not wait_phase: wait_phase = job.phase @@ -203,7 +202,7 @@ def not_done(j: Job) -> bool: while not_done(job): await asyncio.sleep(delay) job = await self._storage.get(job_id) - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) if now >= end_time: break delay *= 1.5 @@ -216,9 +215,9 @@ async def list_jobs( self, user: str, *, - phases: Optional[list[ExecutionPhase]] = None, - after: Optional[datetime] = None, - count: Optional[int] = None, + phases: list[ExecutionPhase] | None = None, + after: datetime | None = None, + count: int | None = None, ) -> list[JobDescription]: """List the jobs for a particular user. diff --git a/src/vocutouts/uws/storage.py b/src/vocutouts/uws/storage.py index 58fa805..c193892 100644 --- a/src/vocutouts/uws/storage.py +++ b/src/vocutouts/uws/storage.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Awaitable, Callable -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from functools import wraps -from typing import Any, Optional, TypeVar, cast +from typing import Any, TypeVar, cast from safir.database import datetime_from_db, datetime_to_db from sqlalchemy import delete @@ -159,7 +159,7 @@ async def add( self, *, owner: str, - run_id: Optional[str] = None, + run_id: str | None = None, params: list[JobParameter], execution_duration: int, lifetime: int, @@ -188,7 +188,7 @@ async def add( vocutouts.uws.models.Job The internal representation of the newly-created job. """ - now = datetime.now(tz=timezone.utc).replace(microsecond=0) + now = datetime.now(tz=UTC).replace(microsecond=0) destruction_time = now + timedelta(seconds=lifetime) sql_params = [ SQLJobParameter( @@ -223,7 +223,7 @@ async def availability(self) -> Availability: note = "cannot query UWS job database" return Availability(available=False, note=note) except Exception as e: - note = f"{type(e).__name__}: {str(e)}" + note = f"{type(e).__name__}: {e!s}" return Availability(available=False, note=note) async def delete(self, job_id: str) -> None: @@ -242,9 +242,9 @@ async def list_jobs( self, user: str, *, - phases: Optional[list[ExecutionPhase]] = None, - after: Optional[datetime] = None, - count: Optional[int] = None, + phases: list[ExecutionPhase] | None = None, + after: datetime | None = None, + count: int | None = None, ) -> list[JobDescription]: """List the jobs for a particular user. @@ -404,7 +404,7 @@ def mark_completed(self, job_id: str, results: list[JobResult]) -> None: with self._session.begin(): job = self._get_job(job_id) job.phase = ExecutionPhase.COMPLETED - job.end_time = datetime_to_db(datetime.now(tz=timezone.utc)) + job.end_time = datetime_to_db(datetime.now(tz=UTC)) for sequence, result in enumerate(results, start=1): sql_result = SQLJobResult( job_id=job.id, @@ -422,7 +422,7 @@ def mark_errored(self, job_id: str, error: JobError) -> None: with self._session.begin(): job = self._get_job(job_id) job.phase = ExecutionPhase.ERROR - job.end_time = datetime_to_db(datetime.now(tz=timezone.utc)) + job.end_time = datetime_to_db(datetime.now(tz=UTC)) job.error_type = error.error_type job.error_code = error.error_code job.error_message = error.message diff --git a/src/vocutouts/uws/utils.py b/src/vocutouts/uws/utils.py index d1017f5..1e96a55 100644 --- a/src/vocutouts/uws/utils.py +++ b/src/vocutouts/uws/utils.py @@ -2,12 +2,13 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import UTC, datetime def isodatetime(timestamp: datetime) -> str: """Format a timestamp in UTC in the expected UWS ISO date format.""" - assert timestamp.tzinfo in (None, timezone.utc) + if timestamp.tzinfo not in (None, UTC): + raise ValueError("Timestamp not in UTC time zone") return timestamp.strftime("%Y-%m-%dT%H:%M:%SZ") diff --git a/src/vocutouts/workers.py b/src/vocutouts/workers.py index 8006491..8955fc1 100644 --- a/src/vocutouts/workers.py +++ b/src/vocutouts/workers.py @@ -15,7 +15,7 @@ from __future__ import annotations import os -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from urllib.parse import urlparse from uuid import UUID @@ -59,19 +59,19 @@ @dramatiq.actor(queue_name="uws") def job_started(job_id: str, message_id: str, start_time: str) -> None: - pass + """Mark a job as started.""" @dramatiq.actor(queue_name="uws") def job_completed( message: dict[str, Any], result: list[dict[str, str]] ) -> None: - pass + """Mark a job as completed.""" @dramatiq.actor(queue_name="uws") def job_failed(message: dict[str, Any], exception: dict[str, str]) -> None: - pass + """Mark a job as failed.""" # Exceptions of these names are handled specially by job_failed. @@ -102,7 +102,6 @@ def get_backend(butler_label: str, access_token: str) -> ImageCutoutBackend: lsst.image_cutout_backend.ImageCutoutBackend Backend to use. """ - butler = BUTLER_FACTORY.create_butler( label=butler_label, access_token=access_token ) @@ -178,7 +177,7 @@ def cutout( # Tell UWS that we have started executing. message = CurrentMessage.get_current_message() - now = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + now = datetime.now(tz=UTC).strftime("%Y-%m-%dT%H:%M:%SZ") job_started.send(job_id, message.message_id, now) # Currently, only a single dataset ID and a single stencil are supported. @@ -223,8 +222,8 @@ def cutout( result = backend.process_uuid(sky_stencils[0], uuid, mask_plane=None) except Exception as e: logger.exception("Cutout processing failed") - msg = f"Error Cutout processing failed\n{type(e).__name__}: {str(e)}" - raise TaskTransientError(msg) + msg = f"Error Cutout processing failed\n{type(e).__name__}: {e!s}" + raise TaskTransientError(msg) from e # Return the result URL. This must be a dict representation of a # vocutouts.uws.models.JobResult. diff --git a/tests/conftest.py b/tests/conftest.py index 00c403a..f27c5be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import AsyncIterator, Iterator -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from typing import Any import dramatiq @@ -35,7 +35,7 @@ def cutout_test( access_token: str, ) -> list[dict[str, Any]]: message = CurrentMessage.get_current_message() - now = isodatetime(datetime.now(tz=timezone.utc)) + now = isodatetime(datetime.now(tz=UTC)) job_started.send(job_id, message.message_id, now) assert len(dataset_ids) == 1 assert access_token == "sometoken" diff --git a/tests/handlers/external_test.py b/tests/handlers/external_test.py index 3e2d8e7..5d78736 100644 --- a/tests/handlers/external_test.py +++ b/tests/handlers/external_test.py @@ -48,7 +48,7 @@ @pytest.mark.asyncio async def test_get_index(client: AsyncClient) -> None: - """Test ``GET /api/cutout/``""" + """Test ``GET /api/cutout/``.""" response = await client.get("/api/cutout/") assert response.status_code == 200 data = response.json() diff --git a/tests/handlers/internal_test.py b/tests/handlers/internal_test.py index f3d011f..4cc0137 100644 --- a/tests/handlers/internal_test.py +++ b/tests/handlers/internal_test.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio async def test_get_index(client: AsyncClient) -> None: - """Test ``GET /``""" + """Test ``GET /``.""" response = await client.get("/") assert response.status_code == 200 data = response.json() diff --git a/tests/support/uws.py b/tests/support/uws.py index ac382d3..af92eb8 100644 --- a/tests/support/uws.py +++ b/tests/support/uws.py @@ -4,8 +4,8 @@ import asyncio import os -from datetime import datetime, timezone -from typing import Any, Optional +from datetime import UTC, datetime +from typing import Any import dramatiq import structlog @@ -36,7 +36,7 @@ results = StubBackend() """Result backend used by UWS.""" -worker_session: Optional[scoped_session] = None +worker_session: scoped_session | None = None """Shared scoped session used by the UWS worker.""" uws_broker.add_middleware(CurrentMessage()) @@ -71,7 +71,7 @@ def before_worker_boot(self, broker: Broker, worker: Worker) -> None: @dramatiq.actor(broker=uws_broker, queue_name="job", store_results=True) def trivial_job(job_id: str) -> list[dict[str, Any]]: message = CurrentMessage.get_current_message() - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) job_started.send(job_id, message.message_id, isodatetime(now)) return [ { diff --git a/tests/uws/job_error_test.py b/tests/uws/job_error_test.py index e42501c..350aa6a 100644 --- a/tests/uws/job_error_test.py +++ b/tests/uws/job_error_test.py @@ -3,7 +3,7 @@ from __future__ import annotations import time -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any import dramatiq @@ -82,7 +82,7 @@ async def test_temporary_error( @dramatiq.actor(broker=uws_broker, queue_name="job") def error_transient_job(job_id: str) -> list[dict[str, Any]]: message = CurrentMessage.get_current_message() - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) job_started.send(job_id, message.message_id, isodatetime(now)) time.sleep(0.5) raise TaskTransientError( @@ -147,7 +147,7 @@ async def test_fatal_error( @dramatiq.actor(broker=uws_broker, queue_name="job") def error_fatal_job(job_id: str) -> list[dict[str, Any]]: message = CurrentMessage.get_current_message() - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) job_started.send(job_id, message.message_id, isodatetime(now)) time.sleep(0.5) raise TaskFatalError(ErrorCode.ERROR, "Error Whoops\nSome details") @@ -210,7 +210,7 @@ async def test_unknown_error( @dramatiq.actor(broker=uws_broker, queue_name="job") def error_unknown_job(job_id: str) -> list[dict[str, Any]]: message = CurrentMessage.get_current_message() - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) time.sleep(0.5) job_started.send(job_id, message.message_id, isodatetime(now)) raise ValueError("Unknown exception") diff --git a/tests/uws/job_list_test.py b/tests/uws/job_list_test.py index 20fe6ce..dcc46d0 100644 --- a/tests/uws/job_list_test.py +++ b/tests/uws/job_list_test.py @@ -6,7 +6,7 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import pytest from httpx import AsyncClient @@ -112,7 +112,7 @@ async def test_job_list( async with session.begin(): for i, job in enumerate(jobs): hours = (2 - i) * 2 - creation = datetime.now(tz=timezone.utc) - timedelta(hours=hours) + creation = datetime.now(tz=UTC) - timedelta(hours=hours) stmt = ( update(SQLJob) .where(SQLJob.id == int(job.job_id)) @@ -131,7 +131,7 @@ async def test_job_list( assert r.text == expected # Filter by recency. - threshold = datetime.now(tz=timezone.utc) - timedelta(hours=1) + threshold = datetime.now(tz=UTC) - timedelta(hours=1) r = await client.get( "/jobs", headers={"X-Auth-Request-User": "user"}, diff --git a/tests/uws/long_polling_test.py b/tests/uws/long_polling_test.py index a3e3436..4b5d045 100644 --- a/tests/uws/long_polling_test.py +++ b/tests/uws/long_polling_test.py @@ -3,7 +3,7 @@ from __future__ import annotations import time -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from typing import Any import dramatiq @@ -104,13 +104,13 @@ async def test_poll( # Poll for changes for two seconds. Nothing will happen since there is no # worker. - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) r = await client.get( "/jobs/1", headers={"X-Auth-Request-User": "user"}, params={"WAIT": "2"}, ) - assert (datetime.now(tz=timezone.utc) - now).total_seconds() >= 2 + assert (datetime.now(tz=UTC) - now).total_seconds() >= 2 assert r.status_code == 200 assert r.text == PENDING_JOB.strip().format( "PENDING", @@ -121,7 +121,7 @@ async def test_poll( @dramatiq.actor(broker=uws_broker, queue_name="job", store_results=True) def wait_job(job_id: str) -> list[dict[str, Any]]: message = CurrentMessage.get_current_message() - now = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + now = datetime.now(tz=UTC).strftime("%Y-%m-%dT%H:%M:%SZ") job_started.send(job_id, message.message_id, now) time.sleep(2) return [ @@ -147,7 +147,7 @@ def wait_job(job_id: str) -> list[dict[str, Any]]: isodatetime(job.creation_time), isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)), ) - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) worker = Worker(uws_broker, worker_timeout=100) worker.start() @@ -182,6 +182,6 @@ def wait_job(job_id: str) -> list[dict[str, Any]]: isodatetime(job.end_time), isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)), ) - assert (datetime.now(tz=timezone.utc) - now).total_seconds() >= 2 + assert (datetime.now(tz=UTC) - now).total_seconds() >= 2 finally: worker.stop() diff --git a/tests/uws/policy_test.py b/tests/uws/policy_test.py index 937c77c..9a0e773 100644 --- a/tests/uws/policy_test.py +++ b/tests/uws/policy_test.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import pytest from httpx import AsyncClient @@ -20,7 +20,7 @@ class Policy(TrivialPolicy): def validate_destruction( self, destruction: datetime, job: Job ) -> datetime: - max_destruction = datetime.now(tz=timezone.utc) + timedelta(days=1) + max_destruction = datetime.now(tz=UTC) + timedelta(days=1) if destruction > max_destruction: return max_destruction else: @@ -62,7 +62,7 @@ async def test_policy( # Change the destruction time, first to something that should be honored # and then something that should be overridden. - destruction = datetime.now(tz=timezone.utc) + timedelta(hours=1) + destruction = datetime.now(tz=UTC) + timedelta(hours=1) r = await client.post( "/jobs/1/destruction", headers={"X-Auth-Request-User": "user"}, @@ -75,8 +75,8 @@ async def test_policy( ) assert r.status_code == 200 assert r.text == isodatetime(destruction) - destruction = datetime.now(tz=timezone.utc) + timedelta(days=5) - expected = datetime.now(tz=timezone.utc) + timedelta(days=1) + destruction = datetime.now(tz=UTC) + timedelta(days=5) + expected = datetime.now(tz=UTC) + timedelta(days=1) r = await client.post( "/jobs/1/destruction", headers={"X-Auth-Request-User": "user"},