Skip to content

Commit

Permalink
[TIR, Schedule] Generate consumer-in-bound predicate after reverse_co…
Browse files Browse the repository at this point in the history
…mpute_inline
  • Loading branch information
vinx13 committed Sep 6, 2022
1 parent 2e83e03 commit 58ff35d
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 11 deletions.
82 changes: 71 additions & 11 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ static const char kErrBodyReverseInline[] = R"(The body of the inlined block sho
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
where A is the only buffer the block consumes, whose indices are distinct atomic variables,
and there should be no variables other than the index variables), and f is a bijective affine
mapping)";
mapping) and there should not be predicates in the inlined block.)";

class HasInitBlock : public ScheduleError {
public:
Expand Down Expand Up @@ -161,16 +161,25 @@ class NonSingleProducerError : public ScheduleError {
IRModule mod_;
Block block_;

static void Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
const StmtSRef& scope_root_sref) {
/*!
* \brief Check if the block has a single producer.
* \param self The schedule state
* \param block_sref The sref of the block to be checked
* \param scope_root_sref The sref of the scope root
* \return The sref of the producer block if the block has a single producer
* \throw ScheduleError if the block does not have a single producer
*/
static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
const StmtSRef& scope_root_sref) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<Dependency> producers = scope->GetDepsByDst(consumer_block_sref);
StmtSRef producer_block_sref{nullptr};
if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) {
const StmtSRef& producer_block_sref = producers[0]->src;
producer_block_sref = producers[0]->src;
if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) {
Array<Dependency> consumers = scope->GetDepsBySrc(producer_block_sref);
if (consumers.size() == 1) {
return;
return producer_block_sref;
}
}
}
Expand Down Expand Up @@ -521,11 +530,27 @@ class ReverseComputeInliner : public BaseInliner {
};

public:
explicit ReverseComputeInliner(const Buffer& inlined_buffer, const Block& consumer_block,
explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block,
const BlockRealize& consumer_block_realize,
const StmtSRef& scope_root_sref)
: BaseInliner(inlined_buffer, consumer_block, scope_root_sref) {}
: BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref),
producer_block_(producer_block) {
// Initialize the consumer's block predicate to ensure consumer block iters are in-bound,
// which should be inserted to its producer after inlining
consumer_predicate_ = Bool(true);
for (const IterVar& iter : consumer_block_realize->block->iter_vars) {
consumer_predicate_ = consumer_predicate_ && (iter->var >= iter->dom->min &&
iter->var < iter->dom->min + iter->dom->extent);
}
}

bool BodyPatternAllowInline(const Block& consumer_block) {
bool BodyPatternAllowInline(const BlockRealize& consumer_block_realize) {
const Block& consumer_block = consumer_block_realize->block;

if (!is_one(consumer_block_realize->predicate)) {
// Failure: Predicate is the consumer block is not supported
return false;
}
if (inlined_store_ == nullptr) {
// Failure: block body is not BufferStore
return false;
Expand Down Expand Up @@ -571,6 +596,34 @@ class ReverseComputeInliner : public BaseInliner {
using BaseInliner::VisitExpr_;
using BaseInliner::VisitStmt_;

/*! \brief Generate the predicate after inlining based on the consumer predicate */
PrimExpr BuildInlinedConsumerPredicate(const BlockRealizeNode* producer_block_realize) {
// Bind the producer block iter domains for simplification
Map<Var, PrimExpr> subst_map;
for (int i = 0, n = producer_block_realize->iter_values.size(); i < n; ++i) {
const IterVar& iter = producer_block_realize->block->iter_vars[i];
analyzer.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
subst_map.Set(iter->var, producer_block_realize->iter_values[i]);
}
// Substitute the consumer block iters with the corresponding iters in the producer blocks
PrimExpr predicate = Substituter(this)(consumer_predicate_);
// Simplify the predicate using the producer block iter domains
predicate = analyzer.Simplify(predicate);
// Substitute the producer block iters with the its bindings since the predicate in BlockRealize
// should not contain the block iters
predicate = Substitute(predicate, subst_map);
return predicate;
}

Stmt VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize new_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
if (op->block.get() == producer_block_) {
new_block_realize.CopyOnWrite()->predicate =
BuildInlinedConsumerPredicate(new_block_realize.get());
}
return std::move(new_block_realize);
}

Stmt VisitStmt_(const BufferStoreNode* _store) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_store));
if (!store->buffer.same_as(inlined_buffer_)) {
Expand Down Expand Up @@ -647,6 +700,10 @@ class ReverseComputeInliner : public BaseInliner {
Array<PrimExpr> buffer_load_indices_;
/*! \brief The IterMap representing the indices of the consumer's BufferLoad */
Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
/*! \brief The producer block */
const BlockNode* producer_block_{nullptr};
/*! \brief The predicate to ensure the consumer block iters are in-bound */
PrimExpr consumer_predicate_{nullptr};
/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer;
};
Expand Down Expand Up @@ -700,6 +757,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
bool check_only = false) {
const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
Block consumer_block = GetRef<Block>(_consumer_block);
BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref);
HasInitBlock::Check(self->mod, consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
Expand All @@ -709,10 +767,12 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
// Step 2. Check completeness
CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
// Step 3. Check if the consumer has a single complete producer
NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
StmtSRef producer_block_sref =
NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
// Step 4. Analyze the block body
ReverseComputeInliner inliner(inlined_buffer, consumer_block, scope_root_sref);
if (!inliner.BodyPatternAllowInline(consumer_block)) {
ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(),
consumer_block_realize, scope_root_sref);
if (!inliner.BodyPatternAllowInline(consumer_block_realize)) {
throw BodyAnalysisError(true, self->mod, consumer_block);
}
// Step 5. Create a plan that removes the leaf block to be inlined
Expand Down
36 changes: 36 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,32 @@ def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
)


@T.prim_func
def elementwise_overcomputed_producer(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(127, 127), "float32"]
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(127, 127):
with T.block("C"):
cvi, cvj = T.axis.remap("SS", [i, j])
C[cvi, cvj] = B[cvi, cvj] + 1.0


@T.prim_func
def elementwise_overcomputed_producer_reverse_inlined(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(127, 127), "float32"]
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.where(i < 127 and j < 127)
C[vi, vj] = A[vi, vj] * 2.0 + 1.0


# pylint: enable=no-member,invalid-name,unused-variable

use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
Expand Down Expand Up @@ -822,5 +848,15 @@ def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name):
)


def test_reverse_compute_inline_overcomputed_producer(use_block_name):
"""Test reverse compute inline overcomputed producer"""
sch = tir.Schedule(elementwise_overcomputed_producer, debug_mask="all")
compute = "C" if use_block_name else sch.get_block("C")
sch.reverse_compute_inline(compute)
tvm.ir.assert_structural_equal(
elementwise_overcomputed_producer_reverse_inlined, sch.mod["main"]
)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 58ff35d

Please sign in to comment.