Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use search query parameters for additional search views in the API #3887

Merged
merged 14 commits into from
Mar 27, 2024
Merged
2 changes: 1 addition & 1 deletion api/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pytest-django = "~=4.6"
pytest-raises = "~=0.11"
pytest-sugar = "~=0.9"
remote-pdb = "~=2.1"
schemathesis = "~=3.23"
schemathesis = "~=3.25"

[packages]
adrf = "~=0.1.2"
Expand Down
201 changes: 94 additions & 107 deletions api/Pipfile.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions api/api/constants/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
COLLECTION = "unstable__collection"
TAG = "unstable__tag"
69 changes: 25 additions & 44 deletions api/api/controllers/search_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,7 @@

# Using TYPE_CHECKING to avoid circular imports when importing types
if TYPE_CHECKING:
from api.serializers.audio_serializers import AudioCollectionRequestSerializer
from api.serializers.media_serializers import (
MediaSearchRequestSerializer,
PaginatedRequestSerializer,
)

MediaListRequestSerializer = (
AudioCollectionRequestSerializer
| MediaSearchRequestSerializer
| PaginatedRequestSerializer
)
from api.serializers.media_serializers import MediaSearchRequestSerializer

module_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -220,7 +210,7 @@ def get_excluded_providers_query() -> Q | None:
def get_index(
exact_index: bool,
origin_index: OriginIndex,
search_params: MediaListRequestSerializer,
search_params: MediaSearchRequestSerializer,
) -> SearchIndex:
if exact_index:
return origin_index
Expand All @@ -234,7 +224,7 @@ def get_index(


def create_search_filter_queries(
search_params: MediaListRequestSerializer,
search_params: MediaSearchRequestSerializer,
) -> dict[str, list[Q]]:
"""
Create a list of Elasticsearch queries for filtering search results.
Expand Down Expand Up @@ -275,7 +265,7 @@ def create_search_filter_queries(


def create_ranking_queries(
search_params: MediaListRequestSerializer,
search_params: MediaSearchRequestSerializer,
) -> list[Q]:
queries = [Q("rank_feature", field="standardized_popularity", boost=DEFAULT_BOOST)]
if search_params.data["unstable__authority"]:
Expand All @@ -286,7 +276,7 @@ def create_ranking_queries(


def build_search_query(
search_params: MediaListRequestSerializer,
search_params: MediaSearchRequestSerializer,
) -> Q:
# Apply filters from the url query search parameters.
url_queries = create_search_filter_queries(search_params)
Expand Down Expand Up @@ -383,12 +373,10 @@ def log_query_features(query: str, query_name) -> None:


def build_collection_query(
search_params: MediaListRequestSerializer,
collection_params: dict[str, str],
search_params: MediaSearchRequestSerializer,
):
"""
Build the query to retrieve items in a collection.
:param collection_params: `tag`, `source` and/or `creator` values from the path.
:param search_params: the validated search parameters.
:return: the search client with the query applied.
"""
Expand All @@ -397,15 +385,12 @@ def build_collection_query(
# with its corresponding field in Elasticsearch. "None" means that the
# names are identical.
filters = [
# Collection filters allow a single value.
("tag", "tags.name.keyword"),
("source", None),
("creator", "creator.keyword"),
]
for serializer_field, es_field in filters:
if serializer_field in collection_params:
if not (argument := collection_params.get(serializer_field)):
continue
if argument := search_params.validated_data.get(serializer_field):
parameter = es_field or serializer_field
search_query["filter"].append({"term": {parameter: argument}})

Expand All @@ -422,20 +407,14 @@ def build_collection_query(
return Q("bool", **search_query)


def build_query(
strategy: SearchStrategy,
search_params: MediaListRequestSerializer,
collection_params: dict[str, str] | None,
) -> Q:
if strategy == "collection":
return build_collection_query(search_params, collection_params)
return build_search_query(search_params)
query_builders = {
"search": build_search_query,
"collection": build_collection_query,
}


def query_media(
strategy: SearchStrategy,
search_params: MediaListRequestSerializer,
collection_params: dict[str, str] | None,
search_params: MediaSearchRequestSerializer,
origin_index: OriginIndex,
exact_index: bool,
page_size: int,
Expand All @@ -444,17 +423,15 @@ def query_media(
page: int = 1,
) -> tuple[list[Hit], int, int, dict]:
"""
If ``strategy`` is ``search``, perform a ranked paginated search
Build the search or collection query, execute it, and return
the paginated result.
For queries with the `collection` parameter, returns media filtered
by the `tag`, `source` or `source`/`creator` combination, ordered
by the time when they were added to Openverse.
For other queries, performs a ranked paginated search
from the set of keywords and, optionally, filters.
If `strategy` is `collection`, perform a paginated search
for the `tag`, `source` or `source` and `creator` combination.

:param collection_params: The path parameters for collection search, if
strategy is `collection`.
:param strategy: Whether to perform a default search or retrieve a collection.
:param search_params: If `strategy` is `collection`, `PaginatedRequestSerializer`
or `AudioCollectionRequestSerializer`. If `strategy` is `search`, search
query params, see :class: `MediaRequestSerializer`.

:param search_params: Search query params, see :class:`MediaSearchRequestSerializer`.
:param origin_index: The Elasticsearch index to search (e.g. 'image')
:param exact_index: whether to skip all modifications to the index name
:param page_size: The number of results to return per page.
Expand All @@ -468,7 +445,11 @@ def query_media(
"""
index = get_index(exact_index, origin_index, search_params)

query = build_query(strategy, search_params, collection_params)
strategy: SearchStrategy = (
"collection" if search_params.validated_data.get("collection") else "search"
)

query = query_builders[strategy](search_params)

s = Search(index=index).query(query)

Expand Down
44 changes: 14 additions & 30 deletions api/api/docs/audio_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from drf_spectacular.utils import OpenApiResponse, extend_schema

from api.docs.base_docs import collection_schema, custom_extend_schema, fields_to_md
from api.docs.base_docs import (
SEARCH_DESCRIPTION,
custom_extend_schema,
fields_to_md,
)
from api.examples import (
audio_complain_201_example,
audio_complain_curl,
Expand Down Expand Up @@ -35,24 +39,17 @@
from api.serializers.provider_serializers import ProviderSerializer


search = custom_extend_schema(
desc=f"""
Search audio files using a query string.
serializer = AudioSearchRequestSerializer(context={"media_type": "audio"})
audio_filter_fields = fields_to_md([f for f in serializer.field_names if f != "q"])

By using this endpoint, you can obtain search results based on specified
query and optionally filter results by
{fields_to_md(AudioSearchRequestSerializer.field_names)}.

Results are ranked in order of relevance and paginated on the basis of the
`page` param. The `page_size` param controls the total number of pages.
audio_search_description = SEARCH_DESCRIPTION.format(
filter_fields=audio_filter_fields,
media_type="audio files",
)

Although there may be millions of relevant records, only the most relevant
several thousand records can be viewed. This is by design: the search
endpoint should be used to find the top 10,000 most relevant results, not
for exhaustive search or bulk download of every barely relevant result. As
such, the caller should not try to access pages beyond `page_count`, or else
the server will reject the query.""",
params=AudioSearchRequestSerializer,
search = custom_extend_schema(
desc=audio_search_description,
params=serializer,
res={
200: (AudioSerializer, audio_search_200_example),
400: (ValidationError, audio_search_400_example),
Expand Down Expand Up @@ -122,16 +119,3 @@
},
eg=[audio_waveform_curl],
)

source_collection = collection_schema(
media_type="audio",
collection="source",
)
creator_collection = collection_schema(
media_type="audio",
collection="creator",
)
tag_collection = collection_schema(
media_type="audio",
collection="tag",
)
Loading