Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR, Schedule] Add schedule primitive PadEinsum #12750

Merged
merged 6 commits into from
Sep 15, 2022

Conversation

vinx13
Copy link
Member

@vinx13 vinx13 commented Sep 9, 2022

Co-authored-by: Bohan Hou 32121147+spectrometerHBH@users.noreply.github.com

This PR adds a schedule primitive PadEinsum. It is used for computation in Einsum pattern specifically, which cover most cases for tensorization. Different from general cases for padding in https://github.com/apache/tvm-rfcs/blob/main/rfcs/0077-layout-transform-padding.md, this primitive pads the output blocks and the input blocks at once, which eliminates the need to extra arithmetic analysis to provide the guarantee of program correctness.

cc @Hzfengsy @wrongtest-intellif @spectrometerHBH @Lunderberg



@T.prim_func
def matmul_expected(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compare to #12720 cc @Lunderberg
Could I understand that it equals with a bundle of operations in certain workload pattern? Like

for buffer in [A_shared, B_shared, C_shared]:
     s.transpose_layout(buffer, (127, 127) -> (128, 128), pad_value=0)
for block in [A, B, C_shared]:
     for axis in s.get_loops(block)
         s.fuse(*s.split(axis, [1, 128]))
s.annotate(C_shared, "en_some_predicate_versus_overcomputation_selection", 1)
        

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It pads the producers with init value (zero) and over-computes the reduction block

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
@vinx13 vinx13 force-pushed the feat/tir-pad-einsum branch 2 times, most recently from db1b3e5 to 9a0a81c Compare September 12, 2022 21:31
@vinx13 vinx13 force-pushed the feat/tir-pad-einsum branch from 9a0a81c to 09ed7ab Compare September 12, 2022 23:08
Copy link
Member

@Hzfengsy Hzfengsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, looks good, but just a few usability questions and potential improvements. I like seeing which assumptions are made here, which lead to a much simpler analysis than the more general case from the padding RFC.

I think the biggest question is the padding specified, and whether it can be specified as both a left/right padding, rather than only padding on the right.

* the output buffer and the producer buffer to be allocated inside the PrimFunc.
*
* The padding is a list of non-negative integers, each element corresponds to the padding for
* each block iter in the order of block iters. The block and it's producer blocks should have
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: "it's" should be "its", without an apostrphe

* The output buffer and the producer buffer is resized according to the padding size. It requires
* the output buffer and the producer buffer to be allocated inside the PrimFunc.
*
* The padding is a list of non-negative integers, each element corresponds to the padding for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the padding can only be applied to the end of an axis/iterator, and cannot be applied to the beginning. Could we specify two arrays of padding, one for the lower end each block iter and one for the upper end?

src/tir/schedule/primitive/pad_einsum.cc Outdated Show resolved Hide resolved
src/tir/schedule/primitive/pad_einsum.cc Outdated Show resolved Hide resolved
src/tir/schedule/primitive/pad_einsum.cc Outdated Show resolved Hide resolved
@vinx13
Copy link
Member Author

vinx13 commented Sep 14, 2022

@Lunderberg The current assumption is to over compute the reduction block, and infer the padding of the producer. Since the padding is inferred from buffer access pattern, I think we can't specify the padding as tuple

@Lunderberg
Copy link
Contributor

@vinx13 Thank you, and that makes sense. So, one of the simplifying assumptions that is all padding will only be on one side, and if the padding is allowed on both sides, that wouldn't just add a free parameter for the final output, but also for each producer.

@vinx13 vinx13 force-pushed the feat/tir-pad-einsum branch from f7726ee to 0becfce Compare September 15, 2022 00:33
Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@vinx13 vinx13 merged commit 1f8b5de into apache:main Sep 15, 2022
@junrushao
Copy link
Member

@vinx13 let's fix the following warnings:

/root/Projects/tvm-dev/src/tir/schedule/primitive/pad_einsum.cc:231:8: warning: 'tvm::tir::PadEinsumRewriter::VisitStmt_' hides overloaded virtual function [-Woverloaded-virtual]
  Stmt VisitStmt_(const ForNode* op) final {
       ^
/root/Projects/tvm-dev/src/tir/schedule/primitive/.././transform.h:134:8: note: hidden overloaded virtual function 'tvm::tir::ReplaceBufferMutator::VisitStmt_' declared here: type mismatch at 1st parameter ('const tvm::tir::BufferStoreNode *' vs 'const tvm::tir::ForNode *')
  Stmt VisitStmt_(const BufferStoreNode* op) final;
       ^
/root/Projects/tvm-dev/src/tir/schedule/primitive/pad_einsum.cc:374:47: warning: lambda capture 'buffer_remap' is not used [-Wunused-lambda-capture]
  auto f_pad_buffer = [&padded_iter_extents, &buffer_remap](Buffer buffer,
                                           ~~~^~~~~~~~~~~~

xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
* [TIR, Schedule] Add schedule primitive PadEinsum

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>

* lint

* [TIR] Fix producer indices check in PadEinsum

* address comments

* simplify lambda expr

* fix

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants