Skip to content

Commit

Permalink
CR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Apakottur committed Dec 21, 2021
1 parent 1faeb2f commit adb9042
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions strawberry/ext/mypy_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from decimal import Decimal
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast

from typing_extensions import Final

Expand Down Expand Up @@ -58,6 +58,8 @@


class MypyVersion:
"""Stores the mypy version to be used by the plugin"""

VERSION: Decimal


Expand Down Expand Up @@ -274,32 +276,23 @@ def is_dataclasses_field_or_strawberry_field(expr: Expression) -> bool:
return False


def _collect_field_args(
expr: Expression, ctx: ClassDefContext
) -> Tuple[bool, Dict[str, Expression]]:
def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
"""
if (
isinstance(expr, CallExpr)
and isinstance(expr.callee, RefExpr)
and expr.callee.fullname == "dataclasses.field"
):

if is_dataclasses_field_or_strawberry_field(expr):
expr = cast(CallExpr, expr)

# field() only takes keyword arguments.
args = {}

for name, arg in zip(expr.arg_names, expr.args):
if name is None:
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
ctx.api.fail(
'Unpacking **kwargs in "field()" is not supported',
expr,
)
return True, {}
assert name is not None
args[name] = arg
return True, args

return False, {}


Expand Down Expand Up @@ -500,7 +493,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
is_init_var = True
node.type = node_type.args[0]

has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)
has_field_call, field_args = _collect_field_args(stmt.rvalue)

is_in_init_param = field_args.get("init")
if is_in_init_param is None:
Expand Down

0 comments on commit adb9042

Please sign in to comment.