Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Sep 15, 2022
1 parent 48ac825 commit 0becfce
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ class ScheduleNode : public runtime::Object {
* the output buffer and the producer buffer to be allocated inside the PrimFunc.
*
* The padding is a list of non-negative integers, each element corresponds to the padding for
* each block iter in the order of block iters. The block and it's producer blocks should have
* each block iter in the order of block iters. The block and its producer blocks should have
* trivial bindings, i.e. each block iter is bound to a single loop variable. After padding, the
* block iter extent and the corresponding outer loop is extended by the padding size.
*
Expand Down
3 changes: 1 addition & 2 deletions src/tir/schedule/primitive/pad_einsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integ
const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
InvalidPaddingError::Check(self, GetRef<Block>(block), padding);

const Array<StmtSRef> producers = GetProducers(self, block_sref);
{
auto f_check_block_properties = [&](const StmtSRef& block_sref, bool is_producer) {
CheckBlockHasTrivialBinding(self, block_sref);
Expand All @@ -331,7 +332,6 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integ
f_check_block_properties(block_sref, false);

// Check block properties of the producer block
const Array<StmtSRef> producers = GetProducers(self, block_sref);
for (const StmtSRef& producer_sref : producers) {
f_check_block_properties(producer_sref, true);
}
Expand Down Expand Up @@ -384,7 +384,6 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array<Integ

buffer_remap.Set(einsum.output_buffer, f_pad_buffer(einsum.output_buffer, einsum.output_indices));

// std::unordered_set<const BlockNode*> producers;
std::unordered_map<const BlockNode*, PrimExpr> producer_predicate;

// Different from the output block, the padding for the producer block is not directly specified
Expand Down

0 comments on commit 0becfce

Please sign in to comment.