Skip to content

Commit

Permalink
[dataclasses plugin] Support kw_only=True (#10867)
Browse files Browse the repository at this point in the history
Fixes #10865
  • Loading branch information
tgallant authored Aug 8, 2021
1 parent c90026b commit ed2b4c7
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 10 deletions.
55 changes: 50 additions & 5 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing_extensions import Final

from mypy.nodes import (
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
Context, Expression, JsonDict, NameExpr, RefExpr,
SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode
)
Expand Down Expand Up @@ -36,6 +36,7 @@ def __init__(
column: int,
type: Optional[Type],
info: TypeInfo,
kw_only: bool,
) -> None:
self.name = name
self.is_in_init = is_in_init
Expand All @@ -45,13 +46,21 @@ def __init__(
self.column = column
self.type = type
self.info = info
self.kw_only = kw_only

def to_argument(self) -> Argument:
arg_kind = ARG_POS
if self.kw_only and self.has_default:
arg_kind = ARG_NAMED_OPT
elif self.kw_only and not self.has_default:
arg_kind = ARG_NAMED
elif not self.kw_only and self.has_default:
arg_kind = ARG_OPT
return Argument(
variable=self.to_var(),
type_annotation=self.type,
initializer=None,
kind=ARG_OPT if self.has_default else ARG_POS,
kind=arg_kind,
)

def to_var(self) -> Var:
Expand All @@ -67,13 +76,16 @@ def serialize(self) -> JsonDict:
'line': self.line,
'column': self.column,
'type': self.type.serialize(),
'kw_only': self.kw_only,
}

@classmethod
def deserialize(
cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface
) -> 'DataclassAttribute':
data = data.copy()
if data.get('kw_only') is None:
data['kw_only'] = False
typ = deserialize_and_fixup_type(data.pop('type'), api)
return cls(type=typ, info=info, **data)

Expand Down Expand Up @@ -122,7 +134,8 @@ def transform(self) -> None:
add_method(
ctx,
'__init__',
args=[attr.to_argument() for attr in attributes if attr.is_in_init],
args=[attr.to_argument() for attr in attributes if attr.is_in_init
and not self._is_kw_only_type(attr.type)],
return_type=NoneType(),
)

Expand Down Expand Up @@ -211,6 +224,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
cls = self._ctx.cls
attrs: List[DataclassAttribute] = []
known_attrs: Set[str] = set()
kw_only = _get_decorator_bool_argument(ctx, 'kw_only', False)
for stmt in cls.defs.body:
# Any assignment that doesn't use the new type declaration
# syntax can be ignored out of hand.
Expand Down Expand Up @@ -247,6 +261,9 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
is_init_var = True
node.type = node_type.args[0]

if self._is_kw_only_type(node_type):
kw_only = True

has_field_call, field_args = _collect_field_args(stmt.rvalue)

is_in_init_param = field_args.get('init')
Expand All @@ -270,6 +287,13 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
# on self in the generated __init__(), not in the class body.
sym.implicit = True

is_kw_only = kw_only
# Use the kw_only field arg if it is provided. Otherwise use the
# kw_only value from the decorator parameter.
field_kw_only_param = field_args.get('kw_only')
if field_kw_only_param is not None:
is_kw_only = bool(ctx.api.parse_bool(field_kw_only_param))

known_attrs.add(lhs.name)
attrs.append(DataclassAttribute(
name=lhs.name,
Expand All @@ -280,6 +304,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
column=stmt.column,
type=sym.type,
info=cls.info,
kw_only=is_kw_only,
))

# Next, collect attributes belonging to any class in the MRO
Expand Down Expand Up @@ -314,15 +339,18 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
super_attrs.append(attr)
break
all_attrs = super_attrs + all_attrs
all_attrs.sort(key=lambda a: a.kw_only)

# Ensure that arguments without a default don't follow
# arguments that have a default.
found_default = False
# Ensure that the KW_ONLY sentinel is only provided once
found_kw_sentinel = False
for attr in all_attrs:
# If we find any attribute that is_in_init but that
# If we find any attribute that is_in_init, not kw_only, and that
# doesn't have a default after one that does have one,
# then that's an error.
if found_default and attr.is_in_init and not attr.has_default:
if found_default and attr.is_in_init and not attr.has_default and not attr.kw_only:
# If the issue comes from merging different classes, report it
# at the class definition point.
context = (Context(line=attr.line, column=attr.column) if attr in attrs
Expand All @@ -333,6 +361,14 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
)

found_default = found_default or (attr.has_default and attr.is_in_init)
if found_kw_sentinel and self._is_kw_only_type(attr.type):
context = (Context(line=attr.line, column=attr.column) if attr in attrs
else ctx.cls)
ctx.api.fail(
'There may not be more than one field with the KW_ONLY type',
context,
)
found_kw_sentinel = found_kw_sentinel or self._is_kw_only_type(attr.type)

return all_attrs

Expand Down Expand Up @@ -372,6 +408,15 @@ def _propertize_callables(self, attributes: List[DataclassAttribute]) -> None:
var._fullname = info.fullname + '.' + var.name
info.names[var.name] = SymbolTableNode(MDEF, var)

def _is_kw_only_type(self, node: Optional[Type]) -> bool:
"""Checks if the type of the node is the KW_ONLY sentinel value."""
if node is None:
return False
node_type = get_proper_type(node)
if not isinstance(node_type, Instance):
return False
return node_type.type.fullname == 'dataclasses.KW_ONLY'


def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
"""Hooks into the class typechecking process to add support for dataclasses.
Expand Down
134 changes: 133 additions & 1 deletion test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class Person:
name: str
age: int = field(init=None) # E: No overload variant of "field" matches argument type "None" \
# N: Possible overload variant: \
# N: def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ...) -> Any \
# N: def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> Any \
# N: <2 more non-matching overloads not shown>

[builtins fixtures/list.pyi]
Expand Down Expand Up @@ -311,6 +311,138 @@ class Application:

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnly]
# flags: --python-version 3.10
from dataclasses import dataclass

@dataclass(kw_only=True)
class Application:
name: str = 'Unnamed'
rating: int

Application(rating=5)
Application(name='name', rating=5)
Application() # E: Missing named argument "rating" for "Application"
Application('name') # E: Too many positional arguments for "Application" # E: Missing named argument "rating" for "Application"
Application('name', 123) # E: Too many positional arguments for "Application"
Application('name', rating=123) # E: Too many positional arguments for "Application"
Application(name=123, rating='name') # E: Argument "name" to "Application" has incompatible type "int"; expected "str" # E: Argument "rating" to "Application" has incompatible type "str"; expected "int"
Application(rating='name', name=123) # E: Argument "rating" to "Application" has incompatible type "str"; expected "int" # E: Argument "name" to "Application" has incompatible type "int"; expected "str"

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyOnField]
# flags: --python-version 3.10
from dataclasses import dataclass, field

@dataclass
class Application:
name: str = 'Unnamed'
rating: int = field(kw_only=True)

Application(rating=5)
Application('name', rating=123)
Application(name='name', rating=5)
Application() # E: Missing named argument "rating" for "Application"
Application('name') # E: Missing named argument "rating" for "Application"
Application('name', 123) # E: Too many positional arguments for "Application"
Application(123, rating='name') # E: Argument 1 to "Application" has incompatible type "int"; expected "str" # E: Argument "rating" to "Application" has incompatible type "str"; expected "int"

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyOnFieldFalse]
# flags: --python-version 3.10
from dataclasses import dataclass, field

@dataclass
class Application:
name: str = 'Unnamed'
rating: int = field(kw_only=False) # E: Attributes without a default cannot follow attributes with one

Application(name='name', rating=5)
Application('name', 123)
Application('name', rating=123)
Application() # E: Missing positional argument "name" in call to "Application"
Application('name') # E: Too few arguments for "Application"

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyWithSentinel]
# flags: --python-version 3.10
from dataclasses import dataclass, KW_ONLY

@dataclass
class Application:
_: KW_ONLY
name: str = 'Unnamed'
rating: int

Application(rating=5)
Application(name='name', rating=5)
Application() # E: Missing named argument "rating" for "Application"
Application('name') # E: Too many positional arguments for "Application" # E: Missing named argument "rating" for "Application"
Application('name', 123) # E: Too many positional arguments for "Application"
Application('name', rating=123) # E: Too many positional arguments for "Application"

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyWithSentinelAndFieldOverride]
# flags: --python-version 3.10
from dataclasses import dataclass, field, KW_ONLY

@dataclass
class Application:
_: KW_ONLY
name: str = 'Unnamed'
rating: int = field(kw_only=False) # E: Attributes without a default cannot follow attributes with one

Application(name='name', rating=5)
Application() # E: Missing positional argument "name" in call to "Application"
Application('name') # E: Too many positional arguments for "Application" # E: Too few arguments for "Application"
Application('name', 123) # E: Too many positional arguments for "Application"
Application('name', rating=123) # E: Too many positional arguments for "Application"

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyWithSentinelAndSubclass]
# flags: --python-version 3.10
from dataclasses import dataclass, field, KW_ONLY

@dataclass
class Base:
x: str
_: KW_ONLY
y: int = 0
w: int = 1

@dataclass
class D(Base):
z: str
a: str = "a"

D("Hello", "World")
D(x="Hello", z="World")
D("Hello", "World", y=1, w=2, a="b")
D("Hello") # E: Missing positional argument "z" in call to "D"
D() # E: Missing positional arguments "x", "z" in call to "D"
D(123, "World") # E: Argument 1 to "D" has incompatible type "int"; expected "str"
D("Hello", False) # E: Argument 2 to "D" has incompatible type "bool"; expected "str"
D(123, False) # E: Argument 1 to "D" has incompatible type "int"; expected "str" # E: Argument 2 to "D" has incompatible type "bool"; expected "str"

[case testDataclassesOrderingKwOnlyWithMultipleSentinel]
# flags: --python-version 3.10
from dataclasses import dataclass, field, KW_ONLY

@dataclass
class Base:
x: str
_: KW_ONLY
y: int = 0
__: KW_ONLY # E: There may not be more than one field with the KW_ONLY type
w: int = 1

[builtins fixtures/list.pyi]

[case testDataclassesClassmethods]
# flags: --python-version 3.7
from dataclasses import dataclass
Expand Down
10 changes: 6 additions & 4 deletions test-data/unit/lib-stub/dataclasses.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,28 @@ _T = TypeVar('_T')
class InitVar(Generic[_T]):
...

class KW_ONLY: ...

@overload
def dataclass(_cls: Type[_T]) -> Type[_T]: ...

@overload
def dataclass(*, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool = ...,
unsafe_hash: bool = ..., frozen: bool = ...) -> Callable[[Type[_T]], Type[_T]]: ...
unsafe_hash: bool = ..., frozen: bool = ..., match_args: bool = ...,
kw_only: bool = ..., slots: bool = ...) -> Callable[[Type[_T]], Type[_T]]: ...


@overload
def field(*, default: _T,
init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> _T: ...
metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...,) -> _T: ...

@overload
def field(*, default_factory: Callable[[], _T],
init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> _T: ...
metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...,) -> _T: ...

@overload
def field(*,
init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> Any: ...
metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...,) -> Any: ...

0 comments on commit ed2b4c7

Please sign in to comment.