Add compile fixture to test integration ops with compile mode #387

Merged · 5 commits · Nov 21, 2024
61 changes: 61 additions & 0 deletions tripy/tests/integration/conftest.py
@@ -0,0 +1,61 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest

import tripy as tp


@pytest.fixture(params=["compile", "eager"])
def eager_or_compiled(request):
    def wrapper(func, *args, **kwargs):
        def get_input_info(x: tp.Tensor):
            return tp.InputInfo(list(map(int, x.shape)), dtype=x.dtype)

        if request.param == "eager":
            return func(*args, **kwargs)

        assert request.param == "compile"

        compile_args = []
        for arg in args:
            # We don't want to feed DimensionSize as a dynamic input to the compiler
            # (https://github.com/NVIDIA/TensorRT-Incubator/issues/65).
            if isinstance(arg, tp.Tensor) and not isinstance(arg, tp.DimensionSize):
                compile_args.append(get_input_info(arg))
            else:
                compile_args.append(arg)
        compile_args = tuple(compile_args)

        compile_kwargs = dict(
            (
                k,
                ((get_input_info(v) if isinstance(v, tp.Tensor) and not isinstance(v, tp.DimensionSize) else v)),
            )
            for k, v in kwargs.items()
        )

        compiled_func = tp.compile(func, args=compile_args, kwargs=compile_kwargs)

        # Only Tensor arguments become runtime inputs of the compiled function;
        # all other arguments were baked in at compile time.
        tensor_args = tuple(x for x in args if isinstance(x, tp.Tensor) and not isinstance(x, tp.DimensionSize))

        tensor_kwargs = {
            k: v for k, v in kwargs.items() if isinstance(v, tp.Tensor) and not isinstance(v, tp.DimensionSize)
        }

        return compiled_func(*tensor_args, **tensor_kwargs)

    return wrapper
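
Every test that accepts this fixture runs twice, once per fixture param: eagerly and through tp.compile. A minimal sketch of how a test consumes it (hypothetical test name; the data and call pattern mirror the cumsum test later in this diff):

import cupy as cp
import numpy as np

import tripy as tp


def test_cumsum_sketch(eager_or_compiled):
    inp = tp.Tensor([1.0, 2.0, 3.0], dtype=tp.float32)
    # Eager param: calls tp.cumsum(inp, dim=0) directly.
    # Compile param: compiles tp.cumsum with `inp` replaced by
    # tp.InputInfo([3], dtype=tp.float32) and dim=0 baked in.
    out = eager_or_compiled(tp.cumsum, inp, dim=0)
    assert np.array_equal(cp.from_dlpack(out).get(), np.array([1.0, 3.0, 6.0], dtype=np.float32))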
4 changes: 2 additions & 2 deletions tripy/tests/integration/test_batchnorm.py
@@ -26,7 +26,7 @@ class TestBatchNorm:

@pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
@pytest.mark.parametrize("input_shape", [(2, 2, 2, 2)])
-def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape):
+def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, eager_or_compiled):
eps = 1e-5
num_features = input_shape[1] # Number of channels in the input tensor
batchnorm = torch.nn.BatchNorm2d(num_features=num_features, eps=eps, dtype=torch_dtype)
@@ -45,7 +45,7 @@ def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape):
input = torch.randn(input_shape, dtype=torch_dtype).to("cuda")
tp_input = tp.Tensor(input, dtype=tp_dtype)

-output = tp_batchnorm(tp_input)
+output = eager_or_compiled(tp_batchnorm, tp_input)

batchnorm.to("cuda").eval()
with torch.no_grad():
27 changes: 13 additions & 14 deletions tripy/tests/integration/test_cast.py
@@ -30,54 +30,53 @@ class TestCast:
[
(np.int32, np.float32),
(np.float32, np.int32),
(np.int64, np.float32),
(np.float32, np.int64),
(np.int64, np.int32),
(np.int64, np.int8),
(np.int32, np.int8),
(np.float32, np.int8),
(np.int8, np.int64),
(np.int8, np.int32),
(np.int8, np.float32),
# important to test conversion into bool because default StableHLO semantics
# are simply to truncate to i1, which is not desirable: truncation keeps only
# the low bit (e.g., 2 -> False), whereas a cast should map any nonzero value to True
(np.float32, bool),
(np.int32, bool),
(np.int64, bool),
# requires a dequantization first
# TODO(#219): Dequantize fails with dynamic shapes
# (np.int8, bool),
],
)
-def test_cast(self, input_dtype, target_dtype):
+def test_cast(self, input_dtype, target_dtype, eager_or_compiled):
tp_input_dtype = NUMPY_TO_TRIPY[input_dtype]
tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]

# TODO(#222): Integer casts with negative numbers fail in many cases
input_tensor = tp.Tensor([0, 1, 2], dtype=tp_input_dtype)
np_input = cp.from_dlpack(input_tensor).get()
-output = tp.cast(input_tensor, tp_target_dtype)
+output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype)

assert np.array_equal(cp.from_dlpack(output).get(), np_input.astype(target_dtype))

# these dtypes don't have analogues in numpy
@pytest.mark.parametrize("source_dtype", [pytest.param(tp.float8, marks=skip_if_older_than_sm89), tp.int4])
-def test_cast_quantized_dtypes_into_bool(self, source_dtype):
+def test_cast_quantized_dtypes_into_bool(self, source_dtype, eager_or_compiled):
# TODO(#223): Using an odd size leads to a strange crash, so can't just use [-1.0, 0.0, 1.0]
input_tensor = tp.Tensor([-1.0, 0.0, 0.0, 1.0], dtype=tp.float32)
-q = tp.quantize(input_tensor, scale=1.0, dtype=source_dtype)
-output = tp.cast(q, tp.bool)

+def func(input):
+    q = tp.quantize(input, scale=1.0, dtype=source_dtype)
+    output = tp.cast(q, tp.bool)
+    return output

+output = eager_or_compiled(func, input_tensor)
assert cp.from_dlpack(output).get().tolist() == [True, False, False, True]
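
Note the shape of this change: tp.compile traces a single callable, so the quantize-then-cast sequence is wrapped in a local func and handed to the fixture as one unit; source_dtype is captured by the closure and fixed at compile time, while the lone Tensor argument becomes a dynamic input. For this particular input, the fixture ends up doing the equivalent of (a sketch; the InputInfo values come from get_input_info):

compiled = tp.compile(func, args=(tp.InputInfo([4], dtype=tp.float32),), kwargs={})
output = compiled(input_tensor)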

@pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int64, np.int8])
def test_cast_from_bool(self, target_dtype):
@pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int8])
def test_cast_from_bool(self, target_dtype, eager_or_compiled):
tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]

# in principle, it is not important what *specific* values we convert to,
# so long as false is mapped to 0 and true to nonzero
input_tensor = tp.Tensor([False, True], dtype=tp.bool)
np_input = cp.from_dlpack(input_tensor).get()
-output = tp.cast(input_tensor, tp_target_dtype)
+output = eager_or_compiled(tp.cast, input_tensor, tp_target_dtype)

tp_compare_to_zero = cp.from_dlpack(output).get() == 0
np_compare_to_zero = np_input.astype(target_dtype) == 0
8 changes: 4 additions & 4 deletions tripy/tests/integration/test_concatenate.py
@@ -33,9 +33,9 @@ class TestConcatenate:
([(2, 3, 4)], 0),
],
)
-def test_concat(self, tensor_shapes, dim):
+def test_concat(self, tensor_shapes, dim, eager_or_compiled):
tensors = [tp.ones(shape) for shape in tensor_shapes]
-out = tp.concatenate(tensors, dim=dim)
+out = eager_or_compiled(tp.concatenate, tensors, dim=dim)
assert np.array_equal(
cp.from_dlpack(out).get(), np.concatenate([np.ones(shape) for shape in tensor_shapes], axis=dim)
)
@@ -44,8 +44,8 @@ def test_concat(self, tensor_shapes, dim):
"tensor_shapes, dim",
[([(2, 3, 4), (2, 4, 4)], 0), ([(4, 5, 6), (4, 1, 6)], -1)],
)
-def test_negative_concat(self, tensor_shapes, dim):
+def test_negative_concat(self, tensor_shapes, dim, eager_or_compiled):
tensors = [tp.ones(shape) for shape in tensor_shapes]
with helper.raises(tp.TripyException, match=f"not compatible at non-concat index"):
-out = tp.concatenate(tensors, dim=dim)
+out = eager_or_compiled(tp.concatenate, tensors, dim=dim)
print(out)
16 changes: 8 additions & 8 deletions tripy/tests/integration/test_conv.py
@@ -75,7 +75,7 @@ class ConvTestCase:
@pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
class TestConvolution:
@pytest.mark.parametrize("test_case", test_cases_1d)
-def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
+def test_convolution_1d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -122,7 +122,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = eager_or_compiled(conv_layer, input)

# FP32 kernel seems to lose some precision, and FP16 needs to be run in FP32 on torch
rtol_ = 4e-5 if tp_dtype == tp.float32 else 1e-3
@@ -131,7 +131,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
assert list(output_torch.shape) == list(expected.shape)

@pytest.mark.parametrize("test_case", test_cases_2d)
-def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
+def test_convolution_2d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -178,15 +178,15 @@ def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = eager_or_compiled(conv_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 1.5e-3
output_torch = torch.from_dlpack(output)
assert torch.allclose(output_torch, expected, rtol=rtol_)
assert list(output_torch.shape) == list(expected.shape)

@pytest.mark.parametrize("test_case", test_cases_3d)
-def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
+def test_convolution_3d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
pytest.skip("TODO (#260): Fix accuracy bugs in 3D conv")
if not test_case.torch_pad:
test_case.torch_pad = 0
@@ -245,14 +245,14 @@ def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
return

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = eager_or_compiled(conv_layer, input)

rtol_ = 2e-4 if tp_dtype == tp.float32 else 1.4e-3 # 3d conv has greater accumulation error
output_torch = torch.from_dlpack(output)
assert torch.allclose(output_torch, expected, rtol=rtol_)
assert list(output_torch.shape) == list(expected.shape)

-def test_uneven_padding(self, torch_dtype, tp_dtype):
+def test_uneven_padding(self, torch_dtype, tp_dtype, eager_or_compiled):
input_torch = torch.arange(200, dtype=torch.float32, device=torch.device("cuda")).reshape(*(2, 4, 5, 5))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -282,7 +282,7 @@ def test_uneven_padding(self, torch_dtype, tp_dtype):

input_torch = torch_pad(input_torch)
expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = eager_or_compiled(conv_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 2e-3
output_torch = torch.from_dlpack(output)
24 changes: 12 additions & 12 deletions tripy/tests/integration/test_conv_transpose.py
@@ -81,7 +81,7 @@ class ConvTestCase:
@pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
class TestConvolution:
@pytest.mark.parametrize("test_case", test_cases_transpose_1d)
-def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
+def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -129,14 +129,14 @@ def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = eager_or_compiled(conv_layer, input)

-rtol_ = 1e-3
+rtol_ = 3e-3
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == list(expected.shape)

@pytest.mark.parametrize("test_case", test_cases_transpose_2d)
-def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
+def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -184,14 +184,14 @@ def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = eager_or_compiled(conv_layer, input)

rtol_ = 1e-2
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == list(expected.shape)

@pytest.mark.parametrize("test_case", test_cases_transpose_3d)
-def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
+def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -239,12 +239,12 @@ def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = eager_or_compiled(conv_layer, input)
rtol_ = 1.3e-6 if tp_dtype == tp.float32 else 1.6e-3
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == list(expected.shape)

-def test_transposed_equivalency(self, torch_dtype, tp_dtype):
+def test_transposed_equivalency(self, torch_dtype, tp_dtype, eager_or_compiled):
input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -277,8 +277,8 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):

expected = conv_layer_torch(input_torch).to(torch_dtype)
expected_transpose = conv_transpose_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
-output_transpose = conv_transpose_layer(input)
+output = eager_or_compiled(conv_layer, input)
+output_transpose = eager_or_compiled(conv_transpose_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
@@ -291,7 +291,7 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):
assert list(expected.shape) == list(expected_transpose.shape)

@pytest.mark.parametrize("test_case", test_cases_transpose_downscale)
-def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
+def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case, eager_or_compiled):
input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -320,7 +320,7 @@ def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
conv_layer.weight = tp.cast(tp.Tensor(conv_layer_torch.weight.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
-output = conv_layer(input)
+output = eager_or_compiled(conv_layer, input)

rtol_ = 1e-15 if tp_dtype == tp.float32 else 1e-10
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
5 changes: 2 additions & 3 deletions tripy/tests/integration/test_cumsum.py
@@ -30,11 +30,10 @@ class TestCumsum:
([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], 0, [[[1, 2], [3, 4]], [[6, 8], [10, 12]]]),
],
)
-def test_cumsum(self, data, dim, expected):
+def test_cumsum(self, data, dim, expected, eager_or_compiled):
inp = tp.Tensor(data, dtype=tp.float32)

-out = tp.cumsum(inp, dim=dim)
-
+out = eager_or_compiled(tp.cumsum, inp, dim=dim)
expected = tp.Tensor(expected, dtype=tp.float32)
assert tp.allclose(out, expected)
assert out.shape == expected.shape
16 changes: 12 additions & 4 deletions tripy/tests/integration/test_dequantize.py
@@ -29,28 +29,36 @@ class TestDequantize:
@pytest.mark.parametrize(
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
-def test_dequantize_int8_per_tensor(self, dtype):
+def test_dequantize_int8_per_tensor(self, dtype, eager_or_compiled):
data = [4, 8]
input_tp = tp.Tensor(data, dtype=tp.int8)
scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype])
scale_tp = tp.Tensor(scale, dtype=dtype)
-dequantized = tp.dequantize(input_tp, scale_tp, dtype)

+def func(input):
+    return tp.dequantize(input, scale_tp, dtype)

+dequantized = eager_or_compiled(func, input_tp)
expected = torch.tensor(data) * scale
output = torch.from_dlpack(dequantized)
assert torch.allclose(expected, output.to("cpu"))

@pytest.mark.parametrize(
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
-def test_dequantize_int8_per_channel(self, dtype):
+def test_dequantize_int8_per_channel(self, dtype, eager_or_compiled):
# TODO: Fix in #153
if dtype == tp.float16:
pytest.skip("TRT does not support fp16->int8 per-channel dequant.")
data = [[4, 8], [4, 8]]
input_tp = tp.Tensor(data, dtype=tp.int8)
scale = torch.tensor([0.8, 0.9], dtype=TORCH_DTYPES[dtype])
scale_tp = tp.Tensor(scale, dtype=dtype)
-dequantized = tp.dequantize(input_tp, scale_tp, dtype, dim=0)

+def func(input):
+    return tp.dequantize(input, scale_tp, dtype, dim=0)

+dequantized = eager_or_compiled(func, input_tp)
expected = torch.tensor(data) * scale.reshape((2, 1))
output = torch.from_dlpack(dequantized)
assert torch.allclose(expected, output.to("cpu"))