From 678ea32f37ad0968d019db2a1228ac87c2cedf92 Mon Sep 17 00:00:00 2001
From: Alexander Root <32245479+rootjalex@users.noreply.github.com>
Date: Thu, 24 Aug 2023 12:49:57 -0700
Subject: [PATCH] [ARM] support new udot/sdot patterns (#7800)

---
 src/CodeGen_ARM.cpp                    | 85 ++++++++++++++++++++++++++
 src/FindIntrinsics.cpp                 | 11 +++-
 test/correctness/simd_op_check_arm.cpp | 17 +++++-
 3 files changed, 109 insertions(+), 4 deletions(-)

diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp
index 8e4304f89941..87e170da98d1 100644
--- a/src/CodeGen_ARM.cpp
+++ b/src/CodeGen_ARM.cpp
@@ -117,6 +117,7 @@ class CodeGen_ARM : public CodeGen_Posix {
     /** Nodes for which we want to emit specific neon intrinsics */
     // @{
     void visit(const Cast *) override;
+    void visit(const Add *) override;
     void visit(const Sub *) override;
     void visit(const Min *) override;
     void visit(const Max *) override;
@@ -910,6 +911,90 @@ void CodeGen_ARM::visit(const Cast *op) {
     CodeGen_Posix::visit(op);
 }
 
+void CodeGen_ARM::visit(const Add *op) {
+    if (neon_intrinsics_disabled() ||
+        !op->type.is_vector() ||
+        !target.has_feature(Target::ARMDotProd) ||
+        !op->type.is_int_or_uint() ||
+        op->type.bits() != 32) {
+        CodeGen_Posix::visit(op);
+        return;
+    }
+
+    struct Pattern {
+        Expr pattern;
+        const char *intrin;
+        Type coeff_type = UInt(8);
+    };
+
+    // Initial values.
+    Expr init_i32 = Variable::make(Int(32, 0), "init");
+    Expr init_u32 = Variable::make(UInt(32, 0), "init");
+    // Values.
+    Expr a_i8 = Variable::make(Int(8, 0), "a"), b_i8 = Variable::make(Int(8, 0), "b");
+    Expr c_i8 = Variable::make(Int(8, 0), "c"), d_i8 = Variable::make(Int(8, 0), "d");
+    Expr a_u8 = Variable::make(UInt(8, 0), "a"), b_u8 = Variable::make(UInt(8, 0), "b");
+    Expr c_u8 = Variable::make(UInt(8, 0), "c"), d_u8 = Variable::make(UInt(8, 0), "d");
+    // Coefficients.
+    Expr ac_i8 = Variable::make(Int(8, 0), "ac"), bc_i8 = Variable::make(Int(8, 0), "bc");
+    Expr cc_i8 = Variable::make(Int(8, 0), "cc"), dc_i8 = Variable::make(Int(8, 0), "dc");
+    Expr ac_u8 = Variable::make(UInt(8, 0), "ac"), bc_u8 = Variable::make(UInt(8, 0), "bc");
+    Expr cc_u8 = Variable::make(UInt(8, 0), "cc"), dc_u8 = Variable::make(UInt(8, 0), "dc");
+
+    // clang-format off
+    static const Pattern patterns[] = {
+        // If we had better normalization, we could drastically reduce the number of patterns here.
+        // Signed variants.
+        {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product"},
+        {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8)), "dot_product", Int(8)},
+        {init_i32 + widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
+        {init_i32 + widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
+        {init_i32 + widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8)), "dot_product", Int(8)},
+        // Signed variants (associative).
+        {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product"},
+        {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), i16(d_i8))), "dot_product", Int(8)},
+        {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), widening_mul(b_i8, bc_i8)) + widening_add(i16(c_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
+        {init_i32 + (widening_add(widening_mul(a_i8, ac_i8), i16(b_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
+        {init_i32 + (widening_add(i16(a_i8), widening_mul(b_i8, bc_i8)) + widening_add(widening_mul(c_i8, cc_i8), widening_mul(d_i8, dc_i8))), "dot_product", Int(8)},
+        // Unsigned variants.
+        {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product"},
+        {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8)), "dot_product", UInt(8)},
+        {init_u32 + widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
+        {init_u32 + widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
+        {init_u32 + widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8)), "dot_product", UInt(8)},
+        // Unsigned variants (associative).
+        {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product"},
+        {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), u16(d_u8))), "dot_product", UInt(8)},
+        {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), widening_mul(b_u8, bc_u8)) + widening_add(u16(c_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
+        {init_u32 + (widening_add(widening_mul(a_u8, ac_u8), u16(b_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
+        {init_u32 + (widening_add(u16(a_u8), widening_mul(b_u8, bc_u8)) + widening_add(widening_mul(c_u8, cc_u8), widening_mul(d_u8, dc_u8))), "dot_product", UInt(8)},
+    };
+    // clang-format on
+
+    std::map<std::string, Expr> matches;
+    for (const Pattern &p : patterns) {
+        if (expr_match(p.pattern, op, matches)) {
+            Expr init = matches["init"];
+            Expr values = Shuffle::make_interleave({matches["a"], matches["b"], matches["c"], matches["d"]});
+            // Coefficients can be 1 if not in the pattern.
+            Expr one = make_one(p.coeff_type.with_lanes(op->type.lanes()));
+            // try_emplace fetches the coefficient matched for a key, or
+            // inserts (and returns) the default value of one if the
+            // pattern did not bind that key.
+            Expr _ac = matches.try_emplace("ac", one).first->second;
+            Expr _bc = matches.try_emplace("bc", one).first->second;
+            Expr _cc = matches.try_emplace("cc", one).first->second;
+            Expr _dc = matches.try_emplace("dc", one).first->second;
+            Expr coeffs = Shuffle::make_interleave({_ac, _bc, _cc, _dc});
+            value = call_overloaded_intrin(op->type, p.intrin, {init, values, coeffs});
+            if (value) {
+                return;
+            }
+        }
+    }
+
+    CodeGen_Posix::visit(op);
+}
+
 void CodeGen_ARM::visit(const Sub *op) {
     if (neon_intrinsics_disabled()) {
         CodeGen_Posix::visit(op);
diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index 29a8913e1068..a77a7b1798f3 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -203,6 +203,7 @@ class FindIntrinsics : public IRMutator {
     IRMatcher::Wild<0> x;
     IRMatcher::Wild<1> y;
     IRMatcher::Wild<2> z;
+    IRMatcher::Wild<3> w;
     IRMatcher::WildConst<0> c0;
     IRMatcher::WildConst<1> c1;
@@ -255,7 +256,7 @@ class FindIntrinsics : public IRMutator {
                 result = widen_right_add(b, narrow_a);
             }
             internal_assert(result.type() == op->type);
-            return result;
+            return mutate(result);
         } else if (narrow_b.defined()) {
             Expr result;
             if (a.type().code() != narrow_b.type().code()) {
@@ -420,7 +421,7 @@ class FindIntrinsics : public IRMutator {
                 result = widen_right_mul(b, narrow_a);
             }
             internal_assert(result.type() == op->type);
-            return result;
+            return mutate(result);
         } else if (narrow_b.defined()) {
             Expr result;
             if (a.type().code() != narrow_b.type().code()) {
@@ -772,6 +773,12 @@ class FindIntrinsics : public IRMutator {
                        x + cast(op->type, widening_sub(z, y)),
                        is_x_same_uint) ||
 
+               // (x + (y + widen(z))) + widen(w) = x + (y + widening_add(z, w))
+               rewrite(widen_right_add(x + widen_right_add(y, z), w),
+                       x + (y + widening_add(z, w)),
+                       // We only care about integers; this should be trivially true.
+                       is_x_same_int_or_uint) ||
+
                // Saturating patterns.
                rewrite(saturating_cast(op->type, widening_add(x, y)),
                        saturating_add(x, y),
diff --git a/test/correctness/simd_op_check_arm.cpp b/test/correctness/simd_op_check_arm.cpp
index ba588b090bae..68fbf91a0081 100644
--- a/test/correctness/simd_op_check_arm.cpp
+++ b/test/correctness/simd_op_check_arm.cpp
@@ -29,8 +29,8 @@ class SimdOpCheckARM : public SimdOpCheckTest {
         Expr f64_1 = in_f64(x), f64_2 = in_f64(x + 16), f64_3 = in_f64(x + 32);
         Expr f32_1 = in_f32(x), f32_2 = in_f32(x + 16), f32_3 = in_f32(x + 32);
         Expr f16_1 = in_f16(x), f16_2 = in_f16(x + 16), f16_3 = in_f16(x + 32);
-        Expr i8_1 = in_i8(x), i8_2 = in_i8(x + 16), i8_3 = in_i8(x + 32);
-        Expr u8_1 = in_u8(x), u8_2 = in_u8(x + 16), u8_3 = in_u8(x + 32);
+        Expr i8_1 = in_i8(x), i8_2 = in_i8(x + 16), i8_3 = in_i8(x + 32), i8_4 = in_i8(x + 48);
+        Expr u8_1 = in_u8(x), u8_2 = in_u8(x + 16), u8_3 = in_u8(x + 32), u8_4 = in_u8(x + 48);
         Expr i16_1 = in_i16(x), i16_2 = in_i16(x + 16), i16_3 = in_i16(x + 32);
         Expr u16_1 = in_u16(x), u16_2 = in_u16(x + 16), u16_3 = in_u16(x + 32);
         Expr i32_1 = in_i32(x), i32_2 = in_i32(x + 16), i32_3 = in_i32(x + 32);
@@ -587,6 +587,19 @@ class SimdOpCheckARM : public SimdOpCheckTest {
                 check(arm32 ? "vpaddl.s8" : "sdot", 8, sum_(i32(in_i8(f * x + r))));
                 check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(i32(in_u8(f * x + r))));
                 check(arm32 ? "vpaddl.u8" : "udot", 8, sum_(u32(in_u8(f * x + r))));
"vpaddl.u8" : "udot", 8, sum_(u32(in_u8(f * x + r)))); + if (!arm32) { + check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12); + check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4)); + check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) * 6 + i32(i8_3) + i32(i8_4) * 12); + check("sdot", 8, i32_1 + i32(i8_1) * 3 + i32(i8_2) + i32(i8_3) * 9 + i32(i8_4) * 12); + check("sdot", 8, i32_1 + i32(i8_1) + i32(i8_2) * 6 + i32(i8_3) * 9 + i32(i8_4) * 12); + + check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12); + check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4)); + check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) * 6 + u32(u8_3) + u32(u8_4) * 12); + check("udot", 8, u32_1 + u32(u8_1) * 3 + u32(u8_2) + u32(u8_3) * 9 + u32(u8_4) * 12); + check("udot", 8, u32_1 + u32(u8_1) + u32(u8_2) * 6 + u32(u8_3) * 9 + u32(u8_4) * 12); + } } else { check(arm32 ? "vpaddl.s8" : "saddlp", 8, sum_(i32(in_i8(f * x + r)))); check(arm32 ? "vpaddl.u8" : "uaddlp", 8, sum_(i32(in_u8(f * x + r))));