From 78e0860cfa1370e442bc052b9d9d2701b7037fdf Mon Sep 17 00:00:00 2001 From: Pete Gadomski Date: Wed, 25 Jan 2023 15:57:03 -0700 Subject: [PATCH] wip, refactor: backends --- .../pgstac/stac_fastapi/pgstac/backend.py | 95 +++++++++++++++++++ .../pgstac/stac_fastapi/pgstac/core.py | 83 ++++------------ .../stac_fastapi/pgstac/models/links.py | 80 ++++++++-------- .../types/stac_fastapi/types/backend.py | 35 +++++++ .../types/stac_fastapi/types/links.py | 38 +++++++- 5 files changed, 222 insertions(+), 109 deletions(-) create mode 100644 stac_fastapi/pgstac/stac_fastapi/pgstac/backend.py create mode 100644 stac_fastapi/types/stac_fastapi/types/backend.py diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/backend.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/backend.py new file mode 100644 index 000000000..66f22af7a --- /dev/null +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/backend.py @@ -0,0 +1,95 @@ +"""Asynchronous backend for pgstac.""" + +from typing import Any, Dict, List, Optional, Tuple + +import buildpg +from asyncpg.exceptions import InvalidDatetimeFormatError +from starlette.datastructures import State + +from stac_fastapi.types.backend import AsyncBackend +from stac_fastapi.types.errors import InvalidQueryParameter +from stac_fastapi.types.links import PaginationLinks +from stac_fastapi.types.search import BaseSearchPostRequest +from stac_fastapi.types.stac import Collection, ItemCollection + + +class PgstacBackend(AsyncBackend): + """An asynchronous backend for pgstac.""" + + async def all_collections(state: State) -> List[Collection]: + """Get all collections.""" + pool = state.readpool + async with pool.acquire() as conn: + collections = await conn.fetchval( + """ + SELECT * FROM all_collections(); + """ + ) + if collections is None: + return list() + else: + return [Collection(**collection) for collection in collections] + + async def get_collection(state: State, collection_id: str) -> Optional[Collection]: + """Get a single collection.""" + pool = state.readpool + async with pool.acquire() as conn: + q, p = buildpg.render( + """ + SELECT * FROM get_collection(:id::text); + """, + id=collection_id, + ) + collection = await conn.fetchval(q, *p) + if collection is None: + return None + else: + return Collection(**collection) + + async def search_post( + state: State, search_request: BaseSearchPostRequest + ) -> Tuple[ItemCollection, PaginationLinks]: + """Search the database.""" + items: Dict[str, Any] + req = search_request.json(exclude_none=True, by_alias=True) + pool = state.readpool + try: + async with pool.acquire() as conn: + q, p = buildpg.render( + """ + SELECT * FROM search(:req::text::jsonb); + """, + req=req, + ) + items = await conn.fetchval(q, *p) + except InvalidDatetimeFormatError: + raise InvalidQueryParameter( + f"Datetime parameter {search_request.datetime} is invalid." + ) + + def make_query_dict(name: str) -> Optional[Dict[str, str]]: + value = items.pop(name, None) + if value is None: + return None + else: + return {"token": f"{name}:{value}"} + + next = make_query_dict("next") + prev = make_query_dict("prev") + pagination_links = PaginationLinks.from_dicts(next=next, prev=prev) + + return (ItemCollection(**items), pagination_links) + + async def get_base_item( + state: State, collection_id: str + ) -> Optional[Dict[str, Any]]: + """Get the base item from the database.""" + pool = state.readpool + async with pool.acquire() as conn: + q, p = buildpg.render( + """ + SELECT * FROM collection_base_item(:collection_id::text); + """, + collection_id=collection_id, + ) + return await conn.fetchval(q, *p) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index e074a9f9c..39bd59c20 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -6,8 +6,6 @@ import attr import orjson -from asyncpg.exceptions import InvalidDatetimeFormatError -from buildpg import render from fastapi import HTTPException from pydantic import ValidationError from pygeofilter.backends.cql2_json import to_cql2 @@ -17,6 +15,7 @@ from stac_pydantic.shared import MimeTypes from starlette.requests import Request +from stac_fastapi.pgstac.backend import PgstacBackend from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.models.links import ( CollectionLinks, @@ -27,7 +26,7 @@ from stac_fastapi.pgstac.types.search import PgstacSearch from stac_fastapi.pgstac.utils import filter_fields from stac_fastapi.types.core import AsyncBaseCoreClient -from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError +from stac_fastapi.types.errors import NotFoundError from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection @@ -42,23 +41,15 @@ async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" request: Request = kwargs["request"] base_url = get_base_url(request) - pool = request.app.state.readpool - - async with pool.acquire() as conn: - collections = await conn.fetchval( - """ - SELECT * FROM all_collections(); - """ - ) + collections = await PgstacBackend.all_collections(request.app.state) linked_collections: List[Collection] = [] if collections is not None and len(collections) > 0: for c in collections: - coll = Collection(**c) - coll["links"] = await CollectionLinks( - collection_id=coll["id"], request=request - ).get_links(extra_links=coll.get("links")) + c["links"] = await CollectionLinks( + collection_id=c["id"], request=request + ).get_links(extra_links=c["links"]) - linked_collections.append(coll) + linked_collections.append(c) links = [ { @@ -91,26 +82,18 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection: Returns: Collection. """ - collection: Optional[Dict[str, Any]] - request: Request = kwargs["request"] - pool = request.app.state.readpool - async with pool.acquire() as conn: - q, p = render( - """ - SELECT * FROM get_collection(:id::text); - """, - id=collection_id, - ) - collection = await conn.fetchval(q, *p) + collection = await PgstacBackend.get_collection( + request.app.state, collection_id + ) if collection is None: raise NotFoundError(f"Collection {collection_id} does not exist.") collection["links"] = await CollectionLinks( collection_id=collection_id, request=request - ).get_links(extra_links=collection.get("links")) + ).get_links(extra_links=collection["links"]) - return Collection(**collection) + return collection async def _get_base_item( self, collection_id: str, request: Request @@ -123,18 +106,7 @@ async def _get_base_item( Returns: Item. """ - item: Optional[Dict[str, Any]] - - pool = request.app.state.readpool - async with pool.acquire() as conn: - q, p = render( - """ - SELECT * FROM collection_base_item(:collection_id::text); - """, - collection_id=collection_id, - ) - item = await conn.fetchval(q, *p) - + item = await PgstacBackend.get_base_item(request.app.state, collection_id) if item is None: raise NotFoundError(f"A base item for {collection_id} does not exist.") @@ -155,33 +127,14 @@ async def _search_base( Returns: ItemCollection containing items which match the search criteria. """ - items: Dict[str, Any] - request: Request = kwargs["request"] settings: Settings = request.app.state.settings - pool = request.app.state.readpool search_request.conf = search_request.conf or {} search_request.conf["nohydrate"] = settings.use_api_hydrate - req = search_request.json(exclude_none=True, by_alias=True) - - try: - async with pool.acquire() as conn: - q, p = render( - """ - SELECT * FROM search(:req::text::jsonb); - """, - req=req, - ) - items = await conn.fetchval(q, *p) - except InvalidDatetimeFormatError: - raise InvalidQueryParameter( - f"Datetime parameter {search_request.datetime} is invalid." - ) - - next: Optional[str] = items.pop("next", None) - prev: Optional[str] = items.pop("prev", None) - collection = ItemCollection(**items) + collection, pagination_links = await PgstacBackend.search_post( + request.app.state, search_request + ) exclude = search_request.fields.exclude if exclude and len(exclude) == 0: @@ -244,9 +197,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]: collection["features"] = cleaned_features collection["links"] = await PagingLinks( - request=request, - next=next, - prev=prev, + request=request, pagination_links=pagination_links ).get_links() return collection diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py index c59891876..d7a7b87fb 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py @@ -1,5 +1,6 @@ """link helpers.""" +import copy from typing import Any, Dict, List, Optional from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urljoin, urlparse @@ -8,6 +9,7 @@ from stac_pydantic.shared import MimeTypes from starlette.requests import Request +from stac_fastapi.types.links import PaginationLinks, UnresolvedLink from stac_fastapi.types.requests import get_base_url # These can be inferred from the item/collection so they aren't included in the database @@ -116,54 +118,48 @@ async def get_links( class PagingLinks(BaseLinks): """Create links for paging.""" - next: Optional[str] = attr.ib(kw_only=True, default=None) - prev: Optional[str] = attr.ib(kw_only=True, default=None) + pagination_links: PaginationLinks = attr.ib() def link_next(self) -> Optional[Dict[str, Any]]: """Create link for next page.""" - if self.next is not None: - method = self.request.method - if method == "GET": - href = merge_params(self.url, {"token": f"next:{self.next}"}) - link = dict( - rel=Relations.next.value, - type=MimeTypes.geojson.value, - method=method, - href=href, - ) - return link - if method == "POST": - return { - "rel": Relations.next, - "type": MimeTypes.geojson, - "method": method, - "href": f"{self.request.url}", - "body": {**self.request.postbody, "token": f"next:{self.next}"}, - } - - return None + if self.pagination_links.next is not None: + return self._link(Relations.next, self.pagination_links.next) + else: + return None def link_prev(self) -> Optional[Dict[str, Any]]: """Create link for previous page.""" - if self.prev is not None: - method = self.request.method - if method == "GET": - href = merge_params(self.url, {"token": f"prev:{self.prev}"}) - return dict( - rel=Relations.previous.value, - type=MimeTypes.geojson.value, - method=method, - href=href, - ) - if method == "POST": - return { - "rel": Relations.previous, - "type": MimeTypes.geojson, - "method": method, - "href": f"{self.request.url}", - "body": {**self.request.postbody, "token": f"prev:{self.prev}"}, - } - return None + if self.pagination_links.prev is not None: + return self._link(Relations.previous, self.pagination_links.prev) + else: + return None + + def _link(self, rel: Relations, unresolved_link: UnresolvedLink) -> Dict[str, Any]: + method = self.request.method + if method == "GET": + href = merge_params(self.url, unresolved_link.query) + link = dict( + rel=rel.value, + type=MimeTypes.geojson.value, + method=method, + href=href, + ) + link.update(unresolved_link.extra_fields) + return link + elif method == "POST": + body = copy.deepcopy(self.request.postbody) + body.update(unresolved_link.query) + link = { + "rel": rel.value, + "type": MimeTypes.geojson, + "method": method, + "href": f"{self.request.url}", + "body": body, + } + link.update(unresolved_link.extra_fields) + return link + else: + raise ValueError(f"unsupported paging link method: {method}") @attr.s diff --git a/stac_fastapi/types/stac_fastapi/types/backend.py b/stac_fastapi/types/stac_fastapi/types/backend.py new file mode 100644 index 000000000..c3f8b76fb --- /dev/null +++ b/stac_fastapi/types/stac_fastapi/types/backend.py @@ -0,0 +1,35 @@ +"""Storage backends for stac-fastapi. + +Backends are used to fetch data for a client. They intentionally have +restrictive method signatures in order to enforce separation of responsibilities +between backends and clients. +""" + +from abc import ABC, abstractclassmethod +from typing import List, Optional + +from starlette.datastructures import State + +from stac_fastapi.types.search import BaseSearchPostRequest +from stac_fastapi.types.stac import Collection, ItemCollection + + +class AsyncBackend(ABC): + """An asynchronous backend.""" + + @abstractclassmethod + async def all_collections(state: State) -> List[Collection]: + """Get all collections.""" + ... + + @abstractclassmethod + async def get_collection(state: State, collection_id: str) -> Optional[Collection]: + """Get a single collection by id.""" + ... + + @abstractclassmethod + async def search_post( + state: State, search: BaseSearchPostRequest + ) -> ItemCollection: + """Search the backend for items using a search POST request model.""" + ... diff --git a/stac_fastapi/types/stac_fastapi/types/links.py b/stac_fastapi/types/stac_fastapi/types/links.py index 0349984b1..459ebd03d 100644 --- a/stac_fastapi/types/stac_fastapi/types/links.py +++ b/stac_fastapi/types/stac_fastapi/types/links.py @@ -1,6 +1,6 @@ """link helpers.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from urllib.parse import urljoin import attr @@ -108,3 +108,39 @@ def create_links(self) -> List[Dict[str, Any]]: self.root(), ] return links + + +@attr.s +class UnresolvedLink: + """An resolved link.""" + + query: Dict[str, str] = attr.ib() + extra_fields: Dict[str, Any] = attr.ib() + + @classmethod + def from_opt_dict( + cls, maybe_query: Optional[Dict[str, str]] + ) -> Optional["UnresolvedLink"]: + """Create an unresolved link from an optional query dictionary.""" + if maybe_query is None: + return None + else: + return UnresolvedLink(query=maybe_query, extra_fields={}) + + +@attr.s +class PaginationLinks: + """Unresolved links for pagination.""" + + next: Optional[UnresolvedLink] = attr.ib() + prev: Optional[UnresolvedLink] = attr.ib() + + @classmethod + def from_dicts( + cls, next: Optional[Dict[str, str]], prev: Optional[Dict[str, str]] + ) -> "PaginationLinks": + """Create a new PaginationLinks from two optional dictionaries of queries.""" + return cls( + next=UnresolvedLink.from_opt_dict(next), + prev=UnresolvedLink.from_opt_dict(prev), + )