Skip to content

Commit

Permalink
enabled test_conv_transpose; all tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
markkraay committed Aug 14, 2024
1 parent f424b2b commit f032baa
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions tripy/tests/integration/test_conv_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ConvTestCase:
@pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
class TestConvolution:
@pytest.mark.parametrize("test_case", test_cases_transpose_1d)
def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
Expand Down Expand Up @@ -131,14 +131,14 @@ def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 1e-3
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == expected.shape

@pytest.mark.parametrize("test_case", test_cases_transpose_2d)
def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
Expand Down Expand Up @@ -186,14 +186,14 @@ def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 1e-3
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == expected.shape

@pytest.mark.parametrize("test_case", test_cases_transpose_3d)
def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
Expand Down Expand Up @@ -241,12 +241,12 @@ def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)
rtol_ = 1.3e-6 if tp_dtype == tp.float32 else 1.6e-3
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == expected.shape

def test_transposed_equivalency(self, torch_dtype, tp_dtype):
def test_transposed_equivalency(self, torch_dtype, tp_dtype, compile_fixture):
input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

Expand Down Expand Up @@ -279,8 +279,8 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):

expected = conv_layer_torch(input_torch).to(torch_dtype)
expected_transpose = conv_transpose_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output_transpose = conv_transpose_layer(input)
output = compile_fixture(conv_layer, input)
output_transpose = compile_fixture(conv_transpose_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
Expand All @@ -293,7 +293,7 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):
assert expected.shape == expected_transpose.shape

@pytest.mark.parametrize("test_case", test_cases_transpose_downscale)
def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case, compile_fixture):
input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

Expand Down Expand Up @@ -322,7 +322,7 @@ def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
conv_layer.weight = tp.cast(tp.Tensor(conv_layer_torch.weight.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 1e-15 if tp_dtype == tp.float32 else 1e-10
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
Expand Down

0 comments on commit f032baa

Please sign in to comment.