diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 739c9b09710..2a6f7882d60 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -215,7 +215,7 @@ def _cumsum(x: Tensor, dim: Optional[int] = 0, dtype: Optional[torch.dtype] = No "Expect some slowdowns.", TorchMetricsUserWarning, ) - return x.cpu().cumsum(dim=dim, dtype=dtype).cuda() + return x.cpu().cumsum(dim=dim, dtype=dtype).to(x.device) return torch.cumsum(x, dim=dim, dtype=dtype) diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py index e61b2ec0e33..ee46adc349d 100644 --- a/tests/unittests/utilities/test_utilities.py +++ b/tests/unittests/utilities/test_utilities.py @@ -188,7 +188,9 @@ def test_cumsum_still_not_supported(use_deterministic_algorithms): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") def test_custom_cumsum(use_deterministic_algorithms): """Test custom cumsum implementation.""" - x = torch.arange(100).float().cuda() + # check that cumsum works as expected on non-default cuda device + device = torch.device("cuda:1") if torch.cuda.device_count() > 1 else torch.device("cuda:0") + x = torch.arange(100).float().to(device) if sys.platform != "win32": with pytest.warns( TorchMetricsUserWarning, match="You are trying to use a metric in deterministic mode on GPU that.*"