Skip to content

Commit

Permalink
[interpreter] Implement SIMD extended multiply instructions
Browse files Browse the repository at this point in the history
These were accepted into the proposal in WebAssembly#376.

There are 12 instructions in total:

- i16x8.extmul_{low,high}_i8x16_{s,u}
- i32x4.extmul_{low,high}_i16x8_{s,u}
- i64x2.extmul_{low,high}_i32x4_{s,u}

The implementation is straightforward, widen (using existing
operations), then a multiply with the wider shape.

Added a test generation script that reuses some logic in the generator
for arithmetic instructions. Since these instructions have different
src and dst shapes, I tweaked the base class to allow for having
different shapes.
  • Loading branch information
ngzhian committed Feb 2, 2021
1 parent 7554a37 commit 4ce490f
Show file tree
Hide file tree
Showing 12 changed files with 1,380 additions and 25 deletions.
12 changes: 12 additions & 0 deletions interpreter/exec/eval_simd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ module SimdOp (SXX : Simd.S) (Value : ValueType with type t = SXX.t) = struct
| I16x8 MaxS -> SXX.I16x8.max_s
| I16x8 MaxU -> SXX.I16x8.max_u
| I16x8 AvgrU -> SXX.I16x8.avgr_u
| I16x8 ExtMulLowS -> SXX.I16x8_convert.extmul_low_s
| I16x8 ExtMulHighS -> SXX.I16x8_convert.extmul_high_s
| I16x8 ExtMulLowU -> SXX.I16x8_convert.extmul_low_u
| I16x8 ExtMulHighU -> SXX.I16x8_convert.extmul_high_u
| I32x4 Add -> SXX.I32x4.add
| I32x4 Sub -> SXX.I32x4.sub
| I32x4 MinS -> SXX.I32x4.min_s
Expand All @@ -119,9 +123,17 @@ module SimdOp (SXX : Simd.S) (Value : ValueType with type t = SXX.t) = struct
| I32x4 GeS -> SXX.I32x4.ge_s
| I32x4 GeU -> SXX.I32x4.ge_u
| I32x4 DotI16x8S -> SXX.I32x4_convert.dot_i16x8_s
| I32x4 ExtMulLowS -> SXX.I32x4_convert.extmul_low_s
| I32x4 ExtMulHighS -> SXX.I32x4_convert.extmul_high_s
| I32x4 ExtMulLowU -> SXX.I32x4_convert.extmul_low_u
| I32x4 ExtMulHighU -> SXX.I32x4_convert.extmul_high_u
| I64x2 Add -> SXX.I64x2.add
| I64x2 Sub -> SXX.I64x2.sub
| I64x2 Mul -> SXX.I64x2.mul
| I64x2 ExtMulLowS -> SXX.I64x2_convert.extmul_low_s
| I64x2 ExtMulHighS -> SXX.I64x2_convert.extmul_high_s
| I64x2 ExtMulLowU -> SXX.I64x2_convert.extmul_low_u
| I64x2 ExtMulHighU -> SXX.I64x2_convert.extmul_high_u
| F32x4 Eq -> SXX.F32x4.eq
| F32x4 Ne -> SXX.F32x4.ne
| F32x4 Lt -> SXX.F32x4.lt
Expand Down
38 changes: 34 additions & 4 deletions interpreter/exec/simd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ sig
val widen_high_s : t -> t
val widen_low_u : t -> t
val widen_high_u : t -> t
val extmul_low_s : t -> t -> t
val extmul_high_s : t -> t -> t
val extmul_low_u : t -> t -> t
val extmul_high_u : t -> t -> t
end
module I32x4_convert : sig
val trunc_sat_f32x4_s : t -> t
Expand All @@ -186,10 +190,20 @@ sig
val widen_low_u : t -> t
val widen_high_u : t -> t
val dot_i16x8_s : t -> t -> t
val extmul_low_s : t -> t -> t
val extmul_high_s : t -> t -> t
val extmul_low_u : t -> t -> t
val extmul_high_u : t -> t -> t
end
module I64x2_convert : sig
val widen_low_s : t -> t
val widen_high_s : t -> t
val widen_low_u : t -> t
val widen_high_u : t -> t
val extmul_low_s : t -> t -> t
val extmul_high_s : t -> t -> t
val extmul_low_u : t -> t -> t
val extmul_high_u : t -> t -> t
end
module F32x4_convert : sig
val convert_i32x4_s : t -> t
Expand Down Expand Up @@ -417,6 +431,10 @@ struct
let widen_low_u = widen Lib.List.take 0xffl
let widen_high_u = widen Lib.List.drop 0xffl

let extmul_low_s x y = I16x8.mul (widen_low_s x) (widen_low_s y)
let extmul_high_s x y = I16x8.mul (widen_high_s x) (widen_high_s y)
let extmul_low_u x y = I16x8.mul (widen_low_u x) (widen_low_u y)
let extmul_high_u x y = I16x8.mul (widen_high_u x) (widen_high_u y)
end

module I32x4_convert = struct
Expand All @@ -441,16 +459,28 @@ struct
| [], [] -> []
| _, _ -> assert false
in Rep.of_i32x4 (dot xs ys)

let extmul_low_s x y = I32x4.mul (widen_low_s x) (widen_low_s y)
let extmul_high_s x y = I32x4.mul (widen_high_s x) (widen_high_s y)
let extmul_low_u x y = I32x4.mul (widen_low_u x) (widen_low_u y)
let extmul_high_u x y = I32x4.mul (widen_high_u x) (widen_high_u y)
end

module I64x2_convert = struct
let widen mask x =
let widen take_or_drop mask x =
Rep.of_i64x2
(List.map
(fun i32 -> Int64.(logand mask (of_int32 i32)))
(Lib.List.take 2 (Rep.to_i32x4 x)))
let widen_low_s = widen 0xffffffffffffffffL
let widen_low_u = widen 0xffffffffL
(take_or_drop 2 (Rep.to_i32x4 x)))
let widen_low_s = widen Lib.List.take 0xffffffffffffffffL
let widen_high_s = widen Lib.List.drop 0xffffffffffffffffL
let widen_low_u = widen Lib.List.take 0xffffffffL
let widen_high_u = widen Lib.List.drop 0xffffffffL

let extmul_low_s x y = I64x2.mul (widen_low_s x) (widen_low_s y)
let extmul_high_s x y = I64x2.mul (widen_high_s x) (widen_high_s y)
let extmul_low_u x y = I64x2.mul (widen_low_u x) (widen_low_u y)
let extmul_high_u x y = I64x2.mul (widen_high_u x) (widen_high_u y)
end

module F32x4_convert = struct
Expand Down
1 change: 1 addition & 0 deletions interpreter/syntax/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct
| Swizzle | Shuffle of int list | NarrowS | NarrowU
| AddSatS | AddSatU | SubSatS | SubSatU
| DotI16x8S
| ExtMulLowS | ExtMulHighS | ExtMulLowU | ExtMulHighU
type funop = Abs | Neg | Sqrt
| Ceil | Floor | Trunc | Nearest
| ConvertI32x4S | ConvertI32x4U
Expand Down
12 changes: 12 additions & 0 deletions interpreter/syntax/operators.ml
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ let i16x8_min_u = Binary (V128 V128Op.(I16x8 MinU))
let i16x8_max_s = Binary (V128 V128Op.(I16x8 MaxS))
let i16x8_max_u = Binary (V128 V128Op.(I16x8 MaxU))
let i16x8_avgr_u = Binary (V128 V128Op.(I16x8 AvgrU))
let i16x8_extmul_low_i8x16_s = Binary (V128 V128Op.(I16x8 ExtMulLowS))
let i16x8_extmul_high_i8x16_s = Binary (V128 V128Op.(I16x8 ExtMulHighS))
let i16x8_extmul_low_i8x16_u = Binary (V128 V128Op.(I16x8 ExtMulLowU))
let i16x8_extmul_high_i8x16_u = Binary (V128 V128Op.(I16x8 ExtMulHighU))

let i32x4_splat = Convert (V128 V128Op.(I32x4 Splat))
let i32x4_extract_lane imm = SimdExtract (V128Op.I32x4 (ZX, imm))
Expand Down Expand Up @@ -375,6 +379,10 @@ let i32x4_mul = Binary (V128 V128Op.(I32x4 Mul))
let i32x4_trunc_sat_f32x4_s = Unary (V128 V128Op.(I32x4 TruncSatF32x4S))
let i32x4_trunc_sat_f32x4_u = Unary (V128 V128Op.(I32x4 TruncSatF32x4U))
let i32x4_dot_i16x8_s = Binary (V128 V128Op.(I32x4 DotI16x8S))
let i32x4_extmul_low_i16x8_s = Binary (V128 V128Op.(I32x4 ExtMulLowS))
let i32x4_extmul_high_i16x8_s = Binary (V128 V128Op.(I32x4 ExtMulHighS))
let i32x4_extmul_low_i16x8_u = Binary (V128 V128Op.(I32x4 ExtMulLowU))
let i32x4_extmul_high_i16x8_u = Binary (V128 V128Op.(I32x4 ExtMulHighU))

let i64x2_splat = Convert (V128 V128Op.(I64x2 Splat))
let i64x2_extract_lane imm = SimdExtract (V128Op.I64x2 (ZX, imm))
Expand All @@ -386,6 +394,10 @@ let i64x2_mul = Binary (V128 V128Op.(I64x2 Mul))
let i64x2_shl = SimdShift V128Op.(I64x2 Shl)
let i64x2_shr_s = SimdShift V128Op.(I64x2 ShrS)
let i64x2_shr_u = SimdShift V128Op.(I64x2 ShrU)
let i64x2_extmul_low_i32x4_s = Binary (V128 V128Op.(I64x2 ExtMulLowS))
let i64x2_extmul_high_i32x4_s = Binary (V128 V128Op.(I64x2 ExtMulHighS))
let i64x2_extmul_low_i32x4_u = Binary (V128 V128Op.(I64x2 ExtMulLowU))
let i64x2_extmul_high_i32x4_u = Binary (V128 V128Op.(I64x2 ExtMulHighU))

let f32x4_splat = Convert (V128 V128Op.(F32x4 Splat))
let f32x4_extract_lane imm = SimdExtract (V128Op.F32x4 (ZX, imm))
Expand Down
13 changes: 13 additions & 0 deletions interpreter/text/lexer.mll
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,19 @@ rule token = parse
| "i32x4.dot_i16x8_s"
{ BINARY i32x4_dot_i16x8_s }

| "i16x8.extmul_low_i8x16_"(sign as s)
{ BINARY (ext s i16x8_extmul_low_i8x16_s i16x8_extmul_low_i8x16_u) }
| "i16x8.extmul_high_i8x16_"(sign as s)
{ BINARY (ext s i16x8_extmul_high_i8x16_s i16x8_extmul_high_i8x16_u) }
| "i32x4.extmul_low_i16x8_"(sign as s)
{ BINARY (ext s i32x4_extmul_low_i16x8_s i32x4_extmul_low_i16x8_u) }
| "i32x4.extmul_high_i16x8_"(sign as s)
{ BINARY (ext s i32x4_extmul_high_i16x8_s i32x4_extmul_high_i16x8_u) }
| "i64x2.extmul_low_i32x4_"(sign as s)
{ BINARY (ext s i64x2_extmul_low_i32x4_s i64x2_extmul_low_i32x4_u) }
| "i64x2.extmul_high_i32x4_"(sign as s)
{ BINARY (ext s i64x2_extmul_high_i32x4_s i64x2_extmul_high_i32x4_u) }

| (simd_shape as s) { SIMD_SHAPE (simd_shape s) }

| name as s { VAR s }
Expand Down
1 change: 1 addition & 0 deletions test/core/simd/meta/gen_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
'simd_f64x2_pmin_pmax',
'simd_i32x4_dot_i16x8',
'simd_load_lane',
'simd_ext_mul',
)


Expand Down
21 changes: 17 additions & 4 deletions test/core/simd/meta/simd_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,27 @@ def __str__(self):
def lane(self):
return self.LANE_VALUE.get(self.LANE_TYPE)

@property
def dst_lane(self):
return self.lane

@property
def src_lane(self):
# Used for arithmetic that extends the lane, e.g. i16x8 lanes, which
# are extended multiply to i32x4.
if hasattr(self, 'SRC_LANE_TYPE'):
return self.LANE_VALUE.get(self.SRC_LANE_TYPE)
else:
return self.lane

@property
def normal_unary_op_test_data(self):
lane = self.lane
lane = self.src_lane
return [0, 1, -1, lane.max - 1, lane.min + 1, lane.min, lane.max, lane.mask]

@property
def normal_binary_op_test_data(self):
lane = self.lane
lane = self.src_lane
return [
(0, 0),
(0, 1),
Expand Down Expand Up @@ -170,7 +183,7 @@ def get_case_data(self):
for data_group, v128_forms in self.bin_test_data:
for data in data_group:
case_data.append([op_name, [str(data[0]), str(data[1])],
str(o.binary_op(data[0], data[1], self.lane)),
str(o.binary_op(data[0], data[1], self.src_lane, self.dst_lane)),
v128_forms])
for data_group in self.full_bin_test_data:
for data in data_group.get(op_name):
Expand All @@ -183,7 +196,7 @@ def get_case_data(self):
for data_group, v128_forms in self.unary_test_data:
for data in data_group:
case_data.append([op_name, [str(data)],
str(o.unary_op(data, self.lane)),
str(o.unary_op(data, self.dst_lane)),
v128_forms])

return case_data
Expand Down
63 changes: 56 additions & 7 deletions test/core/simd/meta/simd_ext_mul.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,74 @@
#!/usr/bin/env python3

"""
TODO comment
""" Base class for generating extended multiply instructions. These
instructions 2 inputs of the same (narrower) lane shape, multiplies
corresponding lanes with extension (no overflow/wraparound), producing 1 output
of a (wider) shape. These instructions can choose to work on the low or high
halves of the inputs, and perform signed or unsigned multiply.
Subclasses need to define 3 attributes:
- LANE_TYPE (this is the output shape)
- SRC_LANE_TYPE (this is the input (narrower) shape)
- BINARY_OPS (list of operations)
"""

from simd_arithmetic import SimdArithmeticCase


class SimdExtMulCase(SimdArithmeticCase):
UNARY_OPS = ()

@property
def full_bin_test_data(self):
return []

def get_combine_cases(self):
return ''

@property
def bin_test_data(self):
lane_forms = [self.SRC_LANE_TYPE, self.SRC_LANE_TYPE, self.LANE_TYPE]
return [(self.normal_binary_op_test_data, lane_forms)]

@property
def hex_binary_op_test_data(self):
return []

class SimdI16x8(SimdExtMulCase):
LANE_LEN = 16
def gen_test_cases(self):
wast_filename = '../simd_{wide}_extmul_{narrow}.wast'.format(
wide=self.LANE_TYPE, narrow=self.SRC_LANE_TYPE)
with open(wast_filename, 'w') as fp:
fp.write(self.get_all_cases())


class SimdI16x8ExtMulCase(SimdExtMulCase):
LANE_TYPE = 'i16x8'
SRC_LANE_TYPE = 'i8x16'
BINARY_OPS = ('extmul_low_i8x16_s', 'extmul_high_i8x16_s',
'extmul_low_i8x16_u', 'extmul_high_i8x16_u')
'extmul_low_i8x16_u', 'extmul_high_i8x16_u')


class SimdI32x4ExtMulCase(SimdExtMulCase):
LANE_TYPE = 'i32x4'
SRC_LANE_TYPE = 'i16x8'
BINARY_OPS = ('extmul_low_i16x8_s', 'extmul_high_i16x8_s',
'extmul_low_i16x8_u', 'extmul_high_i16x8_u')


class SimdI64x2ExtMulCase(SimdExtMulCase):
LANE_TYPE = 'i64x2'
SRC_LANE_TYPE = 'i32x4'
BINARY_OPS = ('extmul_low_i32x4_s', 'extmul_high_i32x4_s',
'extmul_low_i32x4_u', 'extmul_high_i32x4_u')


def gen_test_cases():
simd_i16x8_arith = SimdI16x8ArithmeticCase()
simd_i16x8_arith.gen_test_cases()
simd_i16x8_ext_mul_case = SimdI16x8ExtMulCase()
simd_i16x8_ext_mul_case.gen_test_cases()
simd_i32x4_ext_mul_case = SimdI32x4ExtMulCase()
simd_i32x4_ext_mul_case.gen_test_cases()
simd_i64x2_ext_mul_case = SimdI64x2ExtMulCase()
simd_i64x2_ext_mul_case.gen_test_cases()


if __name__ == '__main__':
Expand Down
32 changes: 22 additions & 10 deletions test/core/simd/meta/simd_integer_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class ArithmeticOp:
add_sat_s, add_sat_u,
sub_sat_s, sub_sat_u,
min_s, min_u, max_s, max_u, avgr_u, abs
ext_mul_s, ext_mul_u
"""
def __init__(self, op: str):
self.op = op
Expand Down Expand Up @@ -121,7 +122,7 @@ def unary_op(self, operand, lane):

return str(result)

def binary_op(self, operand1, operand2, lane):
def binary_op(self, operand1, operand2, src_lane, dst_lane=None):
"""General integer arithmetic and saturating arithmetic operations
with 2 operands.
Expand All @@ -130,12 +131,15 @@ def binary_op(self, operand1, operand2, lane):
add_sat_s, add_sat_u,
sub_sat_s, sub_sat_u,
min_s, min_u, max_s, max_u, avgr_u
ext_mul_s, ext_mul_u (same as mul)
:param operand1: the operand 1, integer or literal string in hex or decimal format
:param operand2: the operand 2, integer or literal string in hex or decimal format
:param lane: the LaneValue instance of a lane in v128
:param src_lane: the LaneValue instance of a lane in v128
:return: the string of the result of <p1 self.op p2> in hex or decimal format
"""
if not dst_lane:
dst_lane = src_lane
v1 = operand1
v2 = operand2
base1 = base2 = 10
Expand All @@ -155,27 +159,35 @@ def binary_op(self, operand1, operand2, lane):
value = v1 - v2
elif self.op == 'mul':
value = v1 * v2
elif self.op.startswith('extmul_'):
if self.op.endswith('s'):
i1 = self.get_valid_value(v1, src_lane)
i2 = self.get_valid_value(v2, src_lane)
else:
i1 = self.get_valid_value(v1, src_lane, signed=False)
i2 = self.get_valid_value(v2, src_lane, signed=False)
value = i1 * i2
elif 'sat' in self.op:
value = self._saturate(v1, v2, lane)
value = self._saturate(v1, v2, src_lane)
if self.op.endswith('_u'):
result_signed = False
elif self.op in ['min_s', 'max_s']:
i1 = self.get_valid_value(v1, lane)
i2 = self.get_valid_value(v2, lane)
i1 = self.get_valid_value(v1, src_lane)
i2 = self.get_valid_value(v2, src_lane)
if self.op == 'min_s':
return operand1 if i1 <= i2 else operand2
else:
return operand1 if i1 >= i2 else operand2
elif self.op in ['min_u', 'max_u']:
i1 = self.get_valid_value(v1, lane, signed=False)
i2 = self.get_valid_value(v2, lane, signed=False)
i1 = self.get_valid_value(v1, src_lane, signed=False)
i2 = self.get_valid_value(v2, src_lane, signed=False)
if self.op == 'min_u':
return operand1 if i1 <= i2 else operand2
else:
return operand1 if i1 >= i2 else operand2
elif self.op == 'avgr_u':
i1 = self.get_valid_value(v1, lane, signed=False)
i2 = self.get_valid_value(v2, lane, signed=False)
i1 = self.get_valid_value(v1, src_lane, signed=False)
i2 = self.get_valid_value(v2, src_lane, signed=False)
result = (i1 + i2 + 1) // 2
if base1 == 16 or base2 == 16:
return hex(result)
Expand All @@ -184,5 +196,5 @@ def binary_op(self, operand1, operand2, lane):
else:
raise Exception('Unknown binary operation')

result = self.get_valid_value(value, lane, signed=result_signed)
result = self.get_valid_value(value, dst_lane, signed=result_signed)
return str(result)
Loading

0 comments on commit 4ce490f

Please sign in to comment.