Skip to content

Commit

Permalink
[Fix][TIR]fix symbolic strides lower (#16000)
Browse files Browse the repository at this point in the history
* [Fix][TIR]fix symbolic strides lower

* [Fix][TIR] run the black formatter
  • Loading branch information
JackWeiw authored Oct 30, 2023
1 parent 7eedea5 commit 57597f6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
if (buffer->strides.size()) {
ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
ICHECK(
arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0));
alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
}
}
Expand Down
48 changes: 48 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_opaque_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,50 @@ def transformed_strided_buffer_func(
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)


@T.prim_func
def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
n = T.int32()
A = T.match_buffer(a, (1, n, 10240))
padded_size = T.meta_var(T.min((n + 63) // 64 * 64, 96))
# with T.block("root"):
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
with T.block(""):
A_pad_shared_dyn = T.alloc_buffer(
(1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn"
)
for ax0, ax1 in T.grid(96, 64):
with T.block("A_pad_shared.dyn"):
T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(
i * 128 + j * 32 + ax0 < n,
A[0, i * 128 + j * 32 + ax0, k * 64 + ax1],
T.float32(0),
)


@T.prim_func
def transformed_symbolic_strided_buffer_func(a: T.handle):
n = T.int32()
A = T.match_buffer(a, (1, n, 10240))
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
A_pad_shared_dyn = T.allocate(
[1, T.min((n + 63) // 64 * 64, 96), 72], "float32", "shared.dyn"
)
A_pad_shared_dyn_1 = T.decl_buffer(
(1, T.min((n + 63) // 64 * 64, 96), 64),
data=A_pad_shared_dyn,
strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1),
scope="shared.dyn",
)
for ax0, ax1 in T.grid(96, 64):
if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64:
A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(
i * 128 + j * 32 + ax0 < n,
A[0, i * 128 + j * 32 + ax0, k * 64 + ax1],
T.float32(0),
)


@T.prim_func
def annotated_loops(a: T.handle) -> None:
A = T.match_buffer(a, (16,), "float32")
Expand Down Expand Up @@ -301,6 +345,10 @@ def test_strided_buffer():
_check(compacted_strided_buffer_func, transformed_strided_buffer_func)


def test_symbolic_strided_buffer():
_check(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func)


def test_lower_te():
x = te.placeholder((1,))
y = te.compute((1,), lambda i: x[i] + 2)
Expand Down

0 comments on commit 57597f6

Please sign in to comment.