From d53cec8312da48008b4f80271156712770c58354 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Tue, 19 Nov 2024 11:26:12 -0800 Subject: [PATCH] Allow passing None for dim to iota apis --- tripy/tests/integration/test_iota.py | 8 ++------ tripy/tests/integration/test_quantize.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tripy/tests/integration/test_iota.py b/tripy/tests/integration/test_iota.py index 7962f411b..44cb38ab6 100644 --- a/tripy/tests/integration/test_iota.py +++ b/tripy/tests/integration/test_iota.py @@ -49,17 +49,13 @@ def _compute_ref_iota(self, dtype, shape, dim): "shape, dim", [ ((2, 3), 1), - ((2, 3), None), + ((2, 3), 0), ((2, 3), -1), ((2, 3, 4), 2), ], ) def test_iota(self, dtype, shape, dim, eager_or_compiled): - if dim: - output = eager_or_compiled(tp.iota, shape, dim, dtype[1]) - else: - output = eager_or_compiled(tp.iota, shape, dtype=dtype[1]) - + output = eager_or_compiled(tp.iota, shape, dim, dtype[1]) assert np.array_equal(cp.from_dlpack(output).get(), self._compute_ref_iota(dtype[0], shape, dim)) @pytest.mark.parametrize("dtype", DTYPE_PARAMS) diff --git a/tripy/tests/integration/test_quantize.py b/tripy/tests/integration/test_quantize.py index 6757e6dd7..ee458d108 100644 --- a/tripy/tests/integration/test_quantize.py +++ b/tripy/tests/integration/test_quantize.py @@ -44,7 +44,17 @@ def func(input): assert torch.equal(expected, torch.from_dlpack(quantized).to("cpu")) @pytest.mark.parametrize( - "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)] + "dtype", + [ + tp.float32, + pytest.param( + tp.float16, + marks=pytest.mark.skip( + reason="Known float16 precision issues due to https://github.com/NVIDIA/TensorRT-Incubator/issues/392" + ), + ), + pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80), + ], ) def test_quantize_int8_per_channel(self, dtype, eager_or_compiled): input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=TORCH_DTYPES[dtype])