diff --git a/jax/_src/api.py b/jax/_src/api.py index 62347fe1fc86..cbc75b305208 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -24,16 +24,15 @@ import collections from collections.abc import Generator, Hashable, Iterable, Sequence +from contextlib import contextmanager, ExitStack from functools import partial import inspect import math -import typing -from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload, - cast) +from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast, overload import weakref import numpy as np -from contextlib import contextmanager, ExitStack +from typing_extensions import ParamSpec from jax._src import linear_util as lu from jax._src import stages @@ -94,6 +93,9 @@ F = TypeVar("F", bound=Callable) T = TypeVar("T") U = TypeVar("U") +V_co = TypeVar("V_co", covariant=True) +P = ParamSpec("P") + map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -139,7 +141,7 @@ def _update_debug_special_thread_local(_): def jit( - fun: Callable, + fun: Callable[P, V_co], in_shardings=sharding_impls.UNSPECIFIED, out_shardings=sharding_impls.UNSPECIFIED, static_argnums: int | Sequence[int] | None = None, @@ -151,7 +153,7 @@ def jit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, -) -> stages.Wrapped: +) -> stages.Wrapped[P, V_co]: """Sets up ``fun`` for just-in-time compilation with XLA. Args: @@ -1804,7 +1806,7 @@ def cache_miss(*args, **kwargs): ### Decide whether we can support the C++ fast path use_fastpath = False if execute is not None and isinstance(execute, pxla.ExecuteReplicated): - execute_replicated = typing.cast(pxla.ExecuteReplicated, execute) + execute_replicated = cast(pxla.ExecuteReplicated, execute) use_fastpath = ( # TODO(sharadmv): Enable effects in replicated computation not execute_replicated.has_unordered_effects @@ -1814,7 +1816,7 @@ def cache_miss(*args, **kwargs): ### If we can use the fastpath, we return required info to the caller. if use_fastpath: - execute_replicated = typing.cast(pxla.ExecuteReplicated, execute) + execute_replicated = cast(pxla.ExecuteReplicated, execute) out_handler = execute_replicated.out_handler in_handler = execute_replicated.in_handler diff --git a/jax/_src/stages.py b/jax/_src/stages.py index c470d1da2c6a..0e050817b0e7 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -32,7 +32,8 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, NamedTuple, Protocol, Union +from typing import Any, Generic, NamedTuple, Protocol, TypeVar, Union +from typing_extensions import ParamSpec import jax @@ -646,7 +647,11 @@ def cost_analysis(self) -> Any | None: return None -class Wrapped(Protocol): +V_co = TypeVar("V_co", covariant=True) +P = ParamSpec("P") + + +class Wrapped(Protocol, Generic[P, V_co]): """A function ready to be specialized, lowered, and compiled. This protocol reflects the output of functions such as @@ -655,7 +660,7 @@ class Wrapped(Protocol): to compilation, and the result compiled prior to execution. """ - def __call__(self, *args, **kwargs): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> V_co: """Executes the wrapped function, lowering and compiling as needed.""" raise NotImplementedError diff --git a/setup.py b/setup.py index 711a4e9d7a90..bec977b14d0a 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,7 @@ def generate_proto(source): # Python versions < 3.10. Can be dropped when 3.10 is the minimum # required Python version. 'importlib_metadata>=4.6;python_version<"3.10"', + 'typing_extensions>=4.5.0', ], extras_require={ # Minimum jaxlib version; used in testing.