Skip to content

Commit

Permalink
feat: support for manytomany
Browse files Browse the repository at this point in the history
  • Loading branch information
superlevure committed Nov 30, 2023
1 parent 339f020 commit 19bb716
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 16 deletions.
4 changes: 2 additions & 2 deletions docs/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ By default it will resolve the default related queryset of the Django model, but
model = Recipe
fields = "__all__"
ingredients_dataloaded = DjangoDataloadedListField("ingredients")
ingredients_dataloaded_custom_resolver = DjangoDataloadedListField("ingredients")
ingredients_dataloaded = DjangoDataloadedListField(IngredientType, field="ingredients")
ingredients_dataloaded_custom_resolver = DjangoDataloadedListField(IngredientType, field="ingredients")
def resolve_ingredients_dataloaded_custom_resolver(self, info):
# Important: the queryset returned by the resolver must derivate
Expand Down
34 changes: 23 additions & 11 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

import django
from django.db.models import IntegerField, Value
from django.db.models import F, IntegerField, Value
from django.db.models.query import QuerySet
from graphql_relay import (
cursor_to_offset,
Expand Down Expand Up @@ -88,8 +88,8 @@ class DjangoDataloadedListField(Field):
def __init__(
self,
_type,
field,
*args,
related_name=None,
**kwargs,
):
from graphene_django.types import DjangoObjectType
Expand All @@ -104,7 +104,7 @@ def __init__(
self._underlying_type, DjangoObjectType
), "DjangoListField only accepts DjangoObjectType types"

self._related_name = related_name
self._field = field

@property
def _underlying_type(self):
Expand All @@ -122,9 +122,11 @@ def get_manager(self):

@staticmethod
def list_resolver(
related_name, django_object_type, resolver, default_manager, root, info, **args
field, django_object_type, resolver, default_manager, root, info, **args
):
related_name = related_name or root._meta.db_table
related_name = root._meta.get_field(field).remote_field.name
many_to_many = root._meta.get_field(field).many_to_many

queryset = maybe_queryset(resolver(root, info, **args))
if queryset is None:
queryset = maybe_queryset(default_manager)
Expand All @@ -146,11 +148,18 @@ def list_resolver(

def load_many(keys):
results_by_ids = defaultdict(list)
lookup = {
f"{related_name}_id__in": keys,
}

qs: QuerySet = queryset.filter(**lookup)
if many_to_many:
lookup = {
f"{related_name}__in": keys,
}
annotation = {f"{related_name}_id": F(related_name)}
qs = queryset.filter(**lookup).annotate(**annotation)
else:
lookup = {
f"{related_name}_id__in": keys,
}

qs = queryset.filter(**lookup)

for result in qs.iterator():
results_by_ids[
Expand All @@ -163,6 +172,9 @@ def load_many(keys):

return info.context.dataloaders[dataloader_key].load(root.id)

if many_to_many:
return queryset.filter(**{f"{related_name}": root.id})

return queryset.filter(**{f"{related_name}_id": root.id})

def wrap_resolve(self, parent_resolver):
Expand All @@ -173,7 +185,7 @@ def wrap_resolve(self, parent_resolver):
django_object_type = _type.of_type.of_type
return partial(
self.list_resolver,
self._related_name,
self._field,
django_object_type,
resolver,
self.get_manager(),
Expand Down
1 change: 1 addition & 0 deletions graphene_django/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class FilmDetails(models.Model):


class Film(models.Model):
name = models.CharField(max_length=30)
genre = models.CharField(
max_length=2,
help_text="Genre",
Expand Down
104 changes: 101 additions & 3 deletions graphene_django/tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def resolve_articles(root, info):
captured.captured_queries[1]["sql"],
)

def test_django_dataloaded_list_field(
def test_django_one_to_many_dataloaded_list_field(
self, execution_context_class, django_assert_max_num_queries
):
class Context:
Expand All @@ -665,9 +665,9 @@ class Meta:
model = ReporterModel
fields = ("first_name", "articles")

articles_dataloaded = DjangoDataloadedListField(Article)
articles_dataloaded = DjangoDataloadedListField(Article, field="articles")
articles_dataloaded_with_custom_resolver = DjangoDataloadedListField(
Article
Article, field="articles"
)

def resolve_articles_dataloaded_with_custom_resolver(self, info):
Expand Down Expand Up @@ -833,3 +833,101 @@ class Query(ObjectType):
r"SELECT .* FROM \"tests_article\" WHERE \(\"tests_article\".\"headline\" LIKE \'%Not%\' ESCAPE \'\\\' AND \"tests_article\".\"reporter_id\" = \d+\) .*",
captured.captured_queries[2]["sql"],
)

def test_django_many_to_many_dataloaded_list_field(
self, execution_context_class, django_assert_max_num_queries
):
class Context:
pass

class Film(DjangoObjectType):
class Meta:
model = FilmModel
fields = ("name",)

class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")

films_dataloaded = DjangoDataloadedListField(Film, field="films")

class Query(ObjectType):
reporters = DjangoListField(Reporter)

schema = Schema(query=Query)

r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
FilmModel.objects.create(name="Zoro").reporters.add(r1)
FilmModel.objects.create(name="Inception").reporters.add(r1)

r2 = ReporterModel.objects.create(first_name="Debra", last_name="Payne")
FilmModel.objects.create(name="Interstellar").reporters.add(r2)
FilmModel.objects.create(name="Cube").reporters.add(r2)

FilmModel.objects.create(name="Lost in translation").reporters.add(r1, r2)

query = """
query {
reporters {
firstName
filmsDataloaded {
name
}
}
}
"""

with django_assert_max_num_queries(3) as captured:
result = schema.execute(
query,
execution_context_class=execution_context_class,
context_value=Context(),
)

assert not result.errors
assert result.data == {
"reporters": [
{
"firstName": "Tara",
"filmsDataloaded": [
{"name": "Zoro"},
{"name": "Inception"},
{"name": "Lost in translation"},
],
},
{
"firstName": "Debra",
"filmsDataloaded": [
{"name": "Interstellar"},
{"name": "Cube"},
{"name": "Lost in translation"},
],
},
]
}

if execution_context_class == DeferredExecutionContext:
assert len(captured.captured_queries) == 2
assert re.match(
r'SELECT .* FROM "tests_reporter"',
captured.captured_queries[0]["sql"],
)
assert re.match(
r"SELECT .* FROM \"tests_film\" INNER JOIN \"tests_film_reporters\" ON \(\"tests_film\".\"id\" = \"tests_film_reporters\".\"film_id\"\) WHERE \"tests_film_reporters\".\"reporter_id\" IN \(\d+, \d+\)",
captured.captured_queries[1]["sql"],
)
else:
assert len(captured.captured_queries) == 3
assert re.match(
r'SELECT .* FROM "tests_reporter"',
captured.captured_queries[0]["sql"],
)
assert re.match(
r"SELECT .* FROM \"tests_film\" INNER JOIN \"tests_film_reporters\" ON \(\"tests_film\".\"id\" = \"tests_film_reporters\".\"film_id\"\) WHERE \"tests_film_reporters\".\"reporter_id\" = \d+",
captured.captured_queries[1]["sql"],
)
assert re.match(
r"SELECT .* FROM \"tests_film\" INNER JOIN \"tests_film_reporters\" ON \(\"tests_film\".\"id\" = \"tests_film_reporters\".\"film_id\"\) WHERE \"tests_film_reporters\".\"reporter_id\" = \d+",
captured.captured_queries[2]["sql"],
)

0 comments on commit 19bb716

Please sign in to comment.