From 1a386edc76ef8e0ea5c739ea7c27f7a81b1a813a Mon Sep 17 00:00:00 2001
From: Mark Kraay
Date: Tue, 6 Aug 2024 13:52:42 -0700
Subject: [PATCH] enable compile fixture in more ops

---
 tripy/tests/integration/test_convolution.py | 16 ++++++++--------
 tripy/tests/integration/test_dequantize.py  | 21 ++++++++++-----------
 tripy/tests/integration/test_expand.py      | 16 ++++++++--------
 tripy/tests/integration/test_flip.py        |  6 +++---
 4 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/tripy/tests/integration/test_convolution.py b/tripy/tests/integration/test_convolution.py
index 343f66dbb..d14819f64 100644
--- a/tripy/tests/integration/test_convolution.py
+++ b/tripy/tests/integration/test_convolution.py
@@ -75,7 +75,7 @@ class ConvTestCase:
 @pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
 class TestConvolution:
     @pytest.mark.parametrize("test_case", test_cases_1d)
-    def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
+    def test_convolution_1d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -122,7 +122,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
             conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = compile_fixture(conv_layer, input)
 
         # FP32 kernel seems to lose some precision, and FP16 needs to be run in FP32 on torch
         rtol_ = 4e-5 if tp_dtype == tp.float32 else 1e-3
@@ -131,7 +131,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
         assert output_torch.shape == expected.shape
 
     @pytest.mark.parametrize("test_case", test_cases_2d)
-    def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
+    def test_convolution_2d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -178,7 +178,7 @@ def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
             conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = compile_fixture(conv_layer, input)
 
         rtol_ = 2e-7 if tp_dtype == tp.float32 else 1.5e-3
         output_torch = torch.from_dlpack(output)
@@ -186,7 +186,7 @@ def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
         assert output_torch.shape == expected.shape
 
     @pytest.mark.parametrize("test_case", test_cases_3d)
-    def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
+    def test_convolution_3d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
         pytest.skip("TODO (#260): Fix accuracy bugs in 3D conv")
         if not test_case.torch_pad:
             test_case.torch_pad = 0
@@ -245,14 +245,14 @@ def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
             return
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = compile_fixture(conv_layer, input)
 
         rtol_ = 2e-4 if tp_dtype == tp.float32 else 1.4e-3  # 3d conv has greater accumulation error
         output_torch = torch.from_dlpack(output)
         assert torch.allclose(output_torch, expected, rtol=rtol_)
         assert output_torch.shape == expected.shape
 
-    def test_uneven_padding(self, torch_dtype, tp_dtype):
+    def test_uneven_padding(self, torch_dtype, tp_dtype, compile_fixture):
         input_torch = torch.arange(200, dtype=torch.float32, device=torch.device("cuda")).reshape(*(2, 4, 5, 5))
         input = tp.cast(tp.Tensor(input_torch), tp_dtype)
 
@@ -282,7 +282,7 @@ def test_uneven_padding(self, torch_dtype, tp_dtype):
         input_torch = torch_pad(input_torch)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = compile_fixture(conv_layer, input)
 
         rtol_ = 2e-7 if tp_dtype == tp.float32 else 2e-3
         output_torch = torch.from_dlpack(output)
diff --git a/tripy/tests/integration/test_dequantize.py b/tripy/tests/integration/test_dequantize.py
index 7223e3429..40f06dde9 100644
--- a/tripy/tests/integration/test_dequantize.py
+++ b/tripy/tests/integration/test_dequantize.py
@@ -30,12 +30,12 @@ class TestDequantize:
     @pytest.mark.parametrize(
         "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
     )
-    def test_dequantize_int8_per_tensor(self, dtype):
+    def test_dequantize_int8_per_tensor(self, dtype, compile_fixture):
         data = [4, 8]
         input_tp = tp.Tensor(data, dtype=tp.int8)
         scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype])
         scale_tp = tp.Tensor(scale, dtype=dtype)
-        dequantized = tp.dequantize(input_tp, scale_tp, dtype)
+        dequantized = compile_fixture(tp.dequantize, input_tp, scale_tp, dtype)
         expected = torch.tensor(data) * scale
         output = torch.from_dlpack(dequantized)
         assert torch.allclose(expected, output.to("cpu"))
@@ -43,7 +43,7 @@ def test_dequantize_int8_per_tensor(self, dtype):
     @pytest.mark.parametrize(
         "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
     )
-    def test_dequantize_int8_per_channel(self, dtype):
+    def test_dequantize_int8_per_channel(self, dtype, compile_fixture):
         # TODO: Fix in #153
         if dtype == tp.float16:
             pytest.skip("TRT does not support fp16->int8 per-channel dequant.")
@@ -51,7 +51,7 @@ def test_dequantize_int8_per_channel(self, dtype):
         input_tp = tp.Tensor(data, dtype=tp.int8)
         scale = torch.tensor([0.8, 0.9], dtype=TORCH_DTYPES[dtype])
         scale_tp = tp.Tensor(scale, dtype=dtype)
-        dequantized = tp.dequantize(input_tp, scale_tp, dtype, dim=0)
+        dequantized = compile_fixture(tp.dequantize, input_tp, scale_tp, dtype, dim=0)
         expected = torch.tensor(data) * scale.reshape((2, 1))
         output = torch.from_dlpack(dequantized)
         assert torch.allclose(expected, output.to("cpu"))
@@ -61,14 +61,13 @@ def test_dequantize_int8_per_channel(self, dtype):
         "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
     )
     @skip_if_older_than_sm89
-    def test_dequantize_fp8_per_tensor(self, dtype):
+    def test_dequantize_fp8_per_tensor(self, dtype, compile_fixture):
         data_value = [1.0, 1.0]
         input_tp = tp.Tensor(data_value, dtype=tp.float8)
         scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype])
         scale_tp = tp.Tensor(scale, dtype=dtype)
-        dequantized = tp.dequantize(input_tp, scale_tp, dtype)
+        dequantized = compile_fixture(tp.dequantize, input_tp, scale_tp, dtype)
         assert dequantized.dtype == dtype
-        print(dequantized)
         expected = torch.Tensor(data_value) * scale
         output = torch.from_dlpack(dequantized).to(dtype=torch.float32)
         assert torch.allclose(expected, output.to("cpu"))
@@ -77,23 +76,23 @@ def test_dequantize_fp8_per_tensor(self, dtype):
         "dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
     )
     @skip_if_older_than_sm89
-    def test_dequantize_fp8_per_channel(self, dtype):
+    def test_dequantize_fp8_per_channel(self, dtype, compile_fixture):
         data_value = [[1.0, 1.0], [1.0, 1.0]]
         input_tp = tp.Tensor(data_value, dtype=tp.float8)
         scale = torch.tensor([0.8, 0.9], dtype=TORCH_DTYPES[dtype])
         scale_tp = tp.Tensor(scale, dtype=dtype)
-        dequantized = tp.dequantize(input_tp, scale_tp, dtype, dim=0)
+        dequantized = compile_fixture(tp.dequantize, input_tp, scale_tp, dtype, dim=0)
         assert dequantized.dtype == dtype
         print(dequantized)
         expected = torch.Tensor(data_value) * scale.reshape((2, 1))
         output = torch.from_dlpack(dequantized).to(dtype=torch.float32)
         assert torch.allclose(expected, output.to("cpu"))
 
-    def test_negative_non_constant_scale(self):
+    def test_negative_non_constant_scale(self, compile_fixture):
         data = [[4, 8], [4, 8]]
         input = tp.Tensor(data, dtype=tp.int8)
         scale = tp.ones((2,))
-        dequantized = tp.dequantize(input, scale, tp.float32, dim=0)
+        dequantized = compile_fixture(tp.dequantize, input, scale, tp.float32, dim=0)
         with raises(
             tp.TripyException,
             match="Scale must be a constant tensor in dequantize op",
diff --git a/tripy/tests/integration/test_expand.py b/tripy/tests/integration/test_expand.py
index 3e144f564..ab9edc1f4 100644
--- a/tripy/tests/integration/test_expand.py
+++ b/tripy/tests/integration/test_expand.py
@@ -22,24 +22,24 @@ class TestExpand:
-    def test_int_sizes(self):
+    def test_int_sizes(self, compile_fixture):
         input = tp.ones((2, 1))
-        out = tp.expand(input, (-1, 2))
+        out = compile_fixture(tp.expand, input, (-1, 2))
         assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 2), dtype=np.float32))
 
-    def test_shape_sizes(self):
+    def test_shape_sizes(self, compile_fixture):
         input = tp.ones((2, 1))
         a = tp.ones((2, 4))
-        out = tp.expand(input, a.shape)
+        out = compile_fixture(tp.expand, input, a.shape)
         assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 4), dtype=np.float32))
 
-    def test_extra_dims(self):
+    def test_extra_dims(self, compile_fixture):
         input = tp.ones((2, 1))
-        out = tp.expand(input, (1, -1, 2))
+        out = compile_fixture(tp.expand, input, (1, -1, 2))
         assert np.array_equal(cp.from_dlpack(out).get(), np.ones((1, 2, 2), dtype=np.float32))
 
-    def test_mixed_sizes(self):
+    def test_mixed_sizes(self, compile_fixture):
         input = tp.ones((2, 1, 1))
         a = tp.ones((4, 4))
-        out = tp.expand(input, (-1, a.shape[0], a.shape[1]))
+        out = compile_fixture(tp.expand, input, (-1, a.shape[0], a.shape[1]))
         assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 4, 4), dtype=np.float32))
diff --git a/tripy/tests/integration/test_flip.py b/tripy/tests/integration/test_flip.py
index a3703543f..6184f8560 100644
--- a/tripy/tests/integration/test_flip.py
+++ b/tripy/tests/integration/test_flip.py
@@ -26,14 +26,14 @@ class TestFlip:
         "dims",
         [0, 1, None, [0, 1], [1, 0], -1, -2, [0, -1], [-2, 1]],
     )
-    def test_flip(self, dims):
+    def test_flip(self, dims, compile_fixture):
         cp_a = cp.arange(16).reshape((4, 4)).astype(cp.float32)
         a = tp.Tensor(cp_a, device=tp.device("gpu"))
-        f = tp.flip(a, dims=dims)
+        f = compile_fixture(tp.flip, a, dims=dims)
         assert np.array_equal(cp.from_dlpack(f).get(), np.flip(cp_a.get(), axis=dims))
 
         # also ensure that flipping a second time restores the original value
-        f2 = tp.flip(f, dims=dims)
+        f2 = compile_fixture(tp.flip, f, dims=dims)
         assert cp.array_equal(cp.from_dlpack(f2), cp_a)
 
     def test_no_op(self):
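
Note: each converted test now routes the op call through a `compile_fixture` pytest fixture instead of invoking the op directly, so the same assertions exercise both eager execution and execution through `tp.compile`. The fixture itself is not part of this patch (it presumably lives in a conftest.py of the test suite). The sketch below is only a rough, hypothetical illustration of the call convention these tests assume; the parametrization values, the wrapper that separates tensor from non-tensor arguments, and the `tp.compile`/`tp.InputInfo` usage are assumptions rather than the repository's actual implementation.

# Hypothetical sketch of a compile_fixture -- not the real fixture from this repo.
import cupy as cp
import pytest

import tripy as tp


@pytest.fixture(params=["eager", "compiled"])
def compile_fixture(request):
    def run(func, *args, **kwargs):
        # Eager path: just call the op as the tests did before this patch.
        if request.param == "eager":
            return func(*args, **kwargs)

        # Compiled path: tensor arguments become runtime inputs; everything
        # else (dtypes, dims, shape tuples) is captured by the closure as a
        # compile-time constant.
        tensors = [a for a in args if isinstance(a, tp.Tensor)]

        def wrapper(*tensor_args):
            it = iter(tensor_args)
            rebuilt = [next(it) if isinstance(a, tp.Tensor) else a for a in args]
            return func(*rebuilt, **kwargs)

        # Assumption: static input shapes taken from the concrete test inputs.
        infos = [tp.InputInfo(list(cp.from_dlpack(t).shape), dtype=t.dtype) for t in tensors]
        return tp.compile(wrapper, args=infos)(*tensors)

    return run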