From ed78ba505b11e5712b1764216040d6128199e59b Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Sat, 25 Feb 2023 22:45:18 -0500 Subject: [PATCH] Make typing_extensions a dev-dependency --- jax/_src/api.py | 9 ++++++--- jax/_src/stages.py | 9 ++++++--- setup.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 021b39c3dcb3..37a11f0feb82 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -28,11 +28,13 @@ from functools import partial import inspect import math -from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast, overload +from typing import (Any, Callable, Literal, NamedTuple, TypeVar, cast, + overload, TYPE_CHECKING) import weakref import numpy as np -from typing_extensions import ParamSpec +if TYPE_CHECKING: + from typing_extensions import ParamSpec from jax._src import linear_util as lu from jax._src import stages @@ -103,7 +105,8 @@ T = TypeVar("T") U = TypeVar("U") V_co = TypeVar("V_co", covariant=True) -P = ParamSpec("P") +if TYPE_CHECKING: + P = ParamSpec("P") map, unsafe_map = safe_map, map diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 9fa7f8eb1ae9..5c198aec03a1 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -32,8 +32,10 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Generic, NamedTuple, Protocol, TypeVar, Union -from typing_extensions import ParamSpec +from typing import (Any, Generic, NamedTuple, Protocol, TypeVar, Union, + TYPE_CHECKING) +if TYPE_CHECKING: + from typing_extensions import ParamSpec import jax @@ -640,7 +642,8 @@ def cost_analysis(self) -> Any | None: V_co = TypeVar("V_co", covariant=True) -P = ParamSpec("P") +if TYPE_CHECKING: + P = ParamSpec("P") class Wrapped(Protocol, Generic[P, V_co]): diff --git a/setup.py b/setup.py index 9f6156ebd01a..dfc0d814dc9d 100644 --- a/setup.py +++ b/setup.py @@ -71,9 +71,9 @@ def generate_proto(source): # Python versions < 3.10. Can be dropped when 3.10 is the minimum # required Python version. 'importlib_metadata>=4.6;python_version<"3.10"', - 'typing_extensions>=4.5.0', ], extras_require={ + 'dev': ['typing_extensions>=4.5.0'], # Minimum jaxlib version; used in testing. 'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],