Skip to content

Commit

Permalink
perf(listconnection): only resolve edges when edges or pageInfo are s…
Browse files Browse the repository at this point in the history
…elected (#3480)
  • Loading branch information
euriostigue authored May 1, 2024
1 parent 65142b8 commit a2b41fc
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
8 changes: 8 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Release type: patch

This release adds an optimization to `ListConnection` such that only queries with
`edges` or `pageInfo` in their selected fields triggers `resolve_edges`.

This change is particularly useful for the `strawberry-django` extension's
`ListConnectionWithTotalCount` and the only selected field is `totalCount`. An
extraneous SQL query is prevented with this optimization.
13 changes: 12 additions & 1 deletion strawberry/relay/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from strawberry.utils.inspect import in_async_context
from strawberry.utils.typing import eval_type, is_classvar

from .utils import from_base64, to_base64
from .utils import from_base64, should_resolve_list_connection_edges, to_base64

if TYPE_CHECKING:
from strawberry.scalars import ID
Expand Down Expand Up @@ -933,6 +933,17 @@ async def resolver():
overfetch,
)

if not should_resolve_list_connection_edges(info):
return cls(
edges=[],
page_info=PageInfo(
start_cursor=None,
end_cursor=None,
has_previous_page=False,
has_next_page=False,
),
)

edges = [
edge_class.resolve_edge(
cls.resolve_node(v, info=info, **kwargs),
Expand Down
24 changes: 24 additions & 0 deletions strawberry/relay/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Any, Tuple, Union
from typing_extensions import assert_never

from strawberry.types.info import Info
from strawberry.types.nodes import InlineFragment
from strawberry.types.types import StrawberryObjectDefinition


Expand Down Expand Up @@ -61,3 +63,25 @@ def to_base64(type_: Union[str, type, StrawberryObjectDefinition], node_id: Any)
raise ValueError(f"{type_} is not a valid GraphQL type or name") from e

return base64.b64encode(f"{type_name}:{node_id}".encode()).decode()


def should_resolve_list_connection_edges(info: Info) -> bool:
"""Check if the user requested to resolve the `edges` field of a connection.
Args:
info:
The strawberry execution info resolve the type name from
Returns:
True if the user requested to resolve the `edges` field of a connection, False otherwise.
"""
resolve_for_field_names = {"edges", "pageInfo"}
for selection_field in info.selected_fields:
for selection in selection_field.selections:
if (
not isinstance(selection, InlineFragment)
and selection.name in resolve_for_field_names
):
return True
return False
36 changes: 35 additions & 1 deletion tests/relay/test_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, AsyncGenerator, AsyncIterable, Optional, Union, cast
from typing_extensions import assert_type
from unittest.mock import MagicMock

import pytest

Expand All @@ -8,7 +9,7 @@
from strawberry.relay.utils import to_base64
from strawberry.types.info import Info

from .schema import Fruit, FruitAsync, schema
from .schema import Fruit, FruitAsync, fruits_resolver, schema


class FakeInfo:
Expand Down Expand Up @@ -255,3 +256,36 @@ def fruit(self) -> Fruit:
return Fruit(color="red") # pragma: no cover

strawberry.Schema(query=Query)


def test_list_connection_without_edges_or_page_info(mocker: MagicMock):
@strawberry.type(name="Connection", description="A connection to a list of items.")
class DummyListConnectionWithTotalCount(relay.ListConnection[relay.NodeType]):
@strawberry.field(description="Total quantity of existing nodes.")
def total_count(self) -> int:
return -1

@strawberry.type
class Query:
fruits: DummyListConnectionWithTotalCount[Fruit] = relay.connection(
resolver=fruits_resolver
)

mock = mocker.patch("strawberry.relay.types.Edge.resolve_edge")
schema = strawberry.Schema(query=Query)
ret = schema.execute_sync(
"""
query {
fruits {
totalCount
}
}
"""
)
mock.assert_not_called()
assert ret.errors is None
assert ret.data == {
"fruits": {
"totalCount": -1,
}
}

0 comments on commit a2b41fc

Please sign in to comment.