Skip to content

Commit

Permalink
Check consumer block iters are covered
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Sep 8, 2022
1 parent 15d90c8 commit ac4a9ec
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 15 deletions.
73 changes: 58 additions & 15 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ static const char kErrBodyReverseInline[] = R"(The body of the inlined block sho
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
where A is the only buffer the block consumes, whose indices are distinct atomic variables,
and there should be no variables other than the index variables), and f is a bijective affine
mapping) and there should not be predicates in the inlined block.)";
mapping and there should not be predicates in the inlined block. The iter domains of the inlined
block should be covered by the producer block.)";

class HasInitBlock : public ScheduleError {
public:
Expand Down Expand Up @@ -534,13 +535,14 @@ class ReverseComputeInliner : public BaseInliner {
const BlockRealize& consumer_block_realize,
const StmtSRef& scope_root_sref)
: BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref),
producer_block_(producer_block) {
// Initialize the consumer's block predicate to ensure consumer block iters are in-bound,
// which should be inserted to its producer after inlining
consumer_predicate_ = Bool(true);
producer_block_(producer_block),
consumer_block_(consumer_block_realize->block.get()) {
// Initialize the predicates to ensure consumer block iters are in-bound
consumer_iter_in_bound_ = Bool(true);
for (const IterVar& iter : consumer_block_realize->block->iter_vars) {
consumer_predicate_ = consumer_predicate_ && (iter->var >= iter->dom->min &&
iter->var < iter->dom->min + iter->dom->extent);
consumer_iter_in_bound_ =
consumer_iter_in_bound_ &&
(iter->var >= iter->dom->min && iter->var < iter->dom->min + iter->dom->extent);
}
}

Expand Down Expand Up @@ -582,13 +584,25 @@ class ReverseComputeInliner : public BaseInliner {
/*input_iters=*/consumer_iter_doms,
/*predicate=*/true,
/*check_level=*/arith::IterMapLevel::Bijective,
/*analyzer=*/&analyzer,
/*analyzer=*/&analyzer_,
/*simplify_trivial_iterators=*/false);
buffer_load_iter_map_ = res->indices;
if (buffer_load_iter_map_.empty()) {
// Failure: indices of BufferLoad are not bijective affine
return false;
}

const BufferStoreNode* producer_store = producer_block_->body.as<BufferStoreNode>();
if (producer_store == nullptr) {
// Failure: producer block body is not BufferStore
return false;
}
CreateInverseMapping(producer_store->indices);
if (!CheckConsumerCovered()) {
// Failure: consumer block iter domains are not covered by the producer block
return false;
}

return true;
}

Expand All @@ -602,13 +616,13 @@ class ReverseComputeInliner : public BaseInliner {
Map<Var, PrimExpr> subst_map;
for (int i = 0, n = producer_block_realize->iter_values.size(); i < n; ++i) {
const IterVar& iter = producer_block_realize->block->iter_vars[i];
analyzer.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
subst_map.Set(iter->var, producer_block_realize->iter_values[i]);
}
// Substitute the consumer block iters with the corresponding iters in the producer blocks
PrimExpr predicate = Substituter(this)(consumer_predicate_);
PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
// Simplify the predicate using the producer block iter domains
predicate = analyzer.Simplify(predicate);
predicate = analyzer_.Simplify(predicate);
// Substitute the producer block iters with the its bindings since the predicate in BlockRealize
// should not contain the block iters
predicate = Substitute(predicate, subst_map);
Expand All @@ -632,6 +646,32 @@ class ReverseComputeInliner : public BaseInliner {
return ReplaceInlinedBuffer(std::move(store));
}

/*!
* \brief Check the consumer block iter domains are covered by the producer block iter domains
* \return Whether the consumer block iter domains are covered
*/
bool CheckConsumerCovered() {
Map<IterVar, arith::IntSet> producer_iter_doms;
for (const IterVar& iter_var : producer_block_->iter_vars) {
producer_iter_doms.Set(iter_var, arith::IntSet::FromRange(iter_var->dom));
}
// For each block iter in the consumer block, find the corresponding expression in the producer
for (const IterVar& iter : consumer_block_->iter_vars) {
if (auto it = idx_sub_.find(iter->var.get()); it != idx_sub_.end()) {
const PrimExpr& producer_iter = it->second;
arith::IntSet producer_iter_range = arith::EvalSet(producer_iter, producer_iter_doms);
if (analyzer_.CanProve(producer_iter_range.min() > iter->dom->min) ||
analyzer_.CanProve(producer_iter_range.max() <
iter->dom->min + iter->dom->extent - 1)) {
return false;
}
} else {
return false;
}
}
return true;
}

/*!
* \brief Apply the inverse of `buffer_load_iter_map_` to producer indices. Update `idx_sub_` with
* the result. It will be later used to transform the BufferStore indices of the producer.
Expand All @@ -645,7 +685,6 @@ class ReverseComputeInliner : public BaseInliner {
}

Stmt ReplaceInlinedBuffer(BufferStore producer) {
CreateInverseMapping(producer->indices);
producer_rhs_ = producer->value;
return Substituter(this)(GetRef<BufferStore>(inlined_store_));
}
Expand Down Expand Up @@ -702,10 +741,14 @@ class ReverseComputeInliner : public BaseInliner {
Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
/*! \brief The producer block */
const BlockNode* producer_block_{nullptr};
/*! \brief The predicate to ensure the consumer block iters are in-bound */
PrimExpr consumer_predicate_{nullptr};
/* \brief The consumer block */
const BlockNode* consumer_block_{nullptr};
/*! \brief The predicate to ensure the consumer block iters are in-bound. It will be inserted
* as the predicate of the producer block after inlining.
*/
PrimExpr consumer_iter_in_bound_{nullptr};
/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer;
arith::Analyzer analyzer_;
};

void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,22 @@ def elementwise_overcomputed_producer_reverse_inlined(
C[vi, vj] = A[vi, vj] * 2.0 + 1.0


@T.prim_func
def elementwise_producer_not_cover_consumer(
A: T.Buffer[(128, 128), "float32"],
D: T.Buffer[(256, 128), "float32"]
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(256, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj], T.float32(0), dtype="float32")


# pylint: enable=no-member,invalid-name,unused-variable

use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
Expand Down Expand Up @@ -858,5 +874,15 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name):
)


def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name):
"""Test reverse compute inline failure when the inlined block iter domains are not covered by
its producer
"""
sch = tir.Schedule(elementwise_producer_not_cover_consumer, debug_mask="all")
compute = "C" if use_block_name else sch.get_block("C")
with pytest.raises(tvm.tir.ScheduleError):
sch.reverse_compute_inline(compute)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit ac4a9ec

Please sign in to comment.