Commit

[mlir][Arith] Let integer range narrowing handle negative values (#119642)

Update integer range narrowing to handle negative values.

The previous restriction to only narrowing known-non-negative values
wasn't needed, as both the signed and unsigned ranges represent bounds
on the values of each variable in the program ... except that one might
be more accurate than the other. So, if either the signed or unsigned
interpretation of the inputs and outputs allows for integer narrowing,
the narrowing is permitted.

This commit also updates the integer optimization rewrites to preserve
the range state of constant-like operations and of narrowed operations so that
rewrites of other operations don't lose that range information.
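
As a hypothetical illustration (not taken from this commit's tests), suppose
range analysis has proven that %arg0 : i64 lies in [-50, 50]. Every operand
and result of the addi below then fits in i32 under the signed interpretation,
so narrowing to 32 bits is now permitted:

    %c = arith.constant -100 : i64
    %r = arith.addi %arg0, %c : i64

becomes

    %c = arith.constant -100 : i64
    %0 = arith.trunci %arg0 : i64 to i32
    %1 = arith.trunci %c : i64 to i32
    %2 = arith.addi %0, %1 : i32
    %r = arith.extsi %2 : i32 to i64

The previous non-negative-only check would have rejected this because of the
negative bounds, even though the signed interpretation loses no information.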
krzysz00 authored Dec 13, 2024
1 parent ecdf0da commit 9bf7930
Showing 2 changed files with 246 additions and 142 deletions.
mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (141 additions, 121 deletions)
@@ -46,6 +46,17 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
return inferredRange.getConstantValue();
}

static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
Value newVal) {
assert(oldVal.getType() == newVal.getType() &&
"Can't copy integer ranges between different types");
auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
if (!oldState)
return;
(void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
*oldState);
}

/// Patterned after SCCP
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
PatternRewriter &rewriter,
@@ -80,6 +91,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
if (!constOp)
return failure();

copyIntegerRange(solver, value, constOp->getResult(0));
rewriter.replaceAllUsesWith(value, constOp->getResult(0));
return success();
}
@@ -195,56 +207,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
DataFlowSolver &solver;
};

/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
Type elemType = getElementTypeOrSelf(type);
if (isa<IndexType>(elemType))
return success();

if (auto intType = dyn_cast<IntegerType>(elemType))
if (intType.getWidth() > targetBitwidth)
return success();

return failure();
}

/// Check if op have same type for all operands and results and this type
/// is suitable for truncation.
static LogicalResult checkElementwiseOpType(Operation *op,
unsigned targetBitwidth) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return failure();

Type type;
for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
if (!type) {
type = val.getType();
continue;
}

if (type != val.getType())
return failure();
}

return checkIntType(type, targetBitwidth);
}

/// Return union of all operands values ranges.
static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
ValueRange operands) {
std::optional<ConstantIntRanges> ret;
for (Value value : operands) {
/// Gather ranges for all the values in `values`. Appends to the existing
/// vector.
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
SmallVectorImpl<ConstantIntRanges> &ranges) {
for (Value val : values) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(value);
solver.lookupState<IntegerValueRangeLattice>(val);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return std::nullopt;
return failure();

const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();

ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
ranges.push_back(inferredRange);
}
return ret;
return success();
}

/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
@@ -258,56 +235,79 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
return dstType;
}

/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
APInt smax, APInt umin, APInt umax) {
auto sge = [](APInt val1, APInt val2) -> bool {
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
val1 = val1.sext(width);
val2 = val2.sext(width);
return val1.sge(val2);
};
auto sle = [](APInt val1, APInt val2) -> bool {
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
val1 = val1.sext(width);
val2 = val2.sext(width);
return val1.sle(val2);
};
auto uge = [](APInt val1, APInt val2) -> bool {
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
val1 = val1.zext(width);
val2 = val2.zext(width);
return val1.uge(val2);
};
auto ule = [](APInt val1, APInt val2) -> bool {
unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
val1 = val1.zext(width);
val2 = val2.zext(width);
return val1.ule(val2);
};
return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
uge(range.umin(), umin) && ule(range.umax(), umax));
namespace {
// Enum for tracking which type of truncation should be performed
// to narrow an operation, if any.
enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
} // namespace

/// If the values within `range` can be represented using only `width` bits,
/// return the kind of truncation needed to preserve that property.
///
/// This check relies on the fact that the signed and unsigned ranges are both
/// always correct, but that one might be an approximation of the other,
/// so we want to use the correct truncation operation.
static CastKind checkTruncatability(const ConstantIntRanges &range,
unsigned targetWidth) {
unsigned srcWidth = range.smin().getBitWidth();
if (srcWidth <= targetWidth)
return CastKind::None;
unsigned removedWidth = srcWidth - targetWidth;
// The sign bits need to extend into the sign bit of the target width. For
// example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
// bits.
bool canTruncateSigned =
range.smin().getNumSignBits() >= (removedWidth + 1) &&
range.smax().getNumSignBits() >= (removedWidth + 1);
bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
range.umax().countLeadingZeros() >= removedWidth;
if (canTruncateSigned && canTruncateUnsigned)
return CastKind::Both;
if (canTruncateSigned)
return CastKind::Signed;
if (canTruncateUnsigned)
return CastKind::Unsigned;
return CastKind::None;
}

static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
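// `None` is absorbing: a value that can't be narrowed blocks the whole op.
// `Both` is the identity, and conflicting signed/unsigned requirements
// collapse to `None`.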
if (lhs == CastKind::None || rhs == CastKind::None)
return CastKind::None;
if (lhs == CastKind::Both)
return rhs;
if (rhs == CastKind::Both)
return lhs;
if (lhs == rhs)
return lhs;
return CastKind::None;
}

static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
CastKind castKind) {
Type srcType = src.getType();
assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
"Mixing vector and non-vector types");
assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
Type srcElemType = getElementTypeOrSelf(srcType);
Type dstElemType = getElementTypeOrSelf(dstType);
assert(srcElemType.isIntOrIndex() && "Invalid src type");
assert(dstElemType.isIntOrIndex() && "Invalid dst type");
if (srcType == dstType)
return src;

if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
if (castKind == CastKind::Signed)
return builder.create<arith::IndexCastOp>(loc, dstType, src);
return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
}

auto srcInt = cast<IntegerType>(srcElemType);
auto dstInt = cast<IntegerType>(dstElemType);
if (dstInt.getWidth() < srcInt.getWidth())
return builder.create<arith::TruncIOp>(loc, dstType, src);

if (castKind == CastKind::Signed)
return builder.create<arith::ExtSIOp>(loc, dstType, src);
return builder.create<arith::ExtUIOp>(loc, dstType, src);
}

@@ -319,36 +319,47 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
std::optional<ConstantIntRanges> range =
getOperandsRange(solver, op->getResults());
if (!range)
return failure();
if (op->getNumResults() == 0)
return rewriter.notifyMatchFailure(op, "can't narrow resultless op");

SmallVector<ConstantIntRanges> ranges;
if (failed(collectRanges(solver, op->getOperands(), ranges)))
return rewriter.notifyMatchFailure(op, "input without specified range");
if (failed(collectRanges(solver, op->getResults(), ranges)))
return rewriter.notifyMatchFailure(op, "output without specified range");

Type srcType = op->getResult(0).getType();
if (!llvm::all_equal(op->getResultTypes()))
return rewriter.notifyMatchFailure(op, "mismatched result types");
if (op->getNumOperands() == 0 ||
!llvm::all_of(op->getOperandTypes(),
[=](Type t) { return t == srcType; }))
return rewriter.notifyMatchFailure(
op, "no operands or operand types don't match result type");

for (unsigned targetBitwidth : targetBitwidths) {
if (failed(checkElementwiseOpType(op, targetBitwidth)))
continue;

Type srcType = op->getResult(0).getType();

// We are truncating op args to the desired bitwidth before the op and
// then extending op results back to the original width after. extui and
// exti will produce different results for negative values, so limit
// signed range to non-negative values.
auto smin = APInt::getZero(targetBitwidth);
auto smax = APInt::getSignedMaxValue(targetBitwidth);
auto umin = APInt::getMinValue(targetBitwidth);
auto umax = APInt::getMaxValue(targetBitwidth);
if (failed(checkRange(*range, smin, smax, umin, umax)))
CastKind castKind = CastKind::Both;
for (const ConstantIntRanges &range : ranges) {
castKind = mergeCastKinds(castKind,
checkTruncatability(range, targetBitwidth));
if (castKind == CastKind::None)
break;
}
if (castKind == CastKind::None)
continue;

Type targetType = getTargetType(srcType, targetBitwidth);
if (targetType == srcType)
continue;

Location loc = op->getLoc();
IRMapping mapping;
for (Value arg : op->getOperands()) {
Value newArg = doCast(rewriter, loc, arg, targetType);
for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
CastKind argCastKind = castKind;
// When dealing with `index` values, preserve non-negativity in the
// index_casts since we can't recover this in unsigned when equivalent.
if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
argCastKind = CastKind::Both;
Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
mapping.map(arg, newArg);
}

@@ -359,8 +370,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
}
});
SmallVector<Value> newResults;
for (Value res : newOp->getResults())
newResults.emplace_back(doCast(rewriter, loc, res, srcType));
for (auto [newRes, oldRes] :
llvm::zip_equal(newOp->getResults(), op->getResults())) {
Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
copyIntegerRange(solver, oldRes, castBack);
newResults.push_back(castBack);
}

rewriter.replaceOp(op, newResults);
return success();
@@ -382,21 +397,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
Value lhs = op.getLhs();
Value rhs = op.getRhs();

std::optional<ConstantIntRanges> range =
getOperandsRange(solver, {lhs, rhs});
if (!range)
SmallVector<ConstantIntRanges> ranges;
if (failed(collectRanges(solver, op.getOperands(), ranges)))
return failure();
const ConstantIntRanges &lhsRange = ranges[0];
const ConstantIntRanges &rhsRange = ranges[1];

Type srcType = lhs.getType();
for (unsigned targetBitwidth : targetBitwidths) {
Type srcType = lhs.getType();
if (failed(checkIntType(srcType, targetBitwidth)))
continue;

auto smin = APInt::getSignedMinValue(targetBitwidth);
auto smax = APInt::getSignedMaxValue(targetBitwidth);
auto umin = APInt::getMinValue(targetBitwidth);
auto umax = APInt::getMaxValue(targetBitwidth);
if (failed(checkRange(*range, smin, smax, umin, umax)))
CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
// Note: this includes target width > src width.
if (castKind == CastKind::None)
continue;

Type targetType = getTargetType(srcType, targetBitwidth);
@@ -405,12 +418,13 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {

Location loc = op->getLoc();
IRMapping mapping;
for (Value arg : op->getOperands()) {
Value newArg = doCast(rewriter, loc, arg, targetType);
mapping.map(arg, newArg);
}
Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
mapping.map(lhs, lhsCast);
mapping.map(rhs, rhsCast);

Operation *newOp = rewriter.clone(*op, mapping);
copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
rewriter.replaceOp(op, newOp->getResults());
return success();
}
@@ -425,19 +439,23 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
/// This pattern assumes all passed `targetBitwidths` are not wider than index
/// type.
struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
template <typename CastOp>
struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
: OpRewritePattern(context), targetBitwidths(target) {}
: OpRewritePattern<CastOp>(context), targetBitwidths(target) {}

LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
LogicalResult matchAndRewrite(CastOp op,
PatternRewriter &rewriter) const override {
auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
auto srcOp = op.getIn().template getDefiningOp<CastOp>();
if (!srcOp)
return failure();
return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");

Value src = srcOp.getIn();
if (src.getType() != op.getType())
return failure();
return rewriter.notifyMatchFailure(op, "outer types don't match");

if (!srcOp.getType().isIndex())
return rewriter.notifyMatchFailure(op, "intermediate type isn't index");

auto intType = dyn_cast<IntegerType>(op.getType());
if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
@@ -517,7 +535,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
ArrayRef<unsigned> bitwidthsSupported) {
patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
bitwidthsSupported);
patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
bitwidthsSupported);
}
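
// A minimal sketch (not part of this commit) of how a pass can drive these
// patterns; the flow below is illustrative and omits the rewrite listener the
// in-tree passes use to keep the solver state in sync with the IR:
//
//   DataFlowSolver solver;
//   solver.load<dataflow::DeadCodeAnalysis>();
//   solver.load<dataflow::IntegerRangeAnalysis>();
//   if (failed(solver.initializeAndRun(op)))
//     return signalPassFailure();
//
//   RewritePatternSet patterns(op->getContext());
//   arith::populateIntRangeNarrowingPatterns(patterns, solver,
//                                            /*bitwidthsSupported=*/{32});
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();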

std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {