
Commit

add collection search extension (#136)
* add collection-search extension

* define collections_get_request_model

* add test for additional extensions with collection-search

* pass collections_get_request_model to StacApi

* do not pass collection extensions to post_request_model

* do not pass collection extensions to get_request_model

* Do not add extensions to collection-search extension

* use CollectionSearchExtension.from_extensions()

* keep extensions and collection_search_extension separate

* update tests

* filter -> filter_query

* add collection_get_request_model to client

* add collection_request_model to the client

* recycle collections_get_request_model in client

* drop print statement

* simplify

* remove unused

* clean up control flow for extension-specific logic

* add link to PR in changelog

---------

Co-authored-by: vincentsarago <vincent.sarago@gmail.com>
hrodmn and vincentsarago authored Oct 2, 2024
1 parent c89360c commit 9a3797a
Showing 9 changed files with 269 additions and 91 deletions.
4 changes: 3 additions & 1 deletion .dockerignore
@@ -9,5 +9,7 @@ coverage.xml
*.log
.git
.envrc
*egg-info

venv
venv
env
2 changes: 1 addition & 1 deletion .github/workflows/cicd.yaml
@@ -47,7 +47,7 @@ jobs:
runs-on: ubuntu-latest
services:
pgstac:
image: ghcr.io/stac-utils/pgstac:v0.7.10
image: ghcr.io/stac-utils/pgstac:v0.8.6
env:
POSTGRES_USER: username
POSTGRES_PASSWORD: password
3 changes: 2 additions & 1 deletion CHANGES.md
@@ -2,7 +2,8 @@

## [Unreleased]

- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, https://github.com/stac-utils/stac-fastapi-pgstac/pull/142)
- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, <https://github.com/stac-utils/stac-fastapi-pgstac/pull/142>)
- Add collection search extension ([#139](https://github.com/stac-utils/stac-fastapi-pgstac/pull/139))

- Fix `filter` extension implementation in `CoreCrudClient`

6 changes: 3 additions & 3 deletions setup.py
@@ -10,9 +10,9 @@
"orjson",
"pydantic",
"stac_pydantic==3.1.*",
"stac-fastapi.api~=3.0",
"stac-fastapi.extensions~=3.0",
"stac-fastapi.types~=3.0",
"stac-fastapi.api~=3.0.2",
"stac-fastapi.extensions~=3.0.2",
"stac-fastapi.types~=3.0.2",
"asyncpg",
"buildpg",
"brotli_asgi",
39 changes: 28 additions & 11 deletions stac_fastapi/pgstac/app.py
@@ -10,6 +10,7 @@
from fastapi.responses import ORJSONResponse
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import (
EmptyRequest,
ItemCollectionUri,
create_get_request_model,
create_post_request_model,
@@ -22,6 +23,7 @@
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.core.collection_search import CollectionSearchExtension
from stac_fastapi.extensions.third_party import BulkTransactionExtension

from stac_fastapi.pgstac.config import Settings
@@ -47,34 +49,49 @@
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
}

if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"):
extensions = [
extensions_map[extension_name] for extension_name in enabled_extensions.split(",")
]
else:
extensions = list(extensions_map.values())
enabled_extensions = (
os.environ["ENABLED_EXTENSIONS"].split(",")
if "ENABLED_EXTENSIONS" in os.environ
else list(extensions_map.keys()) + ["collection_search"]
)
extensions = [
extension for key, extension in extensions_map.items() if key in enabled_extensions
]

if any(isinstance(ext, TokenPaginationExtension) for ext in extensions):
items_get_request_model = create_request_model(
items_get_request_model = (
create_request_model(
model_name="ItemCollectionUri",
base_model=ItemCollectionUri,
mixins=[TokenPaginationExtension().GET],
request_type="GET",
)
else:
items_get_request_model = ItemCollectionUri
if any(isinstance(ext, TokenPaginationExtension) for ext in extensions)
else ItemCollectionUri
)

collection_search_extension = (
CollectionSearchExtension.from_extensions(extensions)
if "collection_search" in enabled_extensions
else None
)
collections_get_request_model = (
collection_search_extension.GET if collection_search_extension else EmptyRequest
)

post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
get_request_model = create_get_request_model(extensions)

api = StacApi(
settings=settings,
extensions=extensions,
extensions=extensions + [collection_search_extension]
if collection_search_extension
else extensions,
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
response_class=ORJSONResponse,
items_get_request_model=items_get_request_model,
search_get_request_model=get_request_model,
search_post_request_model=post_request_model,
collections_get_request_model=collections_get_request_model,
)
app = api.app
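
Below is a minimal, hedged sketch (not part of this commit) of how the collection-search wiring above could be exercised against a locally running app; the host/port, the `ENABLED_EXTENSIONS` value, and the query parameter values are illustrative assumptions.

```python
# Sketch only: assumes the app above is served locally, e.g.
#   ENABLED_EXTENSIONS="query,sort,fields,filter,pagination,collection_search" \
#     uvicorn stac_fastapi.pgstac.app:app --port 8080
# with a reachable pgstac database; host, port, and parameter values are hypothetical.
import httpx

# bbox/datetime/limit are exposed on GET /collections through the
# collection-search request model (collections_get_request_model above).
resp = httpx.get(
    "http://localhost:8080/collections",
    params={
        "bbox": "-180,-90,180,90",
        "datetime": "2020-01-01T00:00:00Z/2024-12-31T23:59:59Z",
        "limit": 10,
    },
    timeout=30.0,
)
resp.raise_for_status()
print(len(resp.json()["collections"]), "matching collections")
```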

207 changes: 135 additions & 72 deletions stac_fastapi/pgstac/core.py
@@ -1,5 +1,6 @@
"""Item crud client."""

import json
import re
from typing import Any, Dict, List, Optional, Set, Union
from urllib.parse import unquote_plus, urljoin
@@ -14,12 +15,11 @@
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
from pypgstac.hydration import hydrate
from stac_fastapi.api.models import JSONResponse
from stac_fastapi.types.core import AsyncBaseCoreClient
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
from stac_fastapi.types.requests import get_base_url
from stac_fastapi.types.rfc3339 import DateTimeType
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
from stac_pydantic.links import Relations
from stac_pydantic.shared import BBox, MimeTypes

from stac_fastapi.pgstac.config import Settings
@@ -39,17 +39,66 @@
class CoreCrudClient(AsyncBaseCoreClient):
"""Client for core endpoints defined by stac."""

async def all_collections(self, request: Request, **kwargs) -> Collections:
"""Read all collections from the database."""
async def all_collections( # noqa: C901
self,
request: Request,
# Extensions
bbox: Optional[BBox] = None,
datetime: Optional[DateTimeType] = None,
limit: Optional[int] = None,
query: Optional[str] = None,
token: Optional[str] = None,
fields: Optional[List[str]] = None,
sortby: Optional[str] = None,
filter: Optional[str] = None,
filter_lang: Optional[str] = None,
**kwargs,
) -> Collections:
"""Cross catalog search (GET).
Called with `GET /collections`.
Returns:
Collections which match the search criteria, returns all
collections by default.
"""
base_url = get_base_url(request)

# Parse request parameters
base_args = {
"bbox": bbox,
"limit": limit,
"token": token,
"query": orjson.loads(unquote_plus(query)) if query else query,
}

clean_args = clean_search_args(
base_args=base_args,
datetime=datetime,
fields=fields,
sortby=sortby,
filter_query=filter,
filter_lang=filter_lang,
)

async with request.app.state.get_connection(request, "r") as conn:
collections = await conn.fetchval(
"""
SELECT * FROM all_collections();
q, p = render(
"""
SELECT * FROM collection_search(:req::text::jsonb);
""",
req=json.dumps(clean_args),
)
collections_result: Collections = await conn.fetchval(q, *p)

next: Optional[str] = None
prev: Optional[str] = None

if links := collections_result.get("links"):
next = collections_result["links"].pop("next")
prev = collections_result["links"].pop("prev")

linked_collections: List[Collection] = []
collections = collections_result["collections"]
if collections is not None and len(collections) > 0:
for c in collections:
coll = Collection(**c)
@@ -71,25 +120,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:

linked_collections.append(coll)

links = [
{
"rel": Relations.root.value,
"type": MimeTypes.json,
"href": base_url,
},
{
"rel": Relations.parent.value,
"type": MimeTypes.json,
"href": base_url,
},
{
"rel": Relations.self.value,
"type": MimeTypes.json,
"href": urljoin(base_url, "collections"),
},
]
collection_list = Collections(collections=linked_collections or [], links=links)
return collection_list
links = await PagingLinks(
request=request,
next=next,
prev=prev,
).get_links()

return Collections(
collections=linked_collections or [],
links=links,
)

async def get_collection(
self, collection_id: str, request: Request, **kwargs
@@ -386,7 +426,7 @@ async def post_search(

return ItemCollection(**item_collection)

async def get_search( # noqa: C901
async def get_search(
self,
request: Request,
collections: Optional[List[str]] = None,
@@ -421,51 +461,15 @@ async def get_search( # noqa: C901
"query": orjson.loads(unquote_plus(query)) if query else query,
}

if filter:
if filter_lang == "cql2-text":
filter = to_cql2(parse_cql2_text(filter))
filter_lang = "cql2-json"

base_args["filter"] = orjson.loads(filter)
base_args["filter-lang"] = filter_lang

if datetime:
base_args["datetime"] = format_datetime_range(datetime)

if intersects:
base_args["intersects"] = orjson.loads(unquote_plus(intersects))

if sortby:
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
sort_param = []
for sort in sortby:
sortparts = re.match(r"^([+-]?)(.*)$", sort)
if sortparts:
sort_param.append(
{
"field": sortparts.group(2).strip(),
"direction": "desc" if sortparts.group(1) == "-" else "asc",
}
)
base_args["sortby"] = sort_param

if fields:
includes = set()
excludes = set()
for field in fields:
if field[0] == "-":
excludes.add(field[1:])
elif field[0] == "+":
includes.add(field[1:])
else:
includes.add(field)
base_args["fields"] = {"include": includes, "exclude": excludes}

# Remove None values from dict
clean = {}
for k, v in base_args.items():
if v is not None and v != []:
clean[k] = v
clean = clean_search_args(
base_args=base_args,
intersects=intersects,
datetime=datetime,
fields=fields,
sortby=sortby,
filter_query=filter,
filter_lang=filter_lang,
)

# Do the request
try:
@@ -476,3 +480,62 @@ async def get_search(
) from e

return await self.post_search(search_request, request=request)


def clean_search_args( # noqa: C901
base_args: Dict[str, Any],
intersects: Optional[str] = None,
datetime: Optional[DateTimeType] = None,
fields: Optional[List[str]] = None,
sortby: Optional[str] = None,
filter_query: Optional[str] = None,
filter_lang: Optional[str] = None,
) -> Dict[str, Any]:
"""Clean up search arguments to match format expected by pgstac"""
if filter_query:
if filter_lang == "cql2-text":
filter_query = to_cql2(parse_cql2_text(filter_query))
filter_lang = "cql2-json"

base_args["filter"] = orjson.loads(filter_query)
base_args["filter_lang"] = filter_lang

if datetime:
base_args["datetime"] = format_datetime_range(datetime)

if intersects:
base_args["intersects"] = orjson.loads(unquote_plus(intersects))

if sortby:
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
sort_param = []
for sort in sortby:
sortparts = re.match(r"^([+-]?)(.*)$", sort)
if sortparts:
sort_param.append(
{
"field": sortparts.group(2).strip(),
"direction": "desc" if sortparts.group(1) == "-" else "asc",
}
)
base_args["sortby"] = sort_param

if fields:
includes = set()
excludes = set()
for field in fields:
if field[0] == "-":
excludes.add(field[1:])
elif field[0] == "+":
includes.add(field[1:])
else:
includes.add(field)
base_args["fields"] = {"include": includes, "exclude": excludes}

# Remove None values from dict
clean = {}
for k, v in base_args.items():
if v is not None and v != []:
clean[k] = v

return clean
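
For reference, a hedged usage sketch of `clean_search_args` as defined above; the input values are illustrative and the expected output is inferred from the function body rather than taken from this commit's tests.

```python
# Sketch only: illustrative call to clean_search_args (defined above in this module).
args = clean_search_args(
    base_args={
        "bbox": [-10.0, -10.0, 10.0, 10.0],
        "limit": 10,
        "token": None,   # None values are dropped by the final cleanup loop
        "query": None,
    },
    sortby=["-datetime", "id"],
    fields=["+id", "-geometry"],
    filter_query='{"op": "=", "args": [{"property": "collection"}, "some-collection"]}',
    filter_lang="cql2-json",
)
# Expected shape (inferred from the function body):
# {
#     "bbox": [-10.0, -10.0, 10.0, 10.0],
#     "limit": 10,
#     "filter": {"op": "=", "args": [{"property": "collection"}, "some-collection"]},
#     "filter_lang": "cql2-json",
#     "sortby": [
#         {"field": "datetime", "direction": "desc"},
#         {"field": "id", "direction": "asc"},
#     ],
#     "fields": {"include": {"id"}, "exclude": {"geometry"}},
# }
```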
