Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Support Union in DjangoFilterConnectionField and DjangoConnectionField #1537

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion graphene_django/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from .fields import DjangoConnectionField, DjangoListField
from .types import DjangoObjectType
from .types import DjangoObjectType, DjangoUnionType
from .utils import bypass_get_queryset

__version__ = "3.2.2"

__all__ = [
"__version__",
"DjangoObjectType",
"DjangoUnionType",
"DjangoListField",
"DjangoConnectionField",
"bypass_get_queryset",
Expand Down
6 changes: 3 additions & 3 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,16 @@ def __init__(self, *args, **kwargs):

@property
def type(self):
from .types import DjangoObjectType
from .types import DjangoObjectType, DjangoUnionType

_type = super(ConnectionField, self).type
non_null = False
if isinstance(_type, NonNull):
_type = _type.of_type
non_null = True
assert issubclass(
_type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types"
_type, (DjangoObjectType, DjangoUnionType)
), "DjangoConnectionField only accepts DjangoObjectType or DjangoUnionType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__
)
Expand Down
6 changes: 3 additions & 3 deletions graphene_django/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ def __init__(self):
self._field_registry = {}

def register(self, cls):
from .types import DjangoObjectType
from .types import DjangoObjectType, DjangoUnionType

assert issubclass(
cls, DjangoObjectType
), f'Only DjangoObjectTypes can be registered, received "{cls.__name__}"'
cls, (DjangoObjectType, DjangoUnionType)
), f'Only DjangoObjectTypes or DjangoUnionType can be registered, received "{cls.__name__}"'
assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) == cls, (
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
Expand Down
52 changes: 51 additions & 1 deletion graphene_django/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@

from .. import registry
from ..filter import DjangoFilterConnectionField
from ..types import DjangoObjectType, DjangoObjectTypeOptions
from ..types import (
DjangoObjectType,
DjangoObjectTypeOptions,
DjangoUnionType,
)
from .models import (
APNewsReporter as APNewsReporterModel,
Article as ArticleModel,
CNNReporter as CNNReporterModel,
Reporter as ReporterModel,
)

Expand Down Expand Up @@ -799,3 +805,47 @@ class Query(ObjectType):
assert "type Reporter implements Node {" not in schema
assert "type ReporterConnection {" not in schema
assert "type ReporterEdge {" not in schema


@with_local_registry
def test_django_uniontype_name_connection_propagation():
class CNNReporter(DjangoObjectType):
class Meta:
model = CNNReporterModel
name = "CNNReporter"
fields = "__all__"
filter_fields = ["email"]
interfaces = (Node,)

class APNewsReporter(DjangoObjectType):
class Meta:
model = APNewsReporterModel
name = "APNewsReporter"
fields = "__all__"
filter_fields = ["email"]
interfaces = (Node,)

class ReporterUnion(DjangoUnionType):
class Meta:
model = ReporterModel
types = (CNNReporter, APNewsReporter)
interfaces = (Node,)
filter_fields = ("id", "first_name", "last_name")

@classmethod
def resolve_type(cls, instance, info):
if isinstance(instance, CNNReporterModel):
return CNNReporter
elif isinstance(instance, APNewsReporterModel):
return APNewsReporter
return None

class Query(ObjectType):
reporter = Node.Field(ReporterUnion)
reporters = DjangoFilterConnectionField(ReporterUnion)

schema = str(Schema(query=Query))

assert "union ReporterUnion = CNNReporter | APNewsReporter" in schema
assert "CNNReporter implements Node" in schema
assert "ReporterUnionConnection" in schema
132 changes: 132 additions & 0 deletions graphene_django/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import graphene
from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.union import Union, UnionOptions
from graphene.types.utils import yank_fields_from_attrs

from .converter import convert_django_field_with_choices
Expand Down Expand Up @@ -293,6 +294,137 @@ def get_node(cls, info, id):
return None


class DjangoUnionTypeOptions(UnionOptions, ObjectTypeOptions):
model = None # type: Type[Model]
registry = None # type: Registry
connection = None # type: Type[Connection]

filter_fields = ()
filterset_class = None


class DjangoUnionType(Union):
"""
A Django specific Union type that allows to map multiple Django object types
One use case is to handle polymorphic relationships for a Django model, using a library like django-polymorphic.

Can be used in combination with DjangoConnectionField and DjangoFilterConnectionField

Args:
Meta (class): The meta class of the union type
model (Model): The Django model that represents the union type
types (tuple): A tuple of DjangoObjectType classes that represent the possible types of the union

Example:
```python
from graphene_django.types import DjangoObjectType, DjangoUnionType

class AssessmentUnion(DjangoUnionType):
class Meta:
model = Assessment
types = (HomeworkAssessmentNode, QuizAssessmentNode)
interfaces = (graphene.relay.Node,)
filter_fields = ("id", "title", "description")

@classmethod
def resolve_type(cls, instance, info):
if isinstance(instance, HomeworkAssessment):
return HomeworkAssessmentNode
elif isinstance(instance, QuizAssessment):
return QuizAssessmentNode

class Query(graphene.ObjectType):
all_assessments = DjangoFilterConnectionField(AssessmentUnion)
```
"""

class Meta:
abstract = True

@classmethod
def __init_subclass_with_meta__(
cls,
model=None,
types=None,
registry=None,
skip_registry=False,
_meta=None,
fields=None,
exclude=None,
convert_choices_to_enum=None,
filter_fields=None,
filterset_class=None,
connection=None,
connection_class=None,
use_connection=None,
interfaces=(),
**options,
):
django_fields = yank_fields_from_attrs(
construct_fields(model, registry, fields, exclude, convert_choices_to_enum),
_as=graphene.Field,
)

if use_connection is None and interfaces:
use_connection = any(
issubclass(interface, Node) for interface in interfaces
)

if not registry:
registry = get_global_registry()

assert isinstance(registry, Registry), (
f"The attribute registry in {cls.__name__} needs to be an instance of "
f'Registry, received "{registry}".'
)

if filter_fields and filterset_class:
raise Exception("Can't set both filter_fields and filterset_class")

if not DJANGO_FILTER_INSTALLED and (filter_fields or filterset_class):
raise Exception(
"Can only set filter_fields or filterset_class if "
"Django-Filter is installed"
)

if not _meta:
_meta = DjangoUnionTypeOptions(cls)

_meta.model = model
_meta.types = types
_meta.fields = django_fields
_meta.filter_fields = filter_fields
_meta.filterset_class = filterset_class
_meta.registry = registry

if use_connection and not connection:
# We create the connection automatically
if not connection_class:
connection_class = Connection

connection = connection_class.create_type(
"{}Connection".format(options.get("name") or cls.__name__), node=cls
)

if connection is not None:
assert issubclass(
connection, Connection
), f"The connection must be a Connection. Received {connection.__name__}"

_meta.connection = connection

super().__init_subclass_with_meta__(
types=types, _meta=_meta, interfaces=interfaces, **options
)

if not skip_registry:
registry.register(cls)

@classmethod
def get_queryset(cls, queryset, info):
return queryset


class ErrorType(ObjectType):
field = graphene.String(required=True)
messages = graphene.List(graphene.NonNull(graphene.String), required=True)
Expand Down