diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index b4c5d45cbf8e..0939e25efddf 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -911,7 +911,9 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
   if (t.is_uint()) {
     // Use IntImm if it is a small integer
     uint64_t uval = static_cast<uint64_t>(value);
-    if (uval <= static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
+    if (value < static_cast<ValueType>(0)) {
+      LOG(FATAL) << "cannot make uint from negative value " << value;
+    } else if (uval <= static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
       return IntImm(t, static_cast<int64_t>(value), span);
     } else {
       uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
@@ -932,6 +934,11 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
   return PrimExpr();
 }
 
+template <>
+inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
+  return MakeConstScalar(t, static_cast<int>(value), span);
+}
+
 template <typename ValueType>
 inline PrimExpr make_const(DataType t, ValueType value, Span span) {
   if (t.lanes() == 1) {
diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py
index 7a55d3ef244e..05426dfb1aeb 100644
--- a/python/tvm/runtime/object_generic.py
+++ b/python/tvm/runtime/object_generic.py
@@ -115,11 +115,17 @@ def _scalar_type_inference(value):
     elif isinstance(value, bool):
         dtype = "bool"
     elif isinstance(value, float):
-        # We intentionally convert the float to float32 since it's more common in DL.
-        dtype = "float32"
+        # We intentionally prefer to convert the float to float32 since it's more common in DL.
+        if -3.40282347e38 <= value <= 3.40282347e38:
+            dtype = "float32"
+        else:
+            dtype = "float64"
     elif isinstance(value, int):
-        # We intentionally convert the python int to int32 since it's more common in DL.
-        dtype = "int32"
+        # We intentionally prefer to convert the python int to int32 since it's more common in DL.
+        if -2147483648 <= value <= 2147483647:
+            dtype = "int32"
+        else:
+            dtype = "int64"
     else:
         raise NotImplementedError(
             "Cannot automatically inference the type." " value={}".format(value)
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index f3919afe5a24..bd9aa1fdadfd 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -89,6 +89,11 @@ def truncmod(x, y, span):
     return tvm.tir.truncmod(x, y, span)
 
 
+@register
+def truncdiv(x, y, span):
+    return tvm.tir.truncdiv(x, y, span)
+
+
 @register
 def ceildiv(x, y, span):
     return tvm.tir.ceildiv(x, y, span)
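The widened inference above is user-visible through tir.const(..., dtype=None). A quick sketch of the new behavior (mirroring test_tir_const_auto_dtype added later in this patch; assumes a build that includes it):

    import tvm
    from tvm import tir

    # In-range Python scalars keep the DL-friendly defaults.
    assert tir.const(7, dtype=None).dtype == "int32"
    assert tir.const(3.14159, dtype=None).dtype == "float32"
    # Out-of-range scalars now widen instead of truncating silently.
    assert tir.const(2**31, dtype=None).dtype == "int64"
    assert tir.const(3.402e39, dtype=None).dtype == "float64"
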
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index 9c3afe41b901..d0e09a1a7429 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -29,6 +29,7 @@
 #include <algorithm>
 #include <cmath>
+#include <limits>
 
 #include "int_operator.h"
@@ -73,6 +74,39 @@ inline bool IsIndexType(const DataType& type) {
   return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64);
 }
 
+/*! \brief Helper to get const folding result repr in int64. */
+inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) {
+  if (dtype.bits() < 64) {
+    x &= (1LL << dtype.bits()) - 1;
+  }
+  if (dtype.is_int()) {
+    // get sign extended value of integer with specified bits
+    int64_t m = 1LL << (dtype.bits() - 1);
+    x = (x ^ m) - m;
+  }
+  return x;
+}
+
+/*! \brief Helper to get fp32 const folding result repr in double. */
+inline double GetFoldResultDoubleRepr(float x) {
+  double res = static_cast<double>(x);
+  if (std::isinf(res) || std::isnan(res)) {
+    return res;
+  }
+  // On certain platforms (e.g., gcc7 on i386) the folding arithmetic is done
+  // on float, and the write-back to double is optimized into double-precision
+  // arithmetic. This is legal, so we additionally check the output range to
+  // ensure consistency when the float result is inf.
+  if (res < std::numeric_limits<float>::lowest()) {
+    LOG(WARNING) << "underlying float value overflow";
+    return -std::numeric_limits<double>::infinity();
+  } else if (res > std::numeric_limits<float>::max()) {
+    LOG(WARNING) << "underlying float value overflow";
+    return std::numeric_limits<double>::infinity();
+  }
+  return res;
+}
+
 #define TVM_ARITH_CONST_PROPAGATION(BODY)        \
   using tir::FloatImmNode;                       \
   const IntImmNode* pa = a.as<IntImmNode>();     \
@@ -95,10 +129,22 @@ template <>
 inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
-    if (pa && pb) return IntImm(rtype, pa->value + pb->value);
+    if (pa && pb) {
+      int64_t res = pa->value + pb->value;
+      return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
+    }
     if (pa && pa->value == 0) return b;
     if (pb && pb->value == 0) return a;
-    if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
+    if (fa && fb) {
+      if (rtype.bits() == 32) {
+        return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) +
+                                                       static_cast<float>(fb->value)));
+      } else if (rtype.bits() == 64) {
+        return FloatImm(rtype, fa->value + fb->value);
+      } else {
+        return PrimExpr();
+      }
+    }
     if (fa && fa->value == 0) return b;
     if (fb && fb->value == 0) return a;
   });
@@ -113,9 +159,21 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
         << "Checked failed. Minuend 's value is 0U and it's dtype is uint "
        << "while Subtrahend's dtype is uint; which will cause a negative uint";
     const DataType& rtype = a.dtype();
-    if (pa && pb) return IntImm(rtype, pa->value - pb->value);
+    if (pa && pb) {
+      int64_t res = pa->value - pb->value;
+      return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
+    }
     if (pb && pb->value == 0) return a;
-    if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
+    if (fa && fb) {
+      if (rtype.bits() == 32) {
+        return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) -
+                                                       static_cast<float>(fb->value)));
+      } else if (rtype.bits() == 64) {
+        return FloatImm(rtype, fa->value - fb->value);
+      } else {
+        return PrimExpr();
+      }
+    }
     if (fb && fb->value == 0) return a;
   });
   return PrimExpr();
@@ -125,7 +183,10 @@ template <>
 inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
-    if (pa && pb) return IntImm(rtype, pa->value * pb->value);
+    if (pa && pb) {
+      int64_t res = pa->value * pb->value;
+      return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
+    }
     if (pa) {
       if (pa->value == 1) return b;
       if (pa->value == 0) return a;
@@ -134,7 +195,16 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
       if (pb->value == 1) return a;
       if (pb->value == 0) return b;
     }
-    if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
+    if (fa && fb) {
+      if (rtype.bits() == 32) {
+        return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) *
+                                                       static_cast<float>(fb->value)));
+      } else if (rtype.bits() == 64) {
+        return FloatImm(rtype, fa->value * fb->value);
+      } else {
+        return PrimExpr();
+      }
+    }
     if (fa) {
       if (fa->value == 1) return b;
       if (fa->value == 0) return a;
@@ -155,7 +225,8 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
       // due to division and mod can have different modes
       // NOTE: this will assumes truc div.
       ICHECK_NE(pb->value, 0) << "Divide by zero";
-      return IntImm(rtype, pa->value / pb->value);
+      int64_t res = pa->value / pb->value;
+      return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
     }
     if (pa) {
       if (pa->value == 0) return a;
@@ -165,7 +236,14 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
       ICHECK_NE(pb->value, 0) << "Divide by zero";
     }
     if (fa && fb && fb->value != 0) {
-      return FloatImm(rtype, fa->value / fb->value);
+      if (rtype.bits() == 32) {
+        return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) /
+                                                       static_cast<float>(fb->value)));
+      } else if (rtype.bits() == 64) {
+        return FloatImm(rtype, fa->value / fb->value);
+      } else {
+        return PrimExpr();
+      }
     }
     if (fa && fa->value == 0) return a;
     if (fb) {
@@ -182,7 +260,8 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
     const DataType& rtype = a.dtype();
     if (pa && pb) {
       ICHECK_NE(pb->value, 0) << "Divide by zero";
-      return IntImm(rtype, pa->value % pb->value);
+      int64_t res = pa->value % pb->value;
+      return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
     }
     if (pa) {
       if (pa->value == 0) return a;
@@ -201,7 +280,8 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
     const DataType& rtype = a.dtype();
     if (pa && pb) {
       ICHECK_NE(pb->value, 0) << "Divide by zero";
-      return IntImm(rtype, arith::floordiv(pa->value, pb->value));
+      int64_t res = arith::floordiv(pa->value, pb->value);
+      return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
     }
     if (pa) {
       if (pa->value == 0) return a;
@@ -211,7 +291,14 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
       ICHECK_NE(pb->value, 0) << "Divide by zero";
     }
     if (fa && fb && fb->value != 0) {
-      return FloatImm(rtype, std::floor(fa->value / fb->value));
+      if (rtype.bits() == 32) {
+        return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast<float>(fa->value) /
+                                                                  static_cast<float>(fb->value))));
+      } else if (rtype.bits() == 64) {
+        return FloatImm(rtype, std::floor(fa->value / fb->value));
+      } else {
+        return PrimExpr();
+      }
     }
     if (fa && fa->value == 0) return a;
     if (fb) {
@@ -228,7 +315,8 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
     const DataType& rtype = a.dtype();
     if (pa && pb) {
       ICHECK_NE(pb->value, 0) << "Divide by zero";
-      return IntImm(rtype, floormod(pa->value, pb->value));
+      int64_t res = arith::floormod(pa->value, pb->value);
+      return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
     }
     if (pa) {
       if (pa->value == 0) return a;
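GetFoldResultInt64Repr above is what makes folded integer arithmetic wrap exactly like the target dtype: truncate to the dtype's width, then sign-extend for signed types. A Python sketch of the same trick (the helper name int64_repr is illustrative only, not part of the patch):

    def int64_repr(x, bits, signed):
        # Mirror of GetFoldResultInt64Repr: truncate to `bits`, then sign-extend.
        if bits < 64:
            x &= (1 << bits) - 1
        if signed:
            m = 1 << (bits - 1)  # sign bit
            x = (x ^ m) - m      # sign-extend back into a Python int
        return x

    assert int64_repr(127 + 1, 8, True) == -128  # int8 addition wraps
    assert int64_repr(255 + 1, 8, False) == 0    # uint8 addition wraps
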
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index d3e23800d6c7..c926cc56e89a 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -33,6 +33,8 @@
 #include <tvm/te/tensor.h>
 #include <tvm/tir/expr.h>
 
+#include "../support/scalars.h"
+
 namespace tvm {
 
 PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {}
@@ -76,7 +78,20 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) {
   ICHECK(dtype.is_int() || dtype.is_uint())
       << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
   if (dtype.is_uint()) {
-    ICHECK_GE(value, 0U);
+    ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
+                         << " is negative for unsigned integer type " << dtype;
+    if (dtype.bits() < 64) {
+      ICHECK_LT(value, 1LL << dtype.bits())
+          << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
+    }
+  } else if (dtype.bits() == 1) {
+    // int(1) is used to represent bool
+    ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
+  } else if (dtype.bits() < 64) {
+    ICHECK_GE(value, -(1LL << (dtype.bits() - 1)))
+        << "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
+    ICHECK_LT(value, 1LL << (dtype.bits() - 1))
+        << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
   }
   ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
   node->dtype = dtype;
@@ -103,6 +118,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 FloatImm::FloatImm(DataType dtype, double value, Span span) {
   ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";
+
+  // check range for float32 and float16 since they have specified ranges.
+  if (!std::isinf(value) && !std::isnan(value)) {
+    if (dtype.bits() == 32) {
+      ICHECK_GE(value, std::numeric_limits<float>::lowest())
+          << "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
+      ICHECK_LE(value, std::numeric_limits<float>::max())
+          << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
+    } else if (dtype.is_float16()) {
+      ICHECK_GE(value, -support::kMaxFloat16)
+          << "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
+      ICHECK_LE(value, support::kMaxFloat16)
+          << "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
+    }
+  }
   ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
   node->dtype = dtype;
   node->value = value;
diff --git a/src/support/scalars.cc b/src/support/scalars.cc
index 9caa7ca58915..0ab16899bae9 100644
--- a/src/support/scalars.cc
+++ b/src/support/scalars.cc
@@ -174,10 +174,6 @@ IntImm ValueToIntImm(int64_t value, int width) {
   }
 }
 
-// 2^15 * (1 + 1023/1024)
-// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
-constexpr double kMaxFloat16 = 65504.0;
-
 FloatImm ValueToFloatImm(double value, int width) {
   if (width == 16) {
     if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) {
diff --git a/src/support/scalars.h b/src/support/scalars.h
index 60b8fc40a8de..2fdbb001d922 100644
--- a/src/support/scalars.h
+++ b/src/support/scalars.h
@@ -61,6 +61,10 @@ std::string FloatImmToString(const FloatImm& float_imm);
 IntImm ValueToIntImm(int64_t value, int width);
 FloatImm ValueToFloatImm(double value, int width);
 
+// 2^15 * (1 + 1023/1024)
+// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
+constexpr double kMaxFloat16 = 65504.0;
+
 }  // namespace support
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index 89de2f6a9520..a8eb7f406c37 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -512,7 +512,7 @@ def verify(
     # Test backwards slicing.
     verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3))
     # Test slicing with overlarge indices.
-    verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int64).max] * 3, [1, 1, 1], (3, 4, 3))
+    verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int32).max] * 3, [1, 1, 1], (3, 4, 3))
     # Test slice mode.
     verify(
         (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index cacce5603e5f..fe662a30766c 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -777,7 +777,7 @@ def test_fuse_dynamic_squeeze_slice_take():
     squeeze = relay.op.squeeze(x, axis=[0])
 
     strided_slice = relay.op.strided_slice(
-        squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1]
+        squeeze, begin=[0, 0], end=[15130, 2147483647], strides=[1, 1]
     )
 
     take = relay.op.take(strided_slice, take_val, axis=0)
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index 82e1372f991e..c880f90ddffe 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -951,6 +951,8 @@ def test_cast_simplify():
         ck.verify(tvm.tir.Cast(dtype1, x == x), tvm.tir.const(1, dtype1))
         for dtype2 in dtypes:
             for i in [0, 1, 2, 3]:
+                if i > 1 and (dtype1 == "bool" or dtype2 == "bool"):
+                    continue
                 ck.verify(tvm.tir.Cast(dtype1, tvm.tir.const(i, dtype2)), tvm.tir.const(i, dtype1))
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index 994a85095728..96b947e20655 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -1,4 +1,5 @@
 # Licensed to the Apache Software Foundation (ASF) under one
+
 # or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
 # regarding copyright ownership. The ASF licenses this file
@@ -194,13 +195,13 @@ def check_cuda(n, value, lanes):
         fun(a)
         np.testing.assert_equal(a.numpy(), np_a)
 
-    check_cuda(64, 0xAB, 4)
+    check_cuda(64, np.int8(0xAB), 4)
     check_cuda(64, 0, 4)
     check_cuda(64, -3, 4)
-    check_cuda(64, 0xAB, 3)
+    check_cuda(64, np.int8(0xAB), 3)
     check_cuda(64, 0, 3)
     check_cuda(64, -3, 3)
-    check_cuda(64, 0xAB, 2)
+    check_cuda(64, np.int8(0xAB), 2)
     check_cuda(64, 0, 2)
     check_cuda(64, -3, 2)
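The test adjustments above are consequences of the stricter literal handling: indices and fill values that used to rely on silent truncation now have to fit their dtype. A sketch of the new failure mode and of the 0xAB reinterpretation (expectations mirror the new unit tests below; the astype round-trip is used here because out-of-range scalar conversion varies across NumPy versions):

    import numpy as np
    import pytest
    import tvm
    from tvm import tir

    with pytest.raises(tvm.TVMError):
        tir.const(256, "uint8")  # exceeds maximum of uint8
    with pytest.raises(tvm.TVMError):
        tir.const(65505.0, "float16")  # beyond kMaxFloat16 == 2**15 * (1 + 1023/1024) == 65504.0

    # 0xAB is 171 > 127, so it is no longer accepted as an int8 literal;
    # np.int8(0xAB) passes in the reinterpreted value -85 explicitly.
    assert np.uint8(0xAB).astype(np.int8) == -85
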
diff --git a/tests/python/unittest/test_tir_imm_values.py b/tests/python/unittest/test_tir_imm_values.py
new file mode 100644
index 000000000000..a2a19a09ad87
--- /dev/null
+++ b/tests/python/unittest/test_tir_imm_values.py
@@ -0,0 +1,577 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import math
+import random
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+
+
+@pytest.mark.parametrize(
+    "dtype, literals",
+    [
+        ["int8", [-128, 0, 127]],
+        ["uint8", [0, 255]],
+        ["int32", [-2147483648, 2147483647]],
+        ["uint32", [0, 4294967295]],
+        ["int64", [-9223372036854775808, 9223372036854775807]],
+        ["uint64", [0, 9223372036854775807]],
+    ],
+)
+def test_tir_make_intimm(dtype, literals):
+    for l in literals:
+        imm = tir.const(l, dtype)
+        assert imm.value == l, imm
+
+
+@pytest.mark.parametrize(
+    "dtype, literals",
+    [
+        ["int8", [-129, 128]],
+        ["uint8", [-1, 256]],
+        ["int32", [-2147483650, 2147483648]],
+        ["uint32", [-1, 4294967296]],
+        ["uint64", [-1, 18446744073709551616]],
+    ],
+)
+def test_tir_invalid_intimm(dtype, literals):
+    for l in literals:
+        with pytest.raises(tvm.TVMError):
+            tir.const(l, dtype)
+
+
+@pytest.mark.parametrize(
+    "dtype, literals",
+    [
+        [
+            "uint64",
+            {
+                9223372036854775807: 9223372036854775807,
+                18446744073709551615: 18446744073709551615,
+            },
+        ],
+    ],
+)
+def test_tir_large_py_int_literals(dtype, literals):
+    """
+    For large uint values, the LargeUIntImm intrinsic is used.
+    """
+    for l in literals:
+        x = tir.const(l, dtype)
+        if isinstance(x, (tir.IntImm, tir.FloatImm)):
+            assert x.value == literals[l]
+        else:
+            # LargeUIntImm(low32, hi32)
+            assert (int(x.args[1]) << 32) + int(x.args[0]) == literals[l]
+
+
+def test_tir_intimm_overflow():
+    assert int(tir.const(255, "uint8") + tir.const(1, "uint8")) == 0
+    assert int(tir.const(2**31 - 1, "int32") + tir.const(1, "int32")) == -(2**31)
+    assert int(tir.const(2**32 - 1, "uint32") + tir.const(1, "uint32")) == 0
+    assert int(tir.const(2**63 - 1, "int64") + tir.const(1, "int64")) == -(2**63)
+    assert int(tir.const(2**32, "uint64") * tir.const(2**32, "uint64")) == 0
+    # customized int types
+    assert int(tir.const(7, "int4") + tir.const(1, "int4")) == -8
+    assert int(tir.const(2**39 - 1, "int40") + tir.const(1, "int40")) == -(2**39)
+
+
+def compare_float_value(value, expect, msg):
+    if math.isfinite(value):
+        assert np.abs(value - expect) < 1e-5, f"{value} vs {expect}, {msg}"
+    elif math.isnan(value):
+        assert math.isnan(expect), f"{value} vs {expect}, {msg}"
+    elif math.isinf(value):
+        assert math.isinf(expect), f"{value} vs {expect}, {msg}"
+
+
+@pytest.mark.parametrize(
+    "dtype, literals",
+    [
+        ["float16", [-65504.0, 3.14, 65504.0, np.inf, np.nan]],
+        ["bfloat16", [-3.38953139e38, 3.38953139e38, 3.14]],
+        ["float32", [np.finfo("float32").min, 3.14, np.finfo("float32").max, np.inf, np.nan]],
+        ["float64", [np.finfo("float64").min, 3.14, np.finfo("float64").max, np.inf, np.nan]],
+    ],
+)
+def test_tir_make_floatimm(dtype, literals):
+    for l in literals:
+        imm = tir.const(l, dtype)
+        compare_float_value(imm.value, l, "imm value should match feed value")
+
+
+@pytest.mark.parametrize(
+    "dtype, literals",
+    [
+        ["float16", [-65505.0, 65505.0]],
+        ["float32", [-3.402e39, 3.402e39]],
+    ],
+)
+def test_tir_invalid_floatimm(dtype, literals):
+    """Currently only fp16 and fp32 have range checks."""
+    for l in literals:
+        with pytest.raises(tvm.TVMError):
+            tir.const(l, dtype)
+
+
+@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
+@pytest.mark.parametrize("literal", [3.14, np.nan, np.inf])
+def test_tir_special_floatimms(dtype, literal):
+    x = tir.const(literal, dtype)
+    compare_float_value(x.value, literal, "imm value should match feed value")
+
+
+@tvm.testing.requires_llvm()
+def test_tir_too_large_literal_f64():
+    # Behavior check: if the literal f64 value is out of the dtype's range, the
+    # object is still constructed and evaluates to infinity.
+    @T.prim_func
+    def imm_overflow_fp64() -> T.float64:
+        T.evaluate(T.ret(T.float64(1.7976e309), dtype="float64"))
+
+    f = tvm.build(imm_overflow_fp64, target="llvm")
+    assert math.isinf(f())
+
+
+@pytest.mark.parametrize(
+    "literal, expect_dtype",
+    [
+        (256, "int32"),
+        (2147483647, "int32"),
+        (-2147483648, "int32"),
+        (2147483648, "int64"),
+        (-2147483649, "int64"),
+        (3.14159, "float32"),
+        (np.finfo("float32").min, "float32"),
+        (np.finfo("float32").max, "float32"),
+        (-3.402e39, "float64"),
+        (3.402e39, "float64"),
+    ],
+)
+def test_tir_const_auto_dtype(literal, expect_dtype):
+    x = tir.const(literal, dtype=None)
+    assert x.dtype == expect_dtype
+    assert x.value == literal
+
+
+def check_tir_const_fold(
+    dtype, foldf, calcf, x_range=None, y_range=None, expect=None, skip_overflow=False
+):
+    """Helper to check constant folding behavior
+
+    Parameters
+    ----------
+    dtype: str
+        Datatype of constants
+
+    foldf: (x, y) -> z
+        Folding function to call
+
+    calcf: (x, y) -> z
+        Compiled calculation function to call
+
+    x_range: Union[int, float, tuple]
+        Single value or value range [min, max]
+
+    y_range: Union[int, float, tuple]
+        Single value or value range [min, max]
+
+    expect: Union[int, float]
+        Expected calculation result
+
+    skip_overflow: bool
+        Skip the assertion if overflow happens
+    """
+    seed = random.randint(0, 2147483648)
+    np.random.seed(seed)
+    ninfo = np.finfo(dtype) if dtype.startswith("float") else np.iinfo(dtype)
+
+    if x_range is None:
+        x_range = (ninfo.min, ninfo.max)
+    if isinstance(x_range, (int, float)):
+        x = x_range
+    elif dtype.startswith("int") or dtype.startswith("uint"):
+        x = np.random.randint(x_range[0], x_range[1] + 1, dtype=dtype)
+    else:
+        x = np.random.uniform(x_range[0], x_range[1])
+
+    if y_range is None:
+        y_range = (ninfo.min, ninfo.max)
+    if isinstance(y_range, (int, float)):
+        y = y_range
+    elif dtype.startswith("int") or dtype.startswith("uint"):
+        y = np.random.randint(y_range[0], y_range[1] + 1, dtype=dtype)
+    else:
+        y = np.random.uniform(y_range[0], y_range[1])
+
+    if skip_overflow:
+        py_res = foldf(x, y)
+        if isinstance(py_res, (tir.IntImm, tir.FloatImm)):
+            py_res = py_res.value
+        if not (ninfo.min <= py_res <= ninfo.max):
+            # If the result overflows, the arithmetic is not well-defined,
+            # so we intentionally do not fail the test.
+            return
+
+    fold_res = foldf(tir.const(x, dtype), tir.const(y, dtype))
+    calc_res = calcf(x, y)
+
+    flaky_msg = (
+        f"{dtype} ({x}, {y}, {expect}) const folding check failed.\n"
+        + "This test is intentionally non-deterministic; "
+        + f"if it fails, please report it in a github issue together with this seed {seed}\n"
+    )
+    if dtype.startswith("float"):
+        compare_float_value(calc_res, fold_res.value, flaky_msg)
+        if expect is not None:
+            compare_float_value(expect, calc_res, flaky_msg)
+    else:
+        assert calc_res == fold_res.value, flaky_msg
+        if expect is not None:
+            assert expect == calc_res, flaky_msg
+
+
+@tvm.testing.requires_llvm()
+def test_tir_floatimm_const_fold():
+    """Behavior check: folding fp32 matches platform f32 arithmetic"""
+
+    @T.prim_func
+    def float_imm_multiply(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
+        z[()] = x * y
+
+    @T.prim_func
+    def float_imm_add(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
+        z[()] = x + y
+
+    @T.prim_func
+    def float_imm_sub(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
+        z[()] = x - y
+
+    @T.prim_func
+    def float_imm_div(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]):
+        z[()] = x / y
+
+    def __wrap_build(f):
+        lib = tvm.build(f, target="llvm")
+        z = tvm.nd.array(np.zeros([]).astype("float32"))
+
+        def _func(x, y):
+            lib(x, y, z)
+            return z.numpy()
+
+        return _func
+
+    fmul = __wrap_build(float_imm_multiply)
+    fadd = __wrap_build(float_imm_add)
+    fsub = __wrap_build(float_imm_sub)
+    fdiv = __wrap_build(float_imm_div)
+
+    # overflow
+    check_tir_const_fold("float32", lambda x, y: x * y, fmul, 3.0e30, 3.0e30, np.inf)
+    check_tir_const_fold("float32", lambda x, y: x * y, fmul, 3.0e30, -3.0e30, -np.inf)
+    check_tir_const_fold("float32", lambda x, y: x / y, fdiv, 3.0e30, 3.0e-30, np.inf)
+
+    # divide by zero
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("float32", lambda x, y: x / y, fdiv, 1.0, 0.0)
+
+    # nan and inf
+    check_tir_const_fold("float32", lambda x, y: x + y, fadd, 1.0, np.nan, np.nan)
+    check_tir_const_fold("float32", lambda x, y: x + y, fadd, 1.0, np.inf, np.inf)
+    check_tir_const_fold("float32", lambda x, y: x + y, fadd, 1.0, -np.inf, -np.inf)
+
+    # randomized check
+    check_tir_const_fold("float32", lambda x, y: x * y, fmul)
+    check_tir_const_fold("float32", lambda x, y: x + y, fadd)
+    check_tir_const_fold("float32", lambda x, y: x - y, fsub)
+    check_tir_const_fold(
+        "float32", lambda x, y: x / y, fdiv, y_range=(0.01, np.finfo("float32").max)
+    )
+
+
+@tvm.testing.requires_llvm()
+def test_tir_int8_const_fold():
+    """Behavior check: folding i8 operations matches platform i8 arithmetic"""
+
+    @T.prim_func
+    def imm_multiply(x: T.int8, y: T.int8) -> T.int8:
+        T.evaluate(T.ret(x * y, dtype="int8"))
+
+    @T.prim_func
+    def imm_add(x: T.int8, y: T.int8) -> T.int8:
+        T.evaluate(T.ret(x + y, dtype="int8"))
+
+    @T.prim_func
+    def imm_sub(x: T.int8, y: T.int8) -> T.int8:
+        T.evaluate(T.ret(x - y, dtype="int8"))
+
+    @T.prim_func
+    def imm_truncdiv(x: T.int8, y: T.int8) -> T.int8:
+        T.evaluate(T.ret(T.truncdiv(x, y), dtype="int8"))
+
+    @T.prim_func
+    def imm_floordiv(x: T.int8, y: T.int8) -> T.int8:
+        T.evaluate(T.ret(T.floordiv(x, y), dtype="int8"))
+
+    fmul = tvm.build(imm_multiply, target="llvm")
+    fadd = tvm.build(imm_add, target="llvm")
+    fsub = tvm.build(imm_sub, target="llvm")
+    ffloordiv = tvm.build(imm_floordiv, target="llvm")
+    ftruncdiv = tvm.build(imm_truncdiv, target="llvm")
+
+    # overflow
+    check_tir_const_fold("int8", lambda x, y: x + y, fadd, 127, 1, -128)
+    check_tir_const_fold("int8", lambda x, y: x * y, fmul, 127, 127, 1)
+
+    # divide by zero
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("int8", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0)
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("int8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0)
+
+    # i8 mod folding is not implemented
+    assert not isinstance(tir.floormod(tir.const(7, "int8"), tir.const(3, "int8")), tir.IntImm)
+    assert not isinstance(tir.truncmod(tir.const(7, "int8"), tir.const(3, "int8")), tir.IntImm)
+
+    # randomized check
+    check_tir_const_fold("int8", lambda x, y: x * y, fmul)
+    check_tir_const_fold("int8", lambda x, y: x + y, fadd)
+    check_tir_const_fold("int8", lambda x, y: x - y, fsub)
+    check_tir_const_fold(
+        "int8", lambda x, y: tir.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("int8").max)
+    )
+    check_tir_const_fold(
+        "int8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("int8").max)
+    )
+
+
+@tvm.testing.requires_llvm()
+def test_tir_uint8_const_fold():
+    """Behavior check: folding u8 operations matches platform u8 arithmetic"""
+
+    @T.prim_func
+    def imm_multiply(x: T.uint8, y: T.uint8) -> T.uint8:
+        T.evaluate(T.ret(x * y, dtype="uint8"))
+
+    @T.prim_func
+    def imm_add(x: T.uint8, y: T.uint8) -> T.uint8:
+        T.evaluate(T.ret(x + y, dtype="uint8"))
+
+    @T.prim_func
+    def imm_sub(x: T.uint8, y: T.uint8) -> T.uint8:
+        T.evaluate(T.ret(x - y, dtype="uint8"))
+
+    @T.prim_func
+    def imm_truncdiv(x: T.uint8, y: T.uint8) -> T.uint8:
+        T.evaluate(T.ret(T.truncdiv(x, y), dtype="uint8"))
+
+    @T.prim_func
+    def imm_floordiv(x: T.uint8, y: T.uint8) -> T.uint8:
+        T.evaluate(T.ret(T.floordiv(x, y), dtype="uint8"))
+
+    fmul = tvm.build(imm_multiply, target="llvm")
+    fadd = tvm.build(imm_add, target="llvm")
+    fsub = tvm.build(imm_sub, target="llvm")
+    ffloordiv = tvm.build(imm_floordiv, target="llvm")
+    ftruncdiv = tvm.build(imm_truncdiv, target="llvm")
+
+    # overflow
+    check_tir_const_fold("uint8", lambda x, y: x + y, fadd, 255, 1, 0)
+
+    # unsigned subtraction that would go negative
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("uint8", lambda x, y: x - y, fsub, 0, 10)
+
+    # divide by zero
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("uint8", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0)
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("uint8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0)
+
+    # u8 mod folding is not implemented
+    assert not isinstance(tir.floormod(tir.const(7, "uint8"), tir.const(3, "uint8")), tir.IntImm)
+    assert not isinstance(tir.truncmod(tir.const(7, "uint8"), tir.const(3, "uint8")), tir.IntImm)
+
+    # randomized check
+    check_tir_const_fold("uint8", lambda x, y: x * y, fmul)
+    check_tir_const_fold("uint8", lambda x, y: x + y, fadd)
+    check_tir_const_fold("uint8", lambda x, y: x - y, fsub)
+    check_tir_const_fold(
+        "uint8", lambda x, y: tir.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("uint8").max)
+    )
+    check_tir_const_fold(
+        "uint8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("uint8").max)
+    )
+
+
+@tvm.testing.requires_llvm()
+def test_tir_int32_const_fold():
+    """Behavior check: folding i32 operations matches platform i32 arithmetic"""
+
+    @T.prim_func
+    def imm_multiply(x: T.int32, y: T.int32) -> T.int32:
+        T.evaluate(T.ret(x * y, dtype="int32"))
+
+    @T.prim_func
+    def imm_add(x: T.int32, y: T.int32) -> T.int32:
+        T.evaluate(T.ret(x + y, dtype="int32"))
+
+    @T.prim_func
+    def imm_sub(x: T.int32, y: T.int32) -> T.int32:
+        T.evaluate(T.ret(x - y, dtype="int32"))
+
+    @T.prim_func
+    def imm_truncdiv(x: T.int32, y: T.int32) -> T.int32:
+        T.evaluate(T.ret(T.truncdiv(x, y), dtype="int32"))
+
+    @T.prim_func
+    def imm_truncmod(x: T.int32, y: T.int32) -> T.int32:
+        T.evaluate(T.ret(T.truncmod(x, y), dtype="int32"))
+
+    @T.prim_func
+    def imm_floordiv(x: T.int32, y: T.int32) -> T.int32:
+        T.evaluate(T.ret(T.floordiv(x, y), dtype="int32"))
+
+    @T.prim_func
+    def imm_floormod(x: T.int32, y: T.int32) -> T.int32:
+        T.evaluate(T.ret(T.floormod(x, y), dtype="int32"))
+
+    fmul = tvm.build(imm_multiply, target="llvm")
+    fadd = tvm.build(imm_add, target="llvm")
+    fsub = tvm.build(imm_sub, target="llvm")
+    ffloordiv = tvm.build(imm_floordiv, target="llvm")
+    ffloormod = tvm.build(imm_floormod, target="llvm")
+    ftruncdiv = tvm.build(imm_truncdiv, target="llvm")
+    ftruncmod = tvm.build(imm_truncmod, target="llvm")
+
+    # i32 overflow is not specified; only check that the result stays in range
+    assert -(2**31) <= int(tir.const(2**31 - 1, "int32") + tir.const(1, "int32")) < 2**31
+    assert -(2**31) <= int(tir.const(-(2**31), "int32") - tir.const(1, "int32")) < 2**31
+
+    # divide by zero
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("int32", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0)
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("int32", lambda x, y: tir.floormod(x, y), ffloormod, 1, 0)
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("int32", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0)
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("int32", lambda x, y: tir.truncmod(x, y), ftruncmod, 1, 0)
+
+    # randomized check
+    check_tir_const_fold("int32", lambda x, y: x * y, fmul, skip_overflow=True)
+    check_tir_const_fold("int32", lambda x, y: x + y, fadd, skip_overflow=True)
+    check_tir_const_fold("int32", lambda x, y: x - y, fsub, skip_overflow=True)
+    check_tir_const_fold(
+        "int32",
+        lambda x, y: tir.floordiv(x, y),
+        ffloordiv,
+        y_range=(1, np.iinfo("int32").max),
+        skip_overflow=True,
+    )
+    check_tir_const_fold(
+        "int32",
+        lambda x, y: tir.truncdiv(x, y),
+        ftruncdiv,
+        y_range=(1, np.iinfo("int32").max),
+        skip_overflow=True,
+    )
+    check_tir_const_fold(
+        "int32",
+        lambda x, y: tir.floormod(x, y),
+        ffloormod,
+        y_range=(1, np.iinfo("int32").max),
+        skip_overflow=False,
+    )
+    check_tir_const_fold(
+        "int32",
+        lambda x, y: tir.truncmod(x, y),
+        ftruncmod,
+        y_range=(1, np.iinfo("int32").max),
+        skip_overflow=False,
+    )
+
+
+@tvm.testing.requires_llvm()
+def test_tir_uint32_const_fold():
+    """Behavior check: folding u32 operations matches platform u32 arithmetic"""
+
+    @T.prim_func
+    def imm_multiply(x: T.uint32, y: T.uint32) -> T.uint32:
+        T.evaluate(T.ret(x * y, dtype="uint32"))
+
+    @T.prim_func
+    def imm_add(x: T.uint32, y: T.uint32) -> T.uint32:
+        T.evaluate(T.ret(x + y, dtype="uint32"))
+
+    @T.prim_func
+    def imm_sub(x: T.uint32, y: T.uint32) -> T.uint32:
+        T.evaluate(T.ret(x - y, dtype="uint32"))
+
+    @T.prim_func
+    def imm_truncdiv(x: T.uint32, y: T.uint32) -> T.uint32:
+        T.evaluate(T.ret(T.truncdiv(x, y), dtype="uint32"))
+
+    @T.prim_func
+    def imm_floordiv(x: T.uint32, y: T.uint32) -> T.uint32:
+        T.evaluate(T.ret(T.floordiv(x, y), dtype="uint32"))
+
+    fmul = tvm.build(imm_multiply, target="llvm")
+    fadd = tvm.build(imm_add, target="llvm")
+    fsub = tvm.build(imm_sub, target="llvm")
+    ffloordiv = tvm.build(imm_floordiv, target="llvm")
+    ftruncdiv = tvm.build(imm_truncdiv, target="llvm")
+
+    # u32 overflow is not specified; only check that the result stays in range
+    assert 0 <= int(tir.const(2**32 - 1, "uint32") + tir.const(1, "uint32")) < 2**32
+
+    # divide by zero
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("uint32", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0)
+    with pytest.raises(tvm.TVMError):
+        check_tir_const_fold("uint32", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0)
+
+    # u32 mod folding is not implemented
+    assert not isinstance(tir.floormod(tir.const(7, "uint32"), tir.const(3, "uint32")), tir.IntImm)
+    assert not isinstance(tir.truncmod(tir.const(7, "uint32"), tir.const(3, "uint32")), tir.IntImm)
+
+    # randomized check
+    check_tir_const_fold("uint32", lambda x, y: x * y, fmul, skip_overflow=True)
+    check_tir_const_fold("uint32", lambda x, y: x + y, fadd, skip_overflow=True)
+    check_tir_const_fold("uint32", lambda x, y: x - y, fsub, skip_overflow=True)
+    check_tir_const_fold(
+        "uint32",
+        lambda x, y: tir.floordiv(x, y),
+        ffloordiv,
+        y_range=(1, np.iinfo("uint32").max),
+        skip_overflow=False,
+    )
+    check_tir_const_fold(
+        "uint32",
+        lambda x, y: tir.truncdiv(x, y),
+        ftruncdiv,
+        y_range=(1, np.iinfo("uint32").max),
+        skip_overflow=False,
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index d66b4ef5dd5b..20818a5b326a 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -67,8 +67,6 @@ def check(m, n, target_bits, target_dtype):
     # const shape
     # i32 -> i32
     check(2, 2, 32, "int32")
-    # i32 + i32 is not promoted to i64 even if overflow
-    check(2**16, 2**16, 32, "int32")
     # i64 -> i32
     check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32")
     check(const(2**16, dtype="int64"), const(2**16, dtype="int64"), 32, "int64")
@@ -100,12 +98,6 @@ def check(m, n, target_bits, target_dtype):
     # i32 -> i32
     check(2, 32, target_bits=32, target_dtype="int32")
-    check(
-        2**30,
-        32,  # i32 + i32 is not promoted to i64 even in the case of overflow
-        target_bits=32,
-        target_dtype="int32",
-    )
     # i64 -> i32
     check(const(2, dtype="int64"), const(32, dtype="int64"), target_bits=32, target_dtype="int32")
     check(
@@ -162,7 +154,6 @@ def check(m, lanes, target_bits, target_dtype):
     # i32 -> i32
     check(const(2**10, dtype="int32"), 2, target_bits=32, target_dtype="int32")
-    check(const(2**32, dtype="int32"), 2, target_bits=32, target_dtype="int32")
     # i64 -> i32
     check(const(2**10, dtype="int64"), 2, target_bits=32, target_dtype="int32")
     check(const(2**32, dtype="int64"), 2, target_bits=32, target_dtype="int64")
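Taken together, the patch makes folded constants match what the compiled target computes. A minimal end-to-end recap (a sketch, assuming a build with this patch; the expectations mirror the new tests above):

    import math
    from tvm import tir

    assert int(tir.const(255, "uint8") + tir.const(1, "uint8")) == 0  # uint8 wraps
    assert int(tir.const(127, "int8") * tir.const(127, "int8")) == 1  # 16129 & 0xFF == 1
    prod = tir.const(3.0e30, "float32") * tir.const(3.0e30, "float32")
    assert math.isinf(prod.value)  # fp32 folding overflows to inf, as on hardware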