
Functions decorated with @util._wraps(...) lose their type annotations. #16863

Open
packquickly opened this issue Jul 27, 2023 · 3 comments
Labels
bug Something isn't working


@packquickly

Description

I've noticed that the functions in _src/numpy/lax_numpy.py that are decorated with @util._wraps(...) lose their type annotations.

For example, jnp.clip is decorated with @util._wraps(np.clip, skip_params=['out']). Consider the following code:

```python
import jax.numpy as jnp
from typing import TYPE_CHECKING

x = jnp.arange(17)
x_clipped = jnp.clip(x, a_max=14)

if TYPE_CHECKING:
    reveal_type(x) # Array
    reveal_type(x_clipped) # Unknown or Any
```

mypy says x_clipped has type Any, and Pyright says it has type Unknown; it should be jax._src.basearray.Array. In lax_numpy, clip is properly annotated before this decorator is applied, so the annotations are lost by the decorator.

What jax/jaxlib version are you using?

0.4.13

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.10.12 on macOS

NVIDIA GPU info

No response

@packquickly packquickly added the bug Something isn't working label Jul 27, 2023
@jakevdp
Collaborator

jakevdp commented Jul 27, 2023

Thanks for the report – this is a known issue, and historically has not been easy to fix because the solution would require nontrivial protocol types which were not available in earlier Python versions, and even when available have not been well-supported by type checkers like mypy.

Now that we're on Python 3.8 and type checkers have evolved somewhat, it may be worth trying to support this again. Is this something you're interested in exploring?

@edwardwli

jit returns Wrapped, which has no generic arguments, so calling a jitted function always returns Any:

```python
from typing import TYPE_CHECKING

from jax import jit, Array
import jax.numpy as jnp


def add(a: Array, b: Array) -> Array:
    return a + b


jit_add = jit(add)

a = jnp.array([0])
b = a

x = add(a, b)
jit_x = jit_add(a, b)

if TYPE_CHECKING:
    reveal_type(jit_add) # Wrapped
    reveal_type(x)       # Array
    reveal_type(jit_x)   # Any
```

I think adding a generic argument to Wrapped would preserve the return type. Then the signature of jit could be changed to

```python
def jit(f: Callable[..., T], ...) -> Wrapped[T]
```

This would fix functions that are only decorated with a plain @jit.
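The suggested change amounts to making the wrapper generic in its return type. Below is a minimal sketch of that idea; `Wrapped` and `jit_sketch` are hypothetical stand-ins that only forward the call, not `jax.jit` or JAX's own `Wrapped` class.

```python
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")


class Wrapped(Generic[T]):
    """Hypothetical generic wrapper: arguments unchecked, return type kept."""

    def __init__(self, fun: Callable[..., T]) -> None:
        self._fun = fun

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        # A real jit would trace and compile here; this sketch just forwards.
        return self._fun(*args, **kwargs)


def jit_sketch(f: Callable[..., T]) -> Wrapped[T]:
    return Wrapped(f)


def add(a: int, b: int) -> int:
    return a + b


jit_add = jit_sketch(add)
x = jit_add(1, 2)  # a checker should reveal int here rather than Any
```

Because the parameters are typed as `...`, a checker keeps the return type but still does not check the arguments; preserving the full signature would additionally need a `ParamSpec`.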

Many functions in jax are decorated with @functools.partial(jit, ...), and functools.partial doesn't pass through the return type of the decorated function:

```python
from collections.abc import Callable
import functools
from typing import TYPE_CHECKING, TypeVar

T = TypeVar("T")


def decorator(f: Callable[..., T]) -> Callable[..., T]:
    return f


def one() -> int:
    return 1


if TYPE_CHECKING:
    reveal_type(functools.partial(decorator)(one))   # def (*Any, **Any) -> T
    reveal_type(functools.partial(decorator)(one)()) # T
```

It might be possible to specialize functools.partial for decorators that preserve the decorated function's return type, since their __call__ should only have one required argument:

```python
from collections.abc import Callable
import functools
from typing import Any, TypeVar

T = TypeVar("T")


class DecoratorPartial(functools.partial[Callable[..., Any]]):
    def __call__(self, f: Callable[..., T]) -> Callable[..., T]:
        # super() itself is not callable; delegate to partial.__call__.
        return super().__call__(f)
```

Do you think this could be added to the library?
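At runtime the proposed class behaves like an ordinary `partial`. A short usage sketch follows; `repeat` is a hypothetical decorator with a keyword option, standing in for the `functools.partial(jit, ...)` pattern, and whether a particular checker accepts the narrowed `__call__` override is not verified here.

```python
import functools
from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")


class DecoratorPartial(functools.partial[Callable[..., Any]]):
    def __call__(self, f: Callable[..., T]) -> Callable[..., T]:
        return super().__call__(f)


def repeat(f: Callable[..., T], *, times: int) -> Callable[..., T]:
    """Hypothetical decorator with an option: call f `times` times."""
    def wrapper(*args: Any, **kwargs: Any) -> T:
        result = f(*args, **kwargs)
        for _ in range(times - 1):
            result = f(*args, **kwargs)
        return result
    return wrapper


calls = []


@DecoratorPartial(repeat, times=3)
def one() -> int:
    calls.append(1)
    return 1


value = one()  # a checker should see int here rather than T
```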

@jakevdp
Collaborator

jakevdp commented Aug 19, 2023

That sort of thing has been discussed; see e.g. #14688.

My own opinion is that if the deficiencies of the type checker necessitate changing clean, pythonic coding style (e.g. not using partial when it otherwise makes sense), then the type checker is not really ready for production, and I wouldn't want to spend too much energy modifying our runtime code to satisfy it.
