diff --git a/CHANGES.md b/CHANGES.md index d04917a10..b29958eaa 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,6 +10,8 @@ ### Changed * Updated CI to test against [pgstac v0.6.12](https://github.com/stac-utils/pgstac/releases/tag/v0.6.12) ([#511](https://github.com/stac-utils/stac-fastapi/pull/511)) +* Reworked `update_openapi` and added a test for it ([#523](https://github.com/stac-utils/stac-fastapi/pull/523)) +* Limit values above 10,000 are now replaced with 10,000 instead of returning a 400 error ([#526](https://github.com/stac-utils/stac-fastapi/pull/526)) ### Removed @@ -22,6 +24,12 @@ * `self` link rel for `/collections/{c_id}/items` ([#508](https://github.com/stac-utils/stac-fastapi/pull/508)) * Media type of the item collection endpoint ([#508](https://github.com/stac-utils/stac-fastapi/pull/508)) * Manually exclude non-truthy optional values from sqlalchemy serialization of Collections ([#508](https://github.com/stac-utils/stac-fastapi/pull/508)) +* Support `intersects` in GET requests ([#521](https://github.com/stac-utils/stac-fastapi/pull/521)) +* Deleting items that had repeated ids in other collections ([#520](https://github.com/stac-utils/stac-fastapi/pull/520)) + +### Deprecated + +* Deprecated `VndOaiResponse` and `config_openapi`, will be removed in v3.0 ([#523](https://github.com/stac-utils/stac-fastapi/pull/523)) ## [2.4.3] - 2022-11-25 diff --git a/stac_fastapi/api/stac_fastapi/api/openapi.py b/stac_fastapi/api/stac_fastapi/api/openapi.py index 574176a46..2ccd48282 100644 --- a/stac_fastapi/api/stac_fastapi/api/openapi.py +++ b/stac_fastapi/api/stac_fastapi/api/openapi.py @@ -1,8 +1,11 @@ """openapi.""" +import warnings + from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from starlette.requests import Request -from starlette.responses import JSONResponse +from starlette.responses import JSONResponse, Response +from starlette.routing import Route, request_response from stac_fastapi.api.config import ApiExtensions from stac_fastapi.types.config import ApiSettings @@ -13,37 +16,54 @@ class VndOaiResponse(JSONResponse): media_type = "application/vnd.oai.openapi+json;version=3.0" + def __init__(self, *args, **kwargs): + """Init function with deprecation warning.""" + warnings.warn( + "VndOaiResponse is deprecated and will be removed in v3.0", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) + def update_openapi(app: FastAPI) -> FastAPI: """Update OpenAPI response content-type. This function modifies the openapi route to comply with the STAC API spec's - required content-type response header + required content-type response header. """ - urls = (server_data.get("url") for server_data in app.servers) - server_urls = {url for url in urls if url} - - async def openapi(req: Request) -> JSONResponse: - root_path = req.scope.get("root_path", "").rstrip("/") - if root_path not in server_urls: - if root_path and app.root_path_in_servers: - app.servers.insert(0, {"url": root_path}) - server_urls.add(root_path) - return VndOaiResponse(app.openapi()) - - # Remove the default openapi route - app.router.routes = list( - filter(lambda r: r.path != app.openapi_url, app.router.routes) + # Find the route for the openapi_url in the app + openapi_route: Route = next( + route for route in app.router.routes if route.path == app.openapi_url ) - # Add the updated openapi route - app.add_route(app.openapi_url, openapi, include_in_schema=False) + # Store the old endpoint function so we can call it from the patched function + old_endpoint = openapi_route.endpoint + + # Create a patched endpoint function that modifies the content type of the response + async def patched_openapi_endpoint(req: Request) -> Response: + # Get the response from the old endpoint function + response: JSONResponse = await old_endpoint(req) + # Update the content type header in place + response.headers[ + "content-type" + ] = "application/vnd.oai.openapi+json;version=3.0" + # Return the updated response + return response + + # When a Route is accessed the `handle` function calls `self.app`. Which is + # the endpoint function wrapped with `request_response`. So we need to wrap + # our patched function and replace the existing app with it. + openapi_route.app = request_response(patched_openapi_endpoint) + + # return the patched app return app -# TODO: Remove or fix, this is currently unused -# and calls a missing method on ApiSettings def config_openapi(app: FastAPI, settings: ApiSettings): """Config openapi.""" + warnings.warn( + "config_openapi is deprecated and will be removed in v3.0", + DeprecationWarning, + ) def custom_openapi(): """Config openapi.""" diff --git a/stac_fastapi/api/tests/test_api.py b/stac_fastapi/api/tests/test_api.py index ab5a304d4..15629e7b7 100644 --- a/stac_fastapi/api/tests/test_api.py +++ b/stac_fastapi/api/tests/test_api.py @@ -49,6 +49,15 @@ def _assert_dependency_applied(api, routes): ), "Authenticated requests should be accepted" assert response.json() == "dummy response" + def test_openapi_content_type(self): + api = self._build_api() + with TestClient(api.app) as client: + response = client.get(api.settings.openapi_url) + assert ( + response.headers["content-type"] + == "application/vnd.oai.openapi+json;version=3.0" + ) + def test_build_api_with_route_dependencies(self): routes = [ {"path": "/collections", "method": "POST"}, diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index e074a9f9c..a8c73d9f8 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -352,6 +352,7 @@ async def get_search( sortby: Optional[str] = None, filter: Optional[str] = None, filter_lang: Optional[str] = None, + intersects: Optional[str] = None, **kwargs, ) -> ItemCollection: """Cross catalog search (GET). @@ -389,6 +390,9 @@ async def get_search( if datetime: base_args["datetime"] = datetime + if intersects: + base_args["intersects"] = orjson.loads(unquote_plus(intersects)) + if sortby: # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form sort_param = [] diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/db.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/db.py index 2991a92a5..bc6cc96e6 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/db.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/db.py @@ -1,7 +1,8 @@ """Database connection handling.""" import json -from typing import Dict, Union +from contextlib import contextmanager +from typing import Dict, Generator, Union import attr import orjson @@ -61,7 +62,7 @@ async def dbfunc(pool: pool, func: str, arg: Union[str, Dict]): arg -- the argument to the PostgreSQL function as either a string or a dict that will be converted into jsonb """ - try: + with translate_pgstac_errors(): if isinstance(arg, str): async with pool.acquire() as conn: q, p = render( @@ -80,6 +81,13 @@ async def dbfunc(pool: pool, func: str, arg: Union[str, Dict]): item=json.dumps(arg), ) return await conn.fetchval(q, *p) + + +@contextmanager +def translate_pgstac_errors() -> Generator[None, None, None]: + """Context manager that translates pgstac errors into FastAPI errors.""" + try: + yield except exceptions.UniqueViolationError as e: raise ConflictError from e except exceptions.NoDataFoundError as e: diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py index 9013a6319..68479aa34 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py @@ -4,6 +4,7 @@ from typing import Optional, Union import attr +from buildpg import render from fastapi import HTTPException from starlette.responses import JSONResponse, Response @@ -11,7 +12,7 @@ AsyncBaseBulkTransactionsClient, Items, ) -from stac_fastapi.pgstac.db import dbfunc +from stac_fastapi.pgstac.db import dbfunc, translate_pgstac_errors from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks from stac_fastapi.types import stac as stac_types from stac_fastapi.types.core import AsyncBaseTransactionsClient @@ -98,12 +99,19 @@ async def update_collection( return stac_types.Collection(**collection) async def delete_item( - self, item_id: str, **kwargs + self, item_id: str, collection_id: str, **kwargs ) -> Optional[Union[stac_types.Item, Response]]: """Delete item.""" request = kwargs["request"] pool = request.app.state.writepool - await dbfunc(pool, "delete_item", item_id) + async with pool.acquire() as conn: + q, p = render( + "SELECT * FROM delete_item(:item::text, :collection::text);", + item=item_id, + collection=collection_id, + ) + with translate_pgstac_errors(): + await conn.fetchval(q, *p) return JSONResponse({"deleted item": item_id}) async def delete_collection( diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index 06a80675e..e3baf32c1 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -3,6 +3,7 @@ import orjson import pytest +from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent STAC_CORE_ROUTES = [ "GET /", @@ -24,6 +25,24 @@ "PUT /collections/{collection_id}/items/{item_id}", ] +GLOBAL_BBOX = [-180.0, -90.0, 180.0, 90.0] +GLOBAL_GEOMETRY = { + "type": "Polygon", + "coordinates": ( + ( + (180.0, -90.0), + (180.0, 90.0), + (-180.0, 90.0), + (-180.0, -90.0), + (180.0, -90.0), + ), + ), +} +DEFAULT_EXTENT = Extent( + SpatialExtent(GLOBAL_BBOX), + TemporalExtent([[datetime.now(), None]]), +) + async def test_post_search_content_type(app_client): params = {"limit": 1} @@ -183,7 +202,7 @@ async def test_app_query_extension_limit_gt10000( params = {"limit": 10001} resp = await app_client.post("/search", json=params) - assert resp.status_code == 400 + assert resp.status_code == 200 async def test_app_query_extension_gt(load_test_data, app_client, load_test_collection): @@ -310,6 +329,15 @@ async def test_search_point_intersects( resp = await app_client.post(f"/collections/{coll.id}/items", json=item) assert resp.status_code == 200 + new_coordinates = list() + for coordinate in item["geometry"]["coordinates"][0]: + new_coordinates.append([coordinate[0] * -1, coordinate[1] * -1]) + item["id"] = "test-item-other-hemispheres" + item["geometry"]["coordinates"] = [new_coordinates] + item["bbox"] = list(value * -1 for value in item["bbox"]) + resp = await app_client.post(f"/collections/{coll.id}/items", json=item) + assert resp.status_code == 200 + point = [150.04, -33.14] intersects = {"type": "Point", "coordinates": point} @@ -322,6 +350,12 @@ async def test_search_point_intersects( resp_json = resp.json() assert len(resp_json["features"]) == 1 + params["intersects"] = orjson.dumps(params["intersects"]).decode("utf-8") + resp = await app_client.get("/search", params=params) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json["features"]) == 1 + async def test_search_line_string_intersects( load_test_data, app_client, load_test_collection @@ -513,3 +547,33 @@ async def test_bad_collection_queryables( ): resp = await app_client.get("/collections/bad-collection/queryables") assert resp.status_code == 404 + + +async def test_deleting_items_with_identical_ids(app_client): + collection_a = Collection("collection-a", "The first collection", DEFAULT_EXTENT) + collection_b = Collection("collection-b", "The second collection", DEFAULT_EXTENT) + item = Item("the-item", GLOBAL_GEOMETRY, GLOBAL_BBOX, datetime.now(), {}) + + for collection in (collection_a, collection_b): + response = await app_client.post( + "/collections", json=collection.to_dict(include_self_link=False) + ) + assert response.status_code == 200 + item_as_dict = item.to_dict(include_self_link=False) + item_as_dict["collection"] = collection.id + response = await app_client.post( + f"/collections/{collection.id}/items", json=item_as_dict + ) + assert response.status_code == 200 + response = await app_client.get(f"/collections/{collection.id}/items") + assert response.status_code == 200, response.json() + assert len(response.json()["features"]) == 1 + + for collection in (collection_a, collection_b): + response = await app_client.delete( + f"/collections/{collection.id}/items/{item.id}" + ) + assert response.status_code == 200, response.json() + response = await app_client.get(f"/collections/{collection.id}/items") + assert response.status_code == 200, response.json() + assert not response.json()["features"] diff --git a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py index a0a99d044..68995d209 100644 --- a/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py +++ b/stac_fastapi/sqlalchemy/stac_fastapi/sqlalchemy/core.py @@ -249,6 +249,7 @@ def get_search( token: Optional[str] = None, fields: Optional[List[str]] = None, sortby: Optional[str] = None, + intersects: Optional[str] = None, **kwargs, ) -> ItemCollection: """GET search catalog.""" @@ -265,6 +266,9 @@ def get_search( if datetime: base_args["datetime"] = datetime + if intersects: + base_args["intersects"] = json.loads(unquote_plus(intersects)) + if sortby: # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form sort_param = [] diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index cbab5bfc2..6fdbb6ed8 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -209,7 +209,7 @@ def test_app_query_extension_limit_gt10000( params = {"limit": 10001} resp = app_client.post("/search", json=params) - assert resp.status_code == 400 + assert resp.status_code == 200 def test_app_query_extension_limit_10000( @@ -276,6 +276,16 @@ def test_search_point_intersects(load_test_data, app_client, postgres_transactio item["collection"], item, request=MockStarletteRequest ) + new_coordinates = list() + for coordinate in item["geometry"]["coordinates"][0]: + new_coordinates.append([coordinate[0] * -1, coordinate[1] * -1]) + item["id"] = "test-item-other-hemispheres" + item["geometry"]["coordinates"] = [new_coordinates] + item["bbox"] = list(value * -1 for value in item["bbox"]) + postgres_transactions.create_item( + item["collection"], item, request=MockStarletteRequest + ) + point = [150.04, -33.14] intersects = {"type": "Point", "coordinates": point} @@ -288,6 +298,12 @@ def test_search_point_intersects(load_test_data, app_client, postgres_transactio resp_json = resp.json() assert len(resp_json["features"]) == 1 + params["intersects"] = orjson.dumps(params["intersects"]).decode("utf-8") + resp = app_client.get("/search", params=params) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json["features"]) == 1 + def test_datetime_non_interval(load_test_data, app_client, postgres_transactions): item = load_test_data("test_item.json") diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 8470bec08..944f3c352 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -433,6 +433,7 @@ def get_search( token: Optional[str] = None, fields: Optional[List[str]] = None, sortby: Optional[str] = None, + intersects: Optional[str] = None, **kwargs, ) -> stac_types.ItemCollection: """Cross catalog search (GET). @@ -627,6 +628,7 @@ async def get_search( token: Optional[str] = None, fields: Optional[List[str]] = None, sortby: Optional[str] = None, + intersects: Optional[str] = None, **kwargs, ) -> stac_types.ItemCollection: """Cross catalog search (GET). diff --git a/stac_fastapi/types/stac_fastapi/types/search.py b/stac_fastapi/types/stac_fastapi/types/search.py index f12c3c518..185ec4754 100644 --- a/stac_fastapi/types/stac_fastapi/types/search.py +++ b/stac_fastapi/types/stac_fastapi/types/search.py @@ -8,7 +8,7 @@ from datetime import datetime from enum import auto from types import DynamicClassAttribute -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Union import attr from geojson_pydantic.geometries import ( @@ -20,7 +20,9 @@ Polygon, _GeometryBase, ) -from pydantic import BaseModel, conint, validator +from pydantic import BaseModel, ConstrainedInt, validator +from pydantic.errors import NumberNotGtError +from pydantic.validators import int_validator from stac_pydantic.shared import BBox from stac_pydantic.utils import AutoValueEnum @@ -30,6 +32,28 @@ NumType = Union[float, int] +class Limit(ConstrainedInt): + """An positive integer that maxes out at 10,000.""" + + ge: int = 1 + le: int = 10_000 + + @classmethod + def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]: + """Yield the relevant validators.""" + yield int_validator + yield cls.validate + + @classmethod + def validate(cls, value: int) -> int: + """Validate the integer value.""" + if value < cls.ge: + raise NumberNotGtError(limit_value=cls.ge) + if value > cls.le: + return cls.le + return value + + class Operator(str, AutoValueEnum): """Defines the set of operators supported by the API.""" @@ -74,7 +98,7 @@ class BaseSearchGetRequest(APIRequest): collections: Optional[str] = attr.ib(default=None, converter=str2list) ids: Optional[str] = attr.ib(default=None, converter=str2list) bbox: Optional[str] = attr.ib(default=None, converter=str2list) - intersects: Optional[str] = attr.ib(default=None, converter=str2list) + intersects: Optional[str] = attr.ib(default=None) datetime: Optional[str] = attr.ib(default=None) limit: Optional[int] = attr.ib(default=10) @@ -97,7 +121,7 @@ class BaseSearchPostRequest(BaseModel): Union[Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon] ] datetime: Optional[str] - limit: Optional[conint(gt=0, le=10000)] = 10 + limit: Optional[Limit] = 10 @property def start_date(self) -> Optional[datetime]: diff --git a/stac_fastapi/types/tests/test_limit.py b/stac_fastapi/types/tests/test_limit.py new file mode 100644 index 000000000..e5b2125bd --- /dev/null +++ b/stac_fastapi/types/tests/test_limit.py @@ -0,0 +1,22 @@ +import pytest +from pydantic import ValidationError + +from stac_fastapi.types.search import BaseSearchPostRequest + + +@pytest.mark.parametrize("value", [0, -1]) +def test_limit_ge(value): + with pytest.raises(ValidationError): + BaseSearchPostRequest(limit=value) + + +@pytest.mark.parametrize("value", [1, 10_000]) +def test_limit(value): + search = BaseSearchPostRequest(limit=value) + assert search.limit == value + + +@pytest.mark.parametrize("value", [10_001, 100_000, 1_000_000]) +def test_limit_le(value): + search = BaseSearchPostRequest(limit=value) + assert search.limit == 10_000