enable compile fixture in more ops
markkraay committed Aug 12, 2024
1 parent 6f1b21e commit 1a386ed
Showing 4 changed files with 29 additions and 30 deletions.
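
The compile_fixture itself lives in the test suite's shared conftest and is not part of this diff; the four files below only change their call sites to route each op through it. As an illustration of the call contract those call sites rely on — a callable (a tripy op or module) plus its arguments, returning the op's output — here is a minimal sketch of what such a pytest fixture could look like. The mode parametrization and the _toy_compile stand-in are assumptions for illustration only, not the repository's implementation:

import pytest


def _toy_compile(func):
    # Hypothetical stand-in for a real compilation entry point: the repository's
    # actual fixture would presumably invoke tripy's compiler here rather than
    # returning a no-op wrapper.
    def compiled(*args, **kwargs):
        return func(*args, **kwargs)

    return compiled


@pytest.fixture(params=["eager", "compiled"])
def compile_fixture(request):
    # Each test that accepts `compile_fixture` runs once per mode: a plain eager
    # call (what these tests did before this commit) and a call routed through
    # the compiled wrapper.
    def run(func, *args, **kwargs):
        if request.param == "eager":
            return func(*args, **kwargs)
        return _toy_compile(func)(*args, **kwargs)

    return run

Under that contract, a call site such as output = conv_layer(input) becomes output = compile_fixture(conv_layer, input), and non-tensor arguments (dtypes, dim=0, dims=dims, target shapes) pass through unchanged.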
16 changes: 8 additions & 8 deletions tripy/tests/integration/test_convolution.py
@@ -75,7 +75,7 @@ class ConvTestCase:
@pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
class TestConvolution:
@pytest.mark.parametrize("test_case", test_cases_1d)
-def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
+def test_convolution_1d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -122,7 +122,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = compile_fixture(conv_layer, input)

# FP32 kernel seems to lose some precision, and FP16 needs to be run in FP32 on torch
rtol_ = 4e-5 if tp_dtype == tp.float32 else 1e-3
@@ -131,7 +131,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
assert output_torch.shape == expected.shape

@pytest.mark.parametrize("test_case", test_cases_2d)
-def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
+def test_convolution_2d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -178,15 +178,15 @@ def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = compile_fixture(conv_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 1.5e-3
output_torch = torch.from_dlpack(output)
assert torch.allclose(output_torch, expected, rtol=rtol_)
assert output_torch.shape == expected.shape

@pytest.mark.parametrize("test_case", test_cases_3d)
-def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
+def test_convolution_3d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
pytest.skip("TODO (#260): Fix accuracy bugs in 3D conv")
if not test_case.torch_pad:
test_case.torch_pad = 0
@@ -245,14 +245,14 @@ def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
return

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = compile_fixture(conv_layer, input)

rtol_ = 2e-4 if tp_dtype == tp.float32 else 1.4e-3 # 3d conv has greater accumulation error
output_torch = torch.from_dlpack(output)
assert torch.allclose(output_torch, expected, rtol=rtol_)
assert output_torch.shape == expected.shape

-def test_uneven_padding(self, torch_dtype, tp_dtype):
+def test_uneven_padding(self, torch_dtype, tp_dtype, compile_fixture):
input_torch = torch.arange(200, dtype=torch.float32, device=torch.device("cuda")).reshape(*(2, 4, 5, 5))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -282,7 +282,7 @@ def test_uneven_padding(self, torch_dtype, tp_dtype):

input_torch = torch_pad(input_torch)
expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = compile_fixture(conv_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 2e-3
output_torch = torch.from_dlpack(output)
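
Note that in test_convolution.py the callable handed to the fixture is a convolution module instance (conv_layer) rather than a free function; since tripy modules are callable, the same fixture contract covers both styles appearing in this commit, e.g. compile_fixture(conv_layer, input) above and compile_fixture(tp.dequantize, input_tp, scale_tp, dtype) in the next file.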
21 changes: 10 additions & 11 deletions tripy/tests/integration/test_dequantize.py
@@ -30,28 +30,28 @@ class TestDequantize:
@pytest.mark.parametrize(
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
-def test_dequantize_int8_per_tensor(self, dtype):
+def test_dequantize_int8_per_tensor(self, dtype, compile_fixture):
data = [4, 8]
input_tp = tp.Tensor(data, dtype=tp.int8)
scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype])
scale_tp = tp.Tensor(scale, dtype=dtype)
-dequantized = tp.dequantize(input_tp, scale_tp, dtype)
+dequantized = compile_fixture(tp.dequantize, input_tp, scale_tp, dtype)
expected = torch.tensor(data) * scale
output = torch.from_dlpack(dequantized)
assert torch.allclose(expected, output.to("cpu"))

@pytest.mark.parametrize(
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
-def test_dequantize_int8_per_channel(self, dtype):
+def test_dequantize_int8_per_channel(self, dtype, compile_fixture):
# TODO: Fix in #153
if dtype == tp.float16:
pytest.skip("TRT does not support fp16->int8 per-channel dequant.")
data = [[4, 8], [4, 8]]
input_tp = tp.Tensor(data, dtype=tp.int8)
scale = torch.tensor([0.8, 0.9], dtype=TORCH_DTYPES[dtype])
scale_tp = tp.Tensor(scale, dtype=dtype)
-dequantized = tp.dequantize(input_tp, scale_tp, dtype, dim=0)
+dequantized = compile_fixture(tp.dequantize, input_tp, scale_tp, dtype, dim=0)
expected = torch.tensor(data) * scale.reshape((2, 1))
output = torch.from_dlpack(dequantized)
assert torch.allclose(expected, output.to("cpu"))
@@ -61,14 +61,13 @@ def test_dequantize_int8_per_channel(self, dtype):
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
@skip_if_older_than_sm89
-def test_dequantize_fp8_per_tensor(self, dtype):
+def test_dequantize_fp8_per_tensor(self, dtype, compile_fixture):
data_value = [1.0, 1.0]
input_tp = tp.Tensor(data_value, dtype=tp.float8)
scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype])
scale_tp = tp.Tensor(scale, dtype=dtype)
-dequantized = tp.dequantize(input_tp, scale_tp, dtype)
+dequantized = compile_fixture(tp.dequantize, input_tp, scale_tp, dtype)
assert dequantized.dtype == dtype
-print(dequantized)
expected = torch.Tensor(data_value) * scale
output = torch.from_dlpack(dequantized).to(dtype=torch.float32)
assert torch.allclose(expected, output.to("cpu"))
@@ -77,23 +76,23 @@ def test_dequantize_fp8_per_tensor(self, dtype):
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
@skip_if_older_than_sm89
-def test_dequantize_fp8_per_channel(self, dtype):
+def test_dequantize_fp8_per_channel(self, dtype, compile_fixture):
data_value = [[1.0, 1.0], [1.0, 1.0]]
input_tp = tp.Tensor(data_value, dtype=tp.float8)
scale = torch.tensor([0.8, 0.9], dtype=TORCH_DTYPES[dtype])
scale_tp = tp.Tensor(scale, dtype=dtype)
-dequantized = tp.dequantize(input_tp, scale_tp, dtype, dim=0)
+dequantized = compile_fixture(tp.dequantize, input_tp, scale_tp, dtype, dim=0)
assert dequantized.dtype == dtype
print(dequantized)
expected = torch.Tensor(data_value) * scale.reshape((2, 1))
output = torch.from_dlpack(dequantized).to(dtype=torch.float32)
assert torch.allclose(expected, output.to("cpu"))

-def test_negative_non_constant_scale(self):
+def test_negative_non_constant_scale(self, compile_fixture):
data = [[4, 8], [4, 8]]
input = tp.Tensor(data, dtype=tp.int8)
scale = tp.ones((2,))
-dequantized = tp.dequantize(input, scale, tp.float32, dim=0)
+dequantized = compile_fixture(tp.dequantize, input, scale, tp.float32, dim=0)
with raises(
tp.TripyException,
match="Scale must be a constant tensor in dequantize op",
16 changes: 8 additions & 8 deletions tripy/tests/integration/test_expand.py
@@ -22,24 +22,24 @@


class TestExpand:
-def test_int_sizes(self):
+def test_int_sizes(self, compile_fixture):
input = tp.ones((2, 1))
-out = tp.expand(input, (-1, 2))
+out = compile_fixture(tp.expand, input, (-1, 2))
assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 2), dtype=np.float32))

-def test_shape_sizes(self):
+def test_shape_sizes(self, compile_fixture):
input = tp.ones((2, 1))
a = tp.ones((2, 4))
-out = tp.expand(input, a.shape)
+out = compile_fixture(tp.expand, input, a.shape)
assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 4), dtype=np.float32))

-def test_extra_dims(self):
+def test_extra_dims(self, compile_fixture):
input = tp.ones((2, 1))
-out = tp.expand(input, (1, -1, 2))
+out = compile_fixture(tp.expand, input, (1, -1, 2))
assert np.array_equal(cp.from_dlpack(out).get(), np.ones((1, 2, 2), dtype=np.float32))

-def test_mixed_sizes(self):
+def test_mixed_sizes(self, compile_fixture):
input = tp.ones((2, 1, 1))
a = tp.ones((4, 4))
-out = tp.expand(input, (-1, a.shape[0], a.shape[1]))
+out = compile_fixture(tp.expand, input, (-1, a.shape[0], a.shape[1]))
assert np.array_equal(cp.from_dlpack(out).get(), np.ones((2, 4, 4), dtype=np.float32))
6 changes: 3 additions & 3 deletions tripy/tests/integration/test_flip.py
@@ -26,14 +26,14 @@ class TestFlip:
"dims",
[0, 1, None, [0, 1], [1, 0], -1, -2, [0, -1], [-2, 1]],
)
-def test_flip(self, dims):
+def test_flip(self, dims, compile_fixture):
cp_a = cp.arange(16).reshape((4, 4)).astype(cp.float32)
a = tp.Tensor(cp_a, device=tp.device("gpu"))
-f = tp.flip(a, dims=dims)
+f = compile_fixture(tp.flip, a, dims=dims)
assert np.array_equal(cp.from_dlpack(f).get(), np.flip(cp_a.get(), axis=dims))

# also ensure that flipping a second time restores the original value
-f2 = tp.flip(f, dims=dims)
+f2 = compile_fixture(tp.flip, f, dims=dims)
assert cp.array_equal(cp.from_dlpack(f2), cp_a)

def test_no_op(self):
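
If the fixture is parametrized over execution modes as sketched near the top, each converted test is collected once per mode. For example, a hypothetical invocation of the flip tests

pytest tripy/tests/integration/test_flip.py -v

would list an eager and a compiled variant for every dims value in the parametrize list above.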
