diff --git a/CHANGES.md b/CHANGES.md index bf14d02ea..f704c7b34 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,6 +8,7 @@ * Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367)) * Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383)) * Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411)) +* Add APIRouter prefix support for pgstac implementation. ([429](https://github.com/stac-utils/stac-fastapi/pull/429)) * Respect `Forwarded` or `X-Forwarded-*` request headers when building links to better accommodate load balancers and proxies. ### Changed diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 9761deaa6..0632a5bab 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -336,7 +336,7 @@ def customize_openapi(self) -> Optional[Dict[str, Any]]: def add_health_check(self): """Add a health check.""" - mgmt_router = APIRouter() + mgmt_router = APIRouter(prefix=self.app.state.router_prefix) @mgmt_router.get("/_mgmt/ping") async def ping(): @@ -384,6 +384,10 @@ def __attrs_post_init__(self): self.register_core() self.app.include_router(self.router) + # keep link to the router prefix value + router_prefix = self.router.prefix + self.app.state.router_prefix = router_prefix if router_prefix else "" + # register extensions for ext in self.extensions: ext.register(self.app) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py index 243cfaf4e..0854c9f4f 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py @@ -107,6 +107,7 @@ def register(self, app: FastAPI) -> None: Returns: None """ + self.router.prefix = app.state.router_prefix self.router.add_api_route( name="Queryables", path="/queryables", diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py index 476301fc9..f446de2be 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py @@ -160,6 +160,7 @@ def register(self, app: FastAPI) -> None: Returns: None """ + self.router.prefix = app.state.router_prefix self.register_create_item() self.register_update_item() self.register_delete_item() diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py index bdb68d9a8..3fe25c9d1 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py @@ -116,7 +116,7 @@ def register(self, app: FastAPI) -> None: """ items_request_model = create_request_model("Items", base_model=Items) - router = APIRouter() + router = APIRouter(prefix=app.state.router_prefix) router.add_api_route( name="Bulk Create Item", path="/collections/{collection_id}/bulk_items", diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index 2b3297a88..07844f4bb 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -23,6 +23,7 @@ from stac_fastapi.pgstac.utils import filter_fields from stac_fastapi.types.core import AsyncBaseCoreClient from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError +from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection NumType = Union[float, int] @@ -35,7 +36,7 @@ class CoreCrudClient(AsyncBaseCoreClient): async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" request: Request = kwargs["request"] - base_url = str(request.base_url) + base_url = get_base_url(request) pool = request.app.state.readpool async with pool.acquire() as conn: diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py index 4816c0969..798db5426 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py @@ -8,6 +8,8 @@ from stac_pydantic.shared import MimeTypes from starlette.requests import Request +from stac_fastapi.types.requests import get_base_url + # These can be inferred from the item/collection so they aren't included in the database # Instead they are dynamically generated when querying the database using the classes defined below INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"] @@ -45,7 +47,7 @@ class BaseLinks: @property def base_url(self): """Get the base url.""" - return str(self.request.base_url) + return get_base_url(self.request) @property def url(self): diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index 3619055ed..369f9f0ce 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -51,16 +51,24 @@ async def test_api_headers(app_client): assert resp.status_code == 200 -async def test_core_router(api_client): - core_routes = set(STAC_CORE_ROUTES) +async def test_core_router(api_client, app): + core_routes = set() + for core_route in STAC_CORE_ROUTES: + method, path = core_route.split(" ") + core_routes.add("{} {}".format(method, app.state.router_prefix + path)) + api_routes = set( [f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes] ) assert not core_routes - api_routes -async def test_transactions_router(api_client): - transaction_routes = set(STAC_TRANSACTION_ROUTES) +async def test_transactions_router(api_client, app): + transaction_routes = set() + for transaction_route in STAC_TRANSACTION_ROUTES: + method, path = transaction_route.split(" ") + transaction_routes.add("{} {}".format(method, app.state.router_prefix + path)) + api_routes = set( [f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes] ) diff --git a/stac_fastapi/pgstac/tests/conftest.py b/stac_fastapi/pgstac/tests/conftest.py index 200965176..4a32fd73e 100644 --- a/stac_fastapi/pgstac/tests/conftest.py +++ b/stac_fastapi/pgstac/tests/conftest.py @@ -3,9 +3,11 @@ import os import time from typing import Callable, Dict +from urllib.parse import urljoin import asyncpg import pytest +from fastapi import APIRouter from fastapi.responses import ORJSONResponse from httpx import AsyncClient from pypgstac.db import PgstacDB @@ -107,9 +109,26 @@ async def pgstac(pg): # Run all the tests that use the api_client in both db hydrate and api hydrate mode -@pytest.fixture(params=[settings, pgstac_api_hydrate_settings], scope="session") +@pytest.fixture( + params=[ + (settings, ""), + (settings, "/router_prefix"), + (pgstac_api_hydrate_settings, ""), + (pgstac_api_hydrate_settings, "/router_prefix"), + ], + scope="session", +) def api_client(request, pg): - print("creating client with settings, hydrate:", request.param.use_api_hydrate) + api_settings, prefix = request.param + + api_settings.openapi_url = prefix + api_settings.openapi_url + api_settings.docs_url = prefix + api_settings.docs_url + + print( + "creating client with settings, hydrate: {}, router prefix: '{}'".format( + api_settings.use_api_hydrate, prefix + ) + ) extensions = [ TransactionExtension(client=TransactionsClient(), settings=settings), @@ -122,12 +141,13 @@ def api_client(request, pg): ] post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) api = StacApi( - settings=request.param, + settings=api_settings, extensions=extensions, client=CoreCrudClient(post_request_model=post_request_model), search_get_request_model=create_get_request_model(extensions), search_post_request_model=post_request_model, response_class=ORJSONResponse, + router=APIRouter(prefix=prefix), ) return api @@ -150,7 +170,12 @@ async def app(api_client): @pytest.fixture(scope="function") async def app_client(app): print("creating app_client") - async with AsyncClient(app=app, base_url="http://test") as c: + + base_url = "http://test" + if app.state.router_prefix != "": + base_url = urljoin(base_url, app.state.router_prefix) + + async with AsyncClient(app=app, base_url=base_url) as c: yield c diff --git a/stac_fastapi/pgstac/tests/resources/test_conformance.py b/stac_fastapi/pgstac/tests/resources/test_conformance.py index b080c4b8a..b9f78852c 100644 --- a/stac_fastapi/pgstac/tests/resources/test_conformance.py +++ b/stac_fastapi/pgstac/tests/resources/test_conformance.py @@ -46,7 +46,7 @@ def test_landing_page_health(response): @pytest.mark.parametrize("rel_type,expected_media_type,expected_path", link_tests) async def test_landing_page_links( - response_json: Dict, app_client, rel_type, expected_media_type, expected_path + response_json: Dict, app_client, app, rel_type, expected_media_type, expected_path ): link = get_link(response_json, rel_type) @@ -54,9 +54,9 @@ async def test_landing_page_links( assert link.get("type") == expected_media_type link_path = urllib.parse.urlsplit(link.get("href")).path - assert link_path == expected_path + assert link_path == app.state.router_prefix + expected_path - resp = await app_client.get(link_path) + resp = await app_client.get(link_path.rsplit("/", 1)[-1]) assert resp.status_code == 200 @@ -64,7 +64,7 @@ async def test_landing_page_links( # code here seems meaningless since it would be the same as if the endpoint did not exist. Once # https://github.com/stac-utils/stac-fastapi/pull/227 has been merged we can add this to the # parameterized tests above. -def test_search_link(response_json: Dict): +def test_search_link(response_json: Dict, app): for search_link in [ get_link(response_json, "search", "GET"), get_link(response_json, "search", "POST"), @@ -73,4 +73,4 @@ def test_search_link(response_json: Dict): assert search_link.get("type") == "application/geo+json" search_path = urllib.parse.urlsplit(search_link.get("href")).path - assert search_path == "/search" + assert search_path == app.state.router_prefix + "/search" diff --git a/stac_fastapi/pgstac/tests/resources/test_item.py b/stac_fastapi/pgstac/tests/resources/test_item.py index a56fd16dd..d261738ab 100644 --- a/stac_fastapi/pgstac/tests/resources/test_item.py +++ b/stac_fastapi/pgstac/tests/resources/test_item.py @@ -1166,7 +1166,7 @@ async def test_get_missing_item(app_client, load_test_data): assert resp.status_code == 404 -async def test_relative_link_construction(): +async def test_relative_link_construction(app): req = Request( scope={ "type": "http", @@ -1177,11 +1177,14 @@ async def test_relative_link_construction(): "raw_path": b"/tab/abc", "query_string": b"", "headers": {}, + "app": app, "server": ("test", HTTP_PORT), } ) links = CollectionLinks(collection_id="naip", request=req) - assert links.link_items()["href"] == "http://test/stac/collections/naip/items" + assert links.link_items()["href"] == ( + "http://test/stac{}/collections/naip/items".format(app.state.router_prefix) + ) async def test_search_bbox_errors(app_client): diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 416bf1c50..965b5f26b 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -14,6 +14,7 @@ from stac_fastapi.types import stac as stac_types from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES from stac_fastapi.types.extension import ApiExtension +from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.search import BaseSearchPostRequest from stac_fastapi.types.stac import Conformance @@ -349,7 +350,7 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: API landing page, serving as an entry point to the API. """ request: Request = kwargs["request"] - base_url = str(request.base_url) + base_url = get_base_url(request) extension_schemas = [ schema.schema_href for schema in self.extensions if schema.schema_href ] @@ -377,7 +378,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: "rel": "service-desc", "type": "application/vnd.oai.openapi+json;version=3.0", "title": "OpenAPI service description", - "href": urljoin(base_url, request.app.openapi_url.lstrip("/")), + "href": urljoin( + str(request.base_url), request.app.openapi_url.lstrip("/") + ), } ) @@ -387,7 +390,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage: "rel": "service-doc", "type": "text/html", "title": "OpenAPI service documentation", - "href": urljoin(base_url, request.app.docs_url.lstrip("/")), + "href": urljoin( + str(request.base_url), request.app.docs_url.lstrip("/") + ), } ) @@ -538,7 +543,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: API landing page, serving as an entry point to the API. """ request: Request = kwargs["request"] - base_url = str(request.base_url) + base_url = get_base_url(request) extension_schemas = [ schema.schema_href for schema in self.extensions if schema.schema_href ] @@ -564,7 +569,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: "rel": "service-desc", "type": "application/vnd.oai.openapi+json;version=3.0", "title": "OpenAPI service description", - "href": urljoin(base_url, request.app.openapi_url.lstrip("/")), + "href": urljoin( + str(request.base_url), request.app.openapi_url.lstrip("/") + ), } ) @@ -574,7 +581,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage: "rel": "service-doc", "type": "text/html", "title": "OpenAPI service documentation", - "href": urljoin(base_url, request.app.docs_url.lstrip("/")), + "href": urljoin( + str(request.base_url), request.app.docs_url.lstrip("/") + ), } ) diff --git a/stac_fastapi/types/stac_fastapi/types/requests.py b/stac_fastapi/types/stac_fastapi/types/requests.py new file mode 100644 index 000000000..7ce0e81a4 --- /dev/null +++ b/stac_fastapi/types/stac_fastapi/types/requests.py @@ -0,0 +1,14 @@ +"""requests helpers.""" + +from starlette.requests import Request + + +def get_base_url(request: Request) -> str: + """Get base URL with respect of APIRouter prefix.""" + app = request.app + if not app.state.router_prefix: + return str(request.base_url) + else: + return "{}{}/".format( + str(request.base_url), app.state.router_prefix.lstrip("/") + )