Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][CompactBufferAllocation] Improve upperbound estimation of buffer compaction #12527

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,29 @@ Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);
IntSet Intersect(const Array<IntSet>& sets);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate
* \brief Converts the Ranges to IntSets
* \param var_dom The ranges of variables
* \return The integer sets of the variables
*/
Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate.
* The result should be strict, i.e. no region is discarded or relaxed.
* \param region The region to be analyzed
* \param var_dom The ranges of the variables
* \param predicate The predicate for the affine map
* \param analyzer The analyzer used
* \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis
*/
TVM_DLL Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate.
* Some subregion may be discarded during the lower-bound analysis.
* \param region The region to be analyzed
* \param var_dom The ranges of the variables
* \param predicate The predicate for the affine map
Expand All @@ -273,6 +295,21 @@ TVM_DLL Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& reg
const PrimExpr& predicate,
arith::Analyzer* analyzer);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate
* Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added
* to the result.
* \param region The region to be analyzed
* \param var_dom The ranges of the variables
* \param predicate The predicate for the affine map
* \param analyzer The analyzer used
* \return an array of arith::IntSet as the result of analysis
*/
TVM_DLL Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SET_H_
8 changes: 7 additions & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
# under the License.
"""Integer bound analysis, simplification and pattern detection."""

from .int_set import IntSet, IntervalSet, estimate_region_lower_bound
from .int_set import (
IntSet,
IntervalSet,
estimate_region_lower_bound,
estimate_region_strict_bound,
estimate_region_upper_bound,
)
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
Expand Down
48 changes: 48 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, min_value, max_value):

def estimate_region_lower_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
Some subregion may be discarded during the lower-bound analysis.
Parameters
----------
Expand All @@ -103,6 +104,53 @@ def estimate_region_lower_bound(region, var_dom, predicate):
return _ffi_api.EstimateRegionLowerBound(region, var_dom, predicate)


def estimate_region_strict_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
The result should be strict, i.e. no region is discarded or relaxed.
Parameters
----------
region : List[Range]
The region to be analyzed.
var_dom : Dict[Var, Range]
The ranges of the variables
predicate : PrimExpr
The predicate for the affine map
Returns
----------
region_int_set : Optional[List[IntSet]]
None if the detection fails, or an array of IntSets as the result of analysis
"""
return _ffi_api.EstimateRegionStrictBound(region, var_dom, predicate)


def estimate_region_upper_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
Relaxation of the region may be used in upper-bound analysis,
i.e. some extra region may be added to the result.
Parameters
----------
region : List[Range]
The region to be analyzed.
var_dom : Dict[Var, Range]
The ranges of the variables
predicate : PrimExpr
The predicate for the affine map
Returns
----------
region_int_set : List[IntSet]
an array of IntSets as the result of analysis
"""
return _ffi_api.EstimateRegionUpperBound(region, var_dom, predicate)


def pos_inf():
"""Returns the symbolic positive infinity
Expand Down
131 changes: 107 additions & 24 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,9 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map<const VarNode*, IntSet>& dom

IntSet EvalSet(Range r, const Map<Var, IntSet>& dom_map) {
Analyzer ana;
if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) {
return EvalSet(r->min, dom_map);
}
IntervalSetEvaluator m(&ana, dom_map);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
PrimExpr sum = r->min + r->extent - 1;
Expand Down Expand Up @@ -1035,15 +1038,57 @@ IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map) {
return EvalSet(r, ConvertDomMap(dom_map));
}

Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate, Analyzer* analyzer) {
Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
Map<Var, arith::IntSet> result;
for (auto kv : var_dom) {
const Var& var = kv.first;
const Range& range = kv.second;
result.Set(var, arith::IntSet::FromRange(range));
}
return result;
}

/*! \brief Helper function to convert IterSumExpr to the actual touched range. */
static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent,
Analyzer* analyzer) {
if (iter_min->args.empty()) {
return IntSet::FromMinExtent(iter_min->base, extent);
}
ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr";
const IterSplitExpr& split = iter_min->args[0];
if (!analyzer->CanProve(extent >= split->scale)) {
return NullOpt;
}

const PrimExpr& base = iter_min->base;
// IterSplitExpr: (source // lower_factor) % extent * scale
// where `(source // lower_factor) % extent` is within [0, extent - 1]
if (analyzer->CanProve(split->scale < 0)) {
// If scale is negative, the var dom is [(extent - 1) * scale, 0]
// The total base is `base + (extent - 1) * scale`,
// while total extent is `dom_extent + (extent - 1) * (-scale)`
const PrimExpr& var_extent = (split->extent - 1) * split->scale;
return IntSet::FromMinExtent(base + var_extent, extent - var_extent);
} else {
// If scale is positive, the var dom is [0, (extent - 1) * scale]
// The total dom is [base, dom_extent + (extent - 1) * scale]
return IntSet::FromMinExtent(base, extent + (split->extent - 1) * split->scale);
}
}

Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate, Analyzer* analyzer) {
int ndim = region.size();
Array<IterSumExpr> iter_sum_exprs{nullptr};
{
Array<PrimExpr> affine_indices;
affine_indices.reserve(ndim);
for (const Range& range : region) {
if (!is_const_number(range->extent)) {
// dynamic extent is not supported yet.
return NullOpt;
}
affine_indices.push_back(range->min);
}
auto res = DetectIterMap(
Expand All @@ -1060,31 +1105,57 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
for (int i = 0; i < ndim; ++i) {
const IterSumExpr& sum_expr = iter_sum_exprs[i];
const Range& range = region[i];
if (sum_expr->args.empty()) {
result.push_back(IntSet::FromMinExtent(sum_expr->base, range->extent));
continue;
}
ICHECK_EQ(sum_expr->args.size(), 1);
const IterSplitExpr& split = sum_expr->args[0];
if (!analyzer->CanProve(range->extent >= split->scale)) {
Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer);
if (int_set.defined()) {
result.push_back(int_set.value());
} else {
return NullOpt;
}
}
return result;
}

const PrimExpr& base = sum_expr->base;
// IterSplitExpr: (source // lower_factor) % extent * scale
// where `(source // lower_factor) % extent` is within [0, extent - 1]
if (analyzer->CanProve(split->scale < 0)) {
// If scale is negative, the var dom is [(extent - 1) * scale, 0]
// The total base is `base + (extent - 1) * scale`,
// while total extent is `dom_extent + (extent - 1) * (-scale)`
const PrimExpr& var_extent = (split->extent - 1) * split->scale;
result.push_back(IntSet::FromMinExtent(base + var_extent, range->extent - var_extent));
} else {
// If scale is positive, the var dom is [0, (extent - 1) * scale]
// The total dom is [base, dom_extent + (extent - 1) * scale]
result.push_back(
IntSet::FromMinExtent(base, range->extent + (split->extent - 1) * split->scale));
Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer) {
return EstimateRegionStrictBound(region, var_dom, predicate, analyzer);
}

Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region, const Map<Var, Range>& var_dom,
const PrimExpr& predicate, Analyzer* analyzer) {
if (Optional<Array<arith::IntSet>> result = EstimateRegionStrictBound(
/*region=*/region,
/*var_dom=*/var_dom,
/*predicate=*/predicate, /*analyzer=*/analyzer)) {
return result.value();
}
Array<IntSet> result;
result.reserve(region.size());
// try estimate each dimension independently
for (const Range& range : region) {
auto res = DetectIterMap(
/*indices=*/{range->min}, /*input_iters=*/var_dom,
/*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer);
if (!res->indices.empty()) {
ICHECK_EQ(res->indices.size(), 1U);
IterSumExpr sum_expr = res->indices[0];

// dynamic extent is not supported yet.
PrimExpr extent = range->extent;
if (!is_const_number(extent)) {
IntSet relaxed = EvalSet(extent, AsIntSet(var_dom));
ICHECK(relaxed.HasUpperBound());
extent = relaxed.max();
}

if (Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer)) {
result.push_back(int_set.value());
continue;
}
}
// fallback to coarse grained evalset
result.push_back(EvalSet(range, AsIntSet(var_dom)));
}
return result;
}
Expand Down Expand Up @@ -1118,6 +1189,18 @@ TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound")
Analyzer analyzer;
return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer);
});
TVM_REGISTER_GLOBAL("arith.EstimateRegionStrictBound")
.set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
PrimExpr predicate) -> Optional<Array<IntSet>> {
Analyzer analyzer;
return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer);
});
TVM_REGISTER_GLOBAL("arith.EstimateRegionUpperBound")
.set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
PrimExpr predicate) -> Optional<Array<IntSet>> {
Analyzer analyzer;
return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer);
});

TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; });
TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; });
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
runtime::StorageRank rank = scope.rank;
if (rank != previous_rank || !var_dom.defined()) {
previous_rank = rank;
var_dom = AsIntSet(LoopDomainOfSRefTreePath(
var_dom = arith::AsIntSet(LoopDomainOfSRefTreePath(
/*low_inclusive=*/relax_path_low_inclusive,
/*high_exclusive=*/relax_path_high_exclusive,
/*extra_relax_scope=*/scope));
Expand Down
14 changes: 6 additions & 8 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"
#include <tvm/arith/int_set.h>

#include "./utils.h"
namespace tvm {
namespace tir {

Expand All @@ -44,13 +45,10 @@ Array<arith::IntSet> AnalyzeRegionUpperBound(const BufferRegion& region,
/*low_inclusive=*/dom_low_inclusive,
/*high_exclusive=*/dom_high_exclusive,
/*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope()));
if (Optional<Array<arith::IntSet>> result = EstimateRegionLowerBound(
/*region=*/region->region,
/*var_dom=*/var_dom,
/*predicate=*/predicate, /*analyzer=*/analyzer)) {
return result.value();
}
return arith::EvalSet(region->region, AsIntSet(var_dom));
return EstimateRegionUpperBound(
/*region=*/region->region,
/*var_dom=*/var_dom,
/*predicate=*/predicate, /*analyzer=*/analyzer);
}

/*!
Expand Down
18 changes: 0 additions & 18 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,24 +249,6 @@ inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) {
return thread_scope.rank == 1 && thread_scope.dim_index >= 0;
}

/******** Integer set ********/

/*!
* \brief Converts the Ranges to IntSets
* \param var_dom The ranges of variables
* \return The integer sets of the variables
*/
inline Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
std::unordered_map<Var, arith::IntSet, ObjectPtrHash, ObjectPtrEqual> result;
result.reserve(var_dom.size());
for (auto kv : var_dom) {
Var& var = kv.first;
Range& range = kv.second;
result.emplace(std::move(var), arith::IntSet::FromRange(std::move(range)));
}
return {result.begin(), result.end()};
}

/**************** Loop extents ****************/

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
var_dom[GetRef<Var>(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0));
}
Optional<Array<arith::IntSet>> eval_res =
arith::EstimateRegionLowerBound(region, var_dom, predicate, analyzer);
arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer);
if (eval_res.defined()) {
return NDIntSet(eval_res.value().begin(), eval_res.value().end());
}
Expand Down
Loading