diff --git a/jax/_src/api.py b/jax/_src/api.py index cbc75b305208..44852e335fc2 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -28,11 +28,13 @@ from functools import partial import inspect import math -from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast, overload +from typing import (Any, Callable, Literal, NamedTuple, TypeVar, cast, + overload, TYPE_CHECKING) import weakref import numpy as np -from typing_extensions import ParamSpec +if TYPE_CHECKING: + from typing_extensions import ParamSpec from jax._src import linear_util as lu from jax._src import stages @@ -94,7 +96,8 @@ T = TypeVar("T") U = TypeVar("U") V_co = TypeVar("V_co", covariant=True) -P = ParamSpec("P") +if TYPE_CHECKING: + P = ParamSpec("P") map, unsafe_map = safe_map, map diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 0e050817b0e7..2bb677ec40f1 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -32,8 +32,10 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, Protocol, TypeVar, Union -from typing_extensions import ParamSpec +from typing import (Any, Generic, NamedTuple, Protocol, TypeVar, Union, + TYPE_CHECKING) +if TYPE_CHECKING: + from typing_extensions import ParamSpec import jax @@ -648,7 +650,10 @@ def cost_analysis(self) -> Any | None: V_co = TypeVar("V_co", covariant=True) -P = ParamSpec("P") +if TYPE_CHECKING: + P = ParamSpec("P") +else: + P = TypeVar("P") class Wrapped(Protocol, Generic[P, V_co]): diff --git a/setup.py b/setup.py index bec977b14d0a..857c42f23cff 100644 --- a/setup.py +++ b/setup.py @@ -84,9 +84,9 @@ def generate_proto(source): # Python versions < 3.10. Can be dropped when 3.10 is the minimum # required Python version. 'importlib_metadata>=4.6;python_version<"3.10"', - 'typing_extensions>=4.5.0', ], extras_require={ + 'dev': ['typing_extensions>=4.8.0'], # Minimum jaxlib version; used in testing. 'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],