Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DietCode] Local Padding #11793

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ TVM_DLL const Op& reinterpret();
*/
TVM_DLL const Op& likely();

/*!
* \brief Marks a condition as affecting the region size.
*/
TVM_DLL const Op& affect_region_size();

/*!
* \brief Bitwise and operator.
*/
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,13 @@ TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false
* \return The marked expression.
*/
TVM_DLL PrimExpr likely(PrimExpr cond, Span span = Span());
/*!
* \brief Mark condition as affecting the region size.
* \param cond The condition
* \param span The location of this operation in the source.
* \return The marked expression.
*/
TVM_DLL PrimExpr affect_region_size(PrimExpr cond, Span span = Span());
/*!
* \brief Calculate power(x, y)
* \param x The left operand.
Expand Down
14 changes: 12 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ TVM_DLL Pass CoProcSync();
*/
TVM_DLL Pass LiftAttrScope(String attr_key);

/*!
* \brief Pad tensors according to the local workspace size. This is to remove predicates at the
* compute body that could negatively affect the performance.
*
* \param enable_local_pad Whether local padding is enabled.
*
* \return The pass.
*/
TVM_DLL Pass LocalPad(bool enable_local_pad = false);

/*!
* \brief partition loops in the stmt.
*
Expand Down Expand Up @@ -440,10 +450,10 @@ TVM_DLL Pass ConvertBlocksToOpaque();
*
* \endcode
*
*
* \param enable_local_pad Whether local padding is enabled in the downstream.
* \return The pass.
*/
TVM_DLL Pass CompactBufferAllocation();
TVM_DLL Pass CompactBufferAllocation(bool enable_local_pad = false);

/*!
* This pass legalizes packed calls by wrapping their arguments into TVMValues
Expand Down
25 changes: 23 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,22 @@ def LiftAttrScope(attr_key: str):
return _ffi_api.LiftAttrScope(attr_key) # type: ignore


def LocalPad(enable_local_pad: bool = False):
"""Pad tensors by the size of the local workspace.

Parameters
----------
enable_local_pad : bool
Whether local padding has been enabled.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LocalPad(enable_local_pad) # type: ignore


def LoopPartition():
"""Inject virtual thread loops.

Expand Down Expand Up @@ -735,7 +751,7 @@ def ConvertBlocksToOpaque():
return _ffi_api.ConvertBlocksToOpaque() # type: ignore


def CompactBufferAllocation():
def CompactBufferAllocation(enable_local_pad: bool = False):
"""Compact the buffer access region. by removing the buffer regions
that are not accessed, i.e. narrowing the buffer shape and adjust
the access region if necessary.
Expand Down Expand Up @@ -771,13 +787,18 @@ def CompactBufferAllocation():
for j in range(0, 16):
C[i, j] = B[0, j] + 1

Parameters
----------
enable_local_pad : bool
Whether local padding has been enabled.

Returns
-------
fpass : tvm.transform.Pass
The result pass

"""
return _ffi_api.CompactBufferAllocation() # type: ignore
return _ffi_api.CompactBufferAllocation(enable_local_pad) # type: ignore


def LowerMatchBuffer():
Expand Down
9 changes: 6 additions & 3 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_local_pad", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
Expand Down Expand Up @@ -139,9 +140,10 @@ TVM_REGISTER_GLOBAL("driver.get_binds")
return out_arr;
});

Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
Array<tvm::transform::Pass> CreatePassList(bool simple_mode) {
ArmageddonKnight marked this conversation as resolved.
Show resolved Hide resolved
transform::PassContext pass_ctx = transform::PassContext::Current();

bool enable_local_pad = pass_ctx->GetConfig<Bool>("tir.enable_local_pad", Bool(false)).value();
bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
bool disable_storage_rewrite =
pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite", Bool(false)).value();
Expand Down Expand Up @@ -200,7 +202,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::CompactBufferAllocation(enable_local_pad));
pass_list.push_back(tir::transform::LocalPad(enable_local_pad));
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
Expand All @@ -213,7 +216,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());

// PHASE 2
if (!disable_loop_partition) {
if (!simple_mode) {
pass_list.push_back(tir::transform::LoopPartition());
}

Expand Down
5 changes: 5 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ TIR_DEFINE_BUILTIN_FUNC(likely)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
.set_attr<TVectorizable>("TVectorizable", true);

TIR_DEFINE_BUILTIN_FUNC(affect_region_size)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
.set_attr<TVectorizable>("TVectorizable", true);

TIR_DEFINE_BUILTIN_FUNC(bitwise_and)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
Expand Down
5 changes: 5 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,11 @@ PrimExpr likely(PrimExpr cond, Span span) {
return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, span);
}

// no effect on the region size
PrimExpr affect_region_size(PrimExpr cond, Span span) {
return tir::Call(cond.dtype(), tir::builtin::affect_region_size(), {cond}, span);
}

// operator>
PrimExpr operator>(PrimExpr a, PrimExpr b) { return greater(a, b); }
PrimExpr greater(PrimExpr a, PrimExpr b, Span span) {
Expand Down
125 changes: 105 additions & 20 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,31 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
}

inline bool CheckSameAccessRegion(const Region& lhs, const Region& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (size_t region_idx = 0; region_idx < lhs.size(); ++region_idx) {
if (!StructuralEqual()(lhs[region_idx], rhs[region_idx])) {
return false;
}
}
return true;
}

/*!
* \brief Collect the access region of each buffer.
* \note The param buffer regions will not be collected.
*/
class BufferAccessRegionCollector : public StmtExprVisitor {
public:
static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> Collect(
const PrimFunc& f) {
static std::pair<std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>,
std::unordered_map<BlockRealize, PrimExpr, ObjectPtrHash, ObjectPtrEqual>>
Collect(const PrimFunc& f) {
BufferAccessRegionCollector collector;
collector(f->body);
return std::move(collector.buffer_access_region_);
return std::make_pair(collector.buffer_access_region_,
collector.block_predicate_with_annotations_);
}

private:
Expand Down Expand Up @@ -257,10 +271,31 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}

void VisitStmt_(const BlockRealizeNode* op) final {
PrimExpr cur_predicate = predicate_in_scope;
predicate_in_scope = op->predicate;
PrimExpr cur_predicate = predicate_in_scope_;
predicate_in_scope_ = op->predicate;
std::vector<PrimExpr> cur_predicate_in_scope_subexprs = predicate_in_scope_subexprs_;
predicate_in_scope_subexprs_ = DecomposePredicate(predicate_in_scope_);
std::unordered_set<size_t> cur_affect_region_size_indices = affect_region_size_indices_;
affect_region_size_indices_.clear();

StmtExprVisitor::VisitStmt_(op);
predicate_in_scope = cur_predicate;

std::vector<PrimExpr> predicate_subexprs_with_annotations;
predicate_subexprs_with_annotations.reserve(predicate_in_scope_subexprs_.size());
for (size_t subexpr_i = 0; subexpr_i < predicate_in_scope_subexprs_.size(); ++subexpr_i) {
if (affect_region_size_indices_.count(subexpr_i)) {
predicate_subexprs_with_annotations.push_back(
affect_region_size(predicate_in_scope_subexprs_[subexpr_i]));
} else {
predicate_subexprs_with_annotations.push_back(predicate_in_scope_subexprs_[subexpr_i]);
}
}
block_predicate_with_annotations_[GetRef<BlockRealize>(op)] =
FlattenPredicateSubExprs(predicate_subexprs_with_annotations);

predicate_in_scope_ = cur_predicate;
predicate_in_scope_subexprs_ = cur_predicate_in_scope_subexprs;
affect_region_size_indices_ = cur_affect_region_size_indices;
}

/**************** Helper functions ****************/
Expand Down Expand Up @@ -288,7 +323,21 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}
// Step 2. Relax the access region
NDIntSet nd_int_set =
NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_);
NDIntSetEval(buffer_region->region, predicate_in_scope_, dom_map_, &dom_analyzer_);
Region narrowed_buffer_region = SimplifyAndNarrowBufferRegionFromNDIntSet(
nd_int_set, buffer->shape, &dom_analyzer_, ancestor_loops_);

for (size_t subexpr_i = 0; subexpr_i < predicate_in_scope_subexprs_.size(); ++subexpr_i) {
PrimExpr flattened_predicate =
FlattenPredicateSubExprs(predicate_in_scope_subexprs_, subexpr_i);
NDIntSet nd_int_set_wo_subexpr =
NDIntSetEval(buffer_region->region, flattened_predicate, dom_map_, &dom_analyzer_);
Region narrowed_buffer_region_wo_subexpr = SimplifyAndNarrowBufferRegionFromNDIntSet(
nd_int_set_wo_subexpr, buffer->shape, &dom_analyzer_, ancestor_loops_);
if (!CheckSameAccessRegion(narrowed_buffer_region, narrowed_buffer_region_wo_subexpr)) {
affect_region_size_indices_.insert(subexpr_i);
}
}
// Step 3. Restore the non-relaxed ancestor loops domain
for (size_t i = 0; i < n_ancestor_loops; ++i) {
const VarNode* v = ancestor_loops_[i]->loop_var.get();
Expand Down Expand Up @@ -345,8 +394,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
*/
std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash, ObjectPtrEqual>
buffer_var_in_scope_;
/*! \brief The block predicate of current scope */
PrimExpr predicate_in_scope{true};
/*! \brief The block predicate of the current scope */
PrimExpr predicate_in_scope_{true};
/*! \brief The sub-expressions of the predicate of the current scope */
std::vector<PrimExpr> predicate_in_scope_subexprs_;
/*! \brief The set of indicates that affect the buffer size. */
std::unordered_set<size_t> affect_region_size_indices_;

/*! \brief The map from loop vars to their iter range. */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
Expand All @@ -358,6 +411,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_accesses_;
/*! \brief The map from Buffer to it entire access region, used for returning. */
std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> buffer_access_region_;
/*! \brief The map form BlockRealize to its annotated predicate. */
std::unordered_map<BlockRealize, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
block_predicate_with_annotations_;
/*! \brief The map from Buffer to it's access regions annotated by current block. */
std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
access_annotations_;
Expand Down Expand Up @@ -398,7 +454,10 @@ class BufferCompactor : public StmtExprMutator {
const PrimFunc& f,
const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
const std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>&
storage_align) {
storage_align,
std::unordered_map<BlockRealize, PrimExpr, ObjectPtrHash, ObjectPtrEqual>&&
block_predicate_with_annotations,
bool enable_local_pad) {
std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info;

for (const auto& kv : regions) {
Expand All @@ -419,7 +478,8 @@ class BufferCompactor : public StmtExprMutator {
}
buffer_info.emplace(buffer, std::move(buffer_alloc_info));
}
BufferCompactor compactor(std::move(buffer_info));
BufferCompactor compactor(std::move(buffer_info), std::move(block_predicate_with_annotations),
enable_local_pad);
Stmt stmt = compactor(f->body);
return stmt;
}
Expand Down Expand Up @@ -447,9 +507,14 @@ class BufferCompactor : public StmtExprMutator {
explicit BufferAllocInfo(Region region) : region(std::move(region)) {}
};

explicit BufferCompactor(
std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info)
: buffer_info_(std::move(buffer_info)) {}
BufferCompactor(
std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual>&& buffer_info,
std::unordered_map<BlockRealize, PrimExpr, ObjectPtrHash, ObjectPtrEqual>&&
block_predicate_with_annotations,
bool enable_local_pad)
: buffer_info_(std::move(buffer_info)),
block_predicate_with_annotations_(std::move(block_predicate_with_annotations)),
enable_local_pad_(enable_local_pad) {}

Stmt VisitStmt_(const BufferStoreNode* _op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
Expand All @@ -465,6 +530,18 @@ class BufferCompactor : public StmtExprMutator {
return std::move(load);
}

Stmt VisitStmt_(const BlockRealizeNode* op) final {
if (!enable_local_pad_) {
return StmtExprMutator::VisitStmt_(op);
}
BlockRealize ret = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
auto it = block_predicate_with_annotations_.find(GetRef<BlockRealize>(op));
if (it == block_predicate_with_annotations_.end()) {
return ret;
}
return BlockRealize(ret->iter_values, it->second, ret->block);
}

Stmt VisitStmt_(const BlockNode* op) final {
// Step 0. Check there is no Init part.
ICHECK(!op->init.defined());
Expand Down Expand Up @@ -580,17 +657,25 @@ class BufferCompactor : public StmtExprMutator {

/*! \brief The allocation information about each buffer. */
std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
/*! \brief The map form BlockRealize to its annotated predicate. */
std::unordered_map<BlockRealize, PrimExpr, ObjectPtrHash, ObjectPtrEqual>&&
block_predicate_with_annotations_;
/*! \brief Whether local padding has been enabled. */
bool enable_local_pad_;
};

PrimFunc CompactBufferAllocation(PrimFunc f) {
PrimFunc CompactBufferAllocation(PrimFunc f, bool enable_local_pad) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
BufferAccessRegionCollector::Collect(f);
std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region;
std::unordered_map<BlockRealize, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
block_predicate_with_annotations;
std::tie(region, block_predicate_with_annotations) = BufferAccessRegionCollector::Collect(f);
std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
storage_align = StorageAlignCollector::Collect(f);
fptr->body = BufferCompactor::Compact(f, region, storage_align);
fptr->body = BufferCompactor::Compact(
f, region, storage_align, std::move(block_predicate_with_annotations), enable_local_pad);
return f;
} else {
return f;
Expand All @@ -599,9 +684,9 @@ PrimFunc CompactBufferAllocation(PrimFunc f) {

namespace transform {

Pass CompactBufferAllocation() {
Pass CompactBufferAllocation(bool enable_local_pad) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return CompactBufferAllocation(std::move(f));
return CompactBufferAllocation(std::move(f), enable_local_pad);
};
return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {});
}
Expand Down
Loading