Skip to content

Commit

Permalink
[TIR][Arith] Add more strict checking in imm construction and folding. (
Browse files Browse the repository at this point in the history
#12515)

* Add more strict check in tir imm construction and folding.

* fix bool-compare compile error

* fix some illegal imm construction in testcases

* do not test i64 overflow behaviour because it is not consistent on cython and ctypes

* fix float32 testcase

* auto-inferred dtype should be int64 when value exceeds int32 range

* add floatimm range check for fp16 and fp32

* add more folding testcases and fix store fp32 folding result to double

* fix i386 fp16 cases
  • Loading branch information
wrongtest-intellif authored Sep 9, 2022
1 parent b21bf66 commit 029fa46
Show file tree
Hide file tree
Showing 13 changed files with 743 additions and 36 deletions.
9 changes: 8 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,9 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
if (t.is_uint()) {
// Use IntImm if it is a small integer
uint64_t uval = static_cast<uint64_t>(value);
if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
if (value < static_cast<ValueType>(0)) {
LOG(FATAL) << "cannot make uint from negative value " << value;
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
return IntImm(t, static_cast<int64_t>(value), span);
} else {
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
Expand All @@ -932,6 +934,11 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return PrimExpr();
}

template <>
inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
return MakeConstScalar(t, static_cast<int>(value), span);
}

template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
if (t.lanes() == 1) {
Expand Down
14 changes: 10 additions & 4 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,17 @@ def _scalar_type_inference(value):
elif isinstance(value, bool):
dtype = "bool"
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = "float32"
# We intentionally prefer convert the float to float32 since it's more common in DL.
if -3.40282347e38 <= value <= 3.40282347e38:
dtype = "float32"
else:
dtype = "float64"
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = "int32"
# We intentionally prefer convert the python int to int32 since it's more common in DL.
if -2147483648 <= value <= 2147483647:
dtype = "int32"
else:
dtype = "int64"
else:
raise NotImplementedError(
"Cannot automatically inference the type." " value={}".format(value)
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def truncmod(x, y, span):
return tvm.tir.truncmod(x, y, span)


@register
def truncdiv(x, y, span):
return tvm.tir.truncdiv(x, y, span)


@register
def ceildiv(x, y, span):
return tvm.tir.ceildiv(x, y, span)
Expand Down
112 changes: 100 additions & 12 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <algorithm>
#include <cmath>
#include <limits>

#include "int_operator.h"

Expand Down Expand Up @@ -73,6 +74,39 @@ inline bool IsIndexType(const DataType& type) {
return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64);
}

/*! \brief Helper to get const folding result repr in int64. */
inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) {
if (dtype.bits() < 64) {
x &= (1LL << dtype.bits()) - 1;
}
if (dtype.is_int()) {
// get sign extended value of integer with specified bits
int64_t m = 1LL << (dtype.bits() - 1);
x = (x ^ m) - m;
}
return x;
}

/*! \brief Helper to get fp32 const folding result repr in double. */
inline double GetFoldResultDoubleRepr(float x) {
double res = static_cast<double>(x);
if (std::isinf(res) || std::isnan(res)) {
return res;
}
// certain platform (eg, on gcc7-i386) do the folding arithmetic
// on float and write back to double is optimized to double
// precision arithmetic, this is legal and we check the output
// range thus to ensure consistency when the float result is inf.
if (res < std::numeric_limits<float>::lowest()) {
LOG(WARNING) << "underlying float value overflow";
return -std::numeric_limits<double>::infinity();
} else if (res > std::numeric_limits<float>::max()) {
LOG(WARNING) << "underlying float value overflow";
return std::numeric_limits<double>::infinity();
}
return res;
}

#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using tir::FloatImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
Expand All @@ -95,10 +129,22 @@ template <>
inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value + pb->value);
if (pa && pb) {
int64_t res = pa->value + pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) +
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value + fb->value);
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
Expand All @@ -113,9 +159,21 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
<< "Checked failed. Minuend 's value is 0U and it's dtype is uint "
<< "while Subtrahend's dtype is uint; which will cause a negative uint";
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value - pb->value);
if (pa && pb) {
int64_t res = pa->value - pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) -
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value - fb->value);
} else {
return PrimExpr();
}
}
if (fb && fb->value == 0) return a;
});
return PrimExpr();
Expand All @@ -125,7 +183,10 @@ template <>
inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value * pb->value);
if (pa && pb) {
int64_t res = pa->value * pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 1) return b;
if (pa->value == 0) return a;
Expand All @@ -134,7 +195,16 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
if (pb->value == 1) return a;
if (pb->value == 0) return b;
}
if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) *
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value * fb->value);
} else {
return PrimExpr();
}
}
if (fa) {
if (fa->value == 1) return b;
if (fa->value == 0) return a;
Expand All @@ -155,7 +225,8 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value / pb->value);
int64_t res = pa->value / pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -165,7 +236,14 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm(rtype, fa->value / fb->value);
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) /
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value / fb->value);
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return a;
if (fb) {
Expand All @@ -182,7 +260,8 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value % pb->value);
int64_t res = pa->value % pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -201,7 +280,8 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, arith::floordiv(pa->value, pb->value));
int64_t res = arith::floordiv(pa->value, pb->value);
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -211,7 +291,14 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast<float>(fa->value) /
static_cast<float>(fb->value))));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return a;
if (fb) {
Expand All @@ -228,7 +315,8 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, floormod(pa->value, pb->value));
int64_t res = arith::floormod(pa->value, pb->value);
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand Down
32 changes: 31 additions & 1 deletion src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>

#include "../support/scalars.h"

namespace tvm {

PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {}
Expand Down Expand Up @@ -76,7 +78,20 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) {
ICHECK(dtype.is_int() || dtype.is_uint())
<< "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
if (dtype.is_uint()) {
ICHECK_GE(value, 0U);
ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
<< " is negative for unsigned integer type " << dtype;
if (dtype.bits() < 64) {
ICHECK_LT(value, 1LL << dtype.bits())
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
} else if (dtype.bits() == 1) {
// int(1)
ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
} else if (dtype.bits() < 64) {
ICHECK_GE(value, -(1LL << (dtype.bits() - 1)))
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LT(value, 1LL << (dtype.bits() - 1))
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
node->dtype = dtype;
Expand All @@ -103,6 +118,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";

// check range for float32 and float16 since they have specified range.
if (!std::isinf(value) && !std::isnan(value)) {
if (dtype.bits() == 32) {
ICHECK_GE(value, std::numeric_limits<float>::lowest())
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, std::numeric_limits<float>::max())
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
} else if (dtype.is_float16()) {
ICHECK_GE(value, -support::kMaxFloat16)
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, support::kMaxFloat16)
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
}
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
node->dtype = dtype;
node->value = value;
Expand Down
4 changes: 0 additions & 4 deletions src/support/scalars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,6 @@ IntImm ValueToIntImm(int64_t value, int width) {
}
}

// 2^15 * (1 + 1023/1024)
// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
constexpr double kMaxFloat16 = 65504.0;

FloatImm ValueToFloatImm(double value, int width) {
if (width == 16) {
if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) {
Expand Down
4 changes: 4 additions & 0 deletions src/support/scalars.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ std::string FloatImmToString(const FloatImm& float_imm);
IntImm ValueToIntImm(int64_t value, int width);
FloatImm ValueToFloatImm(double value, int width);

// 2^15 * (1 + 1023/1024)
// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
constexpr double kMaxFloat16 = 65504.0;

} // namespace support
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def verify(
# Test backwards slicing.
verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3))
# Test slicing with overlarge indices.
verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int64).max] * 3, [1, 1, 1], (3, 4, 3))
verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int32).max] * 3, [1, 1, 1], (3, 4, 3))
# Test slice mode.
verify(
(3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def test_fuse_dynamic_squeeze_slice_take():

squeeze = relay.op.squeeze(x, axis=[0])
strided_slice = relay.op.strided_slice(
squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1]
squeeze, begin=[0, 0], end=[15130, 2147483647], strides=[1, 1]
)
take = relay.op.take(strided_slice, take_val, axis=0)

Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,8 @@ def test_cast_simplify():
ck.verify(tvm.tir.Cast(dtype1, x == x), tvm.tir.const(1, dtype1))
for dtype2 in dtypes:
for i in [0, 1, 2, 3]:
if i > 1 and (dtype1 == "bool" or dtype2 == "bool"):
continue
ck.verify(tvm.tir.Cast(dtype1, tvm.tir.const(i, dtype2)), tvm.tir.const(i, dtype1))


Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Licensed to the Apache Software Foundation (ASF) under one

# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
Expand Down Expand Up @@ -194,13 +195,13 @@ def check_cuda(n, value, lanes):
fun(a)
np.testing.assert_equal(a.numpy(), np_a)

check_cuda(64, 0xAB, 4)
check_cuda(64, np.int8(0xAB), 4)
check_cuda(64, 0, 4)
check_cuda(64, -3, 4)
check_cuda(64, 0xAB, 3)
check_cuda(64, np.int8(0xAB), 3)
check_cuda(64, 0, 3)
check_cuda(64, -3, 3)
check_cuda(64, 0xAB, 2)
check_cuda(64, np.int8(0xAB), 2)
check_cuda(64, 0, 2)
check_cuda(64, -3, 2)

Expand Down
Loading

0 comments on commit 029fa46

Please sign in to comment.