From 01ae2026ebf9192a65eca188633a8552bdc3fcdf Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 17 Apr 2022 15:59:10 -0700 Subject: [PATCH] Add support for Unpack in args and kwargs (#523) --- docs/changelog.md | 2 + pyanalyze/annotations.py | 107 +++++++++--------- pyanalyze/arg_spec.py | 11 +- pyanalyze/functions.py | 56 +++++++-- pyanalyze/name_check_visitor.py | 23 +++- pyanalyze/signature.py | 76 ++++++++++--- .../stubs/_pyanalyze_tests-stubs/args.pyi | 11 ++ pyanalyze/test_signature.py | 50 ++++++++ pyanalyze/test_typeshed.py | 16 +++ pyanalyze/typeshed.py | 29 +++-- pyanalyze/value.py | 16 +++ 11 files changed, 299 insertions(+), 98 deletions(-) create mode 100644 pyanalyze/stubs/_pyanalyze_tests-stubs/args.pyi diff --git a/docs/changelog.md b/docs/changelog.md index 822f3ccb..05611f34 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,8 @@ ## Unreleased +- Add support for use of the `Unpack` operator to + annotate heterogeneous `*args` and `**kwargs` parameters (#523) - Detect incompatible types for some calls to `list.append`, `list.extend`, `list.__add__`, and `set.add` (#522) - Optimize local variables with very complex inferred types (#521) diff --git a/pyanalyze/annotations.py b/pyanalyze/annotations.py index 67f28057..8012d6d7 100644 --- a/pyanalyze/annotations.py +++ b/pyanalyze/annotations.py @@ -91,6 +91,7 @@ SequenceValue, TypeGuardExtension, TypedValue, + UnpackedValue, annotate_value, unite_values, Value, @@ -278,6 +279,8 @@ def type_from_runtime( node: Optional[ast.AST] = None, globals: Optional[Mapping[str, object]] = None, ctx: Optional[Context] = None, + *, + allow_unpack: bool = False, ) -> Value: """Given a runtime annotation object, return a :class:`Value `. @@ -297,11 +300,13 @@ def type_from_runtime( :param ctx: :class:`Context` to use for evaluation. + :param allow_unpack: Whether to allow `Unpack` types. + """ if ctx is None: ctx = _DefaultContext(visitor, node, globals) - return _type_from_runtime(val, ctx) + return _type_from_runtime(val, ctx, allow_unpack=allow_unpack) def type_from_value( @@ -309,7 +314,9 @@ def type_from_value( visitor: Optional["NameCheckVisitor"] = None, node: Optional[ast.AST] = None, ctx: Optional[Context] = None, + *, is_typeddict: bool = False, + allow_unpack: bool = False, ) -> Value: """Given a :class:`Value ` representing the type. @@ -336,7 +343,9 @@ def type_from_value( """ if ctx is None: ctx = _DefaultContext(visitor, node) - return _type_from_value(value, ctx, is_typeddict=is_typeddict) + return _type_from_value( + value, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack + ) def value_from_ast( @@ -355,20 +364,20 @@ def _type_from_ast( ctx: Context, *, is_typeddict: bool = False, - unpack_allowed: bool = False, + allow_unpack: bool = False, ) -> Value: val = value_from_ast(node, ctx) return _type_from_value( - val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack ) def _type_from_runtime( - val: Any, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False + val: Any, ctx: Context, *, is_typeddict: bool = False, allow_unpack: bool = False ) -> Value: if isinstance(val, str): return _eval_forward_ref( - val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack ) elif isinstance(val, tuple): # This happens under some Python versions for types @@ -383,16 +392,14 @@ def _type_from_runtime( args = (val[1],) else: args = val[1:] - return _value_of_origin_args( - origin, args, val, ctx, unpack_allowed=unpack_allowed - ) + return _value_of_origin_args(origin, args, val, ctx, allow_unpack=allow_unpack) elif GenericAlias is not None and isinstance(val, GenericAlias): origin = get_origin(val) args = get_args(val) if origin is tuple and not args: return SequenceValue(tuple, []) return _value_of_origin_args( - origin, args, val, ctx, unpack_allowed=origin is tuple + origin, args, val, ctx, allow_unpack=origin is tuple ) elif typing_inspect.is_literal_type(val): args = typing_inspect.get_args(val) @@ -417,7 +424,8 @@ def _type_from_runtime( else: return _make_sequence_value( tuple, - [_type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args], + [_type_from_runtime(arg, ctx, allow_unpack=True) for arg in args], + ctx, ) elif is_instance_of_typing_name(val, "_TypedDictMeta"): required_keys = getattr(val, "__required_keys__", None) @@ -464,7 +472,7 @@ def _type_from_runtime( val, ctx, is_typeddict=is_typeddict, - unpack_allowed=unpack_allowed or origin is tuple or origin is Tuple, + allow_unpack=allow_unpack or origin is tuple or origin is Tuple, ) elif typing_inspect.is_callable_type(val): args = typing_inspect.get_args(val) @@ -568,8 +576,8 @@ def _type_from_runtime( return AnyValue(AnySource.error) # Also 3.6 only. elif is_instance_of_typing_name(val, "_Unpack"): - if unpack_allowed: - return _make_unpacked_value(_type_from_runtime(val.__type__, ctx), ctx) + if allow_unpack: + return UnpackedValue(_type_from_runtime(val.__type__, ctx)) else: ctx.show_error("Unpack[] used in unsupported context") return AnyValue(AnySource.error) @@ -622,7 +630,7 @@ def _callable_args_from_runtime( types = [_type_from_runtime(arg, ctx) for arg in arg_types] params = [ SigParameter( - f"__arg{i}", + f"@{i}", kind=ParameterKind.PARAM_SPEC if isinstance(typ, TypeVarValue) and typ.is_paramspec else ParameterKind.POSITIONAL_ONLY, @@ -648,7 +656,7 @@ def _args_from_concatenate(concatenate: Any, ctx: Context) -> Sequence[SigParame types = [_type_from_runtime(arg, ctx) for arg in concatenate.__args__] params = [ SigParameter( - f"__arg{i}", + f"@{i}", kind=ParameterKind.PARAM_SPEC if i == len(types) - 1 else ParameterKind.POSITIONAL_ONLY, @@ -677,7 +685,7 @@ def _get_typeddict_value( def _eval_forward_ref( - val: str, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False + val: str, ctx: Context, *, is_typeddict: bool = False, allow_unpack: bool = False ) -> Value: try: tree = ast.parse(val, mode="eval") @@ -686,7 +694,7 @@ def _eval_forward_ref( return AnyValue(AnySource.error) else: return _type_from_ast( - tree.body, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + tree.body, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack ) @@ -695,11 +703,11 @@ def _type_from_value( ctx: Context, *, is_typeddict: bool = False, - unpack_allowed: bool = False, + allow_unpack: bool = False, ) -> Value: if isinstance(value, KnownValue): return _type_from_runtime( - value.val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + value.val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack ) elif isinstance(value, TypeVarValue): return value @@ -707,7 +715,7 @@ def _type_from_value( return unite_values( *[ _type_from_value( - val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed + val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack ) for val in value.vals ] @@ -720,7 +728,7 @@ def _type_from_value( value.members, ctx, is_typeddict=is_typeddict, - unpack_allowed=unpack_allowed, + allow_unpack=allow_unpack, ) elif isinstance(value, AnyValue): return value @@ -740,7 +748,7 @@ def _type_from_subscripted_value( ctx: Context, *, is_typeddict: bool = False, - unpack_allowed: bool = False, + allow_unpack: bool = False, ) -> Value: if isinstance(root, GenericValue): if len(root.args) == len(members): @@ -758,7 +766,7 @@ def _type_from_subscripted_value( members, ctx, is_typeddict=is_typeddict, - unpack_allowed=unpack_allowed, + allow_unpack=allow_unpack, ) for subval in root.vals ] @@ -800,7 +808,8 @@ def _type_from_subscripted_value( else: return _make_sequence_value( tuple, - [_type_from_value(arg, ctx, unpack_allowed=True) for arg in members], + [_type_from_value(arg, ctx, allow_unpack=True) for arg in members], + ctx, ) elif root is typing.Optional: if len(members) != 1: @@ -840,13 +849,13 @@ def _type_from_subscripted_value( return AnyValue(AnySource.error) return Pep655Value(False, _type_from_value(members[0], ctx)) elif is_typing_name(root, "Unpack"): - if not unpack_allowed: + if not allow_unpack: ctx.show_error("Unpack[] used in unsupported context") return AnyValue(AnySource.error) if len(members) != 1: ctx.show_error("Unpack requires a single argument") return AnyValue(AnySource.error) - return _make_unpacked_value(_type_from_value(members[0], ctx), ctx) + return UnpackedValue(_type_from_value(members[0], ctx)) elif root is Callable or root is typing.Callable: if len(members) == 2: args, return_value = members @@ -955,11 +964,6 @@ class Pep655Value(Value): value: Value -@dataclass -class UnpackedValue(Value): - elements: Sequence[Tuple[bool, Value]] - - class _Visitor(ast.NodeVisitor): def __init__(self, ctx: Context) -> None: self.ctx = ctx @@ -1136,7 +1140,7 @@ def _value_of_origin_args( ctx: Context, *, is_typeddict: bool = False, - unpack_allowed: bool = False, + allow_unpack: bool = False, ) -> Value: if origin is typing.Type or origin is type: if not args: @@ -1151,9 +1155,9 @@ def _value_of_origin_args( return SequenceValue(tuple, []) else: args_vals = [ - _type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args + _type_from_runtime(arg, ctx, allow_unpack=True) for arg in args ] - return _make_sequence_value(tuple, args_vals) + return _make_sequence_value(tuple, args_vals, ctx) elif origin is typing.Union: return unite_values(*[_type_from_runtime(arg, ctx) for arg in args]) elif origin is Callable or origin is typing.Callable: @@ -1218,13 +1222,13 @@ def _value_of_origin_args( return AnyValue(AnySource.error) return Pep655Value(False, _type_from_runtime(args[0], ctx)) elif is_typing_name(origin, "Unpack"): - if not unpack_allowed: + if not allow_unpack: ctx.show_error("Invalid usage of Unpack") return AnyValue(AnySource.error) if len(args) != 1: ctx.show_error("Unpack requires a single argument") return AnyValue(AnySource.error) - return _make_unpacked_value(_type_from_runtime(args[0], ctx), ctx) + return UnpackedValue(_type_from_runtime(args[0], ctx)) elif origin is None and isinstance(val, type): # This happens for SupportsInt in 3.7. return _maybe_typed_value(val) @@ -1243,27 +1247,22 @@ def _maybe_typed_value(val: Union[type, str]) -> Value: return TypedValue(val) -def _make_sequence_value(typ: type, members: Sequence[Value]) -> SequenceValue: +def _make_sequence_value( + typ: type, members: Sequence[Value], ctx: Context +) -> SequenceValue: pairs = [] for val in members: if isinstance(val, UnpackedValue): - pairs += val.elements + elements = val.get_elements() + if elements is None: + ctx.show_error(f"Invalid usage of Unpack with {val}") + elements = [(True, AnyValue(AnySource.error))] + pairs += elements else: pairs.append((False, val)) return SequenceValue(typ, pairs) -def _make_unpacked_value(val: Value, ctx: Context) -> UnpackedValue: - if isinstance(val, SequenceValue) and val.typ is tuple: - return UnpackedValue(val.members) - elif isinstance(val, GenericValue) and val.typ is tuple: - return UnpackedValue([(True, val.args[0])]) - elif isinstance(val, TypedValue) and val.typ is tuple: - return UnpackedValue([(True, AnyValue(AnySource.generic_argument))]) - ctx.show_error(f"Invalid argument for Unpack: {val}") - return UnpackedValue([]) - - def _make_callable_from_value( args: Value, return_value: Value, ctx: Context, is_asynq: bool = False ) -> Value: @@ -1280,15 +1279,13 @@ def _make_callable_from_value( annotation = _type_from_value(arg, ctx) if is_many: param = SigParameter( - f"__arg{i}", + f"@{i}", kind=ParameterKind.VAR_POSITIONAL, annotation=GenericValue(tuple, [annotation]), ) else: param = SigParameter( - f"__arg{i}", - kind=ParameterKind.POSITIONAL_ONLY, - annotation=annotation, + f"@{i}", kind=ParameterKind.POSITIONAL_ONLY, annotation=annotation ) params.append(param) try: @@ -1318,7 +1315,7 @@ def _make_callable_from_value( annotations = [_type_from_value(arg, ctx) for arg in args.members] params = [ SigParameter( - f"__arg{i}", + f"@{i}", kind=ParameterKind.PARAM_SPEC if i == len(annotations) - 1 else ParameterKind.POSITIONAL_ONLY, diff --git a/pyanalyze/arg_spec.py b/pyanalyze/arg_spec.py index ad772ebc..62772672 100644 --- a/pyanalyze/arg_spec.py +++ b/pyanalyze/arg_spec.py @@ -4,6 +4,7 @@ """ +from .functions import translate_vararg_type from .options import Options, PyObjectSequenceOption from .analysis_lib import is_positional_only_arg_name from .extensions import CustomCheck, TypeGuard, get_overloads, get_type_evaluations @@ -423,14 +424,12 @@ def _get_type_for_parameter( is_constructor: bool, ) -> Value: if parameter.annotation is not inspect.Parameter.empty: + kind = ParameterKind(parameter.kind) + ctx = AnnotationsContext(self, func_globals) typ = type_from_runtime( - parameter.annotation, ctx=AnnotationsContext(self, func_globals) + parameter.annotation, ctx=ctx, allow_unpack=kind.allow_unpack() ) - if parameter.kind is inspect.Parameter.VAR_POSITIONAL: - return GenericValue(tuple, [typ]) - elif parameter.kind is inspect.Parameter.VAR_KEYWORD: - return GenericValue(dict, [TypedValue(str), typ]) - return typ + return translate_vararg_type(kind, typ, self.ctx) # If this is the self argument of a method, try to infer the self type. elif index == 0 and parameter.kind in ( inspect.Parameter.POSITIONAL_ONLY, diff --git a/pyanalyze/functions.py b/pyanalyze/functions.py index a3cd273b..7614479e 100644 --- a/pyanalyze/functions.py +++ b/pyanalyze/functions.py @@ -32,6 +32,7 @@ KnownValue, GenericValue, SubclassValue, + UnpackedValue, TypeVarValue, unite_values, CanAssignError, @@ -94,7 +95,9 @@ class Context(ErrorContext, CanAssignContext, Protocol): def visit_expression(self, __node: ast.AST) -> Value: raise NotImplementedError - def value_of_annotation(self, __node: ast.expr) -> Value: + def value_of_annotation( + self, __node: ast.expr, *, allow_unpack: bool = False + ) -> Value: raise NotImplementedError def check_call( @@ -259,7 +262,9 @@ def compute_parameters( params = [] tv_index = 1 - for idx, ((kind, arg), default) in enumerate(zip_longest(args, defaults)): + for idx, (param, default) in enumerate(zip_longest(args, defaults)): + assert param is not None, "must have more args than defaults" + (kind, arg) = param is_self = ( idx == 0 and enclosing_class is not None @@ -267,7 +272,9 @@ def compute_parameters( and not isinstance(node, ast.Lambda) ) if arg.annotation is not None: - value = ctx.value_of_annotation(arg.annotation) + value = ctx.value_of_annotation( + arg.annotation, allow_unpack=kind.allow_unpack() + ) if default is not None: tv_map = value.can_assign(default, ctx) if isinstance(tv_map, CanAssignError): @@ -305,17 +312,50 @@ def compute_parameters( if default is not None: value = unite_values(value, default) - if kind is ParameterKind.VAR_POSITIONAL: - value = GenericValue(tuple, [value]) - elif kind is ParameterKind.VAR_KEYWORD: - value = GenericValue(dict, [TypedValue(str), value]) - + value = translate_vararg_type(kind, value, ctx, error_ctx=ctx, node=arg) param = SigParameter(arg.arg, kind, default, value) info = ParamInfo(param, arg, is_self) params.append(info) return params +def translate_vararg_type( + kind: ParameterKind, + typ: Value, + can_assign_ctx: CanAssignContext, + *, + error_ctx: Optional[ErrorContext] = None, + node: Optional[ast.AST] = None, +) -> Value: + if kind is ParameterKind.VAR_POSITIONAL: + if isinstance(typ, UnpackedValue): + if not TypedValue(tuple).is_assignable(typ.value, can_assign_ctx): + if error_ctx is not None and node is not None: + error_ctx.show_error( + node, + "Expected tuple type inside Unpack[]", + error_code=ErrorCode.invalid_annotation, + ) + return AnyValue(AnySource.error) + return typ.value + else: + return GenericValue(tuple, [typ]) + elif kind is ParameterKind.VAR_KEYWORD: + if isinstance(typ, UnpackedValue): + if not TypedValue(dict).is_assignable(typ.value, can_assign_ctx): + if error_ctx is not None and node is not None: + error_ctx.show_error( + node, + "Expected dict type inside Unpack[]", + error_code=ErrorCode.invalid_annotation, + ) + return AnyValue(AnySource.error) + return typ.value + else: + return GenericValue(dict, [TypedValue(str), typ]) + return typ + + @dataclass class IsGeneratorVisitor(ast.NodeVisitor): """Determine whether an async function is a generator. diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index d57d101e..8ac78d7a 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -2044,20 +2044,35 @@ def _check_function_unused_vars( replacement=replacement, ) - def value_of_annotation(self, node: ast.expr) -> Value: + def value_of_annotation( + self, node: ast.expr, *, allow_unpack: bool = False + ) -> Value: with qcore.override(self, "state", VisitorState.collect_names): annotated_type = self._visit_annotation(node) - return self._value_of_annotation_type(annotated_type, node) + return self._value_of_annotation_type( + annotated_type, node, allow_unpack=allow_unpack + ) def _visit_annotation(self, node: ast.AST) -> Value: with qcore.override(self, "in_annotation", True): return self.visit(node) def _value_of_annotation_type( - self, val: Value, node: ast.AST, is_typeddict: bool = False + self, + val: Value, + node: ast.AST, + *, + is_typeddict: bool = False, + allow_unpack: bool = False, ) -> Value: """Given a value encountered in a type annotation, return a type.""" - return type_from_value(val, visitor=self, node=node, is_typeddict=is_typeddict) + return type_from_value( + val, + visitor=self, + node=node, + is_typeddict=is_typeddict, + allow_unpack=allow_unpack, + ) def _check_method_first_arg( self, node: FunctionNode, function_info: FunctionInfo diff --git a/pyanalyze/signature.py b/pyanalyze/signature.py index d27c4fe0..d7d098db 100644 --- a/pyanalyze/signature.py +++ b/pyanalyze/signature.py @@ -351,6 +351,9 @@ class ParameterKind(enum.Enum): with any other callable. There can only be one ELLIPSIS parameter and it must be the last one.""" + def allow_unpack(self) -> bool: + return self is ParameterKind.VAR_KEYWORD or self is ParameterKind.VAR_POSITIONAL + KIND_TO_ALLOWED_PREVIOUS = { ParameterKind.POSITIONAL_ONLY: {ParameterKind.POSITIONAL_ONLY}, @@ -431,19 +434,28 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "SigParameter": def get_annotation(self) -> Value: return self.annotation + def is_unnamed(self) -> bool: + return self.name.startswith("@") + def __str__(self) -> str: # Adapted from Parameter.__str__ kind = self.kind - formatted = self.name - - if self.annotation != UNANNOTATED: - formatted = f"{formatted}: {self.annotation}" + if self.is_unnamed(): + if self.default is None: + formatted = str(self.annotation) + else: + formatted = f"{self.annotation} = {self.default}" + else: + formatted = self.name - if self.default is not None: if self.annotation != UNANNOTATED: - formatted = f"{formatted} = {self.default}" - else: - formatted = f"{formatted}={self.default}" + formatted = f"{formatted}: {self.annotation}" + + if self.default is not None: + if self.annotation != UNANNOTATED: + formatted = f"{formatted} = {self.default}" + else: + formatted = f"{formatted}={self.default}" if kind is ParameterKind.VAR_POSITIONAL: formatted = "*" + formatted @@ -779,9 +791,14 @@ def bind_arguments( elif actual_args.ellipsis: bound_args[param.name] = UNKNOWN, ELLIPSIS_COMPOSITE else: - self.show_call_error( - f"Missing required positional argument: '{param.name}'", ctx - ) + if param.is_unnamed(): + message = ( + "Missing required positional argument at position" + f" {int(param.name[1:])}" + ) + else: + message = f"Missing required positional argument '{param.name}'" + self.show_call_error(message, ctx) return None elif param.kind is ParameterKind.POSITIONAL_OR_KEYWORD: if positional_index < len(actual_args.positionals): @@ -864,7 +881,7 @@ def bind_arguments( bound_args[param.name] = DEFAULT, ELLIPSIS_COMPOSITE else: self.show_call_error( - f"Missing required argument: '{param.name}'", ctx + f"Missing required argument '{param.name}'", ctx ) return None elif param.kind is ParameterKind.KEYWORD_ONLY: @@ -906,7 +923,7 @@ def bind_arguments( bound_args[param.name] = DEFAULT, ELLIPSIS_COMPOSITE else: self.show_call_error( - f"Missing required argument: '{param.name}'", ctx + f"Missing required argument '{param.name}'", ctx ) return None elif param.kind is ParameterKind.VAR_POSITIONAL: @@ -1662,8 +1679,39 @@ def make( if return_annotation is None: return_annotation = AnyValue(AnySource.unannotated) has_return_annotation = False + param_dict = {} + i = 0 + for param in parameters: + if param.kind is ParameterKind.VAR_POSITIONAL and isinstance( + param.annotation, SequenceValue + ): + simple_members = param.annotation.get_member_sequence() + if simple_members is None: + param_dict[param.name] = param + i += 1 + else: + for member in simple_members: + name = f"@{i}" + param_dict[name] = SigParameter( + name, ParameterKind.POSITIONAL_ONLY, annotation=member + ) + i += 1 + elif param.kind is ParameterKind.VAR_KEYWORD and isinstance( + param.annotation, TypedDictValue + ): + for name, (is_required, value) in param.annotation.items.items(): + param_dict[name] = SigParameter( + name, + ParameterKind.KEYWORD_ONLY, + annotation=value, + default=None if is_required else AnyValue(AnySource.marker), + ) + i += 1 + else: + param_dict[param.name] = param + i += 1 return cls( - {param.name: param for param in parameters}, + param_dict, return_value=return_annotation, impl=impl, callable=callable, diff --git a/pyanalyze/stubs/_pyanalyze_tests-stubs/args.pyi b/pyanalyze/stubs/_pyanalyze_tests-stubs/args.pyi new file mode 100644 index 00000000..b94eef04 --- /dev/null +++ b/pyanalyze/stubs/_pyanalyze_tests-stubs/args.pyi @@ -0,0 +1,11 @@ +from typing import Tuple +from typing_extensions import Unpack, TypedDict + +class TD(TypedDict): + x: int + y: str + +def f(*args: Unpack[Tuple[int, str]]) -> None: ... +def g(**kwargs: Unpack[TD]) -> None: ... +def h(*args: int) -> None: ... +def i(**kwargs: str) -> None: ... diff --git a/pyanalyze/test_signature.py b/pyanalyze/test_signature.py index a4cfdf20..fbe470cf 100644 --- a/pyanalyze/test_signature.py +++ b/pyanalyze/test_signature.py @@ -1174,3 +1174,53 @@ def prop(self: "Capybara[int]") -> int: def caller(ci: Capybara[int], cs: Capybara[str]): assert_is_value(ci.prop, TypedValue(int)) cs.prop # E: incompatible_argument + + +class TestUnpack(TestNameCheckVisitorBase): + @assert_passes() + def test_args(self): + from typing_extensions import Unpack + from typing import Tuple + + def f(*args: Unpack[Tuple[int, str]]) -> None: + assert_is_value( + args, make_simple_sequence(tuple, [TypedValue(int), TypedValue(str)]) + ) + + def capybara(): + f(1, "x") + f(1) # E: incompatible_call + f(1, 1) # E: incompatible_argument + + @assert_passes() + def test_kwargs(self): + from typing_extensions import Unpack, TypedDict, NotRequired, Required + + class TD(TypedDict): + a: NotRequired[int] + b: Required[str] + + def capybara(**kwargs: Unpack[TD]): + assert_is_value( + kwargs, + TypedDictValue( + {"a": (False, TypedValue(int)), "b": (True, TypedValue(str))} + ), + ) + + def caller(): + capybara(a=1, b="x") + capybara(b="x") + capybara() # E: incompatible_call + capybara(a="x", b="x") # E: incompatible_argument + capybara(a=1, b="x", c=3) # E: incompatible_call + + @assert_passes() + def test_invalid(self): + from typing_extensions import Unpack + + def unpack_that_int(*args: Unpack[int]) -> None: # E: invalid_annotation + assert_is_value(args, AnyValue(AnySource.error)) + + def bad_kwargs(**kwargs: Unpack[None]) -> None: # E: invalid_annotation + assert_is_value(kwargs, AnyValue(AnySource.error)) diff --git a/pyanalyze/test_typeshed.py b/pyanalyze/test_typeshed.py index b047e6e5..c8980b18 100644 --- a/pyanalyze/test_typeshed.py +++ b/pyanalyze/test_typeshed.py @@ -270,6 +270,22 @@ def f(x: _ScandirIterator): want_cm(x) len(x) # E: incompatible_argument + @assert_passes() + def test_args_kwargs(self): + def capybara(): + from _pyanalyze_tests.args import f, g, h, i + + f(1) # E: incompatible_call + f(1, "x") + g(x=1) # E: incompatible_call + g(x=1, y="x") + h("x") # E: incompatible_argument + h() + h(1) + i(x=3) # E: incompatible_argument + i(x="x") + i() + class TestConstructors(TestNameCheckVisitorBase): @assert_passes() diff --git a/pyanalyze/typeshed.py b/pyanalyze/typeshed.py index 20c5d5d8..c481d439 100644 --- a/pyanalyze/typeshed.py +++ b/pyanalyze/typeshed.py @@ -4,6 +4,7 @@ """ +from pyanalyze.functions import translate_vararg_type from .node_visitor import Failure from .options import Options, PathSequenceOption from .extensions import evaluated @@ -829,20 +830,16 @@ def _get_signature_from_func_def( arguments = arguments[1:] if args.vararg is not None: - vararg_param = self._parse_param( - args.vararg, None, mod, ParameterKind.VAR_POSITIONAL + arguments.append( + self._parse_param(args.vararg, None, mod, ParameterKind.VAR_POSITIONAL) ) - annotation = GenericValue(tuple, [vararg_param.annotation]) - arguments.append(replace(vararg_param, annotation=annotation)) arguments += self._parse_param_list( args.kwonlyargs, args.kw_defaults, mod, ParameterKind.KEYWORD_ONLY ) if args.kwarg is not None: - kwarg_param = self._parse_param( - args.kwarg, None, mod, ParameterKind.VAR_KEYWORD + arguments.append( + self._parse_param(args.kwarg, None, mod, ParameterKind.VAR_KEYWORD) ) - annotation = GenericValue(dict, [TypedValue(str), kwarg_param.annotation]) - arguments.append(replace(kwarg_param, annotation=annotation)) # some typeshed types have a positional-only after a normal argument, # and Signature doesn't like that seen_non_positional = False @@ -898,7 +895,9 @@ def _parse_param( ) -> SigParameter: typ = AnyValue(AnySource.unannotated) if arg.annotation is not None: - typ = self._parse_type(arg.annotation, module) + typ = self._parse_type( + arg.annotation, module, allow_unpack=kind.allow_unpack() + ) elif objclass is not None: bases = self.get_bases(objclass) if bases is None: @@ -925,6 +924,7 @@ def _parse_param( ): kind = ParameterKind.POSITIONAL_ONLY name = name[2:] + typ = translate_vararg_type(kind, typ, self.ctx) # Mark self as positional-only. objclass should be given only if we believe # it's the "self" parameter. if objclass is not None: @@ -942,11 +942,18 @@ def _parse_expr(self, node: ast.AST, module: str) -> Value: return value_from_ast(node, ctx=ctx) def _parse_type( - self, node: ast.AST, module: str, *, is_typeddict: bool = False + self, + node: ast.AST, + module: str, + *, + is_typeddict: bool = False, + allow_unpack: bool = False, ) -> Value: val = self._parse_expr(node, module) ctx = _AnnotationContext(finder=self, module=module) - typ = type_from_value(val, ctx=ctx, is_typeddict=is_typeddict) + typ = type_from_value( + val, ctx=ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack + ) if self.verbose and isinstance(typ, AnyValue): self.log("Got Any", (ast.dump(node), module)) return typ diff --git a/pyanalyze/value.py b/pyanalyze/value.py index d4d699d4..0bf5b2a3 100644 --- a/pyanalyze/value.py +++ b/pyanalyze/value.py @@ -1973,6 +1973,22 @@ def simplify(self) -> Value: return AnnotatedValue(self.value.simplify(), self.metadata) +@dataclass(frozen=True) +class UnpackedValue(Value): + """Represents the result of PEP 646's Unpack operator.""" + + value: Value + + def get_elements(self) -> Optional[Sequence[Tuple[bool, Value]]]: + if isinstance(self.value, SequenceValue) and self.value.typ is tuple: + return self.value.members + elif isinstance(self.value, GenericValue) and self.value.typ is tuple: + return [(True, self.value.args[0])] + elif isinstance(self.value, TypedValue) and self.value.typ is tuple: + return [(True, AnyValue(AnySource.generic_argument))] + return None + + @dataclass(frozen=True) class VariableNameValue(AnyValue): """Value that is stored in a variable associated with a particular kind of value.