From 49603a67a657616bcc1c020336e5c122781e9a5c Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Thu, 24 Mar 2022 19:06:12 -0400 Subject: [PATCH] Use ParamSpec in jit annotation --- jax/_src/api.py | 19 ++++++++++--------- jax/_src/stages.py | 11 ++++++++--- setup.py | 1 + 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 32648deccc0a..021b39c3dcb3 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -24,17 +24,15 @@ import collections from collections.abc import Generator, Hashable, Iterable, Sequence +from contextlib import contextmanager, ExitStack from functools import partial import inspect import math -import typing -from typing import (Any, Callable, Literal, - NamedTuple, Optional, TypeVar, Union, - overload, cast) +from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast, overload import weakref import numpy as np -from contextlib import contextmanager, ExitStack +from typing_extensions import ParamSpec from jax._src import linear_util as lu from jax._src import stages @@ -104,6 +102,9 @@ F = TypeVar("F", bound=Callable) T = TypeVar("T") U = TypeVar("U") +V_co = TypeVar("V_co", covariant=True) +P = ParamSpec("P") + map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -148,7 +149,7 @@ def _update_debug_special_thread_local(_): def jit( - fun: Callable, + fun: Callable[P, V_co], in_shardings=sharding_impls.UNSPECIFIED, out_shardings=sharding_impls.UNSPECIFIED, static_argnums: int | Sequence[int] | None = None, @@ -160,7 +161,7 @@ def jit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, -) -> stages.Wrapped: +) -> stages.Wrapped[P, V_co]: """Sets up ``fun`` for just-in-time compilation with XLA. Args: @@ -1812,7 +1813,7 @@ def cache_miss(*args, **kwargs): ### Decide whether we can support the C++ fast path use_fastpath = False if execute is not None and isinstance(execute, pxla.ExecuteReplicated): - execute_replicated = typing.cast(pxla.ExecuteReplicated, execute) + execute_replicated = cast(pxla.ExecuteReplicated, execute) use_fastpath = ( # TODO(sharadmv): Enable effects in replicated computation not execute_replicated.has_unordered_effects @@ -1822,7 +1823,7 @@ def cache_miss(*args, **kwargs): ### If we can use the fastpath, we return required info to the caller. if use_fastpath: - execute_replicated = typing.cast(pxla.ExecuteReplicated, execute) + execute_replicated = cast(pxla.ExecuteReplicated, execute) out_handler = execute_replicated.out_handler in_handler = execute_replicated.in_handler diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 3af6be3d6f4f..9fa7f8eb1ae9 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -32,7 +32,8 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, NamedTuple, Protocol, Union +from typing import Any, Generic, NamedTuple, Protocol, TypeVar, Union +from typing_extensions import ParamSpec import jax @@ -638,7 +639,11 @@ def cost_analysis(self) -> Any | None: return None -class Wrapped(Protocol): +V_co = TypeVar("V_co", covariant=True) +P = ParamSpec("P") + + +class Wrapped(Protocol, Generic[P, V_co]): """A function ready to be specialized, lowered, and compiled. This protocol reflects the output of functions such as @@ -647,7 +652,7 @@ class Wrapped(Protocol): to compilation, and the result compiled prior to execution. """ - def __call__(self, *args, **kwargs): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> V_co: """Executes the wrapped function, lowering and compiling as needed.""" raise NotImplementedError diff --git a/setup.py b/setup.py index 36294aad6936..9f6156ebd01a 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ def generate_proto(source): # Python versions < 3.10. Can be dropped when 3.10 is the minimum # required Python version. 'importlib_metadata>=4.6;python_version<"3.10"', + 'typing_extensions>=4.5.0', ], extras_require={ # Minimum jaxlib version; used in testing.