[Fix][TIR]fix symbolic strides lower #15986

Closed · wants to merge 2 commits
Changes from 1 commit
7 changes: 4 additions & 3 deletions src/tir/transforms/inject_ptx_async_copy.cc
@@ -24,6 +24,7 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

@@ -79,7 +80,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
Contributor:
Since your description mentions this as a separate bug, can it either be split out into a separate PR, or (since it is a relatively small change), have a test case added for it?

Contributor Author:
I opened a new PR here. Please take a look.
I will open a new PR to fix the dtype mismatch bug in the InjectPTXAsyncCopy pass after the symbolic strides PR is merged.

load->buffer->data, src_offset, PrimExpr(bytes)};
// use arguments size to indicate whether or not to use predicated cp.async
if (predicated) {
@@ -114,7 +115,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
}();
if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
{store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
} else {
@@ -144,7 +145,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
{store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes), predicate_value}));
}
}
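For context on the tir::Mul to mul change discussed above (not part of the PR): the raw binary-op node constructors in TIR expect both operands to already share a dtype, while the mul helper from tvm/tir/op.h unifies the operand types first. Below is a minimal Python sketch of that distinction; it assumes TVM's Python bindings, where the * operator on PrimExprs performs the same kind of dtype unification as the C++ mul helper, and the variable names only mirror the patch for illustration.

```python
import tvm
from tvm import tir

dst_offset = tir.Var("dst_offset", "int64")  # e.g. an index derived from an int64 shape var
index_factor = tir.IntImm("int32", 4)        # a plain int32 immediate, like PrimExpr(index_factor)

# The raw node constructor (Python counterpart of C++ tir::Mul) expects matching dtypes,
# so mixing int64 and int32 is expected to be rejected here.
try:
    tir.Mul(dst_offset, index_factor)
except Exception as err:
    print("raw Mul constructor rejected mismatched dtypes:", type(err).__name__)

# The * operator goes through the type-unifying multiply path, which casts the int32
# immediate up to int64 before building the node.
product = dst_offset * index_factor
print(product.dtype)  # int64
```

That inserted cast is what avoids the dtype mismatch the author mentions for symbolic (int64) indices.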
3 changes: 2 additions & 1 deletion src/tir/transforms/ir_utils.cc
@@ -417,7 +417,8 @@ Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
if (buffer->strides.size()) {
ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
ICHECK(
arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0));
alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
}
}
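The ir_utils.cc change above swaps a purely structural is_zero check for an arith::Analyzer proof, which is what lets symbolic strides such as (72 * T.min((n + 63) // 64 * 64, 96), 72, 1) pass the divisibility check. A minimal Python sketch of the difference, not from the PR, assuming tvm.arith.Analyzer.can_prove as an analogue of the CanProveEqual call in the patch:

```python
import tvm
from tvm import tir

n = tir.Var("n", "int64")
# Strides shaped like the ones in the test case below: (72 * min((n + 63) // 64 * 64, 96), 72, 1)
outer_stride = tir.Min((n + 63) // 64 * 64, tir.IntImm("int64", 96)) * 72
inner_stride = tir.IntImm("int64", 72)

rem = tir.floormod(outer_stride, inner_stride)

# A structural check in the spirit of is_zero(): the remainder is a FloorMod node,
# not a literal zero constant, so this fails even though the value is provably zero.
print(isinstance(rem, tir.IntImm) and rem.value == 0)  # False

# An arithmetic analyzer can prove the remainder is zero for every n.
analyzer = tvm.arith.Analyzer()
print(analyzer.can_prove(tir.EQ(rem, tir.IntImm("int64", 0))))  # expected: True
```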
32 changes: 32 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_opaque_block.py
@@ -250,6 +250,34 @@ def transformed_strided_buffer_func(
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)


@T.prim_func
def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
n = T.int64()
A = T.match_buffer(a, (1, n, 10240), "float32")
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
Contributor:
Unrelated, the presence of this expression is kind of odd to me. Assuming this example came from a TIR printout, I would have expected ((n + 63) // 64 * 4 + 7) // 8 to be simplified to the equivalent (n + 127) // 128. The fact that it didn't simplify may indicate that I should take a look at the CanonicalSimplifier.

with T.block(""):
T.reads(A[0, i * 128 + j * 32:i * 128 + j * 32 + 96, k * 64:k * 64 + 64])
A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn")
for ax0, ax1 in T.grid(96, 64):
with T.block("A_pad_shared.dyn"):
T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
Contributor:
This looks like the same T.reads and T.writes annotations as would be automatically inferred from the block's body. Unless the test depends on a specific override to use non-default read/write annotations, they should be removed for readability.

T.reads(A[0, i * 128 + j * 32 + ax0, k * 64 + ax1])
T.writes(A_pad_shared_dyn[0, ax0, ax1])
A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float16(0))


@T.prim_func
def transformed_symbolic_strided_buffer_func(a: T.handle):
n = T.int64()
A = T.match_buffer(a, (1, n, 10240))
for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160):
Contributor:
Does the test case depend on T.int64 datatypes? If not, this would be much more readable by using T.int32. Because it is the default integer type in TVMScript, it wouldn't require the explicit type conversions (e.g. (n + 63) instead of (n + T.int64(63))).

Contributor Author:
Thank you for the advice. I've modified it to use T.int32 and pulled the repeated expression out into a padded_size = T.meta_var(...) binding in the test case.

A_pad_shared_dyn = T.allocate([1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72], "float32", "shared.dyn")
A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 64), data=A_pad_shared_dyn, strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72, 1), scope="shared.dyn")
Contributor:
The expression T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)) occurs frequently, and makes this difficult to read. Can it be pulled out into padded_size = T.meta_var(T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)))? The generated TIR will still contain the full expression, but the test case can be easier to read.

for ax0, ax1 in T.grid(96, 64):
if i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0) < (n + T.int64(63)) // T.int64(64) * T.int64(64):
A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0) < n, A[0, i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0), k * 64 + ax1], T.float32(0))


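On the T.meta_var suggestion above: T.meta_var binds a Python-level value so that the TVMScript parser inlines it at every use, which keeps the script readable while the generated TIR still contains the full expression. A small illustrative sketch with a hypothetical function (pad_init is not from this PR):

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def pad_init(a: T.handle):
    n = T.int32()
    A = T.match_buffer(a, (n,), "float32")
    # Bind the repeated padding expression once; every use below is inlined by the parser,
    # so the resulting TIR is identical to writing the expression out in full.
    padded_size = T.meta_var((n + 63) // 64 * 64)
    for i in range(padded_size):
        if i < n:
            A[i] = T.float32(0)
```

Printing the resulting PrimFunc shows the full (n + 63) // 64 * 64 extent rather than a padded_size variable, since meta_var bindings never become TIR variables.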
@T.prim_func
def annotated_loops(a: T.handle) -> None:
A = T.match_buffer(a, (16,), "float32")
@@ -301,6 +329,10 @@ def test_strided_buffer():
_check(compacted_strided_buffer_func, transformed_strided_buffer_func)


def test_symbolic_strided_buffer():
_check(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func)

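The _check helper is defined earlier in this test file and is not shown in the diff; for these LowerOpaqueBlock tests it presumably follows the file's usual pattern of running the pass (plus Simplify) and comparing structurally, roughly along these lines:

```python
import tvm

def _check(original, transformed):
    # Wrap the compacted PrimFunc in a module, lower the opaque blocks, simplify,
    # and compare against the expected transformed PrimFunc.
    func = original.with_attr("global_symbol", "main")
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    tvm.ir.assert_structural_equal(
        mod["main"], transformed.with_attr("global_symbol", "main"), True
    )
```

Mapping free variables in the structural comparison matters here because the compacted and transformed functions each declare their own symbolic n.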

def test_lower_te():
x = te.placeholder((1,))
y = te.compute((1,), lambda i: x[i] + 2)