Skip to content

Commit

Permalink
Don't override annotations when processing a type (#3003)
Browse files Browse the repository at this point in the history
* Ignore some third party warnings in tests

* Only set annotation if it is missing

* Use original annotations when defined

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix older versions of Python

* Fix for pydantic

* Add release notes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
patrick91 and pre-commit-ci[bot] authored Apr 20, 2024
1 parent 0388e15 commit 617d33f
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 9 deletions.
6 changes: 6 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Release type: patch

This release fixes an issue where annotations on `@strawberry.type`s were overridden
by our code. With release all annotations should be preserved.

This is useful for libraries that use annotations to introspect Strawberry types.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ asyncio_mode = "auto"
filterwarnings = [
"ignore::DeprecationWarning:strawberry.*.resolver",
"ignore:LazyType is deprecated:DeprecationWarning",
"ignore::DeprecationWarning:ddtrace.internal",
"ignore::DeprecationWarning:django.utils.encoding",
# ignoring the text instead of the whole warning because we'd
# get an error when django is not installed
"ignore:The default value of USE_TZ",
Expand Down
2 changes: 1 addition & 1 deletion strawberry/experimental/pydantic/error_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def wrap(cls: Type) -> Type:
]

wrapped = _wrap_dataclass(cls)
extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped))
extra_fields = cast(List[dataclasses.Field], _get_fields(wrapped, {}))
private_fields = get_private_fields(wrapped)

all_model_fields.extend(
Expand Down
2 changes: 1 addition & 1 deletion strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]:
)

wrapped = _wrap_dataclass(cls)
extra_strawberry_fields = _get_fields(wrapped)
extra_strawberry_fields = _get_fields(wrapped, {})
extra_fields = cast(List[dataclasses.Field], extra_strawberry_fields)
private_fields = get_private_fields(wrapped)

Expand Down
2 changes: 1 addition & 1 deletion strawberry/federation/schema_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def schema_directive(
) -> Callable[..., T]:
def _wrap(cls: T) -> T:
cls = _wrap_dataclass(cls)
fields = _get_fields(cls)
fields = _get_fields(cls, {})

cls.__strawberry_directive__ = StrawberryFederationSchemaDirective(
python_name=cls.__name__,
Expand Down
29 changes: 26 additions & 3 deletions strawberry/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def _check_field_annotations(cls: Type[Any]):
# If the field has a type override then use that instead of using
# the class annotations or resolver annotation
if field_.type_annotation is not None:
cls_annotations[field_name] = field_.type_annotation.annotation
if field_name not in cls_annotations:
cls_annotations[field_name] = field_.type_annotation.annotation
continue

# Make sure the cls has an annotation
Expand All @@ -85,7 +86,8 @@ def _check_field_annotations(cls: Type[Any]):
field_name, resolver=field_.base_resolver
)

cls_annotations[field_name] = field_.base_resolver.type_annotation
if field_name not in cls_annotations:
cls_annotations[field_name] = field_.base_resolver.type_annotation

# TODO: Make sure the cls annotation agrees with the field's type
# >>> if cls_annotations[field_name] != field.base_resolver.type:
Expand Down Expand Up @@ -133,11 +135,13 @@ def _process_type(
description: Optional[str] = None,
directives: Optional[Sequence[object]] = (),
extend: bool = False,
original_type_annotations: Optional[Dict[str, Any]] = None,
) -> T:
name = name or to_camel_case(cls.__name__)
original_type_annotations = original_type_annotations or {}

interfaces = _get_interfaces(cls)
fields = _get_fields(cls)
fields = _get_fields(cls, original_type_annotations)
is_type_of = getattr(cls, "is_type_of", None)
resolve_type = getattr(cls, "resolve_type", None)

Expand Down Expand Up @@ -245,7 +249,25 @@ def wrap(cls: Type) -> T:
exc = ObjectIsNotClassError.type
raise exc(cls)

# when running `_wrap_dataclass` we lose some of the information about the
# the passed types, especially the type_annotation inside the StrawberryField
# this makes it impossible to customise the field type, like this:
# >>> @strawberry.type
# >>> class Query:
# >>> a: int = strawberry.field(graphql_type=str)
# so we need to extract the information before running `_wrap_dataclass`
original_type_annotations: Dict[str, Any] = {}

annotations = getattr(cls, "__annotations__", {})

for field_name in annotations:
field = getattr(cls, field_name, None)

if field and isinstance(field, StrawberryField) and field.type_annotation:
original_type_annotations[field_name] = field.type_annotation.annotation

wrapped = _wrap_dataclass(cls)

return _process_type(
wrapped,
name=name,
Expand All @@ -254,6 +276,7 @@ def wrap(cls: Type) -> T:
description=description,
directives=directives,
extend=extend,
original_type_annotations=original_type_annotations,
)

if cls is None:
Expand Down
2 changes: 1 addition & 1 deletion strawberry/schema_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def schema_directive(
) -> Callable[..., T]:
def _wrap(cls: T) -> T:
cls = _wrap_dataclass(cls)
fields = _get_fields(cls)
fields = _get_fields(cls, {})

cls.__strawberry_directive__ = StrawberrySchemaDirective(
python_name=cls.__name__,
Expand Down
11 changes: 9 additions & 2 deletions strawberry/types/type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import sys
from typing import Dict, List, Type
from typing import Any, Dict, List, Type

from strawberry.annotation import StrawberryAnnotation
from strawberry.exceptions import (
Expand All @@ -16,7 +16,9 @@
from strawberry.unset import UNSET


def _get_fields(cls: Type) -> List[StrawberryField]:
def _get_fields(
cls: Type[Any], original_type_annotations: Dict[str, Type[Any]]
) -> List[StrawberryField]:
"""Get all the strawberry fields off a strawberry.type cls
This function returns a list of StrawberryFields (one for each field item), while
Expand Down Expand Up @@ -49,6 +51,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o
passing a named function (i.e. not an anonymous lambda) to strawberry.field
(typically as a decorator).
"""

fields: Dict[str, StrawberryField] = {}

# before trying to find any fields, let's first add the fields defined in
Expand Down Expand Up @@ -152,6 +155,10 @@ class if one is not set by either using an explicit strawberry.field(name=...) o
assert_message = "Field must have a name by the time the schema is generated"
assert field_name is not None, assert_message

if field.name in original_type_annotations:
field.type = original_type_annotations[field.name]
field.type_annotation = StrawberryAnnotation(annotation=field.type)

# TODO: Raise exception if field_name already in fields
fields[field_name] = field

Expand Down
16 changes: 16 additions & 0 deletions tests/types/test_object_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,19 @@ class Thing:
match=re.escape("__init__() takes 1 positional argument but 2 were given"),
):
Thing("something")


def test_object_preserves_annotations():
@strawberry.type
class Object:
a: bool
b: Annotated[str, "something"]
c: bool = strawberry.field(graphql_type=int)
d: Annotated[str, "something"] = strawberry.field(graphql_type=int)

assert Object.__annotations__ == {
"a": bool,
"b": Annotated[str, "something"],
"c": bool,
"d": Annotated[str, "something"],
}

0 comments on commit 617d33f

Please sign in to comment.