From adb9042f51b42cf9034da917b1f078e4ce1d63f2 Mon Sep 17 00:00:00 2001 From: yossi <54272821+Apakottur@users.noreply.github.com> Date: Tue, 21 Dec 2021 10:50:02 +0000 Subject: [PATCH] CR fixes --- strawberry/ext/mypy_plugin.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/strawberry/ext/mypy_plugin.py b/strawberry/ext/mypy_plugin.py index f204e2d09c..064b3f2ec5 100644 --- a/strawberry/ext/mypy_plugin.py +++ b/strawberry/ext/mypy_plugin.py @@ -1,5 +1,5 @@ from decimal import Decimal -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast from typing_extensions import Final @@ -58,6 +58,8 @@ class MypyVersion: + """Stores the mypy version to be used by the plugin""" + VERSION: Decimal @@ -274,32 +276,23 @@ def is_dataclasses_field_or_strawberry_field(expr: Expression) -> bool: return False -def _collect_field_args( - expr: Expression, ctx: ClassDefContext -) -> Tuple[bool, Dict[str, Expression]]: +def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]: """Returns a tuple where the first value represents whether or not the expression is a call to dataclass.field and the second is a dictionary of the keyword arguments that field() was called with. """ - if ( - isinstance(expr, CallExpr) - and isinstance(expr.callee, RefExpr) - and expr.callee.fullname == "dataclasses.field" - ): + + if is_dataclasses_field_or_strawberry_field(expr): + expr = cast(CallExpr, expr) + # field() only takes keyword arguments. args = {} + for name, arg in zip(expr.arg_names, expr.args): - if name is None: - # This means that `field` is used with `**` unpacking, - # the best we can do for now is not to fail. - # TODO: we can infer what's inside `**` and try to collect it. - ctx.api.fail( - 'Unpacking **kwargs in "field()" is not supported', - expr, - ) - return True, {} + assert name is not None args[name] = arg return True, args + return False, {} @@ -500,7 +493,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: is_init_var = True node.type = node_type.args[0] - has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx) + has_field_call, field_args = _collect_field_args(stmt.rvalue) is_in_init_param = field_args.get("init") if is_in_init_param is None: