Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle many more intrinsics in Bounds.cpp #7823

Merged
merged 11 commits into from
Dec 1, 2023
175 changes: 161 additions & 14 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,36 @@ using std::string;
using std::vector;

namespace {

bool can_widen(const Expr &e) {
return e.type().bits() < 64;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be <= 32. I'm thinking of the 48 bit types in the xtensa backend.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

bool can_widen_all(const std::vector<Expr> &args) {
for (const auto &e : args) {
if (!can_widen(e)) {
return false;
}
}
return true;
}

Expr widen(Expr a) {
internal_assert(can_widen(a));
Type result_type = a.type().widen();
return Cast::make(result_type, std::move(a));
}

Expr narrow(Expr a) {
Type result_type = a.type().narrow();
return Cast::make(result_type, std::move(a));
}

Expr saturating_narrow(const Expr &a) {
Type narrow = a.type().narrow();
return saturating_cast(narrow, a);
}

int static_sign(const Expr &x) {
if (is_positive_const(x)) {
return 1;
Expand All @@ -56,6 +86,7 @@ int static_sign(const Expr &x) {
}
return 0;
}

} // anonymous namespace

const FuncValueBounds &empty_func_value_bounds() {
Expand Down Expand Up @@ -1195,6 +1226,15 @@ class Bounds : public IRVisitor {
// else fall thru and continue
}

const auto handle_expr_bounds = [this, t](const Expr &e) -> void {
if (e.defined()) {
e.accept(this);
} else {
// Just use the bounds of the type
this->bounds_of_type(t);
}
};

if (op->is_intrinsic(Call::abs)) {
Interval a = arg_bounds.get(0);
interval.min = make_zero(t);
Expand Down Expand Up @@ -1229,12 +1269,6 @@ class Bounds : public IRVisitor {
bounds_of_type(t);
}
}
} else if (op->is_intrinsic(Call::saturating_cast)) {
internal_assert(op->args.size() == 1);

Expr a = lower_saturating_cast(op->type, op->args[0]);
a.accept(this);
return;
} else if (op->is_intrinsic(Call::unsafe_promise_clamped) ||
op->is_intrinsic(Call::promise_clamped)) {
// Unlike an explicit clamp, we are also permitted to
Expand Down Expand Up @@ -1468,6 +1502,7 @@ class Bounds : public IRVisitor {
}
} else if (op->args.size() == 1 &&
(op->is_intrinsic(Call::round) ||
op->is_intrinsic(Call::strict_float) ||
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's going to be a merge conflict here because Call::saturating_cast is in the same category. Probably should add it in this PR in case the other one doesn't go in and we revert the u32 -> i32 cast change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I take it back! saturating_cast doesn't belong here.

op->name == "ceil_f32" || op->name == "ceil_f64" ||
op->name == "floor_f32" || op->name == "floor_f64" ||
op->name == "exp_f32" || op->name == "exp_f64" ||
Expand Down Expand Up @@ -1518,14 +1553,127 @@ class Bounds : public IRVisitor {
}
interval = result;
} else if (op->is_intrinsic(Call::widen_right_add)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this many safety checks for the widening operations. Any Expr in a widening op needs to be able to be widened - we can't lift to widening_mul unless a user widened the inputs. We only need to be careful with operations that we can lift to without widening operations, but that the "simple" lowering pattern involves widening.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, yes and no -- it's true a that the Exprs in a widening op need to be widened, and well-formed code shouldn't pass us cases that don't fit; that said, we absolutely will get misuse in that way, so what should we do when that happens? IMHO we are better off checking for it an explicitly devolving to bounds-of-type, rather than risking that the bounds-calc code makes a mistake and calculates an inappropriate bound due to inadvertent overflow.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I think I misunderstood the use case. Is this for when users write code that produces a LUT index, and uses intermediate 64 bit types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for when users write code that produces a LUT index

Yes? I mean, we have no control of what the user is doing with these functions; they could pass it insane nonsense, so we need to be somewhat defensive here. We'd prefer to avoid a too-loose bounds, but we absolutely cannot risk getting too-tight bounds.

Expr add = Add::make(op->args[0], cast(op->args[0].type(), op->args[1]));
add.accept(this);
} else if (op->is_intrinsic(Call::widen_right_sub)) {
Expr sub = Sub::make(op->args[0], cast(op->args[0].type(), op->args[1]));
sub.accept(this);
internal_assert(op->args.size() == 2);
Expr e = can_widen(op->args[1]) ?
lower_widen_right_add(op->args[0], op->args[1]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::widen_right_mul)) {
Expr mul = Mul::make(op->args[0], cast(op->args[0].type(), op->args[1]));
mul.accept(this);
internal_assert(op->args.size() == 2);
Expr e = can_widen(op->args[1]) ?
lower_widen_right_mul(op->args[0], op->args[1]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::widen_right_sub)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen(op->args[1]) ?
lower_widen_right_sub(op->args[0], op->args[1]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::widening_add)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen_all(op->args) ?
lower_widening_add(op->args[0], op->args[1]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::widening_mul)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen_all(op->args) ?
lower_widening_mul(op->args[0], op->args[1]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::widening_sub)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen_all(op->args) ?
lower_widening_sub(op->args[0], op->args[1]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::saturating_add)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen_all(op->args) ?
narrow(clamp(widen(op->args[0]) + widen(op->args[1]), t.min(), t.max())) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::saturating_sub)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen_all(op->args) ?
narrow(clamp(widen(op->args[0]) - widen(op->args[1]), t.min(), t.max())) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::saturating_cast)) {
internal_assert(op->args.size() == 1);
bounds_of_type(t);
Interval type_bounds = interval;
interval = arg_bounds.get(0);

if (interval.has_lower_bound()) {
interval.min = saturating_cast(t, interval.min);
} else if (op->args[0].type().is_uint()) {
// If we're casting from a uint, we can at least lower bound it
// with zero.
interval.min = make_zero(t);
} else {
interval.min = type_bounds.min;
}
if (interval.has_upper_bound()) {
interval.max = saturating_cast(t, interval.max);
} else {
interval.max = type_bounds.max;
}
} else if (op->is_intrinsic(Call::widening_shift_left)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen(op->args[0]) ?
lower_widening_shift_left(op->args[0], op->args[1]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::widening_shift_right)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen(op->args[0]) ?
lower_widening_shift_right(op->args[0], op->args[1]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::rounding_shift_right)) {
internal_assert(op->args.size() == 2);
// TODO: uses bitwise ops we may not handle well
handle_expr_bounds(lower_rounding_shift_right(op->args[0], op->args[1]));
} else if (op->is_intrinsic(Call::rounding_shift_left)) {
internal_assert(op->args.size() == 2);
// TODO: uses bitwise ops we may not handle well
handle_expr_bounds(lower_rounding_shift_left(op->args[0], op->args[1]));
} else if (op->is_intrinsic(Call::halving_add)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen_all(op->args) ?
narrow((widen(op->args[0]) + widen(op->args[1])) / 2) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::halving_sub)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen_all(op->args) ?
narrow((widen(op->args[0]) - widen(op->args[1])) / 2) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::rounding_halving_add)) {
internal_assert(op->args.size() == 2);
Expr e = can_widen_all(op->args) ?
narrow((widen(op->args[0]) + widen(op->args[1]) + 1) / 2) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::rounding_mul_shift_right)) {
internal_assert(op->args.size() == 3);
Expr e = can_widen_all(op->args) ?
saturating_narrow(rounding_shift_right(widening_mul(op->args[0], op->args[1]), op->args[2])) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::mul_shift_right)) {
internal_assert(op->args.size() == 3);
Expr e = can_widen_all(op->args) ?
saturating_narrow(widening_mul(op->args[0], op->args[1]) >> op->args[2]) :
Expr();
handle_expr_bounds(e);
} else if (op->is_intrinsic(Call::sorted_avg)) {
internal_assert(op->args.size() == 2);
Expr e = lower_sorted_avg(op->args[0], op->args[1]);
handle_expr_bounds(e);
} else if (op->call_type == Call::Halide) {
bounds_of_func(op->name, op->value_index, op->type);
} else {
Expand Down Expand Up @@ -2261,7 +2409,6 @@ class BoxesTouched : public IRGraphVisitor {
Stmt else_case = Evaluate::make(op->args[2]);
equivalent_if = IfThenElse::make(op->args[0], then_case, else_case);
} else {
internal_assert(op->args.size() == 2);
equivalent_if = IfThenElse::make(op->args[0], then_case);
}
equivalent_if.accept(this);
Expand Down
1 change: 1 addition & 0 deletions src/FindIntrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Expr lower_saturating_cast(const Type &t, const Expr &a);
Expr lower_halving_add(const Expr &a, const Expr &b);
Expr lower_halving_sub(const Expr &a, const Expr &b);
Expr lower_rounding_halving_add(const Expr &a, const Expr &b);
Expr lower_sorted_avg(const Expr &a, const Expr &b);

Expr lower_mul_shift_right(const Expr &a, const Expr &b, const Expr &q);
Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q);
Expand Down