From 696cc5b90f0d997c91e54021b8e58b7f3811f54e Mon Sep 17 00:00:00 2001 From: Andy Rock <7538433+ar0ck@users.noreply.github.com> Date: Thu, 7 Mar 2024 14:54:20 -0500 Subject: [PATCH] also require `jaxlib` --- jaxtyping/_decorator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index e4ee19c..db779d8 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -192,6 +192,7 @@ def f(...): ... if ( _tb_flag and importlib.util.find_spec("jax") is not None + and importlib.util.find_spec("jaxlib") is not None and importlib.util.find_spec("jax._src.traceback_util") is not None ): import jax._src.traceback_util as traceback_util