Skip to content

Commit

Permalink
Fix overview for int dot to allow saturation of 16-bit intermediates
Browse files Browse the repository at this point in the history
  • Loading branch information
ngzhian committed Feb 27, 2023
1 parent 917bd5c commit c423e63
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions proposals/relaxed-simd/Overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,22 +261,45 @@ When the second operand of the product has the high bit set in a lane, that
lane's result is implementation defined.

```python
def dot_product(signed, elements, a, b, c):
def i16x8.dot_i8x16_i7x16_s(a, b):
intermediate = []
result = []
for i in range(16):
if (b[i] & 0x80):
lhs = as_signed(a[i]) if signed else a[i]
rhs = IMPLEMENTATION_DEFINED_ONE_OF(as_signed(b[i]), b[i])
lhs = as_signed(a[i])
rhs = IMPLEMENTATION_DEFINED_ONE_OF(as_signed(b[i]), as_unsigned(b[i]))
intermediate[i] = lhs * rhs
else:
intermediate[i] = (as_signed(a[i]) if signed else a[i]) * b[i]
for i in range(0, 16, elements):
result[i/elements] = sum(intermediate[i:i+elements])
result[i/elements] += c[i/elements] if c else 0
intermediate[i] = as_signed(a[i]) * b[i]
for i in range(0, 16, 2):
result[i/2] = IMPLEMENTATION_DEFINED_ONE_OF(
intermediate[i] + intermediate[i+1],
saturate(intermediate[i] + intermediate[i+1]))

i16x8_dot_i8x16_i7x16_s(a, b) = dot_product(signed=True, elements=2, a, b)
i32x4.dot_i8x16_i7x16_add_s(a, b, c) = dot_product(signed=True, elements=4, a, b, c)
def i32x4.dot_i8x16_i7x16_add_s(a, b, c):
intermediate = []
tmp = []
result = []
for i in range(16):
if (b[i] & 0x80):
lhs = as_signed(a[i])
rhs = IMPLEMENTATION_DEFINED_ONE_OF(as_signed(b[i]), as_unsigned(b[i]))
intermediate[i] = lhs * rhs
else:
intermediate[i] = as_signed(a[i]) * b[i]

for i in range(0, 16, 2):
tmp[i/2] = IMPLEMENTATION_DEFINED_ONE_OF(
intermediate[i] + intermediate[i+1],
saturate(intermediate[i] + intermediate[i+1]))

for i in range(0, 8, 2):
result[i/2] = tmp[i] + tmp[i+1]

for i in range(0, 4):
result[i] += c[i]

saturate(x) = min(INT16_MAX, max(INT16_MIN, x))
```

## Binary format
Expand Down

0 comments on commit c423e63

Please sign in to comment.