Skip to content

Commit

Permalink
tests: fixture for use_deterministic_algorithms (#2351)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Feb 7, 2024
1 parent 6f89034 commit dc115fe
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
8 changes: 8 additions & 0 deletions tests/unittests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@
USE_PYTEST_POOL = os.getenv("USE_PYTEST_POOL", "0") == "1"


@pytest.fixture()
def use_deterministic_algorithms():  # noqa: PT004
    """Enable ``torch.use_deterministic_algorithms`` for the duration of a test.

    Yields nothing; the fixture is used purely for its side effect. On teardown
    the *previous* deterministic-mode setting is restored (the original version
    unconditionally switched it off, which would clobber a suite that runs with
    deterministic algorithms globally enabled). The restore happens in a
    ``finally`` block so deterministic mode cannot leak into later tests when
    the test body raises.
    """
    # Capture the current mode so teardown restores it rather than forcing False.
    was_enabled = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(True)
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(was_enabled)


def setup_ddp(rank, world_size):
"""Initialize ddp environment."""
global CURRENT_PORT
Expand Down
14 changes: 3 additions & 11 deletions tests/unittests/utilities/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,12 @@ def test_flatten_dict():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
def test_bincount():
def test_bincount(use_deterministic_algorithms):
"""Test that bincount works in deterministic setting on GPU."""
torch.use_deterministic_algorithms(True)

x = torch.randint(10, size=(100,))
# uses custom implementation
res1 = _bincount(x, minlength=10)

torch.use_deterministic_algorithms(False)

# uses torch.bincount
res2 = _bincount(x, minlength=10)

Expand Down Expand Up @@ -183,22 +179,19 @@ def test_recursive_allclose(inputs, expected):
@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_13, reason="earlier versions was silently non-deterministic, even in deterministic mode"
)
def test_cumsum_still_not_supported():
def test_cumsum_still_not_supported(use_deterministic_algorithms):
"""Make sure that cumsum on gpu and deterministic mode still fails.
If this test begins to pass, it means newer Pytorch versions support this and we can drop internal support.
"""
torch.use_deterministic_algorithms(True)
with pytest.raises(RuntimeError, match="cumsum_cuda_kernel does not have a deterministic implementation.*"):
torch.arange(10).float().cuda().cumsum(0)
torch.use_deterministic_algorithms(False)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
def test_custom_cumsum():
def test_custom_cumsum(use_deterministic_algorithms):
"""Test custom cumsum implementation."""
torch.use_deterministic_algorithms(True)
x = torch.arange(100).float().cuda()
if sys.platform != "win32":
with pytest.warns(
Expand All @@ -207,7 +200,6 @@ def test_custom_cumsum():
res = _cumsum(x, dim=0).cpu()
else:
res = _cumsum(x, dim=0).cpu()
torch.use_deterministic_algorithms(False)
res2 = np.cumsum(x.cpu(), axis=0)
assert torch.allclose(res, res2)

Expand Down

0 comments on commit dc115fe

Please sign in to comment.