diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index e4ee19c..db779d8 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -192,6 +192,7 @@ def f(...): ... if ( _tb_flag and importlib.util.find_spec("jax") is not None + and importlib.util.find_spec("jaxlib") is not None and importlib.util.find_spec("jax._src.traceback_util") is not None ): import jax._src.traceback_util as traceback_util