Skip to content

Commit

Permalink
Implement relaxed SIMD dot product instructions (#4586)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlively authored Apr 11, 2022
1 parent 5f88dcd commit ce0333d
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 13 deletions.
6 changes: 5 additions & 1 deletion scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,11 @@
("f32x4.relaxed_max", "makeBinary(s, BinaryOp::RelaxedMaxVecF32x4)"),
("f64x2.relaxed_min", "makeBinary(s, BinaryOp::RelaxedMinVecF64x2)"),
("f64x2.relaxed_max", "makeBinary(s, BinaryOp::RelaxedMaxVecF64x2)"),
("i16x8.relaxed_q15mulr_s", "makeBinary(s, BinaryOp::RelaxedQ15MulrSVecI16x8)"),
("i16x8.relaxed_q15mulr_s", "makeBinary(s, BinaryOp::RelaxedQ15MulrSVecI16x8)"),
("i16x8.dot_i8x16_i7x16_s", "makeBinary(s, BinaryOp::DotI8x16I7x16SToVecI16x8)"),
("i16x8.dot_i8x16_i7x16_u", "makeBinary(s, BinaryOp::DotI8x16I7x16UToVecI16x8)"),
("i32x4.dot_i8x16_i7x16_add_s", "makeSIMDTernary(s, SIMDTernaryOp::DotI8x16I7x16AddSToVecI32x4)"),
("i32x4.dot_i8x16_i7x16_add_u", "makeSIMDTernary(s, SIMDTernaryOp::DotI8x16I7x16AddUToVecI32x4)"),

# reference types instructions
("ref.null", "makeRefNull(s)"),
Expand Down
33 changes: 30 additions & 3 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,17 @@ switch (op[0]) {
case 'b':
if (strcmp(op, "i16x8.bitmask") == 0) { return makeUnary(s, UnaryOp::BitmaskVecI16x8); }
goto parse_error;
case 'd': {
switch (op[22]) {
case 's':
if (strcmp(op, "i16x8.dot_i8x16_i7x16_s") == 0) { return makeBinary(s, BinaryOp::DotI8x16I7x16SToVecI16x8); }
goto parse_error;
case 'u':
if (strcmp(op, "i16x8.dot_i8x16_i7x16_u") == 0) { return makeBinary(s, BinaryOp::DotI8x16I7x16UToVecI16x8); }
goto parse_error;
default: goto parse_error;
}
}
case 'e': {
switch (op[7]) {
case 'q':
Expand Down Expand Up @@ -1692,9 +1703,25 @@ switch (op[0]) {
case 'b':
if (strcmp(op, "i32x4.bitmask") == 0) { return makeUnary(s, UnaryOp::BitmaskVecI32x4); }
goto parse_error;
case 'd':
if (strcmp(op, "i32x4.dot_i16x8_s") == 0) { return makeBinary(s, BinaryOp::DotSVecI16x8ToVecI32x4); }
goto parse_error;
case 'd': {
switch (op[11]) {
case '1':
if (strcmp(op, "i32x4.dot_i16x8_s") == 0) { return makeBinary(s, BinaryOp::DotSVecI16x8ToVecI32x4); }
goto parse_error;
case '8': {
switch (op[26]) {
case 's':
if (strcmp(op, "i32x4.dot_i8x16_i7x16_add_s") == 0) { return makeSIMDTernary(s, SIMDTernaryOp::DotI8x16I7x16AddSToVecI32x4); }
goto parse_error;
case 'u':
if (strcmp(op, "i32x4.dot_i8x16_i7x16_add_u") == 0) { return makeSIMDTernary(s, SIMDTernaryOp::DotI8x16I7x16AddUToVecI32x4); }
goto parse_error;
default: goto parse_error;
}
}
default: goto parse_error;
}
}
case 'e': {
switch (op[7]) {
case 'q':
Expand Down
4 changes: 4 additions & 0 deletions src/ir/cost.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
case SwizzleVecI8x16:
case RelaxedSwizzleVecI8x16:
case RelaxedQ15MulrSVecI16x8:
case DotI8x16I7x16SToVecI16x8:
case DotI8x16I7x16UToVecI16x8:
ret = 1;
break;
case InvalidBinary:
Expand Down Expand Up @@ -541,6 +543,8 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
case RelaxedFmsVecF32x4:
case RelaxedFmaVecF64x2:
case RelaxedFmsVecF64x2:
case DotI8x16I7x16AddSToVecI32x4:
case DotI8x16I7x16AddUToVecI32x4:
ret = 1;
break;
}
Expand Down
2 changes: 2 additions & 0 deletions src/literal.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,8 @@ class Literal {
Literal minUI32x4(const Literal& other) const;
Literal maxSI32x4(const Literal& other) const;
Literal maxUI32x4(const Literal& other) const;
Literal dotSI8x16toI16x8(const Literal& other) const;
Literal dotUI8x16toI16x8(const Literal& other) const;
Literal dotSI16x8toI32x4(const Literal& other) const;
Literal extMulLowSI32x4(const Literal& other) const;
Literal extMulHighSI32x4(const Literal& other) const;
Expand Down
12 changes: 12 additions & 0 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,12 @@ struct PrintExpressionContents
case RelaxedFmsVecF64x2:
o << "f64x2.relaxed_fms";
break;
case DotI8x16I7x16AddSToVecI32x4:
o << "i32x4.dot_i8x16_i7x16_add_s";
break;
case DotI8x16I7x16AddUToVecI32x4:
o << "i32x4.dot_i8x16_i7x16_add_u";
break;
}
restoreNormalColor(o);
}
Expand Down Expand Up @@ -1854,6 +1860,12 @@ struct PrintExpressionContents
case RelaxedQ15MulrSVecI16x8:
o << "i16x8.relaxed_q15mulr_s";
break;
case DotI8x16I7x16SToVecI16x8:
o << "i16x8.dot_i8x16_i7x16_s";
break;
case DotI8x16I7x16UToVecI16x8:
o << "i16x8.dot_i8x16_i7x16_u";
break;

case InvalidBinary:
WASM_UNREACHABLE("unvalid binary operator");
Expand Down
4 changes: 4 additions & 0 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,10 @@ enum ASTNodes {
F64x2RelaxedMin = 0xd4,
F64x2RelaxedMax = 0xee,
I16x8RelaxedQ15MulrS = 0x111,
I16x8DotI8x16I7x16S = 0x112,
I16x8DotI8x16I7x16U = 0x113,
I32x4DotI8x16I7x16AddS = 0x114,
I32x4DotI8x16I7x16AddU = 0x115,

// bulk memory opcodes

Expand Down
7 changes: 6 additions & 1 deletion src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,11 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
case RelaxedSwizzleVecI8x16:
return left.swizzleI8x16(right);

case DotI8x16I7x16SToVecI16x8:
return left.dotSI8x16toI16x8(right);
case DotI8x16I7x16UToVecI16x8:
return left.dotUI8x16toI16x8(right);

case InvalidBinary:
WASM_UNREACHABLE("invalid binary op");
}
Expand Down Expand Up @@ -1124,7 +1129,7 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
case RelaxedFmsVecF64x2:
return a.relaxedFmsF64x2(b, c);
default:
// TODO: implement signselect
// TODO: implement signselect and dot_add
WASM_UNREACHABLE("not implemented");
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ enum BinaryOp {
RelaxedMinVecF64x2,
RelaxedMaxVecF64x2,
RelaxedQ15MulrSVecI16x8,
DotI8x16I7x16SToVecI16x8,
DotI8x16I7x16UToVecI16x8,

InvalidBinary
};
Expand Down Expand Up @@ -552,6 +554,8 @@ enum SIMDTernaryOp {
LaneselectI16x8,
LaneselectI32x4,
LaneselectI64x2,
DotI8x16I7x16AddSToVecI32x4,
DotI8x16I7x16AddUToVecI32x4,
};

enum RefIsOp {
Expand Down
30 changes: 23 additions & 7 deletions src/wasm/literal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2372,17 +2372,33 @@ Literal Literal::pmaxF64x2(const Literal& other) const {
return binary<2, &Literal::getLanesF64x2, &Literal::pmax>(*this, other);
}

Literal Literal::dotSI16x8toI32x4(const Literal& other) const {
LaneArray<8> lhs = getLanesSI16x8();
LaneArray<8> rhs = other.getLanesSI16x8();
LaneArray<4> result;
for (size_t i = 0; i < 4; ++i) {
result[i] = Literal(lhs[i * 2].geti32() * rhs[i * 2].geti32() +
lhs[i * 2 + 1].geti32() * rhs[i * 2 + 1].geti32());
// Generic widening dot product over vector lanes.
//
// Template parameters:
//   Lanes     - number of lanes in the result vector.
//   Factor    - number of adjacent input lanes whose products are summed into
//               a single result lane.
//   IntoLanes - Literal member function that splits a vector into
//               Lanes * Factor lane literals.
//
// For each result lane i, computes the sum over j in [0, Factor) of
// left[i * Factor + j] * right[i * Factor + j], where each input lane is read
// with geti32(). The per-lane sums are then packed with Literal(LaneArray).
template<size_t Lanes,
         size_t Factor,
         LaneArray<Lanes * Factor> (Literal::*IntoLanes)() const>
static Literal dot(const Literal& left, const Literal& right) {
  LaneArray<Lanes * Factor> lhs = (left.*IntoLanes)();
  LaneArray<Lanes * Factor> rhs = (right.*IntoLanes)();
  LaneArray<Lanes> result;
  for (size_t i = 0; i < Lanes; ++i) {
    // Accumulate in unsigned 32-bit arithmetic so that overflow wraps instead
    // of being undefined behavior. With 16-bit inputs, for example,
    // (-32768 * -32768) * 2 == 2^31 does not fit in int32_t. The low 32 bits
    // of the unsigned result are identical to the two's complement signed
    // result.
    uint32_t sum = 0;
    for (size_t j = 0; j < Factor; ++j) {
      sum += uint32_t(lhs[i * Factor + j].geti32()) *
             uint32_t(rhs[i * Factor + j].geti32());
    }
    result[i] = Literal(int32_t(sum));
  }
  return Literal(result);
}

// Dot product of this vector with `other`, both split into 16 lanes via
// getLanesSI8x16(). Each adjacent pair of lane products is summed into one of
// the 8 result lanes.
Literal Literal::dotSI8x16toI16x8(const Literal& other) const {
  constexpr size_t ResultLanes = 8;
  constexpr size_t LanesPerSum = 2;
  return dot<ResultLanes, LanesPerSum, &Literal::getLanesSI8x16>(*this, other);
}
// Dot product of this vector with `other`, both split into 16 lanes via
// getLanesUI8x16(). Each adjacent pair of lane products is summed into one of
// the 8 result lanes.
Literal Literal::dotUI8x16toI16x8(const Literal& other) const {
  constexpr size_t ResultLanes = 8;
  constexpr size_t LanesPerSum = 2;
  return dot<ResultLanes, LanesPerSum, &Literal::getLanesUI8x16>(*this, other);
}
// Dot product of this vector with `other`, both split into 8 lanes via
// getLanesSI16x8(). Each adjacent pair of lane products is summed into one of
// the 4 result lanes.
Literal Literal::dotSI16x8toI32x4(const Literal& other) const {
  constexpr size_t ResultLanes = 4;
  constexpr size_t LanesPerSum = 2;
  return dot<ResultLanes, LanesPerSum, &Literal::getLanesSI16x8>(*this, other);
}

Literal Literal::bitselectV128(const Literal& left,
const Literal& right) const {
return andV128(left).orV128(notV128().andV128(right));
Expand Down
16 changes: 16 additions & 0 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5601,6 +5601,14 @@ bool WasmBinaryBuilder::maybeVisitSIMDBinary(Expression*& out, uint32_t code) {
curr = allocator.alloc<Binary>();
curr->op = RelaxedQ15MulrSVecI16x8;
break;
case BinaryConsts::I16x8DotI8x16I7x16S:
curr = allocator.alloc<Binary>();
curr->op = DotI8x16I7x16SToVecI16x8;
break;
case BinaryConsts::I16x8DotI8x16I7x16U:
curr = allocator.alloc<Binary>();
curr->op = DotI8x16I7x16UToVecI16x8;
break;
default:
return false;
}
Expand Down Expand Up @@ -6075,6 +6083,14 @@ bool WasmBinaryBuilder::maybeVisitSIMDTernary(Expression*& out, uint32_t code) {
curr = allocator.alloc<SIMDTernary>();
curr->op = RelaxedFmsVecF64x2;
break;
case BinaryConsts::I32x4DotI8x16I7x16AddS:
curr = allocator.alloc<SIMDTernary>();
curr->op = DotI8x16I7x16AddSToVecI32x4;
break;
case BinaryConsts::I32x4DotI8x16I7x16AddU:
curr = allocator.alloc<SIMDTernary>();
curr->op = DotI8x16I7x16AddUToVecI32x4;
break;
default:
return false;
}
Expand Down
14 changes: 14 additions & 0 deletions src/wasm/wasm-stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,12 @@ void BinaryInstWriter::visitSIMDTernary(SIMDTernary* curr) {
case RelaxedFmsVecF64x2:
o << U32LEB(BinaryConsts::F64x2RelaxedFms);
break;
case DotI8x16I7x16AddSToVecI32x4:
o << U32LEB(BinaryConsts::I32x4DotI8x16I7x16AddS);
break;
case DotI8x16I7x16AddUToVecI32x4:
o << U32LEB(BinaryConsts::I32x4DotI8x16I7x16AddU);
break;
}
}

Expand Down Expand Up @@ -1846,6 +1852,14 @@ void BinaryInstWriter::visitBinary(Binary* curr) {
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I16x8RelaxedQ15MulrS);
break;
case DotI8x16I7x16SToVecI16x8:
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I16x8DotI8x16I7x16S);
break;
case DotI8x16I7x16UToVecI16x8:
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I16x8DotI8x16I7x16U);
break;

case InvalidBinary:
WASM_UNREACHABLE("invalid binary op");
Expand Down
4 changes: 3 additions & 1 deletion src/wasm/wasm-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1635,7 +1635,9 @@ void FunctionValidator::visitBinary(Binary* curr) {
case NarrowUVecI32x4ToVecI16x8:
case SwizzleVecI8x16:
case RelaxedSwizzleVecI8x16:
case RelaxedQ15MulrSVecI16x8: {
case RelaxedQ15MulrSVecI16x8:
case DotI8x16I7x16SToVecI16x8:
case DotI8x16I7x16UToVecI16x8: {
shouldBeEqualOrFirstIsUnreachable(
curr->left->type, Type(Type::v128), curr, "v128 op");
shouldBeEqualOrFirstIsUnreachable(
Expand Down
112 changes: 112 additions & 0 deletions test/lit/relaxed-simd.wast
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,88 @@
)
)

;; CHECK-BINARY: (func $i16x8.dot_i8x16_i7x16_s (param $0 v128) (param $1 v128) (result v128)
;; CHECK-BINARY-NEXT: (i16x8.dot_i8x16_i7x16_s
;; CHECK-BINARY-NEXT: (local.get $0)
;; CHECK-BINARY-NEXT: (local.get $1)
;; CHECK-BINARY-NEXT: )
;; CHECK-BINARY-NEXT: )
;; CHECK-TEXT: (func $i16x8.dot_i8x16_i7x16_s (param $0 v128) (param $1 v128) (result v128)
;; CHECK-TEXT-NEXT: (i16x8.dot_i8x16_i7x16_s
;; CHECK-TEXT-NEXT: (local.get $0)
;; CHECK-TEXT-NEXT: (local.get $1)
;; CHECK-TEXT-NEXT: )
;; CHECK-TEXT-NEXT: )
(func $i16x8.dot_i8x16_i7x16_s (param $0 v128) (param $1 v128) (result v128)
;; Round-trip check for the binary op i16x8.dot_i8x16_i7x16_s: takes both v128
;; params as operands and returns the v128 result.
(i16x8.dot_i8x16_i7x16_s
(local.get $0)
(local.get $1)
)
)

;; CHECK-BINARY: (func $i16x8.dot_i8x16_i7x16_u (param $0 v128) (param $1 v128) (result v128)
;; CHECK-BINARY-NEXT: (i16x8.dot_i8x16_i7x16_u
;; CHECK-BINARY-NEXT: (local.get $0)
;; CHECK-BINARY-NEXT: (local.get $1)
;; CHECK-BINARY-NEXT: )
;; CHECK-BINARY-NEXT: )
;; CHECK-TEXT: (func $i16x8.dot_i8x16_i7x16_u (param $0 v128) (param $1 v128) (result v128)
;; CHECK-TEXT-NEXT: (i16x8.dot_i8x16_i7x16_u
;; CHECK-TEXT-NEXT: (local.get $0)
;; CHECK-TEXT-NEXT: (local.get $1)
;; CHECK-TEXT-NEXT: )
;; CHECK-TEXT-NEXT: )
(func $i16x8.dot_i8x16_i7x16_u (param $0 v128) (param $1 v128) (result v128)
;; Round-trip check for the binary op i16x8.dot_i8x16_i7x16_u: takes both v128
;; params as operands and returns the v128 result.
(i16x8.dot_i8x16_i7x16_u
(local.get $0)
(local.get $1)
)
)

;; CHECK-BINARY: (func $i32x4.dot_i8x16_i7x16_add_s (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-BINARY-NEXT: (i32x4.dot_i8x16_i7x16_add_s
;; CHECK-BINARY-NEXT: (local.get $0)
;; CHECK-BINARY-NEXT: (local.get $1)
;; CHECK-BINARY-NEXT: (local.get $2)
;; CHECK-BINARY-NEXT: )
;; CHECK-BINARY-NEXT: )
;; CHECK-TEXT: (func $i32x4.dot_i8x16_i7x16_add_s (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-TEXT-NEXT: (i32x4.dot_i8x16_i7x16_add_s
;; CHECK-TEXT-NEXT: (local.get $0)
;; CHECK-TEXT-NEXT: (local.get $1)
;; CHECK-TEXT-NEXT: (local.get $2)
;; CHECK-TEXT-NEXT: )
;; CHECK-TEXT-NEXT: )
(func $i32x4.dot_i8x16_i7x16_add_s (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; Round-trip check for the ternary op i32x4.dot_i8x16_i7x16_add_s: takes all
;; three v128 params as operands and returns the v128 result.
(i32x4.dot_i8x16_i7x16_add_s
(local.get $0)
(local.get $1)
(local.get $2)
)
)

;; CHECK-BINARY: (func $i32x4.dot_i8x16_i7x16_add_u (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-BINARY-NEXT: (i32x4.dot_i8x16_i7x16_add_u
;; CHECK-BINARY-NEXT: (local.get $0)
;; CHECK-BINARY-NEXT: (local.get $1)
;; CHECK-BINARY-NEXT: (local.get $2)
;; CHECK-BINARY-NEXT: )
;; CHECK-BINARY-NEXT: )
;; CHECK-TEXT: (func $i32x4.dot_i8x16_i7x16_add_u (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-TEXT-NEXT: (i32x4.dot_i8x16_i7x16_add_u
;; CHECK-TEXT-NEXT: (local.get $0)
;; CHECK-TEXT-NEXT: (local.get $1)
;; CHECK-TEXT-NEXT: (local.get $2)
;; CHECK-TEXT-NEXT: )
;; CHECK-TEXT-NEXT: )
(func $i32x4.dot_i8x16_i7x16_add_u (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; Round-trip check for the ternary op i32x4.dot_i8x16_i7x16_add_u: takes all
;; three v128 params as operands and returns the v128 result.
(i32x4.dot_i8x16_i7x16_add_u
(local.get $0)
(local.get $1)
(local.get $2)
)
)

)
;; CHECK-NODEBUG: (type $v128_v128_v128_=>_v128 (func (param v128 v128 v128) (result v128)))

Expand Down Expand Up @@ -507,3 +589,33 @@
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

;; CHECK-NODEBUG: (func $18 (param $0 v128) (param $1 v128) (result v128)
;; CHECK-NODEBUG-NEXT: (i16x8.dot_i8x16_i7x16_s
;; CHECK-NODEBUG-NEXT: (local.get $0)
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

;; CHECK-NODEBUG: (func $19 (param $0 v128) (param $1 v128) (result v128)
;; CHECK-NODEBUG-NEXT: (i16x8.dot_i8x16_i7x16_u
;; CHECK-NODEBUG-NEXT: (local.get $0)
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

;; CHECK-NODEBUG: (func $20 (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-NODEBUG-NEXT: (i32x4.dot_i8x16_i7x16_add_s
;; CHECK-NODEBUG-NEXT: (local.get $0)
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: (local.get $2)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

;; CHECK-NODEBUG: (func $21 (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-NODEBUG-NEXT: (i32x4.dot_i8x16_i7x16_add_u
;; CHECK-NODEBUG-NEXT: (local.get $0)
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: (local.get $2)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

0 comments on commit ce0333d

Please sign in to comment.