[INTERPRETER] Fix mask and other fields in load and store ops #3288

Merged: 4 commits, Mar 6, 2024
Changes from 3 commits
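
The substance of the fix is in python/triton/language/semantic.py: the optional mask, other, and val arguments to the load/store/atomic helpers were guarded with plain truthiness checks (if mask:), which only behave as intended when the argument is either None or unconditionally truthy. When the argument is backed by an array-like value, as it can be on the interpreter path, bool() may raise or evaluate to False even though a mask was supplied, and the broadcast and cast steps get skipped. The PR switches every such guard to an explicit "is not None" test; the test changes extend coverage with interpreter marks and a parametrized other value for masked loads. The sketch below illustrates the failure mode with a stand-in class; FakeTensor and its numpy-backed __bool__ are illustrative assumptions, not code from this PR.

# Minimal sketch (illustrative, not Triton code): why a truthiness guard on an
# optional tensor-like argument misbehaves, and why the PR uses `is not None`.
import numpy as np


class FakeTensor:
    """Stand-in for an interpreter tensor that wraps a numpy array."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __bool__(self):
        # numpy semantics: ambiguous (raises) for multi-element arrays.
        return bool(self.data)


def guard_old(mask):
    # Old pattern: a supplied mask can raise here or be treated as missing.
    return "broadcast mask" if mask else "skip"


def guard_new(mask):
    # New pattern: only an omitted argument (None) skips the broadcast.
    return "broadcast mask" if mask is not None else "skip"


mask = FakeTensor([True, False, True])
print(guard_new(mask))          # broadcast mask
try:
    guard_old(mask)             # raises: truth value is ambiguous
except ValueError as exc:
    print("old guard failed:", exc)
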
29 changes: 21 additions & 8 deletions python/test/unit/language/test_core.py
@@ -159,7 +159,7 @@ def check_type_supported(dtype, device):
if is_interpreter():
if dtype in [
tl.float8e4nv, "float8e4nv", tl.bfloat16, "bfloat16", tl.float8e5, "float8e5", tl.float8e4b15,
"float8e4b15", tl.float8e4b15x4, "float8e4b15x4"
"float8e4b15", tl.float8e4b15x4, "float8e4b15x4", torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2
]:
pytest.skip("bfloat16 and float8 are not supported in the interpreter")

@@ -387,6 +387,7 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas)


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]])
def test_addptr(dtype, order, device):
check_type_supported(dtype, device)
@@ -1047,6 +1048,7 @@ def tuples_fn(a, b):
a * b


@pytest.mark.interpreter
def test_tuples(device):

@triton.jit
@@ -3185,6 +3187,7 @@ def kernel(out_ptr):
assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
def test_dot_without_load(dtype_str, device):
if is_cuda():
@@ -3218,6 +3221,7 @@ def _kernel(out, ALLOW_TF32: tl.constexpr):
# ---------------


@pytest.mark.interpreter
@pytest.mark.parametrize("start", [0, 1, 7, 16])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_arange(start, num_ctas, device):
@@ -3240,12 +3244,14 @@ def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr):
# ---------------


@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff)
for dtype_str in torch_dtypes
for size in [128, 512]
for size_diff in [0, 1, 2, 3, 4]])
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other)
for dtype_str in torch_dtypes
for size in [128, 512]
for size_diff in [0, 1, 2, 3, 4]
for other in [0, 1]])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_masked_load(dtype_str, size, size_diff, num_ctas, device):
def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested

@@ -3268,11 +3274,11 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)

mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None"
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas)

reference_out = torch.cat((input, torch.ones((size_diff, ), dtype=dtype, device=device)))
reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device)))
# print((output - reference_out).nonzero())
torch.testing.assert_close(output, reference_out)

@@ -3281,6 +3287,7 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):


# FIXME: Shape too small for ldmatrix when num_ctas=4
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device):

@@ -3323,6 +3330,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_
torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0)


@pytest.mark.interpreter
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
def test_load_cache_modifier(cache, device):
src = torch.empty(128, device=device)
@@ -3407,6 +3415,7 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
# ---------------


@pytest.mark.interpreter
@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"])
def test_store_cache_modifier(cache, device):
src = torch.empty(128, device=device)
@@ -3460,6 +3469,7 @@ def _impl(value=10):
return value


@pytest.mark.interpreter
def test_default(device):
value = 5
ret0 = torch.zeros(1, dtype=torch.int32, device=device)
@@ -3595,6 +3605,7 @@ def kernel(Z, X, Y):
np.testing.assert_allclose(z, to_numpy(z_tri))


@pytest.mark.interpreter
def test_constexpr_shape(device):

@triton.jit
@@ -3607,6 +3618,7 @@ def kernel(X):
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))


@pytest.mark.interpreter
def test_constexpr_scalar_shape(device):

@triton.jit
@@ -4601,6 +4613,7 @@ def do_test(src_layout, dst_layout):
do_test(mma_pair[1], mma_pair[0])


@pytest.mark.interpreter
def test_load_scalar_with_mask(device):

@triton.jit
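For context on the updated test_masked_load above, a compact standalone version of the pattern it exercises is sketched here: lanes whose offsets fall past in_size read the other value instead of touching out-of-bounds memory. This assumes a working triton/torch install with a CUDA device; the names copy_with_padding and run_once are illustrative and not part of this PR.

# Standalone sketch of the masked-load-with-padding pattern; illustrative
# names, not code from this PR.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_with_padding(in_ptr, out_ptr, in_size, OUT_SIZE: tl.constexpr, OTHER: tl.constexpr):
    offsets = tl.arange(0, OUT_SIZE)
    # Masked-off lanes (offsets >= in_size) are filled with OTHER, mirroring
    # the other=0 / other=1 cases the updated test parametrizes over.
    x = tl.load(in_ptr + offsets, mask=offsets < in_size, other=OTHER)
    tl.store(out_ptr + offsets, x)


def run_once(other=1, device="cuda"):
    src = torch.arange(12, dtype=torch.float32, device=device)
    out = torch.empty(16, dtype=torch.float32, device=device)
    copy_with_padding[(1, )](src, out, src.numel(), OUT_SIZE=16, OTHER=other)
    ref = torch.cat((src, torch.full((4, ), other, dtype=src.dtype, device=device)))
    torch.testing.assert_close(out, ref)
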
19 changes: 10 additions & 9 deletions python/triton/language/semantic.py
@@ -992,9 +992,9 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_

# Make `mask` and `other` into the same shape as `ptr`
if ptr.type.is_block():
if mask:
if mask is not None:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
if other:
if other is not None:
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)

# Get `pointer_type<elt_ty>` and `elt_ty`
@@ -1008,7 +1008,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
ptr = cast(ptr, ptr_ty, builder)

# Cast `other` into `ele_ty` type
if other:
if other is not None:
other = cast(other, elt_ty, builder)

# Create loaded result type `dst_ty`
@@ -1028,8 +1028,9 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
is_volatile), dst_ty)


def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check, padding_option: str,
cache_modifier: str, eviction_policy: str, is_volatile: bool, builder: ir.builder) -> tl.tensor:
def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple,
padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool,
builder: ir.builder) -> tl.tensor:
# Cache, eviction and padding options
cache = _str_to_load_cache_modifier(cache_modifier)
eviction = _str_to_eviction_policy(eviction_policy)
@@ -1046,7 +1047,7 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
# Store by a block pointer: `pointer_type<block_type<>>`
# Block pointers can not have the `mask` argument
if mask:
if mask is not None:
raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")

# Check same shape and element type
@@ -1093,7 +1094,7 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
# Make `mask` and `val` into the same shape as `ptr`
if ptr.type.is_block():
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
if mask:
if mask is not None:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)

ptr_ty = ptr.type.scalar
@@ -1154,9 +1155,9 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
raise ValueError("atomic_" + op + " does not support " + str(element_ty))
if ptr.type.is_block():
if mask:
if mask is not None:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
if val:
if val is not None:
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
val = cast(val, ptr.type.scalar.element_ty, builder)
if not mask: