From 7b223ea968d73f99c0376872359ccc7efa5d6fb8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 22 Jul 2024 12:29:13 +0200 Subject: [PATCH] Fix `_cumsum` helper function in multi-gpu (#2636) * fixing staff * add test case * Apply suggestions from code review --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> (cherry picked from commit f7eeace11325b95a21f2abca42c5c9ec83ecfaaa) --- src/torchmetrics/utilities/data.py | 2 +- tests/unittests/utilities/test_utilities.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index 739c9b09710..2a6f7882d60 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -215,7 +215,7 @@ def _cumsum(x: Tensor, dim: Optional[int] = 0, dtype: Optional[torch.dtype] = No "Expect some slowdowns.", TorchMetricsUserWarning, ) - return x.cpu().cumsum(dim=dim, dtype=dtype).cuda() + return x.cpu().cumsum(dim=dim, dtype=dtype).to(x.device) return torch.cumsum(x, dim=dim, dtype=dtype) diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py index e61b2ec0e33..ee46adc349d 100644 --- a/tests/unittests/utilities/test_utilities.py +++ b/tests/unittests/utilities/test_utilities.py @@ -188,7 +188,9 @@ def test_cumsum_still_not_supported(use_deterministic_algorithms): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") def test_custom_cumsum(use_deterministic_algorithms): """Test custom cumsum implementation.""" - x = torch.arange(100).float().cuda() + # check that cumsum works as expected on non-default cuda device + device = torch.device("cuda:1") if torch.cuda.device_count() > 1 else torch.device("cuda:0") + x = torch.arange(100).float().to(device) if sys.platform != "win32": with pytest.warns( TorchMetricsUserWarning, match="You are trying to use a metric in deterministic mode on GPU that.*"