-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix][TIR]fix symbolic strides lower #15986
Conversation
def transformed_symbolic_strided_buffer_func(a: T.handle): | ||
n = T.int64() | ||
A = T.match_buffer(a, (1, n, 10240)) | ||
for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the test case depend on T.int64
datatypes? If not, this would be much more readable by using T.int32
. Because it is the default integer type in TVMScript, it wouldn't require the explicit type conversions. (e.g. (n + 63)
instead of (n + T.int64(63))
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank u for the adivice, i've modified it using T.int32
and pulled out into padded_size = T.meta_var(T.min((n + T.int64(63)) // T.int64(64) * T.int64(64))
in the test case.
A = T.match_buffer(a, (1, n, 10240)) | ||
for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160): | ||
A_pad_shared_dyn = T.allocate([1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72], "float32", "shared.dyn") | ||
A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 64), data=A_pad_shared_dyn, strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72, 1), scope="shared.dyn") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The expression T.min((n + T.int64(63)) // T.int64(64) * T.int64(64)
occurs frequently, and makes it difficult to read. Can this be pulled out into padded_size = T.meta_var(T.min((n + T.int64(63)) // T.int64(64) * T.int64(64))
? The generated TIR will still contain the full expression, but the test case can be easier to read.
A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn") | ||
for ax0, ax1 in T.grid(96, 64): | ||
with T.block("A_pad_shared.dyn"): | ||
T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like the same T.reads
and T.writes
annotations as would be automatically inferred from the block's body. Unless the test depends on a specific override to use non-default read/write annotations, it should be removed for readability.
def compacted_symbolic_strided_buffer_func(a: T.handle) -> None: | ||
n = T.int64() | ||
A = T.match_buffer(a, (1, n, 10240), "float32") | ||
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated, the presence of this expression is kind of odd to me. Assuming this example came from a TIR printout, I would have expected ((n + 63) // 64 * 4 + 7) // 8
to be simplified to the equivalent (n + 127) // 128
. The fact that it didn't simplify may indicate that I should take a look at the CanonicalSimplifier
.
@@ -79,7 +80,7 @@ class PTXAsyncCopyInjector : public StmtMutator { | |||
if (indices_lanes == 1) { | |||
auto src_offset = load->indices[0]; | |||
auto dst_offset = store->indices[0]; | |||
Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), | |||
Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since your description mentions this as a separate bug, can it either be split out into a separate PR, or (since it is a relatively small change), have a test case added for it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I open a new PR here . Please have a check.
I will open a new PR to fix dtype mismatch bug in PASS InjectPTXAsyncCopy
after symbolic strides PR is merged
compact_buffer_region
PASS modify shared buffer stride[0] toT.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96))
and stride[1] isT.int64(72)
but in LowerOpaqueBlock PASS it report error:
InternalError: Check failed: (is_zero(floormod(buffer->strides[i - 1], buffer->strides[i]))) is false:
For more detaied discuss, see here
Another bug occurs in PASS InjectPTXAsyncCopy .
that is dst_offset.dtype could be int64, the dtype of PrimExpr(index_factor) would be set to default to int32.
cause dtype inconsistent when calling tir::Mul.
To reproduce the problem in InjectPTXAsyncCopy, see script here