diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index a68410765367..d1b074ca9e8e 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -39,6 +39,7 @@ ) from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface from mypy.plugins.common import ( + _get_callee_type, _get_decorator_bool_argument, add_attribute_to_class, add_method_to_class, @@ -47,7 +48,7 @@ from mypy.semanal_shared import find_dataclass_transform_spec, require_bool_literal_argument from mypy.server.trigger import make_wildcard_trigger from mypy.state import state -from mypy.typeops import map_type_from_supertype +from mypy.typeops import map_type_from_supertype, try_getting_literals_from_type from mypy.types import ( AnyType, CallableType, @@ -509,7 +510,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: is_in_init_param = field_args.get("init") if is_in_init_param is None: - is_in_init = True + is_in_init = self._get_default_init_value_for_field_specifier(stmt.rvalue) else: is_in_init = bool(self._api.parse_bool(is_in_init_param)) @@ -738,6 +739,30 @@ def _get_bool_arg(self, name: str, default: bool) -> bool: return require_bool_literal_argument(self._api, expression, name, default) return default + def _get_default_init_value_for_field_specifier(self, call: Expression) -> bool: + """ + Find a default value for the `init` parameter of the specifier being called. If the + specifier's type signature includes an `init` parameter with a type of `Literal[True]` or + `Literal[False]`, return the appropriate boolean value from the literal. Otherwise, + fall back to the standard default of `True`. + """ + if not isinstance(call, CallExpr): + return True + + specifier_type = _get_callee_type(call) + if specifier_type is None: + return True + + parameter = specifier_type.argument_by_name("init") + if parameter is None: + return True + + literals = try_getting_literals_from_type(parameter.typ, bool, "builtins.bool") + if literals is None or len(literals) != 1: + return True + + return literals[0] + def add_dataclass_tag(info: TypeInfo) -> None: # The value is ignored, only the existence matters. diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index b0c1cdf56097..e8e7802d3072 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -328,6 +328,38 @@ Foo(a=1, b='bye') [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] +[case testDataclassTransformFieldSpecifierImplicitInit] +# flags: --python-version 3.11 +from typing import dataclass_transform, Literal, overload + +def init(*, init: Literal[True] = True): ... +def no_init(*, init: Literal[False] = False): ... + +@overload +def field_overload(*, custom: None, init: Literal[True] = True): ... +@overload +def field_overload(*, custom: str, init: Literal[False] = False): ... +def field_overload(*, custom, init): ... + +@dataclass_transform(field_specifiers=(init, no_init, field_overload)) +def my_dataclass(cls): return cls + +@my_dataclass +class Foo: + a: int = init() + b: int = field_overload(custom=None) + + bad1: int = no_init() + bad2: int = field_overload(custom="bad2") + +reveal_type(Foo) # N: Revealed type is "def (a: builtins.int, b: builtins.int) -> __main__.Foo" +Foo(a=1, b=2) +Foo(a=1, b=2, bad1=0) # E: Unexpected keyword argument "bad1" for "Foo" +Foo(a=1, b=2, bad2=0) # E: Unexpected keyword argument "bad2" for "Foo" + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] + [case testDataclassTransformOverloadsDecoratorOnOverload] # flags: --python-version 3.11 from typing import dataclass_transform, overload, Any, Callable, Type, Literal