Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix horrifying bug in lossless_cast of a subtract #8155

Merged
merged 40 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
7d80f8b
Fix horrifying bug in lossless_cast of a subtract
abadams Mar 14, 2024
9c33c94
Use constant integer intervals to analyze safety for lossless_cast
abadams Mar 18, 2024
e0f9f8e
Fix ARM and HVX instruction selection
abadams Mar 21, 2024
214f0fd
Using constant_integer_bounds to strengthen FindIntrinsics
abadams Mar 22, 2024
67855a5
Move new classes to new files
abadams Mar 25, 2024
bee38ce
Make the simplifier use ConstantInterval
abadams Mar 25, 2024
7f4bb38
Handle bounds of narrower types in the simplifier too
abadams Mar 25, 2024
6434210
Fix * operator. Add min/max/mod
abadams Mar 28, 2024
f308a8c
Add cache for constant bounds queries
abadams Mar 28, 2024
cffadd8
Fix ConstantInterval multiplication
abadams Apr 1, 2024
2f14881
Add a simplifier rule which is apparently now necessary
abadams Apr 1, 2024
26efb7c
Misc cleanups and test improvements
abadams Apr 1, 2024
b053ec6
Add missing files
abadams Apr 1, 2024
413b4a6
Account for more aggressive simplification in fuse test
abadams Apr 1, 2024
854122f
Remove redundant helpers
abadams Apr 1, 2024
4a293b1
Add missing comment
abadams Apr 1, 2024
0856319
clear_bounds_info -> clear_expr_info
abadams Apr 1, 2024
16a706d
Remove bad TODO
abadams Apr 1, 2024
ecfae44
It's too late to change the semantics of fixed point intrinsics
abadams Apr 1, 2024
66c56f1
Fix some UB
abadams Apr 1, 2024
0fb8d38
Stronger assert in Simplify_Div
abadams Apr 2, 2024
c6065ff
Delete bad rewrite rules
abadams Apr 2, 2024
6bcc66a
Fix bad test when lowering mul_shift_right
abadams Apr 2, 2024
c652667
Avoid UB in lowering of rounding_shift_right/left
abadams Apr 2, 2024
1737a52
Add shifts to the lossless cast fuzzer
abadams Apr 2, 2024
ddab1cf
Fix bug in lossless_negate
abadams Apr 5, 2024
a0f1d23
Add constant interval test
abadams Jun 2, 2024
bf28e00
Merge remote-tracking branch 'origin/main' into abadams/fix_lossless_…
abadams Jun 2, 2024
ac5b13d
Rework find_mpy_ops to handle more structures
abadams Jun 3, 2024
c8f7e8f
Fix bugs in lossless_cast
abadams Jun 3, 2024
9570818
Fix mul_shift_right expansion
abadams Jun 3, 2024
7414ee6
Delete commented-out code
abadams Jun 3, 2024
c33dbfb
Don't introduce out-of-range shifts in lossless_cast
abadams Jun 4, 2024
360add6
Merge branch 'main' into abadams/fix_lossless_cast_of_sub
steven-johnson Jun 5, 2024
0409f2f
Merge remote-tracking branch 'origin/main' into abadams/fix_lossless_…
abadams Jun 6, 2024
0b561c7
Some constant folding only happens after lowering intrinsics in codegen
abadams Jun 10, 2024
adb3a6b
Merge branch 'abadams/fix_lossless_cast_of_sub' of https://github.com…
abadams Jun 10, 2024
e90db04
Merge remote-tracking branch 'origin/main' into abadams/fix_lossless_…
abadams Jun 10, 2024
19b1091
Merge remote-tracking branch 'origin/main' into abadams/fix_lossless_…
abadams Jun 19, 2024
ff352ef
Merge remote-tracking branch 'origin/main' into abadams/fix_lossless_…
abadams Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 28 additions & 36 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1212,50 +1212,42 @@ void CodeGen_ARM::visit(const Add *op) {
Expr ac_u8 = Variable::make(UInt(8, 0), "ac"), bc_u8 = Variable::make(UInt(8, 0), "bc");
Expr cc_u8 = Variable::make(UInt(8, 0), "cc"), dc_u8 = Variable::make(UInt(8, 0), "dc");

// clang-format off
Expr ma_i8 = widening_mul(a_i8, ac_i8);
Expr mb_i8 = widening_mul(b_i8, bc_i8);
Expr mc_i8 = widening_mul(c_i8, cc_i8);
Expr md_i8 = widening_mul(d_i8, dc_i8);

Expr ma_u8 = widening_mul(a_u8, ac_u8);
Expr mb_u8 = widening_mul(b_u8, bc_u8);
Expr mc_u8 = widening_mul(c_u8, cc_u8);
Expr md_u8 = widening_mul(d_u8, dc_u8);

static const Pattern patterns[] = {
// If we had better normalization, we could drastically reduce the number of patterns here.
// Signed variants.
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product"},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
// Signed variants (associative).
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product"},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{(init_i32 + widening_add(ma_i8, mb_i8)) + widening_add(mc_i8, md_i8), "dot_product"},
{init_i32 + (widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8)), "dot_product"},
{widening_add(ma_i8, mb_i8) + widening_add(mc_i8, md_i8), "dot_product"},

// Unsigned variants.
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product"},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
// Unsigned variants (associative).
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product"},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{(init_u32 + widening_add(ma_u8, mb_u8)) + widening_add(mc_u8, md_u8), "dot_product"},
{init_u32 + (widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8)), "dot_product"},
{widening_add(ma_u8, mb_u8) + widening_add(mc_u8, md_u8), "dot_product"},
};
// clang-format on

std::map<std::string, Expr> matches;
for (const Pattern &p : patterns) {
if (expr_match(p.pattern, op, matches)) {
Expr init = matches["init"];
Expr values = Shuffle::make_interleave({matches["a"], matches["b"], matches["c"], matches["d"]});
// Coefficients can be 1 if not in the pattern.
Expr one = make_one(p.coeff_type.with_lanes(op->type.lanes()));
// This hideous code pattern implements fetching a
// default value if the map doesn't contain a key.
Expr _ac = matches.try_emplace("ac", one).first->second;
Expr _bc = matches.try_emplace("bc", one).first->second;
Expr _cc = matches.try_emplace("cc", one).first->second;
Expr _dc = matches.try_emplace("dc", one).first->second;
Expr coeffs = Shuffle::make_interleave({_ac, _bc, _cc, _dc});
Expr init;
auto it = matches.find("init");
if (it == matches.end()) {
init = make_zero(op->type);
} else {
init = it->second;
}
Expr values = Shuffle::make_interleave({matches["a"], matches["b"],
matches["c"], matches["d"]});
Expr coeffs = Shuffle::make_interleave({matches["ac"], matches["bc"],
matches["cc"], matches["dc"]});
value = call_overloaded_intrin(op->type, p.intrin, {init, values, coeffs});
if (value) {
return;
Expand Down
11 changes: 8 additions & 3 deletions src/CodeGen_X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,8 @@ void CodeGen_X86::visit(const Cast *op) {
};

// clang-format off
static const Pattern patterns[] = {
// This isn't rounding_multiply_quantzied(i16, i16, 15) because it doesn't
static Pattern patterns[] = {
// This isn't rounding_mul_shift_right(i16, i16, 15) because it doesn't
// saturate the result.
{"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))},

Expand Down Expand Up @@ -736,7 +736,12 @@ void CodeGen_X86::visit(const Call *op) {
// Handle edge case of possible overflow.
// See https://github.com/halide/Halide/pull/7129/files#r1008331426
// On AVX512 (and with enough lanes) we can use a mask register.
if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) {
ConstantInterval ca = constant_integer_bounds(a);
ConstantInterval cb = constant_integer_bounds(b);
if (!ca.contains(-32768) || !cb.contains(-32768)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-0x8000 is probably clearer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyone reading this code who doesn't know the bit layout of -32768 should not be reading this code :-)

// Overflow isn't possible
pmulhrs.accept(this);
} else if (target.has_feature(Target::AVX512) && op->type.lanes() >= 32) {
Expr expr = select((a == i16_min) && (b == i16_min), i16_max, pmulhrs);
expr.accept(this);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const IntImm *IntImm::make(Type t, int64_t value) {
internal_assert(t.is_int() && t.is_scalar())
<< "IntImm must be a scalar Int\n";
internal_assert(t.bits() >= 1 && t.bits() <= 64)
<< "IntImm must have between 1 and 64 bits\n";
<< "IntImm must have between 1 and 64 bits: " << t << "\n";

// Normalize the value by dropping the high bits.
// Since left-shift of negative value is UB in C++, cast to uint64 first;
Expand All @@ -28,7 +28,7 @@ const UIntImm *UIntImm::make(Type t, uint64_t value) {
internal_assert(t.is_uint() && t.is_scalar())
<< "UIntImm must be a scalar UInt\n";
internal_assert(t.bits() >= 1 && t.bits() <= 64)
<< "UIntImm must have between 1 and 64 bits\n";
<< "UIntImm must have between 1 and 64 bits " << t << "\n";

// Normalize the value by dropping the high bits
value <<= (64 - t.bits());
Expand Down
Loading
Loading