Skip to content

Commit

Permalink
Fix _cumsum helper function in multi-gpu (#2636)
Browse files Browse the repository at this point in the history
* fixing staff
* add test case
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
SkafteNicki and Borda authored Jul 22, 2024
1 parent 55f036e commit f7eeace
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _cumsum(x: Tensor, dim: Optional[int] = 0, dtype: Optional[torch.dtype] = No
"Expect some slowdowns.",
TorchMetricsUserWarning,
)
return x.cpu().cumsum(dim=dim, dtype=dtype).cuda()
return x.cpu().cumsum(dim=dim, dtype=dtype).to(x.device)
return torch.cumsum(x, dim=dim, dtype=dtype)


Expand Down
4 changes: 3 additions & 1 deletion tests/unittests/utilities/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def test_cumsum_still_not_supported(use_deterministic_algorithms):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
def test_custom_cumsum(use_deterministic_algorithms):
"""Test custom cumsum implementation."""
x = torch.arange(100).float().cuda()
# check that cumsum works as expected on non-default cuda device
device = torch.device("cuda:1") if torch.cuda.device_count() > 1 else torch.device("cuda:0")
x = torch.arange(100).float().to(device)
if sys.platform != "win32":
with pytest.warns(
TorchMetricsUserWarning, match="You are trying to use a metric in deterministic mode on GPU that.*"
Expand Down

0 comments on commit f7eeace

Please sign in to comment.