Skip to content

Commit

Permalink
[FRONTEND] Change sort to use integer types for tensor slicing. (trit…
Browse files Browse the repository at this point in the history
…on-lang#2744)

This will avoid potential bad problem with special floating point values
as well as allow optimizing multiplication by 0.
  • Loading branch information
ThomasRaoux authored Dec 3, 2023
1 parent f2bc68e commit 2fd0fa3
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions python/triton/language/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,25 @@ def _indicator(n_dims: core.constexpr, idx: core.constexpr, pos: core.constexpr)
return y


@jit
def _cast_to_int(x):
y = x
if x.dtype.is_floating():
if core.constexpr(x.dtype.primitive_bitwidth) == 16:
dtype_int = core.int16
elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
dtype_int = core.int32
elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
dtype_int = core.int64
else:
raise ValueError("Unsupported dtype")
y = x.to(dtype_int, bitcast=True)
return y


@jit
def _take_slice(x, n_dims: core.constexpr, idx: core.constexpr, pos: core.constexpr, keep_dim: core.constexpr = True):
y = sum(x * _indicator(n_dims, idx, pos), n_dims - 1 - idx)
y = sum(x * _indicator(n_dims, idx, pos).to(x.dtype), n_dims - 1 - idx)
if keep_dim:
y = core.expand_dims(y, n_dims - 1 - idx)

Expand All @@ -317,24 +333,12 @@ def _take_slice(x, n_dims: core.constexpr, idx: core.constexpr, pos: core.conste

@jit
def _compare_and_swap(x, desc_mask, n_dims: core.constexpr, idx: core.constexpr):
l = _take_slice(x, n_dims, idx, 0)
r = _take_slice(x, n_dims, idx, 1)
x_int = _cast_to_int(x)
l_int = _take_slice(x_int, n_dims, idx, 0)
r_int = _take_slice(x_int, n_dims, idx, 1)
l = l_int.to(x.dtype, bitcast=True)
r = r_int.to(x.dtype, bitcast=True)

x_int = x
l_int = l
r_int = r
if x.dtype.is_floating():
if core.constexpr(x.dtype.primitive_bitwidth) == 16:
dtype_int = core.int16
elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
dtype_int = core.int32
elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
dtype_int = core.int64
else:
raise ValueError("Unsupported dtype")
x_int = x.to(dtype_int, bitcast=True)
l_int = l.to(dtype_int, bitcast=True)
r_int = r.to(dtype_int, bitcast=True)
desc_mask = desc_mask.to(x_int.dtype)
zero = zeros_like(x_int)
y = x_int ^ core.where((l > r) ^ desc_mask, l_int ^ r_int, zero)
Expand Down

0 comments on commit 2fd0fa3

Please sign in to comment.