Add compile fixture to test integration ops with compile mode
parthchadha committed Nov 19, 2024
1 parent d37e4f8 commit fe949f7
Showing 33 changed files with 301 additions and 191 deletions.
86 changes: 86 additions & 0 deletions tripy/tests/integration/conftest.py
@@ -0,0 +1,86 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
import inspect

import tripy as tp


@pytest.fixture(params=["compile", "eager"])
def execution_mode(request):
return request.param


@pytest.fixture
def compile_fixture(execution_mode):
def wrapper(func, *args, exclude_arg_names=None, **kwargs):
def get_shape(x: tp.Tensor):
x.eval()
return tp.InputInfo(x.trace_tensor.shape, dtype=x.dtype)

exclude_arg_names = set() if exclude_arg_names is None else set(exclude_arg_names)

if execution_mode == "compile":
# For args, we need to get their parameter names from function signature
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())

compile_args = []
for i, arg in enumerate(args):
if i < len(param_names) and param_names[i] in exclude_arg_names:
compile_args.append(arg)
elif isinstance(arg, tp.Tensor) and not isinstance(arg, tp.DimensionSize):
compile_args.append(get_shape(arg))
else:
compile_args.append(arg)
compile_args = tuple(compile_args)

compile_kwargs = dict(
(
k,
(
v
if k in exclude_arg_names
else (get_shape(v) if isinstance(v, tp.Tensor) and not isinstance(v, tp.DimensionSize) else v)
),
)
for k, v in kwargs.items()
)

compiled_func = tp.compile(func, args=compile_args, kwargs=compile_kwargs)

tensor_args = tuple(
x
for i, x in enumerate(args)
if isinstance(x, tp.Tensor)
and not isinstance(x, tp.DimensionSize)
and (i >= len(param_names) or param_names[i] not in exclude_arg_names)
)

tensor_kwargs = {
k: v
for k, v in kwargs.items()
if isinstance(v, tp.Tensor) and not isinstance(v, tp.DimensionSize) and k not in exclude_arg_names
}

return compiled_func(*tensor_args, **tensor_kwargs)

elif execution_mode == "eager":
return func(*args, **kwargs)

return wrapper
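
For reference, a minimal usage sketch (illustrative only, not part of this commit) based on the fixture above and the test updates below: because execution_mode is parameterized over "compile" and "eager", any test that requests compile_fixture runs twice, once through tp.compile with tensor arguments replaced by tp.InputInfo, and once eagerly. The test name and values here are hypothetical.

import tripy as tp

def test_cumsum_example(compile_fixture):  # hypothetical test mirroring test_cumsum below
    inp = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32)
    # Tensor arguments become runtime inputs when compiling; dim=0 is baked into the compiled function.
    out = compile_fixture(tp.cumsum, inp, dim=0)
    assert tp.allclose(out, tp.Tensor([[1.0, 2.0], [4.0, 6.0]], dtype=tp.float32))

Tensor arguments named in exclude_arg_names are passed to tp.compile unchanged and omitted from the runtime call, so they act as compile-time constants rather than runtime inputs.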
4 changes: 2 additions & 2 deletions tripy/tests/integration/test_batchnorm.py
@@ -26,7 +26,7 @@ class TestBatchNorm:

@pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
@pytest.mark.parametrize("input_shape", [(2, 2, 2, 2)])
def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape):
def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, compile_fixture):
eps = 1e-5
num_features = input_shape[1] # Number of channels in the input tensor
batchnorm = torch.nn.BatchNorm2d(num_features=num_features, eps=eps, dtype=torch_dtype)
@@ -45,7 +45,7 @@ def test_batchnorm_accuracy(self, torch_dtype, tp_dtype, input_shape):
input = torch.randn(input_shape, dtype=torch_dtype).to("cuda")
tp_input = tp.Tensor(input, dtype=tp_dtype)

output = tp_batchnorm(tp_input)
output = compile_fixture(tp_batchnorm, tp_input)

batchnorm.to("cuda").eval()
with torch.no_grad():
27 changes: 13 additions & 14 deletions tripy/tests/integration/test_cast.py
@@ -30,54 +30,53 @@ class TestCast:
[
(np.int32, np.float32),
(np.float32, np.int32),
(np.int64, np.float32),
(np.float32, np.int64),
(np.int64, np.int32),
(np.int64, np.int8),
(np.int32, np.int8),
(np.float32, np.int8),
(np.int8, np.int64),
(np.int8, np.int32),
(np.int8, np.float32),
# important to test conversion into bool because default StableHLO semantics
# are simply to truncate to i1, which is not desirable
(np.float32, bool),
(np.int32, bool),
(np.int64, bool),
# requires a dequantization first
# TODO(#219): Dequantize fails with dynamic shapes
# (np.int8, bool),
],
)
def test_cast(self, input_dtype, target_dtype):
def test_cast(self, input_dtype, target_dtype, compile_fixture):
tp_input_dtype = NUMPY_TO_TRIPY[input_dtype]
tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]

# TODO(#222): Integer casts with negative numbers fail in many cases
input_tensor = tp.Tensor([0, 1, 2], dtype=tp_input_dtype)
np_input = cp.from_dlpack(input_tensor).get()
output = tp.cast(input_tensor, tp_target_dtype)
output = compile_fixture(tp.cast, input_tensor, tp_target_dtype)

assert np.array_equal(cp.from_dlpack(output).get(), np_input.astype(target_dtype))

# these dtypes don't have analogues in numpy
@pytest.mark.parametrize("source_dtype", [pytest.param(tp.float8, marks=skip_if_older_than_sm89), tp.int4])
def test_cast_quantized_dtypes_into_bool(self, source_dtype):
def test_cast_quantized_dtypes_into_bool(self, source_dtype, compile_fixture):
# TODO(#223): Using an odd size leads to a strange crash, so can't just use [-1.0, 0.0, 1.0]
input_tensor = tp.Tensor([-1.0, 0.0, 0.0, 1.0], dtype=tp.float32)
q = tp.quantize(input_tensor, scale=1.0, dtype=source_dtype)
output = tp.cast(q, tp.bool)

def func(input):
q = tp.quantize(input, scale=1.0, dtype=source_dtype)
output = tp.cast(q, tp.bool)
return output

output = compile_fixture(func, input_tensor)
assert cp.from_dlpack(output).get().tolist() == [True, False, False, True]

@pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int64, np.int8])
def test_cast_from_bool(self, target_dtype):
@pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int8])
def test_cast_from_bool(self, target_dtype, compile_fixture):
tp_target_dtype = NUMPY_TO_TRIPY[target_dtype]

# in principle, it is not important what *specific* values we convert to,
# so long as false is mapped to 0 and true to nonzero
input_tensor = tp.Tensor([False, True], dtype=tp.bool)
np_input = cp.from_dlpack(input_tensor).get()
output = tp.cast(input_tensor, tp_target_dtype)
output = compile_fixture(tp.cast, input_tensor, tp_target_dtype)

tp_compare_to_zero = cp.from_dlpack(output).get() == 0
np_compare_to_zero = np_input.astype(target_dtype) == 0
8 changes: 4 additions & 4 deletions tripy/tests/integration/test_concatenate.py
@@ -33,9 +33,9 @@ class TestConcatenate:
([(2, 3, 4)], 0),
],
)
def test_concat(self, tensor_shapes, dim):
def test_concat(self, tensor_shapes, dim, compile_fixture):
tensors = [tp.ones(shape) for shape in tensor_shapes]
out = tp.concatenate(tensors, dim=dim)
out = compile_fixture(tp.concatenate, tensors, dim=dim)
assert np.array_equal(
cp.from_dlpack(out).get(), np.concatenate([np.ones(shape) for shape in tensor_shapes], axis=dim)
)
@@ -44,8 +44,8 @@ def test_concat(self, tensor_shapes, dim):
"tensor_shapes, dim",
[([(2, 3, 4), (2, 4, 4)], 0), ([(4, 5, 6), (4, 1, 6)], -1)],
)
def test_negative_concat(self, tensor_shapes, dim):
def test_negative_concat(self, tensor_shapes, dim, compile_fixture):
tensors = [tp.ones(shape) for shape in tensor_shapes]
with helper.raises(tp.TripyException, match=f"not compatible at non-concat index"):
out = tp.concatenate(tensors, dim=dim)
out = compile_fixture(tp.concatenate, tensors, dim=dim)
print(out)
16 changes: 8 additions & 8 deletions tripy/tests/integration/test_conv.py
@@ -75,7 +75,7 @@ class ConvTestCase:
@pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
class TestConvolution:
@pytest.mark.parametrize("test_case", test_cases_1d)
def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
def test_convolution_1d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -122,7 +122,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

# FP32 kernel seems to lose some precision, and FP16 needs to be run in FP32 on torch
rtol_ = 4e-5 if tp_dtype == tp.float32 else 1e-3
@@ -131,7 +131,7 @@ def test_convolution_1d(self, torch_dtype, tp_dtype, test_case):
assert list(output_torch.shape) == list(expected.shape)

@pytest.mark.parametrize("test_case", test_cases_2d)
def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
def test_convolution_2d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -178,15 +178,15 @@ def test_convolution_2d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 1.5e-3
output_torch = torch.from_dlpack(output)
assert torch.allclose(output_torch, expected, rtol=rtol_)
assert list(output_torch.shape) == list(expected.shape)

@pytest.mark.parametrize("test_case", test_cases_3d)
def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
def test_convolution_3d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
pytest.skip("TODO (#260): Fix accuracy bugs in 3D conv")
if not test_case.torch_pad:
test_case.torch_pad = 0
@@ -245,14 +245,14 @@ def test_convolution_3d(self, torch_dtype, tp_dtype, test_case):
return

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 2e-4 if tp_dtype == tp.float32 else 1.4e-3 # 3d conv has greater accumulation error
output_torch = torch.from_dlpack(output)
assert torch.allclose(output_torch, expected, rtol=rtol_)
assert list(output_torch.shape) == list(expected.shape)

def test_uneven_padding(self, torch_dtype, tp_dtype):
def test_uneven_padding(self, torch_dtype, tp_dtype, compile_fixture):
input_torch = torch.arange(200, dtype=torch.float32, device=torch.device("cuda")).reshape(*(2, 4, 5, 5))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -282,7 +282,7 @@ def test_uneven_padding(self, torch_dtype, tp_dtype):

input_torch = torch_pad(input_torch)
expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 2e-3
output_torch = torch.from_dlpack(output)
24 changes: 12 additions & 12 deletions tripy/tests/integration/test_conv_transpose.py
@@ -81,7 +81,7 @@ class ConvTestCase:
@pytest.mark.parametrize("torch_dtype,tp_dtype", DTYPES)
class TestConvolution:
@pytest.mark.parametrize("test_case", test_cases_transpose_1d)
def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -129,14 +129,14 @@ def test_transposed_convolution_1d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 1e-3
rtol_ = 3e-3
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == list(expected.shape)

@pytest.mark.parametrize("test_case", test_cases_transpose_2d)
def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -184,14 +184,14 @@ def test_transposed_convolution_2d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 1e-2
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == list(expected.shape)

@pytest.mark.parametrize("test_case", test_cases_transpose_3d)
def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case, compile_fixture):
if not test_case.torch_pad:
test_case.torch_pad = 0
if not test_case.stride:
@@ -239,12 +239,12 @@ def test_transposed_convolution_3d(self, torch_dtype, tp_dtype, test_case):
conv_layer.bias = tp.cast(tp.Tensor(conv_layer_torch.bias.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)
rtol_ = 1.3e-6 if tp_dtype == tp.float32 else 1.6e-3
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
assert output.shape == list(expected.shape)

def test_transposed_equivalency(self, torch_dtype, tp_dtype):
def test_transposed_equivalency(self, torch_dtype, tp_dtype, compile_fixture):
input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -277,8 +277,8 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):

expected = conv_layer_torch(input_torch).to(torch_dtype)
expected_transpose = conv_transpose_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output_transpose = conv_transpose_layer(input)
output = compile_fixture(conv_layer, input)
output_transpose = compile_fixture(conv_transpose_layer, input)

rtol_ = 2e-7 if tp_dtype == tp.float32 else 9e-4
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
@@ -291,7 +291,7 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype):
assert list(expected.shape) == list(expected_transpose.shape)

@pytest.mark.parametrize("test_case", test_cases_transpose_downscale)
def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case, compile_fixture):
input_torch = torch.arange(9, dtype=torch.float32, device=torch.device("cuda")).reshape(*(1, 1, 3, 3))
input = tp.cast(tp.Tensor(input_torch), tp_dtype)

@@ -320,7 +320,7 @@ def test_transposed_downscale(self, torch_dtype, tp_dtype, test_case):
conv_layer.weight = tp.cast(tp.Tensor(conv_layer_torch.weight.data), tp_dtype)

expected = conv_layer_torch(input_torch).to(torch_dtype)
output = conv_layer(input)
output = compile_fixture(conv_layer, input)

rtol_ = 1e-15 if tp_dtype == tp.float32 else 1e-10
assert tp.allclose(output, tp.Tensor(expected), rtol=rtol_)
5 changes: 2 additions & 3 deletions tripy/tests/integration/test_cumsum.py
@@ -30,11 +30,10 @@ class TestCumsum:
([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], 0, [[[1, 2], [3, 4]], [[6, 8], [10, 12]]]),
],
)
def test_cumsum(self, data, dim, expected):
def test_cumsum(self, data, dim, expected, compile_fixture):
inp = tp.Tensor(data, dtype=tp.float32)

out = tp.cumsum(inp, dim=dim)

out = compile_fixture(tp.cumsum, inp, dim=dim)
expected = tp.Tensor(expected, dtype=tp.float32)
assert tp.allclose(out, expected)
assert out.shape == expected.shape