diff --git a/jax/_src/api.py b/jax/_src/api.py index 597d0c057844..b3cd8b88b3da 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -24,16 +24,15 @@ import collections from collections.abc import Generator, Hashable, Iterable, Sequence +from contextlib import contextmanager, ExitStack from functools import partial import inspect import math -import typing -from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload, - cast) +from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast, overload import weakref import numpy as np -from contextlib import contextmanager, ExitStack +from typing_extensions import ParamSpec from jax._src import linear_util as lu from jax._src import stages @@ -96,6 +95,9 @@ F = TypeVar("F", bound=Callable) T = TypeVar("T") U = TypeVar("U") +V_co = TypeVar("V_co", covariant=True) +P = ParamSpec("P") + map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -140,7 +142,7 @@ def _update_debug_special_thread_local(_): def jit( - fun: Callable, + fun: Callable[P, V_co], in_shardings=sharding_impls.UNSPECIFIED, out_shardings=sharding_impls.UNSPECIFIED, static_argnums: int | Sequence[int] | None = None, @@ -152,7 +154,7 @@ def jit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, -) -> pjit.JitWrapped: +) -> pjit.JitWrapped[P, V_co]: """Sets up ``fun`` for just-in-time compilation with XLA. Args: @@ -1820,7 +1822,7 @@ def cache_miss(*args, **kwargs): ### Decide whether we can support the C++ fast path use_fastpath = False if execute is not None and isinstance(execute, pxla.ExecuteReplicated): - execute_replicated = typing.cast(pxla.ExecuteReplicated, execute) + execute_replicated = cast(pxla.ExecuteReplicated, execute) use_fastpath = ( # TODO(sharadmv): Enable effects in replicated computation not execute_replicated.has_unordered_effects @@ -1830,7 +1832,7 @@ def cache_miss(*args, **kwargs): ### If we can use the fastpath, we return required info to the caller. if use_fastpath: - execute_replicated = typing.cast(pxla.ExecuteReplicated, execute) + execute_replicated = cast(pxla.ExecuteReplicated, execute) out_handler = execute_replicated.out_handler in_handler = execute_replicated.in_handler diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 2e36e92276fa..88a75e334737 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -32,9 +32,11 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, NamedTuple, Protocol, Union +from typing import Any, Generic, NamedTuple, Protocol, TypeVar, Union import warnings +from typing_extensions import ParamSpec + import jax from jax._src import core @@ -707,7 +709,11 @@ def cost_analysis(self) -> Any | None: return None -class Wrapped(Protocol): +V_co = TypeVar("V_co", covariant=True) +P = ParamSpec("P") + + +class Wrapped(Protocol, Generic[P, V_co]): """A function ready to be specialized, lowered, and compiled. This protocol reflects the output of functions such as @@ -716,7 +722,7 @@ class Wrapped(Protocol): to compilation, and the result compiled prior to execution. """ - def __call__(self, *args, **kwargs): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> V_co: """Executes the wrapped function, lowering and compiling as needed.""" raise NotImplementedError diff --git a/setup.py b/setup.py index 1c35aa0d859c..3ff3927a5ba2 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,7 @@ def generate_proto(source): # Python versions < 3.10. Can be dropped when 3.10 is the minimum # required Python version. 'importlib_metadata>=4.6;python_version<"3.10"', + 'typing_extensions>=4.5.0', ], extras_require={ # Minimum jaxlib version; used in testing.