[x86 & HVX & WASM] Use bounds inference for saturating_narrow instruction selection (#7805)

* x86 bounds inference for saturating_narrow

* bounds inference for HVX too

* use can_represent(ConstantInterval) + clang-format

* use bounds inference for WASM IS too + add tests

* add tracking issue for scoped constant bounds

* add TODO about lossless_cast usage

---------

Co-authored-by: Steven Johnson <srj@google.com>
rootjalex and steven-johnson authored Apr 30, 2024
1 parent d55d82b commit 8141197
Showing 5 changed files with 120 additions and 5 deletions.
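The common idea across all three backends: the saturating narrow instructions these patterns target generally operate on signed inputs, so an unsigned operand may only use them when bounds inference proves that every value it can take also fits in the signed type of the same width. Below is a minimal, self-contained sketch of that check; the Interval struct and helper name are illustrative stand-ins, not the commit's code (the real implementation uses constant_integer_bounds() and Type::can_represent(ConstantInterval)).

#include <cstdint>
#include <iostream>

// Illustrative stand-ins for Halide's ConstantInterval bounds and
// Type::can_represent(ConstantInterval); not the code from this commit.
struct Interval {
    int64_t min, max;  // inclusive bounds proven for an expression
};

// An unsigned `bits`-wide expression can be reinterpreted as the signed type
// of the same width when its proven maximum does not exceed the signed
// maximum (its minimum is already >= 0 by construction).
bool can_reinterpret_unsigned_as_signed(const Interval &bounds, int bits) {
    const int64_t signed_max = (int64_t{1} << (bits - 1)) - 1;
    return bounds.min >= 0 && bounds.max <= signed_max;
}

int main() {
    // A u16 expression of the form (x >> 1) has proven bounds [0, 32767],
    // so it is safe to reinterpret as i16 and use a signed narrowing pack.
    std::cout << can_reinterpret_unsigned_as_signed({0, 32767}, 16) << "\n";  // 1
    // A full-range u16 expression ([0, 65535]) is not.
    std::cout << can_reinterpret_unsigned_as_signed({0, 65535}, 16) << "\n";  // 0
    return 0;
}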
37 changes: 37 additions & 0 deletions src/CodeGen_WebAssembly.cpp
@@ -3,6 +3,7 @@

#include "CodeGen_Posix.h"
#include "ConciseCasts.h"
#include "ConstantBounds.h"
#include "IRMatch.h"
#include "IROperator.h"
#include "LLVM_Headers.h"
@@ -206,6 +207,12 @@ void CodeGen_WebAssembly::visit(const Call *op) {
{"saturating_narrow", i16_sat(wild_i32x_), Target::WasmSimd128},
{"saturating_narrow", u16_sat(wild_i32x_), Target::WasmSimd128},
};
static const Pattern reinterpret_patterns[] = {
{"saturating_narrow", i8_sat(wild_u16x_), Target::WasmSimd128},
{"saturating_narrow", u8_sat(wild_u16x_), Target::WasmSimd128},
{"saturating_narrow", i16_sat(wild_u32x_), Target::WasmSimd128},
{"saturating_narrow", u16_sat(wild_u32x_), Target::WasmSimd128},
};
static const vector<pair<Expr, Expr>> cast_rewrites = {
// Some double-narrowing saturating casts can be better expressed as
// combinations of single-narrowing saturating casts.
@@ -235,6 +242,36 @@ void CodeGen_WebAssembly::visit(const Call *op) {
return;
}
}

// Search for saturating casts where the inner value can be
// reinterpreted to signed, so that we can use existing
// saturating_narrow instructions.
// TODO: should use lossless_cast once it is fixed.
for (const Pattern &p : reinterpret_patterns) {
if (!target.has_feature(p.required_feature)) {
continue;
}
if (expr_match(p.pattern, op, matches)) {
const Expr &expr = matches[0];
const Type &t = expr.type();
// TODO(8212): might want to keep track of scope of bounds information.
const ConstantInterval ibounds = constant_integer_bounds(expr);
const Type reint_type = t.with_code(halide_type_int);
                // If the signed type can represent the maximum possible value of this
                // unsigned expression, we can safely reinterpret it as signed.
if (reint_type.can_represent(ibounds)) {
// Can safely reinterpret to signed integer.
matches[0] = cast(reint_type, matches[0]);

value = call_overloaded_intrin(op->type, p.intrin, matches);
if (value) {
return;
}
}
                // No other reinterpret pattern will match the same input, so stop matching.
break;
}
}
}

if (op->is_intrinsic(Call::round)) {
40 changes: 38 additions & 2 deletions src/CodeGen_X86.cpp
@@ -1,6 +1,7 @@
#include "CodeGen_Internal.h"
#include "CodeGen_Posix.h"
#include "ConciseCasts.h"
#include "ConstantBounds.h"
#include "Debug.h"
#include "IRMatch.h"
#include "IRMutator.h"
@@ -537,7 +538,7 @@ void CodeGen_X86::visit(const Cast *op) {
};

// clang-format off
static Pattern patterns[] = {
static const Pattern patterns[] = {
// This isn't rounding_multiply_quantized(i16, i16, 15) because it doesn't
// saturate the result.
{"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))},
@@ -647,7 +648,7 @@ void CodeGen_X86::visit(const Call *op) {
};

// clang-format off
static Pattern patterns[] = {
static const Pattern patterns[] = {
{"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)},
{"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)},
{"saturating_narrow", i16_sat(wild_i32x_)},
@@ -667,6 +668,41 @@ void CodeGen_X86::visit(const Call *op) {
}
}

// clang-format off
static const Pattern reinterpret_patterns[] = {
{"saturating_narrow", i16_sat(wild_u32x_)},
{"saturating_narrow", u16_sat(wild_u32x_)},
{"saturating_narrow", i8_sat(wild_u16x_)},
{"saturating_narrow", u8_sat(wild_u16x_)},
};
// clang-format on

// Search for saturating casts where the inner value can be
// reinterpreted to signed, so that we can use existing
// saturating_narrow instructions.
// TODO: should use lossless_cast once it is fixed.
for (const auto &pattern : reinterpret_patterns) {
if (expr_match(pattern.pattern, op, matches)) {
const Expr &expr = matches[0];
const Type &t = expr.type();
// TODO(8212): might want to keep track of scope of bounds information.
const ConstantInterval ibounds = constant_integer_bounds(expr);
const Type reint_type = t.with_code(halide_type_int);
            // If the signed type can represent the maximum possible value of this
            // unsigned expression, we can safely reinterpret it as signed.
if (reint_type.can_represent(ibounds)) {
// Can safely reinterpret to signed integer.
matches[0] = cast(reint_type, matches[0]);
value = call_overloaded_intrin(op->type, pattern.intrin, matches);
if (value) {
return;
}
}
            // No other reinterpret pattern will match the same input, so stop matching.
break;
}
}

static const vector<pair<Expr, Expr>> cast_rewrites = {
// Some double-narrowing saturating casts can be better expressed as
// combinations of single-narrowing saturating casts.
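For context (not part of the diff): the SSE pack instructions these x86 patterns lower to saturate from signed inputs, which is why the unsigned operand must first be proven to fit in the signed range. A hedged intrinsics sketch of the u16 to u8 case, with an illustrative wrapper name:

#include <emmintrin.h>  // SSE2

// _mm_packus_epi16 saturates signed 16-bit lanes to unsigned 8-bit lanes.
// Routing a u16 vector through it is only correct when every lane is known
// to be <= 32767, i.e. the value is unchanged when viewed as signed i16,
// which is exactly what the bounds check establishes.
__m128i narrow_u16_to_u8_sat(__m128i lo, __m128i hi) {
    return _mm_packus_epi16(lo, hi);
}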
38 changes: 35 additions & 3 deletions src/HexagonOptimize.cpp
@@ -3,6 +3,7 @@
#include "CSE.h"
#include "CodeGen_Internal.h"
#include "ConciseCasts.h"
#include "ConstantBounds.h"
#include "DistributeShifts.h"
#include "ExprUsesVar.h"
#include "FindIntrinsics.h"
@@ -189,8 +190,10 @@ struct Pattern {
// re-interleave the result.
ReinterleaveOp0 = InterleaveResult | DeinterleaveOp0,

v65orLater = 1 << 10, // Pattern should be matched only for v65 target or later
v66orLater = 1 << 11, // Pattern should be matched only for v66 target or later
SafeReinterpretOp0 = 1 << 10, // Pattern should be matched only if the first arg can be safely reinterpreted.

v65orLater = 1 << 11, // Pattern should be matched only for v65 target or later
v66orLater = 1 << 12, // Pattern should be matched only for v66 target or later
};

string intrin; // Name of the intrinsic
@@ -260,6 +263,27 @@ bool process_match_flags(vector<Expr> &matches, int flags) {
internal_assert(matches.size() >= 3);
std::swap(matches[1], matches[2]);
}
if (flags & Pattern::SafeReinterpretOp0) {
// Use bounds inference to check if the first operand can
// be safely reinterpreted.
// TODO: should use lossless_cast once it is fixed.
const Expr &expr = matches[0];
const Type &t = expr.type();
if (t.is_int()) {
// TODO(8212): might want to keep track of scope of bounds information.
const ConstantInterval ibounds = constant_integer_bounds(expr);
const Type reint_type = UInt(t.bits());
            // A signed integer can be reinterpreted as unsigned if it is known to be non-negative.
return reint_type.can_represent(ibounds);
} else {
internal_assert(t.is_uint());
// TODO(8212): might want to keep track of scope of bounds information.
const ConstantInterval ibounds = constant_integer_bounds(expr);
const Type reint_type = Int(t.bits());
            // An unsigned integer can be reinterpreted as signed if it is known not to exceed the signed type's maximum.
return reint_type.can_represent(ibounds);
}
}
return true;
}

@@ -915,10 +939,18 @@ class OptimizePatterns : public IRMutator {

// Saturating narrowing casts. These may interleave later with trunc_sat.
{"halide.hexagon.pack_satub.vh", u8_sat(wild_i16x)},
{"halide.hexagon.pack_satub.vuh", u8_sat(wild_u16x)},
{"halide.hexagon.pack_satuh.vw", u16_sat(wild_i32x)},
{"halide.hexagon.pack_satb.vh", i8_sat(wild_i16x)},
{"halide.hexagon.pack_sath.vw", i16_sat(wild_i32x)},
// The same patterns as above, but with safely reinterpreting the
// argument to be signed.
{"halide.hexagon.pack_satub.vh", u8_sat(wild_u16x), Pattern::SafeReinterpretOp0},
{"halide.hexagon.pack_satuh.vw", u16_sat(wild_u32x), Pattern::SafeReinterpretOp0},
{"halide.hexagon.pack_satb.vh", i8_sat(wild_u16x), Pattern::SafeReinterpretOp0},
{"halide.hexagon.pack_sath.vw", i16_sat(wild_u32x), Pattern::SafeReinterpretOp0},
            // Slightly more expensive versions of the uint saturating casts than the reinterpret
            // patterns above; these perform vpack(min(UMAX, x)).
{"halide.hexagon.pack_satub.vuh", u8_sat(wild_u16x)},
{"halide.hexagon.pack_satuh.vuw", u16_sat(wild_u32x)},

// We don't have a vpack equivalent to this one, so we match it directly.
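The SafeReinterpretOp0 flag above also handles the opposite direction, a signed operand reinterpreted as unsigned. A minimal sketch of that symmetric check, again with an illustrative name rather than the commit's code: only the lower bound matters, because any non-negative value of an N-bit signed type also fits in UInt(N).

#include <cstdint>

// Illustrative helper, not Halide code: a signed expression can be
// reinterpreted as the unsigned type of the same width when bounds
// inference proves it is never negative.
bool can_reinterpret_signed_as_unsigned(int64_t proven_min) {
    return proven_min >= 0;
}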
5 changes: 5 additions & 0 deletions test/correctness/simd_op_check_hvx.cpp
@@ -304,6 +304,11 @@ class SimdOpCheckHVX : public SimdOpCheckTest {
// for a more detailed explanation.
check("v*.uh = vsat(v*.uw,v*.uw)", hvx_width / 2, u16_sat(u32_1));
check("v*.h = vpack(v*.w,v*.w):sat", hvx_width / 2, i16_sat(i32_1));
// Test that bounds-inference instruction selection is working properly.
check("v*.ub = vpack(v*.h,v*.h):sat", hvx_width / 1, u8_sat(u16_1 >> 1));
check("v*.b = vpack(v*.h,v*.h):sat", hvx_width / 1, i8_sat(u16_1 >> 1));
check("v*.uh = vpack(v*.w,v*.w):sat", hvx_width / 2, u16_sat(u32_1 >> 1));
check("v*.h = vpack(v*.w,v*.w):sat", hvx_width / 2, i16_sat(u32_1 >> 1));

// vpack doesn't interleave its inputs, which means it doesn't
// simplify with widening. This is preferable for when the
5 changes: 5 additions & 0 deletions test/correctness/simd_op_check_wasm.cpp
@@ -506,6 +506,11 @@ class SimdOpCheckWASM : public SimdOpCheckTest {
check("i16x8.narrow_i32x4_u", 8 * w, u16_sat(i32_1));
check("i16x8.narrow_i32x4_s", 8 * w, i8_sat(i32_1));
check("i16x8.narrow_i32x4_s", 8 * w, u8_sat(i32_1));
// Test that bounds-inference instruction selection is working properly.
check("i8x16.narrow_i16x8_s", 16 * w, i8_sat(u16_1 >> 1));
check("i8x16.narrow_i16x8_u", 16 * w, u8_sat(u16_1 >> 1));
check("i16x8.narrow_i32x4_s", 8 * w, i16_sat(u32_1 >> 1));
check("i16x8.narrow_i32x4_u", 8 * w, u16_sat(u32_1 >> 1));

// Integer to integer widening
check("i16x8.extend_low_i8x16_s", 16 * w, i16(i8_1));
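The new checks exercise expressions such as u8_sat(u16_1 >> 1), where the shift guarantees the unsigned intermediate stays within the signed range. A hedged sketch of a user-level pipeline that would reach the same lowering path; the function, file, and target names are illustrative and not part of the commit:

#include "Halide.h"
using namespace Halide;

int main() {
    ImageParam in(UInt(16), 1);
    Var x("x");

    Func narrowed("narrowed");
    // (in >> 1) is provably at most 32767, so the u16 -> u8 saturating cast
    // can be selected as a single signed saturating-narrow instruction.
    narrowed(x) = saturating_cast<uint8_t>(in(x) >> 1);
    narrowed.vectorize(x, 16);

    // Inspect the generated code for a target of interest (e.g. x86 or wasm).
    narrowed.compile_to_assembly("narrowed.s", {in}, Target("x86-64-linux-avx2"));
    return 0;
}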
