From efe651e3bae7a279b60cda151b4bedf31e605bd6 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta <62377868+frazane@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:23:07 +0200 Subject: [PATCH] use torch.as_tensor to preserve gradients (#79) --- scoringrules/backend/registry.py | 8 +++----- scoringrules/backend/torch.py | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/scoringrules/backend/registry.py b/scoringrules/backend/registry.py index 3b174b6..7c4aae5 100644 --- a/scoringrules/backend/registry.py +++ b/scoringrules/backend/registry.py @@ -59,11 +59,9 @@ def __getitem__(self, __key: str) -> ArrayBackend: """Get a backend from the registry.""" try: return super().__getitem__(__key) - except KeyError as err: - raise BackendNotRegistered( - f"The backend '{__key}' is not registered. " - f"You can register it with scoringrules.register_backend('{__key}')" - ) from err + except KeyError: + self.register_backend(__key) + return super().__getitem__(__key) def set_active(self, backend: str): self._active = backend diff --git a/scoringrules/backend/torch.py b/scoringrules/backend/torch.py index eb92717..933271d 100644 --- a/scoringrules/backend/torch.py +++ b/scoringrules/backend/torch.py @@ -32,7 +32,8 @@ def asarray( *, dtype: Dtype | None = None, ) -> "Tensor": - return torch.asarray(obj, dtype=dtype) + # torch.asarray(obj) would cancel gradients! + return torch.as_tensor(obj, dtype=dtype) def broadcast_arrays(self, *arrays: "Tensor") -> tuple["Tensor", ...]: return torch.broadcast_tensors(*arrays)