From 9f76dec0928af16516d02707c7d6a4f09ed42e1c Mon Sep 17 00:00:00 2001 From: squak Date: Fri, 28 Jan 2022 16:45:30 +0300 Subject: [PATCH] feat: add method `create_field_filter_input_type` that create field filter input types from filter set tree leaves --- .../filter_set_converters.py | 49 ++++++++++++++++--- tests/test_filter_set_converters.py | 29 +++++++++-- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/graphene_django_filter/filter_set_converters.py b/graphene_django_filter/filter_set_converters.py index 63b68d8..645de34 100644 --- a/graphene_django_filter/filter_set_converters.py +++ b/graphene_django_filter/filter_set_converters.py @@ -1,9 +1,12 @@ """Functions for converting a FilterSet class to a tree and then to an input type.""" -from typing import List, Optional, Sequence, Type +from typing import Dict, List, Optional, Sequence, Type import graphene from anytree import Node +from anytree.exporter import DictExporter +from anytree.importer import DictImporter +from anytree.search import findall_by_attr from django.db import models from django.db.models.constants import LOOKUP_SEP from django_filters import Filter, FilterSet @@ -13,11 +16,45 @@ from stringcase import camelcase, capitalcase +def create_field_filter_input_types( + type_name: str, + tree: Node, + filter_set_class: Type[FilterSet], +) -> Node: + """Create field filter input types from filter set tree leaves. + + This function return new tree. + """ + tree_copy = DictImporter().import_(DictExporter().export(tree)) + for field_node in findall_by_attr(tree_copy, name='height', value=1): + fields: Dict[str, UnmountedType] = {} + for lookup_node in field_node.children: + fields[lookup_node.name] = get_field( + filter_set_class, + lookup_node.filter_name, + filter_set_class.base_filters[lookup_node.filter_name], + ) + input_type = type( + get_field_filter_input_type_name(type_name, field_node.path), + (graphene.InputObjectType,), + fields, + ) + new_node = Node(name=field_node.name, input_type=input_type) + if tree_copy == field_node: + return new_node + else: + field_node.parent.children = [ + *(node for node in field_node.parent.children if node is not field_node), + new_node, + ] + return tree_copy + + def get_field(filter_set_class: Type[FilterSet], name: str, filter_field: Filter) -> UnmountedType: """Return Graphene type from a filter field. - It is a partial copy of the `get_filtering_args_from_filterset` method from graphene-django. - https://github.com/graphql-python/graphene-django/blob/775644b5369bdc5fbb45d3535ae391a069ebf9d4/graphene_django/filter/utils.py#L25 + It is a partial copy of the `get_filtering_args_from_filterset` function from graphene-django. + https://github.com/graphql-python/graphene-django/blob/caf954861025b9f3d9d3f9c204a7cbbc87352265/graphene_django/filter/utils.py#L11 """ model = filter_set_class._meta.model form_field: Optional[models.Field] = None @@ -41,15 +78,15 @@ def get_field(filter_set_class: Type[FilterSet], name: str, filter_field: Filter return field_type -def get_input_type_name(type_name: str, node_path: Sequence[Node]) -> str: - """Return input type name from a type name and node path.""" +def get_field_filter_input_type_name(type_name: str, node_path: Sequence[Node]) -> str: + """Return field filter input type name from a type name and node path.""" field_name = ''.join( map( lambda node: capitalcase(camelcase(node.name)), node_path, ), ) - return f'{type_name.replace("Type", "")}{field_name}FilterInputType' + return f'{type_name.replace("Type", "")}{field_name}FieldFilterInputType' def filter_set_to_trees(filter_set_class: Type[FilterSet]) -> List[Node]: diff --git a/tests/test_filter_set_converters.py b/tests/test_filter_set_converters.py index f8e5c9a..e0da382 100644 --- a/tests/test_filter_set_converters.py +++ b/tests/test_filter_set_converters.py @@ -4,8 +4,9 @@ from anytree.exporter import DictExporter from django.test import TestCase from graphene_django_filter.filter_set_converters import ( + create_field_filter_input_types, filter_set_to_trees, - get_input_type_name, + get_field_filter_input_type_name, sequence_to_tree, try_add_sequence, ) @@ -51,11 +52,29 @@ def setUp(self) -> None: ), ] - def test_get_input_type_name(self) -> None: - """Test the `get_input_type_name` function.""" + def test_create_field_filter_input_types(self) -> None: + """Test the `create_field_filter_input_types` function.""" + trees = [ + create_field_filter_input_types('TaskType', tree, TaskFilter) + for tree in self.task_filter_trees + ] + name_input_type = getattr(trees[0], 'input_type') + last_name_input_type = getattr(trees[1].children[0], 'input_type') + email_input_type = getattr(trees[1].children[1], 'input_type') + self.assertEqual('TaskNameFieldFilterInputType', name_input_type.__name__) + self.assertTrue(hasattr(name_input_type, 'exact')) + self.assertEqual('TaskUserLastNameFieldFilterInputType', last_name_input_type.__name__) + self.assertTrue(hasattr(last_name_input_type, 'exact')) + self.assertEqual('TaskUserEmailFieldFilterInputType', email_input_type.__name__) + self.assertTrue(hasattr(email_input_type, 'iexact')) + self.assertTrue(hasattr(email_input_type, 'contains')) + self.assertTrue(hasattr(email_input_type, 'icontains')) + + def test_get_field_filter_input_type_name(self) -> None: + """Test the `get_field_filter_input_type_name` function.""" self.assertEqual( - 'TaskUserFirstNameFilterInputType', - get_input_type_name( + 'TaskUserFirstNameFieldFilterInputType', + get_field_filter_input_type_name( 'TaskType', (Node(name='user'), Node(name='first_name')), ), )