diff --git a/tripy/tests/integration/test_conv_transpose.py b/tripy/tests/integration/test_conv_transpose.py
index 6f81dec50..88a03fbf1 100644
--- a/tripy/tests/integration/test_conv_transpose.py
+++ b/tripy/tests/integration/test_conv_transpose.py
@@ -83,7 +83,7 @@ class ConvTestCase:
 @pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
 class TestConvolution:
     @pytest.mark.parametrize("test_case", test_cases_transpose_1d)
-    def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -131,14 +131,14 @@ def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
             conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = compile_fixture(conv_layer, input)
 
         rtol_ = 1e-3
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == expected.shape
 
     @pytest.mark.parametrize("test_case", test_cases_transpose_2d)
-    def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -186,14 +186,14 @@ def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
             conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = compile_fixture(conv_layer, input)
 
         rtol_ = 1e-3
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == expected.shape
 
     @pytest.mark.parametrize("test_case", test_cases_transpose_3d)
-    def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
         if not test_case.torch_pad:
             test_case.torch_pad = 0
         if not test_case.stride:
@@ -241,12 +241,12 @@ def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
             conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = compile_fixture(conv_layer, input)
 
         rtol_ = 1.3e-6 if tp_dtype == tp.float32 else 1.6e-3
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
         assert output.shape == expected.shape
 
-    def test_transposed_equivalency(self, torch_dtype, tp_dtype):
+    def test_transposed_equivalency(self, torch_dtype, tp_dtype, compile_fixture):
         input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
         input = tp.cast(tp.Tensor(input_torch), tp_dtype)
 
@@ -279,8 +279,8 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
         expected_transpose = conv_transpose_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
-        output_transpose = conv_transpose_layer(input)
+        output = compile_fixture(conv_layer, input)
+        output_transpose = compile_fixture(conv_transpose_layer, input)
 
         rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
@@ -293,7 +293,7 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):
         assert expected.shape == expected_transpose.shape
 
     @pytest.mark.parametrize("test_case", test_cases_transpose_downscale)
-    def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
+    def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case, compile_fixture):
         input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
         input = tp.cast(tp.Tensor(input_torch), tp_dtype)
 
@@ -322,7 +322,7 @@ def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
         conv_layer.weight = tp.cast(tp.Tensor(conv_layer_torch.weight.data), tp_dtype)
 
         expected = conv_layer_torch(input_torch).to(torch_dtype)
-        output = conv_layer(input)
+        output = compile_fixture(conv_layer, input)
 
         rtol_ = 1e-15 if tp_dtype == tp.float32 else 1e-10
         assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
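
Note: the `compile_fixture` these tests now request is resolved by pytest from the suite's conftest.py, which this patch does not touch. As a rough sketch of the contract the tests rely on (compile the module for the given inputs, run the compiled executable, and return its output), a minimal hypothetical implementation might look like the following; the fixture body, its `run` helper, and the shape/dtype introspection are all assumptions, not the actual fixture:

```python
# Hypothetical sketch of compile_fixture; the real definition lives in
# conftest.py and is not part of this diff.
import pytest

import tripy as tp


@pytest.fixture
def compile_fixture():
    def run(module, *args):
        # Describe each runtime input to the compiler. This assumes every
        # argument is a tp.Tensor with a concrete shape and dtype.
        input_infos = [tp.InputInfo(list(arg.shape), dtype=arg.dtype) for arg in args]
        # Compile the module for those input signatures, then execute the
        # compiled artifact with the same tensors the eager path would use.
        compiled = tp.compile(module, args=input_infos)
        return compiled(*args)

    return run
```

Routing each `conv_layer(input)` call through the fixture lets the existing parametrized cases exercise the compiled execution path without duplicating any test bodies.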