Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARM] support new udot/sdot patterns #7800

Merged
merged 1 commit into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class CodeGen_ARM : public CodeGen_Posix {
/** Nodes for which we want to emit specific neon intrinsics */
// @{
void visit(const Cast *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
void visit(const Min *) override;
void visit(const Max *) override;
Expand Down Expand Up @@ -906,6 +907,90 @@ void CodeGen_ARM::visit(const Cast *op) {
CodeGen_Posix::visit(op);
}

void CodeGen_ARM::visit(const Add *op) {
if (neon_intrinsics_disabled() ||
!op->type.is_vector() ||
!target.has_feature(Target::ARMDotProd) ||
!op->type.is_int_or_uint() ||
op->type.bits() != 32) {
CodeGen_Posix::visit(op);
return;
}

struct Pattern {
Expr pattern;
const char *intrin;
Type coeff_type = UInt(8);
};

// Initial values.
Expr init_i32 = Variable::make(Int(32, 0), "init");
Expr init_u32 = Variable::make(UInt(32, 0), "init");
// Values
Expr a_i8 = Variable::make(Int(8, 0), "a"), b_i8 = Variable::make(Int(8, 0), "b");
Expr c_i8 = Variable::make(Int(8, 0), "c"), d_i8 = Variable::make(Int(8, 0), "d");
Expr a_u8 = Variable::make(UInt(8, 0), "a"), b_u8 = Variable::make(UInt(8, 0), "b");
Expr c_u8 = Variable::make(UInt(8, 0), "c"), d_u8 = Variable::make(UInt(8, 0), "d");
// Coefficients
Expr ac_i8 = Variable::make(Int(8, 0), "ac"), bc_i8 = Variable::make(Int(8, 0), "bc");
Expr cc_i8 = Variable::make(Int(8, 0), "cc"), dc_i8 = Variable::make(Int(8, 0), "dc");
Expr ac_u8 = Variable::make(UInt(8, 0), "ac"), bc_u8 = Variable::make(UInt(8, 0), "bc");
Expr cc_u8 = Variable::make(UInt(8, 0), "cc"), dc_u8 = Variable::make(UInt(8, 0), "dc");

// clang-format off
static const Pattern patterns[] = {
// If we had better normalization, we could drastically reduce the number of patterns here.
// Signed variants.
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product"},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
{init_i32 + widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
// Signed variants (associative).
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product"},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
{init_i32 + (widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
// Unsigned variants.
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product"},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
{init_u32 + widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
// Unsigned variants (associative).
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product"},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
{init_u32 + (widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
};
// clang-format on

std::map<std::string, Expr> matches;
for (const Pattern &p : patterns) {
if (expr_match(p.pattern, op, matches)) {
Expr init = matches["init"];
Expr values = Shuffle::make_interleave({matches["a"], matches["b"], matches["c"], matches["d"]});
// Coefficients can be 1 if not in the pattern.
Expr one = make_one(p.coeff_type.with_lanes(op->type.lanes()));
// This hideous code pattern implements fetching a
// default value if the map doesn't contain a key.
Expr _ac = matches.try_emplace("ac", one).first->second;
Expr _bc = matches.try_emplace("bc", one).first->second;
Expr _cc = matches.try_emplace("cc", one).first->second;
Expr _dc = matches.try_emplace("dc", one).first->second;
Expr coeffs = Shuffle::make_interleave({_ac, _bc, _cc, _dc});
value = call_overloaded_intrin(op->type, p.intrin, {init, values, coeffs});
if (value) {
return;
}
}
}

CodeGen_Posix::visit(op);
}

void CodeGen_ARM::visit(const Sub *op) {
if (neon_intrinsics_disabled()) {
CodeGen_Posix::visit(op);
Expand Down
11 changes: 9 additions & 2 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class FindIntrinsics : public IRMutator {
IRMatcher::Wild<0> x;
IRMatcher::Wild<1> y;
IRMatcher::Wild<2> z;
IRMatcher::Wild<3> w;
IRMatcher::WildConst<0> c0;
IRMatcher::WildConst<1> c1;

Expand Down Expand Up @@ -255,7 +256,7 @@ class FindIntrinsics : public IRMutator {
result = widen_right_add(b, narrow_a);
}
internal_assert(result.type() == op->type);
return result;
return mutate(result);
} else if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
Expand Down Expand Up @@ -420,7 +421,7 @@ class FindIntrinsics : public IRMutator {
result = widen_right_mul(b, narrow_a);
}
internal_assert(result.type() == op->type);
return result;
return mutate(result);
} else if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
Expand Down Expand Up @@ -772,6 +773,12 @@ class FindIntrinsics : public IRMutator {
x + cast(op->type, widening_sub(z, y)),
is_x_same_uint) ||

// (x + y + widen(z)) + widen(w) = x + y + widening_add(z, w)
rewrite(widen_right_add(x + widen_right_add(y, z), w),
x + (y + widening_add(z, w)),
// We only care about integers, this should be trivially true.
is_x_same_int_or_uint) ||

// Saturating patterns.
rewrite(saturating_cast(op->type, widening_add(x, y)),
saturating_add(x, y),
Expand Down
17 changes: 15 additions & 2 deletions test/correctness/simd_op_check_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class SimdOpCheckARM : public SimdOpCheckTest {
Expr f64_1 = in_f64(x), f64_2 = in_f64(x + 16), f64_3 = in_f64(x + 32);
Expr f32_1 = in_f32(x), f32_2 = in_f32(x + 16), f32_3 = in_f32(x + 32);
Expr f16_1 = in_f16(x), f16_2 = in_f16(x + 16), f16_3 = in_f16(x + 32);
Expr i8_1 = in_i8(x), i8_2 = in_i8(x + 16), i8_3 = in_i8(x + 32);
Expr u8_1 = in_u8(x), u8_2 = in_u8(x + 16), u8_3 = in_u8(x + 32);
Expr i8_1 = in_i8(x), i8_2 = in_i8(x + 16), i8_3 = in_i8(x + 32), i8_4 = in_i8(x + 48);
Expr u8_1 = in_u8(x), u8_2 = in_u8(x + 16), u8_3 = in_u8(x + 32), u8_4 = in_u8(x + 48);
Expr i16_1 = in_i16(x), i16_2 = in_i16(x + 16), i16_3 = in_i16(x + 32);
Expr u16_1 = in_u16(x), u16_2 = in_u16(x + 16), u16_3 = in_u16(x + 32);
Expr i32_1 = in_i32(x), i32_2 = in_i32(x + 16), i32_3 = in_i32(x + 32);
Expand Down Expand Up @@ -575,6 +575,19 @@ class SimdOpCheckARM : public SimdOpCheckTest {
check(arm32 ? "vpaddl.s8" : "sdot", 8, sum_(i32(in_i8(f * x + r))));
check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(i32(in_u8(f * x + r))));
check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(u32(in_u8(f * x + r))));
if (!arm32) {
check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12);
check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4));
check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) + i32(i8_4) * 12);
check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) + i32(i8_3) * 9 + i32(i8_4) * 12);
check("sdot", 8, i32_1 + i32(i8_1) + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12);

check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12);
check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4));
check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) + u32(u8_4) * 12);
check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) + u32(u8_3) * 9 + u32(u8_4) * 12);
check("udot", 8, u32_1 + u32(u8_1) + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12);
}
} else {
check(arm32 ? "vpaddl.s8" : "saddlp", 8, sum_(i32(in_i8(f * x + r))));
check(arm32 ? "vpaddl.u8" : "uaddlp", 8, sum_(i32(in_u8(f * x + r))));
Expand Down