diff --git a/eagerpy/astensor.py b/eagerpy/astensor.py index 836b7c8..f547179 100644 --- a/eagerpy/astensor.py +++ b/eagerpy/astensor.py @@ -43,11 +43,12 @@ def astensor(x: Union[NativeTensor, Tensor]) -> Tensor: # type: ignore # to avoid importing all the frameworks name = _get_module_name(x) m = sys.modules + if name == "torch" and isinstance(x, m[name].Tensor): # type: ignore return PyTorchTensor(x) if name == "tensorflow" and isinstance(x, m[name].Tensor): # type: ignore return TensorFlowTensor(x) - if name == "jax" and isinstance(x, m[name].numpy.ndarray): # type: ignore + if (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray): # type: ignore return JAXTensor(x) if name == "numpy" and isinstance(x, m[name].ndarray): # type: ignore return NumPyTensor(x)