-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix][TIR]fix symbolic strides lower #15986
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -250,6 +250,34 @@ def transformed_strided_buffer_func( | |
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2) | ||
|
||
|
||
@T.prim_func | ||
def compacted_symbolic_strided_buffer_func(a: T.handle) -> None: | ||
n = T.int64() | ||
A = T.match_buffer(a, (1, n, 10240), "float32") | ||
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unrelated, the presence of this expression is kind of odd to me. Assuming this example came from a TIR printout, I would have expected |
||
with T.block(""): | ||
T.reads(A[0, i * 128 + j * 32:i * 128 + j * 32 + 96, k * 64:k * 64 + 64]) | ||
A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn") | ||
for ax0, ax1 in T.grid(96, 64): | ||
with T.block("A_pad_shared.dyn"): | ||
T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like the same |
||
T.reads(A[0, i * 128 + j * 32 + ax0, k * 64 + ax1]) | ||
T.writes(A_pad_shared_dyn[0, ax0, ax1]) | ||
A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(i * 128 + j * 32 + ax0 < n, A[0, i * 128 + j * 32 + ax0, k * 64 + ax1], T.float16(0)) | ||
|
||
|
||
@T.prim_func | ||
def transformed_symbolic_strided_buffer_func(a: T.handle): | ||
n = T.int64() | ||
A = T.match_buffer(a, (1, n, 10240)) | ||
for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does the test case depend on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank u for the adivice, i've modified it using |
||
A_pad_shared_dyn = T.allocate([1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72], "float32", "shared.dyn") | ||
A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 64), data=A_pad_shared_dyn, strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72, 1), scope="shared.dyn") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The expression |
||
for ax0, ax1 in T.grid(96, 64): | ||
if i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0) < (n + T.int64(63)) // T.int64(64) * T.int64(64): | ||
A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0) < n, A[0, i * T.int64(128) + T.Cast("int64", j) * T.int64(32) + T.Cast("int64", ax0), k * 64 + ax1], T.float32(0)) | ||
|
||
|
||
@T.prim_func | ||
def annotated_loops(a: T.handle) -> None: | ||
A = T.match_buffer(a, (16,), "float32") | ||
|
@@ -301,6 +329,10 @@ def test_strided_buffer(): | |
_check(compacted_strided_buffer_func, transformed_strided_buffer_func) | ||
|
||
|
||
def test_symbolic_strided_buffer(): | ||
_check(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func) | ||
|
||
|
||
def test_lower_te(): | ||
x = te.placeholder((1,)) | ||
y = te.compute((1,), lambda i: x[i] + 2) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since your description mentions this as a separate bug, can it either be split out into a separate PR, or (since it is a relatively small change), have a test case added for it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I open a new PR here . Please have a check.
I will open a new PR to fix dtype mismatch bug in PASS
InjectPTXAsyncCopy
after symbolic strides PR is merged