Skip to content

Commit

Permalink
feat: add method create_field_filter_input_type that create field f…
Browse files Browse the repository at this point in the history
…ilter input types from filter set tree leaves
  • Loading branch information
SquakR committed Jan 28, 2022
1 parent 3977969 commit 9f76dec
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 11 deletions.
49 changes: 43 additions & 6 deletions graphene_django_filter/filter_set_converters.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Functions for converting a FilterSet class to a tree and then to an input type."""

from typing import List, Optional, Sequence, Type
from typing import Dict, List, Optional, Sequence, Type

import graphene
from anytree import Node
from anytree.exporter import DictExporter
from anytree.importer import DictImporter
from anytree.search import findall_by_attr
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django_filters import Filter, FilterSet
Expand All @@ -13,11 +16,45 @@
from stringcase import camelcase, capitalcase


def create_field_filter_input_types(
type_name: str,
tree: Node,
filter_set_class: Type[FilterSet],
) -> Node:
"""Create field filter input types from filter set tree leaves.
This function return new tree.
"""
tree_copy = DictImporter().import_(DictExporter().export(tree))
for field_node in findall_by_attr(tree_copy, name='height', value=1):
fields: Dict[str, UnmountedType] = {}
for lookup_node in field_node.children:
fields[lookup_node.name] = get_field(
filter_set_class,
lookup_node.filter_name,
filter_set_class.base_filters[lookup_node.filter_name],
)
input_type = type(
get_field_filter_input_type_name(type_name, field_node.path),
(graphene.InputObjectType,),
fields,
)
new_node = Node(name=field_node.name, input_type=input_type)
if tree_copy == field_node:
return new_node
else:
field_node.parent.children = [
*(node for node in field_node.parent.children if node is not field_node),
new_node,
]
return tree_copy


def get_field(filter_set_class: Type[FilterSet], name: str, filter_field: Filter) -> UnmountedType:
"""Return Graphene type from a filter field.
It is a partial copy of the `get_filtering_args_from_filterset` method from graphene-django.
https://github.com/graphql-python/graphene-django/blob/775644b5369bdc5fbb45d3535ae391a069ebf9d4/graphene_django/filter/utils.py#L25
It is a partial copy of the `get_filtering_args_from_filterset` function from graphene-django.
https://github.com/graphql-python/graphene-django/blob/caf954861025b9f3d9d3f9c204a7cbbc87352265/graphene_django/filter/utils.py#L11
"""
model = filter_set_class._meta.model
form_field: Optional[models.Field] = None
Expand All @@ -41,15 +78,15 @@ def get_field(filter_set_class: Type[FilterSet], name: str, filter_field: Filter
return field_type


def get_input_type_name(type_name: str, node_path: Sequence[Node]) -> str:
"""Return input type name from a type name and node path."""
def get_field_filter_input_type_name(type_name: str, node_path: Sequence[Node]) -> str:
"""Return field filter input type name from a type name and node path."""
field_name = ''.join(
map(
lambda node: capitalcase(camelcase(node.name)),
node_path,
),
)
return f'{type_name.replace("Type", "")}{field_name}FilterInputType'
return f'{type_name.replace("Type", "")}{field_name}FieldFilterInputType'


def filter_set_to_trees(filter_set_class: Type[FilterSet]) -> List[Node]:
Expand Down
29 changes: 24 additions & 5 deletions tests/test_filter_set_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from anytree.exporter import DictExporter
from django.test import TestCase
from graphene_django_filter.filter_set_converters import (
create_field_filter_input_types,
filter_set_to_trees,
get_input_type_name,
get_field_filter_input_type_name,
sequence_to_tree,
try_add_sequence,
)
Expand Down Expand Up @@ -51,11 +52,29 @@ def setUp(self) -> None:
),
]

def test_get_input_type_name(self) -> None:
"""Test the `get_input_type_name` function."""
def test_create_field_filter_input_types(self) -> None:
"""Test the `create_field_filter_input_types` function."""
trees = [
create_field_filter_input_types('TaskType', tree, TaskFilter)
for tree in self.task_filter_trees
]
name_input_type = getattr(trees[0], 'input_type')
last_name_input_type = getattr(trees[1].children[0], 'input_type')
email_input_type = getattr(trees[1].children[1], 'input_type')
self.assertEqual('TaskNameFieldFilterInputType', name_input_type.__name__)
self.assertTrue(hasattr(name_input_type, 'exact'))
self.assertEqual('TaskUserLastNameFieldFilterInputType', last_name_input_type.__name__)
self.assertTrue(hasattr(last_name_input_type, 'exact'))
self.assertEqual('TaskUserEmailFieldFilterInputType', email_input_type.__name__)
self.assertTrue(hasattr(email_input_type, 'iexact'))
self.assertTrue(hasattr(email_input_type, 'contains'))
self.assertTrue(hasattr(email_input_type, 'icontains'))

def test_get_field_filter_input_type_name(self) -> None:
"""Test the `get_field_filter_input_type_name` function."""
self.assertEqual(
'TaskUserFirstNameFilterInputType',
get_input_type_name(
'TaskUserFirstNameFieldFilterInputType',
get_field_filter_input_type_name(
'TaskType', (Node(name='user'), Node(name='first_name')),
),
)
Expand Down

0 comments on commit 9f76dec

Please sign in to comment.