diff --git a/graphene_django_filter/connection_field.py b/graphene_django_filter/connection_field.py index b558a49..6a5ce38 100644 --- a/graphene_django_filter/connection_field.py +++ b/graphene_django_filter/connection_field.py @@ -9,12 +9,12 @@ import graphene from django.core.exceptions import ValidationError from django.db import models -from django_filters import FilterSet from graphene_django import DjangoObjectType from graphene_django.filter import DjangoFilterConnectionField -from .filterset import AdvancedFilterSet, tree_input_type_to_data +from .filterset import AdvancedFilterSet from .filterset_factories import get_filterset_class +from .input_data_factories import tree_input_type_to_data from .input_type_factories import get_filtering_args_from_filterset @@ -74,14 +74,19 @@ def resolve_queryset( info: graphene.ResolveInfo, args: Dict[str, Any], filtering_args: Dict[str, graphene.InputField], - filterset_class: Type[FilterSet], + filterset_class: Type[AdvancedFilterSet], ) -> models.QuerySet: """Return a filtered QuerySet.""" qs = super(DjangoFilterConnectionField, cls).resolve_queryset( connection, iterable, info, args, ) filterset = filterset_class( - data=tree_input_type_to_data(args['filter']), queryset=qs, request=info.context, + data=tree_input_type_to_data( + filterset_class, + args['filter'], + ), + queryset=qs, + request=info.context, ) if filterset.form.is_valid(): return filterset.qs diff --git a/graphene_django_filter/filterset.py b/graphene_django_filter/filterset.py index 19c8fb3..df72769 100644 --- a/graphene_django_filter/filterset.py +++ b/graphene_django_filter/filterset.py @@ -14,34 +14,11 @@ from django_filters import Filter from django_filters.conf import settings as django_settings from django_filters.filterset import BaseFilterSet, FilterSetMetaclass -from graphene.types.inputobjecttype import InputObjectTypeContainer from wrapt import ObjectProxy from .conf import settings -def tree_input_type_to_data( - tree_input_type: 
InputObjectTypeContainer, - prefix: str = '', -) -> Dict[str, Any]: - """Convert a tree_input_type to a FilterSet data.""" - result: Dict[str, Any] = {} - for key, value in tree_input_type.items(): - if key in ('and', 'or'): - result[key] = [tree_input_type_to_data(subtree) for subtree in value] - elif key == 'not': - result[key] = tree_input_type_to_data(value) - else: - k = (prefix + LOOKUP_SEP + key if prefix else key).replace( - LOOKUP_SEP + django_settings.DEFAULT_LOOKUP_EXPR, '', - ) - if isinstance(value, InputObjectTypeContainer): - result.update(tree_input_type_to_data(value, k)) - else: - result[k] = value - return result - - class QuerySetProxy(ObjectProxy): """Proxy for a QuerySet object. @@ -105,7 +82,7 @@ def is_regular_lookup_expr(lookup_expr: str) -> bool: class AdvancedFilterSet(BaseFilterSet, metaclass=FilterSetMetaclass): - """Allow you to use advanced filters with `or` and `and` expressions.""" + """Allow you to use advanced filters.""" class TreeFormMixin(Form): """Tree-like form mixin.""" diff --git a/graphene_django_filter/input_data_factories.py b/graphene_django_filter/input_data_factories.py new file mode 100644 index 0000000..13269fb --- /dev/null +++ b/graphene_django_filter/input_data_factories.py @@ -0,0 +1,203 @@ +"""Functions for converting tree data into data suitable for the FilterSet.""" + +from typing import Any, Dict, List, Type, Union + +from django.contrib.postgres.search import ( + SearchQuery, + SearchRank, + SearchVector, + TrigramDistance, + TrigramSimilarity, +) +from django.core.exceptions import ValidationError +from django.db import models +from django.db.models.constants import LOOKUP_SEP +from django_filters.conf import settings as django_settings +from graphene.types.inputobjecttype import InputObjectTypeContainer +from graphene_django_filter.filters import SearchQueryFilter, SearchRankFilter, TrigramFilter +from graphene_django_filter.input_types import ( + SearchConfigInputType, + SearchQueryFilterInputType, + 
SearchQueryInputType, + SearchRankFilterInputType, + SearchRankWeightsInputType, + SearchVectorInputType, + TrigramFilterInputType, + TrigramSearchKind, +) + +from .conf import settings +from .filterset import AdvancedFilterSet + + +def tree_input_type_to_data( + filterset_class: Type[AdvancedFilterSet], + tree_input_type: InputObjectTypeContainer, + prefix: str = '', +) -> Dict[str, Any]: + """Convert a tree_input_type to a FilterSet data.""" + result: Dict[str, Any] = {} + for key, value in tree_input_type.items(): + if key in ('and', 'or'): + result[key] = [tree_input_type_to_data(filterset_class, subtree) for subtree in value] + elif key == 'not': + result[key] = tree_input_type_to_data(filterset_class, value) + else: + result.update( + create_data( + (prefix + LOOKUP_SEP + key if prefix else key).replace( + LOOKUP_SEP + django_settings.DEFAULT_LOOKUP_EXPR, '', + ), + value, + filterset_class, + ), + ) + return result + + +def create_data(key: str, value: Any, filterset_class: Type[AdvancedFilterSet]) -> Dict[str, Any]: + """Create data from a key and a value.""" + for factory_key, factory in DATA_FACTORIES.items(): + if factory_key in key: + return factory(value, key, filterset_class) + if isinstance(value, InputObjectTypeContainer): + return tree_input_type_to_data(filterset_class, value, key) + else: + return {key: value} + + +def create_search_query_data( + input_type: SearchQueryFilterInputType, + key: str, + filterset_class: Type[AdvancedFilterSet], +) -> Dict[str, SearchQueryFilter.Value]: + """Create a data for the `SearchQueryFilter` class.""" + return { + key: SearchQueryFilter.Value( + annotation_value=create_search_vector(input_type.vector, filterset_class), + search_value=create_search_query(input_type.query), + ), + } + + +def create_search_rank_data( + input_type: Union[SearchRankFilterInputType, InputObjectTypeContainer], + key: str, + filterset_class: Type[AdvancedFilterSet], +) -> Dict[str, SearchRankFilter.Value]: + """Create a data for the 
`SearchRankFilter` class.""" + rank_data = {} + for lookup, value in input_type.lookups.items(): + search_rank_data = { + 'vector': create_search_vector(input_type.vector, filterset_class), + 'query': create_search_query(input_type.query), + 'cover_density': input_type.cover_density, + } + weights = input_type.get('weights', None) + if weights: + search_rank_data['weights'] = create_search_rank_weights(weights) + normalization = input_type.get('normalization', None) + if normalization: + search_rank_data['normalization'] = normalization + rank_data[key + LOOKUP_SEP + lookup] = SearchRankFilter.Value( + annotation_value=SearchRank(**search_rank_data), + search_value=value, + ) + return rank_data + + +def create_trigram_data( + input_type: TrigramFilterInputType, + key: str, + *args +) -> Dict[str, TrigramFilter.Value]: + """Create a data for the `TrigramFilter` class.""" + trigram_data = {} + if input_type.kind == TrigramSearchKind.SIMILARITY: + trigram_class = TrigramSimilarity + else: + trigram_class = TrigramDistance + for lookup, value in input_type.lookups.items(): + trigram_data[key + LOOKUP_SEP + lookup] = TrigramFilter.Value( + annotation_value=trigram_class( + LOOKUP_SEP.join(key.split(LOOKUP_SEP)[:-1]), + input_type.value, + ), + search_value=value, + ) + return trigram_data + + +def create_search_vector( + input_type: Union[SearchVectorInputType, InputObjectTypeContainer], + filterset_class: Type[AdvancedFilterSet], +) -> SearchVector: + """Create an object of the `SearchVector` class.""" + validate_search_vector_fields(filterset_class, input_type.fields) + search_vector_data = {} + config = input_type.get('config', None) + if config: + search_vector_data['config'] = create_search_config(config) + weight = input_type.get('weight', None) + if weight: + search_vector_data['weight'] = weight.value + return SearchVector(*input_type.fields, **search_vector_data) + + +def create_search_query( + input_type: Union[SearchQueryInputType, InputObjectTypeContainer], 
+) -> SearchQuery: + """Create an object of the `SearchQuery` class.""" + config = input_type.get('config', None) + if config: + search_query = SearchQuery(input_type.value, config=create_search_config(config)) + else: + search_query = SearchQuery(input_type.value) + and_search_query = None + for and_input_type in input_type.get(settings.AND_KEY, []): + if and_search_query is None: + and_search_query = create_search_query(and_input_type) + else: + and_search_query = and_search_query & create_search_query(and_input_type) + or_search_query = None + for or_input_type in input_type.get(settings.OR_KEY, []): + if or_search_query is None: + or_search_query = create_search_query(or_input_type) + else: + or_search_query = or_search_query | create_search_query(or_input_type) + not_input_type = input_type.get(settings.NOT_KEY, None) + not_search_query = ~create_search_query(not_input_type) if not_input_type else None + valid_queries = ( + q for q in (and_search_query, or_search_query, not_search_query) if q is not None + ) + for valid_query in valid_queries: + search_query = search_query & valid_query + return search_query + + +def create_search_config(input_type: SearchConfigInputType) -> Union[str, models.F]: + """Create a `SearchVector` or `SearchQuery` object config.""" + return models.F(input_type.value) if input_type.is_field else input_type.value + + +def create_search_rank_weights(input_type: SearchRankWeightsInputType) -> List[float]: + """Create a search rank weights list.""" + return [input_type.D, input_type.C, input_type.B, input_type.A] + + +def validate_search_vector_fields( + filterset_class: Type[AdvancedFilterSet], + fields: List[str], +) -> None: + """Validate that fields is included in full text search fields.""" + full_text_search_fields = filterset_class.get_full_text_search_fields() + for field in fields: + if field not in full_text_search_fields: + raise ValidationError(f'The `{field}` field is not included in full text search fields') + + 
+DATA_FACTORIES = { + SearchQueryFilter.postfix: create_search_query_data, + SearchRankFilter.postfix: create_search_rank_data, + TrigramFilter.postfix: create_trigram_data, +} diff --git a/tests/test_filterset.py b/tests/test_filterset.py index 8a75774..71024ac 100644 --- a/tests/test_filterset.py +++ b/tests/test_filterset.py @@ -2,11 +2,10 @@ from collections import OrderedDict from contextlib import ExitStack -from datetime import datetime, timedelta +from datetime import datetime from typing import List from unittest.mock import MagicMock, patch -import graphene from django.db import models from django.test import TestCase from django.utils.timezone import make_aware @@ -18,7 +17,6 @@ get_q, is_full_text_search_lookup_expr, is_regular_lookup_expr, - tree_input_type_to_data, ) from .data_generation import generate_data @@ -29,99 +27,6 @@ class UtilsTests(TestCase): """Tests for utility functions and classes of the `filterset` module.""" - class TaskNameFilterInputType(graphene.InputObjectType): - exact = graphene.String() - - class TaskDescriptionFilterInputType(graphene.InputObjectType): - exact = graphene.String() - - class TaskUserEmailFilterInputType(graphene.InputObjectType): - exact = graphene.String() - iexact = graphene.String() - contains = graphene.String() - icontains = graphene.String() - - class TaskUserLastNameFilterInputType(graphene.InputObjectType): - exact = graphene.String() - - class TaskUserFilterInputType(graphene.InputObjectType): - exact = graphene.String() - email = graphene.InputField( - lambda: UtilsTests.TaskUserEmailFilterInputType, - ) - last_name = graphene.InputField( - lambda: UtilsTests.TaskUserLastNameFilterInputType, - ) - - class TaskCreatedAtInputType(graphene.InputObjectType): - gt = graphene.DateTime() - - class TaskCompletedAtInputType(graphene.InputObjectType): - lg = graphene.DateTime() - - TaskFilterInputType = type( - 'TaskFilterInputType', - (graphene.InputObjectType,), - { - 'name': graphene.InputField(lambda: 
UtilsTests.TaskNameFilterInputType), - 'description': graphene.InputField(lambda: UtilsTests.TaskDescriptionFilterInputType), - 'user': graphene.InputField(lambda: UtilsTests.TaskUserFilterInputType), - 'created_at': graphene.InputField(lambda: UtilsTests.TaskCreatedAtInputType), - 'completed_at': graphene.InputField(lambda: UtilsTests.TaskCompletedAtInputType), - 'and': graphene.InputField(graphene.List(lambda: UtilsTests.TaskFilterInputType)), - 'or': graphene.InputField(graphene.List(lambda: UtilsTests.TaskFilterInputType)), - 'not': graphene.InputField(lambda: UtilsTests.TaskFilterInputType), - }, - ) - - gt_datetime = datetime.today() - timedelta(days=1) - lt_datetime = datetime.today() - tree_input_type = TaskFilterInputType._meta.container({ - 'name': TaskNameFilterInputType._meta.container({'exact': 'Important task'}), - 'description': TaskDescriptionFilterInputType._meta.container( - {'exact': 'This task in very important'}, - ), - 'user': TaskUserFilterInputType._meta.container( - {'email': TaskUserEmailFilterInputType._meta.container({'contains': 'dev'})}, - ), - 'and': [ - TaskFilterInputType._meta.container({ - 'completed_at': TaskCompletedAtInputType._meta.container({'lt': lt_datetime}), - }), - ], - 'or': [ - TaskFilterInputType._meta.container({ - 'created_at': TaskCreatedAtInputType._meta.container({'gt': gt_datetime}), - }), - ], - 'not': TaskFilterInputType._meta.container({ - 'user': TaskUserFilterInputType._meta.container( - {'first_name': TaskUserEmailFilterInputType._meta.container({'exact': 'John'})}, - ), - }), - }) - - def test_tree_input_type_to_data(self) -> None: - """Test the `tree_input_type_to_data` function.""" - data = tree_input_type_to_data(self.tree_input_type) - self.assertEqual( - { - 'name': 'Important task', - 'description': 'This task in very important', - 'user__email__contains': 'dev', - 'and': [{ - 'completed_at__lt': self.lt_datetime, - }], - 'or': [{ - 'created_at__gt': self.gt_datetime, - }], - 'not': { - 
'user__first_name': 'John', - }, - }, - data, - ) - def test_queryset_proxy(self) -> None: """Test the `QuerySetProxy` class.""" queryset = User.objects.all() diff --git a/tests/test_input_data_factories.py b/tests/test_input_data_factories.py new file mode 100644 index 0000000..c14885e --- /dev/null +++ b/tests/test_input_data_factories.py @@ -0,0 +1,460 @@ +"""Input data factories tests.""" + +from collections import OrderedDict +from contextlib import contextmanager +from datetime import datetime, timedelta +from typing import Any, Generator, Tuple, Type, cast +from unittest.mock import MagicMock, patch + +import graphene +from django.contrib.postgres.search import ( + SearchQuery, + SearchRank, + SearchVector, + TrigramDistance, + TrigramSimilarity, +) +from django.core.exceptions import ValidationError +from django.db import models +from django.test import TestCase +from graphene.types.inputobjecttype import InputObjectTypeContainer +from graphene_django_filter.filters import ( + SearchQueryFilter, + SearchRankFilter, + TrigramFilter, +) +from graphene_django_filter.filterset import AdvancedFilterSet +from graphene_django_filter.input_data_factories import ( + create_data, + create_search_config, + create_search_query, + create_search_query_data, + create_search_rank_data, + create_search_rank_weights, + create_search_vector, + create_trigram_data, + tree_input_type_to_data, + validate_search_vector_fields, +) +from graphene_django_filter.input_types import ( + FloatLookupsInputType, + SearchConfigInputType, + SearchQueryFilterInputType, + SearchQueryInputType, + SearchRankFilterInputType, + SearchRankWeightsInputType, + SearchVectorInputType, + SearchVectorWeight, + TrigramFilterInputType, + TrigramSearchKind, +) + + +class InputDataFactoriesTests(TestCase): + """Input data factories tests.""" + + filterset_class_mock = cast( + Type[AdvancedFilterSet], + MagicMock( + get_full_text_search_fields=MagicMock( + return_value=OrderedDict([ + ('field1', 
MagicMock()), + ('field2', MagicMock()), + ]), + ), + ), + ) + + rank_weights_input_type = SearchRankWeightsInputType._meta.container({ + 'A': 0.9, + 'B': SearchRankWeightsInputType.B.kwargs['default_value'], + 'C': SearchRankWeightsInputType.C.kwargs['default_value'], + 'D': SearchRankWeightsInputType.D.kwargs['default_value'], + }) + + config_search_query_input_type = SearchConfigInputType._meta.container({ + 'config': SearchConfigInputType._meta.container({ + 'value': 'russian', + 'is_field': False, + }), + 'value': 'value', + }) + expressions_search_query_input_type = SearchQueryInputType._meta.container({ + 'value': 'value1', + 'and': [ + SearchQueryInputType._meta.container({'value': 'and_value1'}), + SearchQueryInputType._meta.container({'value': 'and_value2'}), + ], + 'or': [ + SearchQueryInputType._meta.container({'value': 'or_value1'}), + SearchQueryInputType._meta.container({'value': 'or_value2'}), + ], + 'not': SearchQueryInputType._meta.container({'value': 'not_value'}), + }) + + search_rank_input_type = SearchRankFilterInputType._meta.container({ + 'vector': MagicMock(), + 'query': MagicMock(), + 'lookups': FloatLookupsInputType._meta.container({'gt': 0.8, 'lt': 0.9}), + 'weights': rank_weights_input_type, + 'cover_density': True, + 'normalization': 2, + }) + + @contextmanager + def patch_vector_and_query_factories( + self, + ) -> Generator[Tuple[MagicMock, MagicMock, MagicMock, MagicMock], Any, None]: + """Patch `create_search_vector` and `create_search_query` functions.""" + with patch( + 'graphene_django_filter.input_data_factories.create_search_vector', + ) as create_search_vector_mock, patch( + 'graphene_django_filter.input_data_factories.create_search_query', + ) as create_search_query_mock: + search_vector_mock = MagicMock() + create_search_vector_mock.return_value = search_vector_mock + search_query_mock = MagicMock() + create_search_query_mock.return_value = search_query_mock + yield ( + create_search_vector_mock, + search_vector_mock, + 
create_search_query_mock, + search_query_mock, + ) + + class TaskNameFilterInputType(graphene.InputObjectType): + exact = graphene.String() + trigram = graphene.InputField(TrigramFilterInputType) + + class TaskDescriptionFilterInputType(graphene.InputObjectType): + exact = graphene.String() + + class TaskUserEmailFilterInputType(graphene.InputObjectType): + exact = graphene.String() + iexact = graphene.String() + contains = graphene.String() + icontains = graphene.String() + + class TaskUserLastNameFilterInputType(graphene.InputObjectType): + exact = graphene.String() + + class TaskUserFilterInputType(graphene.InputObjectType): + exact = graphene.String() + email = graphene.InputField( + lambda: InputDataFactoriesTests.TaskUserEmailFilterInputType, + ) + last_name = graphene.InputField( + lambda: InputDataFactoriesTests.TaskUserLastNameFilterInputType, + ) + + class TaskCreatedAtInputType(graphene.InputObjectType): + gt = graphene.DateTime() + + class TaskCompletedAtInputType(graphene.InputObjectType): + lg = graphene.DateTime() + + TaskFilterInputType = type( + 'TaskFilterInputType', + (graphene.InputObjectType,), + { + 'name': graphene.InputField( + lambda: InputDataFactoriesTests.TaskNameFilterInputType, + ), + 'description': graphene.InputField( + lambda: InputDataFactoriesTests.TaskDescriptionFilterInputType, + ), + 'user': graphene.InputField( + lambda: InputDataFactoriesTests.TaskUserFilterInputType, + ), + 'created_at': graphene.InputField( + lambda: InputDataFactoriesTests.TaskCreatedAtInputType, + ), + 'completed_at': graphene.InputField( + lambda: InputDataFactoriesTests.TaskCompletedAtInputType, + ), + 'and': graphene.InputField( + graphene.List(lambda: InputDataFactoriesTests.TaskFilterInputType), + ), + 'or': graphene.InputField( + graphene.List(lambda: InputDataFactoriesTests.TaskFilterInputType), + ), + 'not': graphene.InputField( + lambda: InputDataFactoriesTests.TaskFilterInputType, + ), + 'search_query': graphene.InputField( + 
SearchQueryFilterInputType, + ), + 'search_rank': graphene.InputField( + SearchRankFilterInputType, + ), + }, + ) + + task_filterset_class_mock = cast( + Type[AdvancedFilterSet], + MagicMock( + get_full_text_search_fields=MagicMock( + return_value=OrderedDict([ + ('name', MagicMock()), + ]), + ), + ), + ) + gt_datetime = datetime.today() - timedelta(days=1) + lt_datetime = datetime.today() + tree_input_type = TaskFilterInputType._meta.container({ + 'name': TaskNameFilterInputType._meta.container({ + 'exact': 'Important task', + 'trigram': TrigramFilterInputType._meta.container({ + 'kind': TrigramSearchKind.SIMILARITY, + 'lookups': FloatLookupsInputType._meta.container({'gt': 0.8}), + 'value': 'Buy some milk', + }), + }), + 'description': TaskDescriptionFilterInputType._meta.container( + {'exact': 'This task is very important'}, + ), + 'user': TaskUserFilterInputType._meta.container( + {'email': TaskUserEmailFilterInputType._meta.container({'contains': 'dev'})}, + ), + 'and': [ + TaskFilterInputType._meta.container({ + 'completed_at': TaskCompletedAtInputType._meta.container({'lt': lt_datetime}), + }), + ], + 'or': [ + TaskFilterInputType._meta.container({ + 'created_at': TaskCreatedAtInputType._meta.container({'gt': gt_datetime}), + }), + ], + 'not': TaskFilterInputType._meta.container({ + 'user': TaskUserFilterInputType._meta.container( + {'first_name': TaskUserEmailFilterInputType._meta.container({'exact': 'John'})}, + ), + }), + 'search_query': SearchQueryFilterInputType._meta.container({ + 'vector': SearchVectorInputType._meta.container({'fields': ['name']}), + 'query': SearchQueryInputType._meta.container({'value': 'Fix the bug'}), + }), + 'search_rank': SearchRankFilterInputType._meta.container({ + 'vector': SearchVectorInputType._meta.container({'fields': ['name']}), + 'query': SearchQueryInputType._meta.container({'value': 'Fix the bug'}), + 'lookups': FloatLookupsInputType._meta.container({'gt': 0.8}), + 'cover_density': False, + }), + }) + + def 
test_validate_search_vector_fields(self) -> None: + """Test the `validate_search_vector_fields` function.""" + validate_search_vector_fields(self.filterset_class_mock, ['field1', 'field2']) + with self.assertRaisesMessage( + ValidationError, + 'The `field3` field is not included in full text search fields', + ): + validate_search_vector_fields(self.filterset_class_mock, ['field1', 'field2', 'field3']) + + def test_create_search_rank_weights(self) -> None: + """Test the `create_search_rank_weights` function.""" + self.assertEqual( + [0.1, 0.2, 0.4, 0.9], + create_search_rank_weights(self.rank_weights_input_type), + ) + + def test_create_search_config(self) -> None: + """Test the `create_search_config` function.""" + string_input_type = SearchConfigInputType._meta.container({ + 'value': 'russian', + 'is_field': False, + }) + string_config = create_search_config(string_input_type) + self.assertEqual(string_input_type.value, string_config) + field_input_type = SearchConfigInputType._meta.container({ + 'value': 'russian', + 'is_field': True, + }) + field_config = create_search_config(field_input_type) + self.assertEqual(models.F('russian'), field_config) + + def test_create_search_query(self) -> None: + """Test the `create_search_query` function.""" + config_search_query = create_search_query(self.config_search_query_input_type) + self.assertEqual(SearchQuery('value', config='russian'), config_search_query) + expressions_search_query = create_search_query(self.expressions_search_query_input_type) + self.assertEqual( + SearchQuery('value1') & ( + SearchQuery('and_value1') & SearchQuery('and_value2') + ) & ( + SearchQuery('or_value1') | SearchQuery('or_value2') + ) & ~SearchQuery('not_value'), + expressions_search_query, + ) + + def test_create_search_vector(self) -> None: + """Test the `create_search_vector` function.""" + invalid_input_type = SearchVectorInputType._meta.container({ + 'fields': ['field1', 'field2', 'field3'], + }) + with 
self.assertRaises(ValidationError): + create_search_vector(invalid_input_type, self.filterset_class_mock) + config_input_type = SearchVectorInputType._meta.container({ + 'fields': ['field1', 'field2'], + 'config': SearchConfigInputType._meta.container({ + 'value': 'russian', + }), + }) + config_search_vector = create_search_vector(config_input_type, self.filterset_class_mock) + self.assertEqual(SearchVector('field1', 'field2', config='russian'), config_search_vector) + weight_input_type = SearchVectorInputType._meta.container({ + 'fields': ['field1', 'field2'], + 'weight': SearchVectorWeight.A, + }) + weight_search_vector = create_search_vector(weight_input_type, self.filterset_class_mock) + self.assertEqual(SearchVector('field1', 'field2', weight='A'), weight_search_vector) + + def test_create_trigram_data(self) -> None: + """Test the `create_trigram_data` function.""" + for trigram_class in (TrigramSimilarity, TrigramDistance): + with self.subTest(trigram_class=trigram_class): + similarity_input_type = TrigramFilterInputType._meta.container({ + 'kind': TrigramSearchKind.SIMILARITY + if trigram_class == TrigramSimilarity else TrigramSearchKind.DISTANCE, + 'lookups': FloatLookupsInputType._meta.container({'gt': 0.8, 'lt': 0.9}), + 'value': 'value', + }) + trigram_data = create_trigram_data(similarity_input_type, 'field__trigram') + expected_trigram_data = { + 'field__trigram__gt': TrigramFilter.Value( + annotation_value=trigram_class('field', 'value'), + search_value=0.8, + ), + 'field__trigram__lt': TrigramFilter.Value( + annotation_value=trigram_class('field', 'value'), + search_value=0.9, + ), + } + self.assertEqual(expected_trigram_data, trigram_data) + + @patch.object( + SearchRank, + '__eq__', + new=lambda self, other: str(self) + self.function == str(other) + other.function, + ) + def test_create_search_rank_data(self) -> None: + """Test the `create_search_rank_data` function.""" + with self.patch_vector_and_query_factories() as mocks: + create_sv_mock, 
sv_mock, create_sq_mock, sq_mock = mocks + search_rank_data = create_search_rank_data( + self.search_rank_input_type, + 'field', + self.filterset_class_mock, + ) + expected_search_rank = SearchRank( + vector=sv_mock, + query=sq_mock, + weights=[0.1, 0.2, 0.4, 0.9], + cover_density=True, + normalization=2, + ) + self.assertEqual( + { + 'field__gt': SearchRankFilter.Value( + annotation_value=expected_search_rank, + search_value=0.8, + ), + 'field__lt': SearchRankFilter.Value( + annotation_value=expected_search_rank, + search_value=0.9, + ), + }, search_rank_data, + ) + create_sv_mock.assert_called_with( + self.search_rank_input_type.vector, + self.filterset_class_mock, + ) + create_sq_mock.assert_called_with(self.search_rank_input_type.query) + + def test_create_search_query_data(self) -> None: + """Test the `create_search_query_data` function.""" + with self.patch_vector_and_query_factories() as mocks: + create_sv_mock, sv_mock, create_sq_mock, sq_mock = mocks + vector = MagicMock() + query = MagicMock() + search_query_data = create_search_query_data( + SearchQueryFilterInputType._meta.container({ + 'vector': vector, + 'query': query, + }), + 'field', + self.filterset_class_mock, + ) + self.assertEqual( + { + 'field': SearchQueryFilter.Value( + annotation_value=sv_mock, + search_value=sq_mock, + ), + }, search_query_data, + ) + create_sv_mock.assert_called_once_with(vector, self.filterset_class_mock) + create_sq_mock.assert_called_once_with(query) + + def test_create_data(self) -> None: + """Test the `create_data` function.""" + with patch( + 'graphene_django_filter.input_data_factories.DATA_FACTORIES', + new={ + 'search_query': MagicMock(return_value=MagicMock()), + 'search_rank': MagicMock(return_value=MagicMock()), + 'trigram': MagicMock(return_value=MagicMock()), + }, + ): + from graphene_django_filter.input_data_factories import DATA_FACTORIES + for factory_key, factory in DATA_FACTORIES.items(): + value = MagicMock() + self.assertEqual( + factory.return_value, 
+ create_data(factory_key, value, self.filterset_class_mock), + ) + factory.assert_called_once_with(value, factory_key, self.filterset_class_mock) + with patch('graphene_django_filter.input_data_factories.tree_input_type_to_data') as mock: + key = 'field' + value = MagicMock(spec=InputObjectTypeContainer) + mock.return_value = {key: value} + self.assertEqual(mock.return_value, create_data(key, value, self.filterset_class_mock)) + key = 'field' + value = MagicMock() + self.assertEqual({key: value}, create_data(key, value, self.filterset_class_mock)) + + def test_tree_input_type_to_data(self) -> None: + """Test the `tree_input_type_to_data` function.""" + data = tree_input_type_to_data(self.task_filterset_class_mock, self.tree_input_type) + expected_data = { + 'name': 'Important task', + 'name__trigram__gt': TrigramFilter.Value( + annotation_value=TrigramSimilarity('name', 'Buy some milk'), + search_value=0.8, + ), + 'description': 'This task is very important', + 'user__email__contains': 'dev', + 'and': [{ + 'completed_at__lt': self.lt_datetime, + }], + 'or': [{ + 'created_at__gt': self.gt_datetime, + }], + 'not': { + 'user__first_name': 'John', + }, + 'search_query': SearchQueryFilter.Value( + annotation_value=SearchVector('name'), + search_value=SearchQuery('Fix the bug'), + ), + 'search_rank__gt': SearchRankFilter.Value( + annotation_value=SearchRank( + vector=SearchVector('name'), + query=SearchQuery('Fix the bug'), + ), + search_value=0.8, + ), + } + self.assertEqual(expected_data, data) diff --git a/tests/test_input_type_factories.py b/tests/test_input_type_factories.py index fc38b37..d97bfdf 100644 --- a/tests/test_input_type_factories.py +++ b/tests/test_input_type_factories.py @@ -24,7 +24,7 @@ from .object_types import TaskFilterSetClassType -class InputTypeBuildersTests(TestCase): +class InputTypeFactoriesTests(TestCase): """Input type factories tests.""" abstract_tree_root = Node(