Functions decorated with @util._wraps(...) lose their type annotations.
#16863
Comments
Thanks for the report. This is a known issue, and historically it has not been easy to fix because the solution would require nontrivial protocol types, which were not available in earlier Python versions and, even when available, have not been well supported by type checkers like mypy. Now that we're on Python 3.8 and type checkers have evolved somewhat, it may be worth trying to support this again. Is this something you're interested in exploring?
```python
from jax import jit, Array
import jax.numpy as jnp

def add(a: Array, b: Array) -> Array:
    return a + b

jit_add = jit(add)
a = jnp.array([0])
b = a
x = add(a, b)
jit_x = jit_add(a, b)

# reveal_type is understood by mypy and Pyright; at runtime it is only
# importable (as typing.reveal_type) on Python 3.11+.
reveal_type(jit_add)  # Wrapped
reveal_type(x)  # Array
reveal_type(jit_x)  # Any
```

I think adding a generic argument, along these lines:

```python
def jit(f: Callable[..., T], ...) -> Wrapped[T]
```

This would fix functions that are only annotated with a plain …; a lot of functions in jax have … When the decorator is applied through `functools.partial`, the type variable is left unsolved:

```python
from collections.abc import Callable
import functools
from typing import TypeVar

T = TypeVar("T")

def decorator(f: Callable[..., T]) -> Callable[..., T]:
    return f

def one() -> int:
    return 1

reveal_type(functools.partial(decorator)(one))  # def (*Any, **Any) -> T
reveal_type(functools.partial(decorator)(one)())  # T
```

It might be possible to specialize `functools.partial` for decorators:

```python
from collections.abc import Callable
import functools
from typing import Any, TypeVar

T = TypeVar("T")

class DecoratorPartial(functools.partial[Callable[..., Any]]):
    def __call__(self, f: Callable[..., T]) -> Callable[..., T]:
        # A super() proxy is not callable itself, so forward to partial.__call__.
        return super().__call__(f)
```

Do you think this could be added to the library?
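The `Wrapped[T]` proposal from the comment above can be sketched as follows. This is a hypothetical illustration, not JAX's `jit` or its `Wrapped` type: here `Wrapped` is a small generic class and `fake_jit` performs no compilation. Because `fake_jit` accepts `Callable[..., T]` and returns `Wrapped[T]`, a type checker can recover the return type of calls, although the parameter types are still erased to `...`.

```python
from collections.abc import Callable
import functools
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class Wrapped(Generic[T]):
    """Hypothetical generic wrapper: only the return type T is tracked."""

    def __init__(self, fun: Callable[..., T]) -> None:
        self._fun = fun
        # Copy __name__, __doc__, __wrapped__, etc. from the wrapped function.
        functools.update_wrapper(self, fun)

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        return self._fun(*args, **kwargs)


def fake_jit(fun: Callable[..., T]) -> Wrapped[T]:
    """Hypothetical stand-in for jit; it simply defers to the wrapped function."""
    return Wrapped(fun)


def add(a: float, b: float) -> float:
    return a + b


jit_add = fake_jit(add)    # a type checker sees Wrapped[float]
total = jit_add(1.0, 2.0)  # ...and infers float here rather than Any
```

Compared with a `ParamSpec`-based design, this only needs a plain `TypeVar`, which is why it keeps the return type but not the argument types.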
That sort of thing has been discussed; see e.g. #14688. My own opinion is that if the deficiencies of the type checker necessitate changing clean, pythonic coding style (e.g. not using …
Description
I've noticed that the functions in `_src/numpy/lax_numpy.py` which are decorated with `@util._wraps(...)` lose their type annotations. For example, `jnp.clip` is decorated with `@util._wraps(np.clip, skip_params=['out'])`, and for the following code:

mypy says `x_clipped` has type `Any` and Pyright says it has type `Unknown`, but it should be `jax._src.basearray.Array`. In `lax_numpy`, the function `clip` itself is properly annotated; the types are lost only once this decorator is applied.

What jax/jaxlib version are you using?
0.4.13
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.10.12 on macOS
NVIDIA GPU info
No response
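A minimal sketch of how a decorator factory such as `@util._wraps(np.clip, skip_params=['out'])` can erase types, and one way to avoid it. Both `wraps_untyped` and `wraps_typed` are hypothetical helpers, not JAX's `_wraps`; they only copy a docstring from a reference function, and `skip_params` is accepted merely to mirror the call shape. What matters is the annotation of the returned decorator: typed as `Callable[..., Any] -> Callable[..., Any]`, calls through it come out as `Any`; typed as `Callable[[F], F]` with `F` bound to `Callable[..., Any]`, the decorated function keeps its original signature.

```python
from collections.abc import Callable, Sequence
from typing import Any, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def wraps_untyped(
    reference: Callable[..., Any], skip_params: Sequence[str] = ()
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical: the decorated function's annotations are lost to type checkers."""

    def decorator(fun: Callable[..., Any]) -> Callable[..., Any]:
        fun.__doc__ = reference.__doc__  # skip_params is unused in this sketch
        return fun

    return decorator


def wraps_typed(
    reference: Callable[..., Any], skip_params: Sequence[str] = ()
) -> Callable[[F], F]:
    """Hypothetical: returning F unchanged preserves the decorated signature."""

    def decorator(fun: F) -> F:
        fun.__doc__ = reference.__doc__  # skip_params is unused in this sketch
        return fun

    return decorator


def reference_clip(a, a_min=None, a_max=None, out=None):
    """Reference docstring, standing in for np.clip's."""


@wraps_untyped(reference_clip, skip_params=["out"])
def clip_untyped(a: float, a_min: float, a_max: float) -> float:
    return min(max(a, a_min), a_max)


@wraps_typed(reference_clip, skip_params=["out"])
def clip_typed(a: float, a_min: float, a_max: float) -> float:
    return min(max(a, a_min), a_max)


x1 = clip_untyped(5.0, 0.0, 1.0)  # mypy infers Any
x2 = clip_typed(5.0, 0.0, 1.0)    # mypy infers float
```

Both versions behave identically at runtime; the difference is visible only to the type checker.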