Skip to content

Commit

Permalink
Reject ParamSpec-typed callables calls with insufficient arguments (#…
Browse files Browse the repository at this point in the history
…17323)

Fixes #14571.

When type checking a call of a `ParamSpec`-typed callable, currently
there is an incorrect "fast path" (if there are two arguments of shape
`(*args: P.args, **kwargs: P.kwargs)`, accept), which breaks with
`Concatenate` (such call was accepted even for `Concatenate[int, P]`).

Also there was no checking that args and kwargs are actually present:
since `*args` and `**kwargs` are not required, their absence was
silently accepted.
  • Loading branch information
sterliakov authored Sep 24, 2024
1 parent cf3db99 commit 9518b6a
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 6 deletions.
23 changes: 17 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1756,7 +1756,11 @@ def check_callable_call(
)

param_spec = callee.param_spec()
if param_spec is not None and arg_kinds == [ARG_STAR, ARG_STAR2]:
if (
param_spec is not None
and arg_kinds == [ARG_STAR, ARG_STAR2]
and len(formal_to_actual) == 2
):
arg1 = self.accept(args[0])
arg2 = self.accept(args[1])
if (
Expand Down Expand Up @@ -2362,6 +2366,9 @@ def check_argument_count(
# Positional argument when expecting a keyword argument.
self.msg.too_many_positional_arguments(callee, context)
ok = False
elif callee.param_spec() is not None and not formal_to_actual[i]:
self.msg.too_few_arguments(callee, context, actual_names)
ok = False
return ok

def check_for_extra_actual_arguments(
Expand Down Expand Up @@ -2763,9 +2770,9 @@ def plausible_overload_call_targets(
) -> list[CallableType]:
"""Returns all overload call targets that having matching argument counts.
If the given args contains a star-arg (*arg or **kwarg argument), this method
will ensure all star-arg overloads appear at the start of the list, instead
of their usual location.
If the given args contains a star-arg (*arg or **kwarg argument, including
ParamSpec), this method will ensure all star-arg overloads appear at the start
of the list, instead of their usual location.
The only exception is if the starred argument is something like a Tuple or a
NamedTuple, which has a definitive "shape". If so, we don't move the corresponding
Expand Down Expand Up @@ -2793,9 +2800,13 @@ def has_shape(typ: Type) -> bool:
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i]
)

with self.msg.filter_errors():
if self.check_argument_count(
if typ.param_spec() is not None:
# ParamSpec can be expanded in a lot of different ways. We may try
# to expand it here instead, but picking an impossible overload
# is safe: it will be filtered out later.
star_matches.append(typ)
elif self.check_argument_count(
typ, arg_types, arg_kinds, arg_names, formal_to_actual, None
):
if args_have_var_arg and typ.is_var_arg:
Expand Down
111 changes: 111 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2192,3 +2192,114 @@ parametrize(_test, Case(1, b=2), Case(3, b=4))
parametrize(_test, Case(1, 2), Case(3))
parametrize(_test, Case(1, 2), Case(3, b=4))
[builtins fixtures/paramspec.pyi]

[case testRunParamSpecInsufficientArgs]
from typing_extensions import ParamSpec, Concatenate
from typing import Callable

_P = ParamSpec("_P")

def run(predicate: Callable[_P, None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here
predicate() # E: Too few arguments
predicate(*args) # E: Too few arguments
predicate(**kwargs) # E: Too few arguments
predicate(*args, **kwargs)

def fn() -> None: ...
def fn_args(x: int) -> None: ...
def fn_posonly(x: int, /) -> None: ...

run(fn)
run(fn_args, 1)
run(fn_args, x=1)
run(fn_posonly, 1)
run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run"

[builtins fixtures/paramspec.pyi]

[case testRunParamSpecConcatenateInsufficientArgs]
from typing_extensions import ParamSpec, Concatenate
from typing import Callable

_P = ParamSpec("_P")

def run(predicate: Callable[Concatenate[int, _P], None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here
predicate() # E: Too few arguments
predicate(1) # E: Too few arguments
predicate(1, *args) # E: Too few arguments
predicate(1, *args) # E: Too few arguments
predicate(1, **kwargs) # E: Too few arguments
predicate(*args, **kwargs) # E: Argument 1 has incompatible type "*_P.args"; expected "int"
predicate(1, *args, **kwargs)

def fn() -> None: ...
def fn_args(x: int, y: str) -> None: ...
def fn_posonly(x: int, /) -> None: ...
def fn_posonly_args(x: int, /, y: str) -> None: ...

run(fn) # E: Argument 1 to "run" has incompatible type "Callable[[], None]"; expected "Callable[[int], None]"
run(fn_args, 1, 'a') # E: Too many arguments for "run" \
# E: Argument 2 to "run" has incompatible type "int"; expected "str"
run(fn_args, y='a')
run(fn_args, 'a')
run(fn_posonly)
run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run"
run(fn_posonly_args) # E: Missing positional argument "y" in call to "run"
run(fn_posonly_args, 'a')
run(fn_posonly_args, y='a')

[builtins fixtures/paramspec.pyi]

[case testRunParamSpecConcatenateInsufficientArgsInDecorator]
from typing_extensions import ParamSpec, Concatenate
from typing import Callable

P = ParamSpec("P")

def decorator(fn: Callable[Concatenate[str, P], None]) -> Callable[P, None]:
def inner(*args: P.args, **kwargs: P.kwargs) -> None:
fn("value") # E: Too few arguments
fn("value", *args) # E: Too few arguments
fn("value", **kwargs) # E: Too few arguments
fn(*args, **kwargs) # E: Argument 1 has incompatible type "*P.args"; expected "str"
fn("value", *args, **kwargs)
return inner

@decorator
def foo(s: str, s2: str) -> None: ...

[builtins fixtures/paramspec.pyi]

[case testRunParamSpecOverload]
from typing_extensions import ParamSpec
from typing import Callable, NoReturn, TypeVar, Union, overload

P = ParamSpec("P")
T = TypeVar("T")

@overload
def capture(
sync_fn: Callable[P, NoReturn],
*args: P.args,
**kwargs: P.kwargs,
) -> int: ...
@overload
def capture(
sync_fn: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> Union[T, int]: ...
def capture(
sync_fn: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> Union[T, int]:
return sync_fn(*args, **kwargs)

def fn() -> str: return ''
def err() -> NoReturn: ...

reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.int]"
reveal_type(capture(err)) # N: Revealed type is "builtins.int"

[builtins fixtures/paramspec.pyi]

0 comments on commit 9518b6a

Please sign in to comment.