Refactor backend-specific clients to use a new Backend interface #515

Closed · wants to merge 1 commit
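The refactor moves database access behind a new AsyncBackend interface: a client hands request.app.state to a backend class and gets typed STAC objects back, instead of acquiring an asyncpg pool and rendering SQL itself. A minimal sketch of the resulting client-side pattern, assuming only the modules this diff adds and mirroring the error handling in the updated core.py:

# Sketch only: how a pgstac client method delegates to the new backend.
# PgstacBackend and AsyncBackend are introduced by this PR; the other imports
# already exist in stac-fastapi.
from starlette.requests import Request

from stac_fastapi.pgstac.backend import PgstacBackend
from stac_fastapi.types.errors import NotFoundError
from stac_fastapi.types.stac import Collection


async def get_collection(request: Request, collection_id: str) -> Collection:
    # The backend sees only the application state, never the Request itself,
    # so HTTP concerns stay in the client and SQL stays in the backend.
    collection = await PgstacBackend.get_collection(request.app.state, collection_id)
    if collection is None:
        raise NotFoundError(f"Collection {collection_id} does not exist.")
    return collection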
95 changes: 95 additions & 0 deletions stac_fastapi/pgstac/stac_fastapi/pgstac/backend.py
@@ -0,0 +1,95 @@
"""Asynchronous backend for pgstac."""

from typing import Any, Dict, List, Optional, Tuple

import buildpg
from asyncpg.exceptions import InvalidDatetimeFormatError
from starlette.datastructures import State

from stac_fastapi.types.backend import AsyncBackend
from stac_fastapi.types.errors import InvalidQueryParameter
from stac_fastapi.types.links import PaginationLinks
from stac_fastapi.types.search import BaseSearchPostRequest
from stac_fastapi.types.stac import Collection, ItemCollection


class PgstacBackend(AsyncBackend):
"""An asynchronous backend for pgstac."""

async def all_collections(state: State) -> List[Collection]:
"""Get all collections."""
pool = state.readpool
async with pool.acquire() as conn:
collections = await conn.fetchval(
"""
SELECT * FROM all_collections();
"""
)
if collections is None:
return list()
else:
return [Collection(**collection) for collection in collections]

async def get_collection(state: State, collection_id: str) -> Optional[Collection]:
"""Get a single collection."""
pool = state.readpool
async with pool.acquire() as conn:
q, p = buildpg.render(
"""
SELECT * FROM get_collection(:id::text);
""",
id=collection_id,
)
collection = await conn.fetchval(q, *p)
if collection is None:
return None
else:
return Collection(**collection)

async def search_post(
state: State, search_request: BaseSearchPostRequest
) -> Tuple[ItemCollection, PaginationLinks]:
"""Search the database."""
items: Dict[str, Any]
req = search_request.json(exclude_none=True, by_alias=True)
pool = state.readpool
try:
async with pool.acquire() as conn:
q, p = buildpg.render(
"""
SELECT * FROM search(:req::text::jsonb);
""",
req=req,
)
items = await conn.fetchval(q, *p)
except InvalidDatetimeFormatError:
raise InvalidQueryParameter(
f"Datetime parameter {search_request.datetime} is invalid."
)

def make_query_dict(name: str) -> Optional[Dict[str, str]]:
value = items.pop(name, None)
if value is None:
return None
else:
return {"token": f"{name}:{value}"}

next = make_query_dict("next")
prev = make_query_dict("prev")
pagination_links = PaginationLinks.from_dicts(next=next, prev=prev)

return (ItemCollection(**items), pagination_links)

async def get_base_item(
state: State, collection_id: str
) -> Optional[Dict[str, Any]]:
"""Get the base item from the database."""
pool = state.readpool
async with pool.acquire() as conn:
q, p = buildpg.render(
"""
SELECT * FROM collection_base_item(:collection_id::text);
""",
collection_id=collection_id,
)
return await conn.fetchval(q, *p)
83 changes: 17 additions & 66 deletions stac_fastapi/pgstac/stac_fastapi/pgstac/core.py
@@ -6,8 +6,6 @@

import attr
import orjson
from asyncpg.exceptions import InvalidDatetimeFormatError
from buildpg import render
from fastapi import HTTPException
from pydantic import ValidationError
from pygeofilter.backends.cql2_json import to_cql2
@@ -17,6 +15,7 @@
from stac_pydantic.shared import MimeTypes
from starlette.requests import Request

from stac_fastapi.pgstac.backend import PgstacBackend
from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.models.links import (
CollectionLinks,
@@ -27,7 +26,7 @@
from stac_fastapi.pgstac.types.search import PgstacSearch
from stac_fastapi.pgstac.utils import filter_fields
from stac_fastapi.types.core import AsyncBaseCoreClient
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
from stac_fastapi.types.errors import NotFoundError
from stac_fastapi.types.requests import get_base_url
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection

@@ -42,23 +41,15 @@ async def all_collections(self, **kwargs) -> Collections:
"""Read all collections from the database."""
request: Request = kwargs["request"]
base_url = get_base_url(request)
pool = request.app.state.readpool

async with pool.acquire() as conn:
collections = await conn.fetchval(
"""
SELECT * FROM all_collections();
"""
)
collections = await PgstacBackend.all_collections(request.app.state)
linked_collections: List[Collection] = []
if collections is not None and len(collections) > 0:
for c in collections:
coll = Collection(**c)
coll["links"] = await CollectionLinks(
collection_id=coll["id"], request=request
).get_links(extra_links=coll.get("links"))
c["links"] = await CollectionLinks(
collection_id=c["id"], request=request
).get_links(extra_links=c["links"])

linked_collections.append(coll)
linked_collections.append(c)

links = [
{
@@ -91,26 +82,18 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection:
Returns:
Collection.
"""
collection: Optional[Dict[str, Any]]

request: Request = kwargs["request"]
pool = request.app.state.readpool
async with pool.acquire() as conn:
q, p = render(
"""
SELECT * FROM get_collection(:id::text);
""",
id=collection_id,
)
collection = await conn.fetchval(q, *p)
collection = await PgstacBackend.get_collection(
request.app.state, collection_id
)
if collection is None:
raise NotFoundError(f"Collection {collection_id} does not exist.")

collection["links"] = await CollectionLinks(
collection_id=collection_id, request=request
).get_links(extra_links=collection.get("links"))
).get_links(extra_links=collection["links"])

return Collection(**collection)
return collection

async def _get_base_item(
self, collection_id: str, request: Request
@@ -123,18 +106,7 @@ async def _get_base_item(
Returns:
Item.
"""
item: Optional[Dict[str, Any]]

pool = request.app.state.readpool
async with pool.acquire() as conn:
q, p = render(
"""
SELECT * FROM collection_base_item(:collection_id::text);
""",
collection_id=collection_id,
)
item = await conn.fetchval(q, *p)

item = await PgstacBackend.get_base_item(request.app.state, collection_id)
if item is None:
raise NotFoundError(f"A base item for {collection_id} does not exist.")

@@ -155,33 +127,14 @@ async def _search_base(
Returns:
ItemCollection containing items which match the search criteria.
"""
items: Dict[str, Any]

request: Request = kwargs["request"]
settings: Settings = request.app.state.settings
pool = request.app.state.readpool

search_request.conf = search_request.conf or {}
search_request.conf["nohydrate"] = settings.use_api_hydrate
req = search_request.json(exclude_none=True, by_alias=True)

try:
async with pool.acquire() as conn:
q, p = render(
"""
SELECT * FROM search(:req::text::jsonb);
""",
req=req,
)
items = await conn.fetchval(q, *p)
except InvalidDatetimeFormatError:
raise InvalidQueryParameter(
f"Datetime parameter {search_request.datetime} is invalid."
)

next: Optional[str] = items.pop("next", None)
prev: Optional[str] = items.pop("prev", None)
collection = ItemCollection(**items)
collection, pagination_links = await PgstacBackend.search_post(
request.app.state, search_request
)

exclude = search_request.fields.exclude
if exclude and len(exclude) == 0:
@@ -244,9 +197,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:

collection["features"] = cleaned_features
collection["links"] = await PagingLinks(
request=request,
next=next,
prev=prev,
request=request, pagination_links=pagination_links
).get_links()
return collection

80 changes: 38 additions & 42 deletions stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py
@@ -1,5 +1,6 @@
"""link helpers."""

import copy
from typing import Any, Dict, List, Optional
from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urljoin, urlparse

@@ -8,6 +9,7 @@
from stac_pydantic.shared import MimeTypes
from starlette.requests import Request

from stac_fastapi.types.links import PaginationLinks, UnresolvedLink
from stac_fastapi.types.requests import get_base_url

# These can be inferred from the item/collection so they aren't included in the database
@@ -116,54 +118,48 @@ async def get_links(
class PagingLinks(BaseLinks):
"""Create links for paging."""

next: Optional[str] = attr.ib(kw_only=True, default=None)
prev: Optional[str] = attr.ib(kw_only=True, default=None)
pagination_links: PaginationLinks = attr.ib()

def link_next(self) -> Optional[Dict[str, Any]]:
"""Create link for next page."""
if self.next is not None:
method = self.request.method
if method == "GET":
href = merge_params(self.url, {"token": f"next:{self.next}"})
link = dict(
rel=Relations.next.value,
type=MimeTypes.geojson.value,
method=method,
href=href,
)
return link
if method == "POST":
return {
"rel": Relations.next,
"type": MimeTypes.geojson,
"method": method,
"href": f"{self.request.url}",
"body": {**self.request.postbody, "token": f"next:{self.next}"},
}

return None
if self.pagination_links.next is not None:
return self._link(Relations.next, self.pagination_links.next)
else:
return None

def link_prev(self) -> Optional[Dict[str, Any]]:
"""Create link for previous page."""
if self.prev is not None:
method = self.request.method
if method == "GET":
href = merge_params(self.url, {"token": f"prev:{self.prev}"})
return dict(
rel=Relations.previous.value,
type=MimeTypes.geojson.value,
method=method,
href=href,
)
if method == "POST":
return {
"rel": Relations.previous,
"type": MimeTypes.geojson,
"method": method,
"href": f"{self.request.url}",
"body": {**self.request.postbody, "token": f"prev:{self.prev}"},
}
return None
if self.pagination_links.prev is not None:
return self._link(Relations.previous, self.pagination_links.prev)
else:
return None

def _link(self, rel: Relations, unresolved_link: UnresolvedLink) -> Dict[str, Any]:
method = self.request.method
if method == "GET":
href = merge_params(self.url, unresolved_link.query)
link = dict(
rel=rel.value,
type=MimeTypes.geojson.value,
method=method,
href=href,
)
link.update(unresolved_link.extra_fields)
return link
elif method == "POST":
body = copy.deepcopy(self.request.postbody)
body.update(unresolved_link.query)
link = {
"rel": rel.value,
"type": MimeTypes.geojson,
"method": method,
"href": f"{self.request.url}",
"body": body,
}
link.update(unresolved_link.extra_fields)
return link
else:
raise ValueError(f"unsupported paging link method: {method}")


@attr.s
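The PaginationLinks and UnresolvedLink types imported from stac_fastapi.types.links above are not included in this diff. The following is a hypothetical sketch of their shape, inferred only from how PgstacBackend.search_post and PagingLinks._link consume them; the attribute names, defaults, and the from_dicts signature are assumptions, not the PR's actual definitions:

# Hypothetical reconstruction of the stac_fastapi.types.links additions; the
# diff only shows how these types are used, so every detail here is assumed.
from typing import Any, Dict, Optional

import attr


@attr.s
class UnresolvedLink:
    """Query parameters and extra fields still to be resolved into a full link."""

    query: Dict[str, str] = attr.ib(factory=dict)
    extra_fields: Dict[str, Any] = attr.ib(factory=dict)


@attr.s
class PaginationLinks:
    """Unresolved next/prev links produced by a backend search."""

    next: Optional[UnresolvedLink] = attr.ib(default=None)
    prev: Optional[UnresolvedLink] = attr.ib(default=None)

    @classmethod
    def from_dicts(
        cls,
        next: Optional[Dict[str, str]] = None,
        prev: Optional[Dict[str, str]] = None,
    ) -> "PaginationLinks":
        return cls(
            next=UnresolvedLink(query=next) if next is not None else None,
            prev=UnresolvedLink(query=prev) if prev is not None else None,
        )

Under this reading, PgstacBackend.search_post packages pgstac's next/prev tokens as {"token": "next:<value>"} query dictionaries, and PagingLinks._link later merges each query dict into either the GET URL parameters or the POST body.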
35 changes: 35 additions & 0 deletions stac_fastapi/types/stac_fastapi/types/backend.py
@@ -0,0 +1,35 @@
"""Storage backends for stac-fastapi.

Backends are used to fetch data for a client. They intentionally have
restrictive method signatures in order to enforce separation of responsibilities
between backends and clients.
"""

from abc import ABC, abstractclassmethod
from typing import List, Optional, Tuple

from starlette.datastructures import State

from stac_fastapi.types.links import PaginationLinks
from stac_fastapi.types.search import BaseSearchPostRequest
from stac_fastapi.types.stac import Collection, ItemCollection


class AsyncBackend(ABC):
"""An asynchronous backend."""

@abstractclassmethod
async def all_collections(state: State) -> List[Collection]:
"""Get all collections."""
...

@abstractclassmethod
async def get_collection(state: State, collection_id: str) -> Optional[Collection]:
"""Get a single collection by id."""
...

@abstractclassmethod
async def search_post(
state: State, search: BaseSearchPostRequest
    ) -> Tuple[ItemCollection, PaginationLinks]:
"""Search the backend for items using a search POST request model."""
...
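Because AsyncBackend only ever receives a Starlette State, a client written against it needs nothing database-specific, and alternative backends can stay small. A hypothetical example, not part of this PR, of a trivial in-memory backend such as one might use in tests; the state.collections and state.items attributes and all other details are illustrative assumptions:

# Hypothetical sketch (not in the PR): an in-memory AsyncBackend, e.g. for tests.
# state.collections and state.items are assumed to be lists of plain dicts.
from typing import List, Optional, Tuple

from starlette.datastructures import State

from stac_fastapi.types.backend import AsyncBackend
from stac_fastapi.types.links import PaginationLinks
from stac_fastapi.types.search import BaseSearchPostRequest
from stac_fastapi.types.stac import Collection, ItemCollection


class InMemoryBackend(AsyncBackend):
    """Serve STAC objects held directly on the application state."""

    async def all_collections(state: State) -> List[Collection]:
        return [Collection(**c) for c in getattr(state, "collections", [])]

    async def get_collection(state: State, collection_id: str) -> Optional[Collection]:
        for c in getattr(state, "collections", []):
            if c["id"] == collection_id:
                return Collection(**c)
        return None

    async def search_post(
        state: State, search: BaseSearchPostRequest
    ) -> Tuple[ItemCollection, PaginationLinks]:
        # No real filtering or paging: return everything with no next/prev tokens.
        items = getattr(state, "items", [])
        item_collection = ItemCollection(type="FeatureCollection", features=items)
        return (item_collection, PaginationLinks.from_dicts(next=None, prev=None))

As in the pgstac implementation, the methods take no self or cls and are called directly on the class, which is how the updated core.py invokes PgstacBackend.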