Skip to content

Commit

Permalink
Tighten bounds of abs() (#8168)
Browse files Browse the repository at this point in the history
* Tighten bounds of abs()

* make abs bounds tight for non-int32 too

* make int32 min expression match non-int32 min expression
  • Loading branch information
rootjalex authored Apr 5, 2024
1 parent 7d99357 commit a462044
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion dependencies/llvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
message(STATUS "Using ClangConfig.cmake in: ${Clang_DIR}")

if (LLVM_PACKAGE_VERSION VERSION_LESS 16.0)
message(FATAL_ERROR "LLVM version must be 15.0 or newer")
message(FATAL_ERROR "LLVM version must be 16.0 or newer")
endif ()

if (LLVM_PACKAGE_VERSION VERSION_GREATER 19.0)
Expand Down
30 changes: 28 additions & 2 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1237,18 +1237,29 @@ class Bounds : public IRVisitor {

if (op->is_intrinsic(Call::abs)) {
Interval a = arg_bounds.get(0);
interval.min = make_zero(t);

if (a.is_bounded()) {
if (equal(a.min, a.max)) {
interval = Interval::single_point(Call::make(t, Call::abs, {a.max}, Call::PureIntrinsic));
} else if (op->args[0].type().is_int() && op->args[0].type().bits() >= 32) {
interval.max = Max::make(Cast::make(t, -a.min), Cast::make(t, a.max));
interval.min = Cast::make(t, Max::make(a.min, -Min::make(make_zero(a.min.type()), a.max)));
interval.max = Cast::make(t, Max::make(-a.min, a.max));
} else {
interval.min = Cast::make(t, Max::make(a.min, -Min::make(make_zero(a.min.type()), a.max)));
a.min = Call::make(t, Call::abs, {a.min}, Call::PureIntrinsic);
a.max = Call::make(t, Call::abs, {a.max}, Call::PureIntrinsic);
interval.max = Max::make(a.min, a.max);
}
} else {
if (a.has_lower_bound()) {
// If a is strictly positive, then abs(a) is strictly positive.
interval.min = Cast::make(t, Max::make(make_zero(a.min.type()), a.min));
} else if (a.has_upper_bound()) {
// If a is strictly negative, then abs(a) is strictly positive.
interval.min = Cast::make(t, -Min::make(make_zero(a.max.type()), a.max));
} else {
interval.min = make_zero(t);
}
// If the argument is unbounded on one side, then the max is unbounded.
interval.max = Interval::pos_inf();
}
Expand Down Expand Up @@ -3651,6 +3662,21 @@ void bounds_test() {
check(scope, cast<float>(x), 0.0f, 10.0f);

check(scope, cast<int32_t>(abs(cast<float>(x))), 0, 10);
check(scope, abs(2 + x), u32(2), u32(12));
check(scope, abs(x - 11), u32(1), u32(11));
check(scope, abs(x - 5), u32(0), u32(5));
check(scope, abs(2 + cast<float>(x)), 2.f, 12.f);
check(scope, abs(cast<float>(x) - 11), 1.f, 11.f);
check(scope, abs(cast<float>(x) - 5), 0.f, 5.f);
check(scope, abs(2 + cast<int8_t>(x)), u8(2), u8(12));
check(scope, abs(cast<int8_t>(x) - 11), u8(1), u8(11));
check(scope, abs(cast<int8_t>(x) - 5), u8(0), u8(5));
scope.push("x", Interval(123, Interval::pos_inf()));
check(scope, abs(x), u32(123), Interval::pos_inf());
scope.pop("x");
scope.push("x", Interval(Interval::neg_inf(), -123));
check(scope, abs(x), u32(123), Interval::pos_inf());
scope.pop("x");

// Check some vectors
check(scope, Ramp::make(x * 2, 5, 5), 0, 40);
Expand Down

0 comments on commit a462044

Please sign in to comment.