From 465683977b4757d88a9c1b573c95669697b2c677 Mon Sep 17 00:00:00 2001 From: Roland Zimmermann <5895436+zimmerrol@users.noreply.github.com> Date: Thu, 15 Apr 2021 14:22:33 +0200 Subject: [PATCH] Jax tensors are not correctly recognized if they are stored on GPUs (#31) * Fix jax on GPUs * Simplify switch statement --- eagerpy/astensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eagerpy/astensor.py b/eagerpy/astensor.py index 836b7c8..f547179 100644 --- a/eagerpy/astensor.py +++ b/eagerpy/astensor.py @@ -43,11 +43,12 @@ def astensor(x: Union[NativeTensor, Tensor]) -> Tensor: # type: ignore # to avoid importing all the frameworks name = _get_module_name(x) m = sys.modules + if name == "torch" and isinstance(x, m[name].Tensor): # type: ignore return PyTorchTensor(x) if name == "tensorflow" and isinstance(x, m[name].Tensor): # type: ignore return TensorFlowTensor(x) - if name == "jax" and isinstance(x, m[name].numpy.ndarray): # type: ignore + if (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray): # type: ignore return JAXTensor(x) if name == "numpy" and isinstance(x, m[name].ndarray): # type: ignore return NumPyTensor(x)