Skip to content

Commit

Permalink
[INTERPRETER] Support unary ops (triton-lang#3279)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren authored Mar 5, 2024
1 parent eaf3395 commit 5dbd842
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def kernel(Z, X, SIZE: tl.constexpr):
z_ref = eval(expr if numpy_expr is None else numpy_expr)
# triton result
x_tri = to_triton(x, device=device, dst_type=dtype_x)
z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x)
z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x)
kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas)
# compare
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
Expand Down Expand Up @@ -849,6 +849,7 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
# ---------------


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x, expr",
[(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x')
for dtype_x in int_dtypes])
Expand All @@ -862,6 +863,7 @@ def test_unary_op(dtype_x, expr, num_ctas, device):
# ----------------


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x)
for dtype_x in ["float32", "float64"]
for expr in ['exp', 'log', 'cos', 'sin']
Expand All @@ -875,6 +877,7 @@ def test_math_op(dtype_x, expr, device, x):
# ----------------


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16])
def test_abs(dtype_x, device):
_test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device)
Expand Down
9 changes: 8 additions & 1 deletion python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,13 @@ def create_advance(self, ptr, offsets):
ret.offsets[i].data += offsets[i].data
return ret

def get_all_ones_value(self, type):
np_type = self.np_dtype(type)
if "int" in np_type.name:
return TensorHandle(np.full(1, -1, dtype=np_type), type)
else:
raise TypeError(f"unsupported type {type}")


def _patch_attr(obj, name, member, builder):
new_member = lambda *args, member=member, **kwargs: (member(*args, **
Expand Down Expand Up @@ -577,7 +584,7 @@ def fallback(*args, **kwargs):
for name, member in inspect.getmembers(math):
if name in mapping:
setattr(math, name, make_numpy(name))
else:
elif callable(member): # We only wrap functions
setattr(math, name, make_fallback(name))


Expand Down

0 comments on commit 5dbd842

Please sign in to comment.