Skip to content

Commit

Permalink
use torch.as_tensor to preserve gradients (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
frazane authored Oct 23, 2024
1 parent 0c1738a commit efe651e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 3 additions & 5 deletions scoringrules/backend/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ def __getitem__(self, __key: str) -> ArrayBackend:
"""Get a backend from the registry."""
try:
return super().__getitem__(__key)
except KeyError as err:
raise BackendNotRegistered(
f"The backend '{__key}' is not registered. "
f"You can register it with scoringrules.register_backend('{__key}')"
) from err
except KeyError:
self.register_backend(__key)
return super().__getitem__(__key)

def set_active(self, backend: str):
self._active = backend
Expand Down
3 changes: 2 additions & 1 deletion scoringrules/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def asarray(
*,
dtype: Dtype | None = None,
) -> "Tensor":
return torch.asarray(obj, dtype=dtype)
# torch.asarray(obj) would cancel gradients!
return torch.as_tensor(obj, dtype=dtype)

def broadcast_arrays(self, *arrays: "Tensor") -> tuple["Tensor", ...]:
return torch.broadcast_tensors(*arrays)
Expand Down

0 comments on commit efe651e

Please sign in to comment.