[TIR, Schedule] Add schedule primitive PadEinsum #12750
Conversation
@T.prim_func
def matmul_expected(
Compare to #12720 cc @Lunderberg
Could I understand it as equivalent to a bundle of operations for a certain workload pattern? Like:
for buffer in [A_shared, B_shared, C_shared]:
    s.transpose_layout(buffer, (127, 127) -> (128, 128), pad_value=0)
for block in [A, B, C_shared]:
    for axis in s.get_loops(block):
        s.fuse(*s.split(axis, [1, 128]))
s.annotate(C_shared, "en_some_predicate_versus_overcomputation_selection", 1)
Yes. It pads the producers with init value (zero) and over-computes the reduction block
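Concretely, the equivalence can be sketched in plain NumPy (illustrative only, not TVM; the 127-to-128 shapes mirror the matmul example discussed above):

```python
import numpy as np

# Plain-NumPy illustration (not TVM): pad the producers with the reduction
# init value (zero) on the right, over-compute the reduction on the padded
# extents, and check that the valid region of the output is unchanged.
n, pad = 127, 1  # pad each extent of 127 up to 128
rng = np.random.default_rng(0)
A = rng.random((n, n))
B = rng.random((n, n))

# Producers padded with zeros on the upper end of each axis.
A_pad = np.zeros((n + pad, n + pad))
B_pad = np.zeros((n + pad, n + pad))
A_pad[:n, :n] = A
B_pad[:n, :n] = B

# Over-computed reduction on 128x128x128; extra terms multiply by zero.
C_pad = A_pad @ B_pad
assert np.allclose(C_pad[:n, :n], A @ B)
```

Because the padding equals the reduction's init value, the over-computed region contributes nothing to the valid output slice, which is why no predicate is needed.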
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
LGTM
Overall, looks good, with just a few usability questions and potential improvements. I like seeing which assumptions are made here; they lead to a much simpler analysis than the more general case from the padding RFC.
I think the biggest question is how the padding is specified, and whether it can be given as both left and right padding, rather than only padding on the right.
include/tvm/tir/schedule/schedule.h
Outdated
* the output buffer and the producer buffer to be allocated inside the PrimFunc.
*
* The padding is a list of non-negative integers, each element corresponds to the padding for
* each block iter in the order of block iters. The block and it's producer blocks should have
Nitpick: "it's" should be "its", without an apostrophe
* The output buffer and the producer buffer is resized according to the padding size. It requires
* the output buffer and the producer buffer to be allocated inside the PrimFunc.
*
* The padding is a list of non-negative integers, each element corresponds to the padding for
It looks like the padding can only be applied to the end of an axis/iterator, and cannot be applied to the beginning. Could we specify two arrays of padding, one for the lower end of each block iter and one for the upper end?
@Lunderberg The current assumption is to over-compute the reduction block and infer the padding of the producers. Since the padding is inferred from the buffer access pattern, I don't think we can specify the padding as a tuple.
@vinx13 Thank you, that makes sense. So one of the simplifying assumptions is that all padding will only be on one side; if padding were allowed on both sides, that would add a free parameter not just for the final output, but also for each producer.
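A tiny sketch (plain NumPy, illustrative only) of why right-only padding is the simpler case: with upper-end padding, elements keep their original indices, whereas lower-end padding introduces a free offset that every producer access would have to account for:

```python
import numpy as np

# Right (upper-end) padding leaves original indices untouched; left
# (lower-end) padding shifts every access by the left-pad amount.
x = np.arange(4)
right = np.pad(x, (0, 2))  # pad only on the upper end
left = np.pad(x, (2, 0))   # pad only on the lower end

assert all(x[i] == right[i] for i in range(4))     # same index
assert all(x[i] == left[i + 2] for i in range(4))  # shifted by the pad
```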
LGTM!
@vinx13 let's fix the following warnings:
* [TIR, Schedule] Add schedule primitive PadEinsum
* lint
* [TIR] Fix producer indices check in PadEinsum
* address comments
* simplify lambda expr
* fix

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
This PR adds a schedule primitive PadEinsum. It is used specifically for computation in the Einsum pattern, which covers most cases for tensorization. Unlike the general padding cases in https://github.com/apache/tvm-rfcs/blob/main/rfcs/0077-layout-transform-padding.md, this primitive pads the output blocks and the input blocks at once, which eliminates the need for extra arithmetic analysis to guarantee program correctness.

cc @Hzfengsy @wrongtest-intellif @spectrometerHBH @Lunderberg
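A hedged usage sketch of the new primitive (TVMScript syntax and the exact padding semantics vary across TVM versions; the padding list here is taken to be one entry per block iter as in this PR, and the snippet skips gracefully when TVM is not installed):

```python
# Hypothetical usage sketch of pad_einsum; requires a TVM build that
# includes this PR. Skips when TVM is unavailable.
try:
    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def matmul(A: T.Buffer((127, 127), "float32"),
               B: T.Buffer((127, 127), "float32"),
               C: T.Buffer((127, 127), "float32")) -> None:
        for i, j, k in T.grid(127, 127, 127):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] += A[vi, vk] * B[vk, vj]

    sch = tvm.tir.Schedule(matmul)
    # One padding entry per block iter (i, j, k); producers are padded
    # with the init value and the reduction block is over-computed.
    sch.pad_einsum(sch.get_block("C"), [1, 1, 1])
    status = "padded"
except ImportError:
    status = "skipped: tvm not installed"
print(status)
```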