[INTERPRETER] Fix mask and other fields in load and store ops
Jokeren authored Mar 6, 2024
1 parent 233c88b commit 441d6b1
Showing 3 changed files with 37 additions and 20 deletions.
29 changes: 21 additions & 8 deletions python/test/unit/language/test_core.py
@@ -159,7 +159,7 @@ def check_type_supported(dtype, device):
if is_interpreter():
if dtype in [
tl.float8e4nv, "float8e4nv", tl.bfloat16, "bfloat16", tl.float8e5, "float8e5", tl.float8e4b15,
"float8e4b15", tl.float8e4b15x4, "float8e4b15x4"
"float8e4b15", tl.float8e4b15x4, "float8e4b15x4", torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2
]:
pytest.skip("bfloat16 and float8 are not supported in the interpreter")

@@ -387,6 +387,7 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas)


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]])
def test_addptr(dtype, order, device):
check_type_supported(dtype, device)
@@ -1047,6 +1048,7 @@ def tuples_fn(a, b):
a * b


@pytest.mark.interpreter
def test_tuples(device):

@triton.jit
@@ -3185,6 +3187,7 @@ def kernel(out_ptr):
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
def test_dot_without_load(dtype_str, device):
if is_cuda():
@@ -3218,6 +3221,7 @@ def _kernel(out, ALLOW_TF32: tl.constexpr):
# ---------------


@pytest.mark.interpreter
@pytest.mark.parametrize("start", [0, 1, 7, 16])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_arange(start, num_ctas, device):
@@ -3240,12 +3244,14 @@ def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr):
# ---------------


@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff)
for dtype_str in torch_dtypes
for size in [128, 512]
for size_diff in [0, 1, 2, 3, 4]])
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other)
for dtype_str in torch_dtypes
for size in [128, 512]
for size_diff in [0, 1, 2, 3, 4]
for other in [0, 1]])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_masked_load(dtype_str, size, size_diff, num_ctas, device):
def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested

@@ -3268,11 +3274,11 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)

mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None"
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas)

reference_out = torch.cat((input, torch.ones((size_diff, ), dtype=dtype, device=device)))
reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device)))
# print((output - reference_out).nonzero())
torch.testing.assert_close(output, reference_out)
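The new `other` parameter exercises exactly the value the old guard mishandled: an explicit fill of 0. A self-contained usage sketch of the same pattern (kernel and tensor names are illustrative; assumes a CUDA device, or interpreter mode via `TRITON_INTERPRET=1` with the device adjusted accordingly):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def masked_copy(in_ptr, out_ptr, n_in, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Lanes past n_in take `other`; other=0 must survive as a fill value
    # rather than being dropped by a truthiness check during lowering.
    x = tl.load(in_ptr + offs, mask=offs < n_in, other=0)
    tl.store(out_ptr + offs, x)


src = torch.arange(6, dtype=torch.float32, device="cuda")
dst = torch.empty(8, dtype=torch.float32, device="cuda")
masked_copy[(1, )](src, dst, 6, BLOCK=8)
assert torch.equal(dst[6:], torch.zeros(2, device="cuda"))  # tail filled with other=0
```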

@@ -3281,6 +3287,7 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):


# FIXME: Shape too small for ldmatrix when num_ctas=4
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device):

@@ -3323,6 +3330,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_
torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0)


@pytest.mark.interpreter
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache, device):
src = torch.empty(128, device=device)
@@ -3407,6 +3415,7 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
# ---------------


@pytest.mark.interpreter
@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"])
def test_store_cache_modifier(cache, device):
src = torch.empty(128, device=device)
@@ -3460,6 +3469,7 @@ def _impl(value=10):
return value


@pytest.mark.interpreter
def test_default(device):
value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device=device)
@@ -3595,6 +3605,7 @@ def kernel(Z, X, Y):
np.testing.assert_allclose(z, to_numpy(z_tri))


@pytest.mark.interpreter
def test_constexpr_shape(device):

@triton.jit
@@ -3607,6 +3618,7 @@ def kernel(X):
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))


@pytest.mark.interpreter
def test_constexpr_scalar_shape(device):

@triton.jit
@@ -4601,6 +4613,7 @@ def do_test(src_layout, dst_layout):
do_test(mma_pair[1], mma_pair[0])


@pytest.mark.interpreter
def test_load_scalar_with_mask(device):

@triton.jit
9 changes: 6 additions & 3 deletions python/triton/language/core.py
@@ -1460,9 +1460,11 @@ def load(pointer, mask=None, other=None, boundary_check=tuple(), padding_option=
:type volatile: bool, optional
"""
# `mask` and `other` can be constexpr
if _constexpr_to_value(mask) is not None:
mask = _constexpr_to_value(mask)
other = _constexpr_to_value(other)
if mask is not None:
mask = _to_tensor(mask, _builder)
if _constexpr_to_value(other) is not None:
if other is not None:
other = _to_tensor(other, _builder)
padding_option = _constexpr_to_value(padding_option)
cache_modifier = _constexpr_to_value(cache_modifier)
@@ -1513,7 +1515,8 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
"""
# `value` can be constexpr
value = _to_tensor(value, _builder)
if _constexpr_to_value(mask) is not None:
mask = _constexpr_to_value(mask)
if mask is not None:
mask = _to_tensor(mask, _builder)
cache_modifier = _constexpr_to_value(cache_modifier)
eviction_policy = _constexpr_to_value(eviction_policy)
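The shape of the change in `load` and `store` above: unwrap a possibly-`tl.constexpr` argument once, then branch on `is not None`, so the check and the later conversion both see the same plain value. A condensed standalone sketch (the wrapper and converters below are hypothetical stand-ins for `tl.constexpr`, `_constexpr_to_value` and `_to_tensor`, not Triton's real internals):

```python
class Constexpr:
    """Minimal stand-in for a compile-time constant wrapper."""

    def __init__(self, value):
        self.value = value


def unwrap(x):
    # Stand-in for _constexpr_to_value: strip the wrapper, pass None through.
    return x.value if isinstance(x, Constexpr) else x


def to_tensor(x):
    # Stand-in for _to_tensor: here we just tag the value.
    return ("tensor", x)


def normalize_load_args(mask, other):
    # Resolve constexpr wrappers exactly once ...
    mask = unwrap(mask)
    other = unwrap(other)
    # ... then skip only arguments that are truly absent.
    if mask is not None:
        mask = to_tensor(mask)
    if other is not None:
        other = to_tensor(other)
    return mask, other


assert normalize_load_args(None, Constexpr(0)) == (None, ("tensor", 0))
assert normalize_load_args(Constexpr(False), None) == (("tensor", False), None)
```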
19 changes: 10 additions & 9 deletions python/triton/language/semantic.py
@@ -992,9 +992,9 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_

# Make `mask` and `other` into the same shape as `ptr`
if ptr.type.is_block():
if mask:
if mask is not None:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
if other:
if other is not None:
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)

# Get `pointer_type<elt_ty>` and `elt_ty`
@@ -1008,7 +1008,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
ptr = cast(ptr, ptr_ty, builder)

# Cast `other` into `ele_ty` type
if other:
if other is not None:
other = cast(other, elt_ty, builder)

# Create loaded result type `dst_ty`
@@ -1028,8 +1028,9 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
is_volatile), dst_ty)


def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check, padding_option: str,
cache_modifier: str, eviction_policy: str, is_volatile: bool, builder: ir.builder) -> tl.tensor:
def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple,
padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool,
builder: ir.builder) -> tl.tensor:
# Cache, eviction and padding options
cache = _str_to_load_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
@@ -1046,7 +1047,7 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor],
def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
# Store by a block pointer: `pointer_type<block_type<>>`
# Block pointers can not have the `mask` argument
if mask:
if mask is not None:
raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")

# Check same shape and element type
@@ -1093,7 +1094,7 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
# Make `mask` and `val` into the same shape as `ptr`
if ptr.type.is_block():
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
if mask:
if mask is not None:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)

ptr_ty = ptr.type.scalar
@@ -1154,9 +1155,9 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
raise ValueError("atomic_" + op + " does not support " + str(element_ty))
if ptr.type.is_block():
if mask:
if mask is not None:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
if val:
if val is not None:
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
val = cast(val, ptr.type.scalar.element_ty, builder)
if not mask:
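The repeated one-line fix in `semantic.py` — `if mask:` / `if other:` / `if val:` becoming explicit `is not None` checks — distinguishes "argument not supplied" from "argument supplied with a falsy value". A minimal standalone illustration of why the truthiness form is wrong for optional values that may legitimately be 0 or False (generic Python; the list repeat is only a stand-in for `broadcast_impl_shape`):

```python
def prepare_other_buggy(other, block_shape):
    # Truthiness check: other=0 skips the broadcast it actually needs.
    if other:
        other = [other] * block_shape   # stand-in for broadcast_impl_shape
    return other


def prepare_other_fixed(other, block_shape):
    # Explicit None check: only a truly absent argument is skipped.
    if other is not None:
        other = [other] * block_shape
    return other


assert prepare_other_buggy(0, 4) == 0             # left as a scalar: wrong shape
assert prepare_other_fixed(0, 4) == [0, 0, 0, 0]  # broadcast as intended
```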
