Skip to content

Commit

Permalink
Track likely values through lets in loop partitioning (#7930)
Browse files Browse the repository at this point in the history
* Track likely values through lets in loop partitioning

Fixes #7929

Improves runtime of lens_blur app by ~20%

* Add uncaptured likely tags to selects in boundary condition helpers

Now that we look through lets, we end up in more situations where both
sides have a captured likely.

* Better comments
  • Loading branch information
abadams authored Nov 16, 2023
1 parent 0f65435 commit ad0f24e
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 50 deletions.
29 changes: 22 additions & 7 deletions src/BoundaryConditions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ Func constant_exterior(const Func &source, const Tuple &value,
if (value.as_vector().size() > 1) {
std::vector<Expr> def;
for (size_t i = 0; i < value.as_vector().size(); i++) {
def.push_back(select(out_of_bounds, value[i], repeat_edge(source, bounds)(args)[i]));
def.push_back(select(out_of_bounds, value[i], likely(repeat_edge(source, bounds)(args)[i])));
}
bounded(args) = Tuple(def);
} else {
bounded(args) = select(out_of_bounds, value[0], repeat_edge(source, bounds)(args));
bounded(args) = select(out_of_bounds, value[0], likely(repeat_edge(source, bounds)(args)));
}

return bounded;
Expand Down Expand Up @@ -99,10 +99,25 @@ Func repeat_image(const Func &source,
Expr coord = arg_var - min; // Enforce zero origin.
coord = coord % extent; // Range is 0 to w-1
coord = coord + min; // Restore correct min

coord = select(arg_var < min || arg_var >= min + extent, coord,
clamp(likely(arg_var), min, min + extent - 1));

likely(clamp(likely(arg_var), min, min + extent - 1)));

// In the line above, we want loop partitioning to both cause the
// clamp to go away, and also cause the select to go away. For loop
// partitioning to make one of these constructs go away we need one
// of two things to be true:
//
// 1) One arg has a likely intrinsic buried somewhere within it, and
// the other arg doesn't.
// 2) Both args have likely intrinsics, but in one of the args it is
// not within any inner min/max/select node. This is called an
// 'uncaptured' likely.
//
// The issue with this boundary condition is that the true branch of
// the select (coord) may well have a likely within it somewhere
// introduced by a loop tail strategy, so condition 1 doesn't
// hold. To be more robust, we make condition 2 hold, by introducing
// an uncaptured likely to the false branch.
actuals.push_back(coord);
} else if (!min.defined() && !extent.defined()) {
actuals.push_back(arg_var);
Expand Down Expand Up @@ -143,7 +158,7 @@ Func mirror_image(const Func &source,
coord = coord + min; // Restore correct min
coord = clamp(coord, min, min + extent - 1);
coord = select(arg_var < min || arg_var >= min + extent, coord,
clamp(likely(arg_var), min, min + extent - 1));
likely(clamp(likely(arg_var), min, min + extent - 1)));
actuals.push_back(coord);
} else if (!min.defined() && !extent.defined()) {
actuals.push_back(arg_var);
Expand Down Expand Up @@ -188,7 +203,7 @@ Func mirror_interior(const Func &source,

// The boundary condition probably doesn't apply
coord = select(arg_var < min || arg_var >= min + extent, coord,
clamp(likely(arg_var), min, min + extent - 1));
likely(clamp(likely(arg_var), min, min + extent - 1)));

actuals.push_back(coord);
} else if (!min.defined() && !extent.defined()) {
Expand Down
128 changes: 90 additions & 38 deletions src/PartitionLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,34 @@ class HasLikelyTag : public IRVisitor {
}
}

void visit(const Variable *op) override {
result |= scope.contains(op->name);
}

const Scope<> &scope;

public:
HasLikelyTag(const Scope<> &s)
: scope(s) {
}

bool result = false;
};

class HasUncapturedLikelyTag : public HasLikelyTag {
using HasLikelyTag::visit;

// Any likelies buried inside the following ops are captured the by respective ops
void visit(const Select *op) override {
}
void visit(const Min *op) override {
}
void visit(const Max *op) override {
}

public:
HasUncapturedLikelyTag(const Scope<> &s)
: HasLikelyTag(s) {
}
};

// The goal of loop partitioning is to split loops up into a prologue,
Expand Down Expand Up @@ -243,6 +257,7 @@ class FindSimplifications : public IRVisitor {
using IRVisitor::visit;

Scope<> depends_on_loop_var, depends_on_invalid_buffers;
Scope<> vars_with_uncaptured_likely, vars_with_likely;
Scope<> buffers;

void visit(const Allocate *op) override {
Expand All @@ -263,23 +278,20 @@ class FindSimplifications : public IRVisitor {
}
condition = remove_likelies(condition);
Simplification s = {condition, std::move(old), std::move(likely_val), std::move(unlikely_val), true};
while (s.condition.type().is_vector()) {
s.condition = simplify(s.condition);
if (const Broadcast *b = s.condition.as<Broadcast>()) {
s.condition = b->value;
} else {
// Devectorize the condition
s.condition = and_condition_over_domain(s.condition, Scope<Interval>::empty_scope());
s.tight = false;
}
}
internal_assert(s.condition.type().is_scalar()) << s.condition << "\n";
simplifications.push_back(s);
}

bool has_uncaptured_likely(const Expr &e) const {
return has_uncaptured_likely_tag(e, vars_with_uncaptured_likely);
}

bool has_likely(const Expr &e) const {
return has_likely_tag(e, vars_with_likely);
}

void visit(const Min *op) override {
bool likely_a = has_uncaptured_likely_tag(op->a);
bool likely_b = has_uncaptured_likely_tag(op->b);
bool likely_a = has_uncaptured_likely(op->a);
bool likely_b = has_uncaptured_likely(op->b);

// If one side has an uncaptured likely, don't hunt for
// simplifications in the other side.
Expand All @@ -294,20 +306,23 @@ class FindSimplifications : public IRVisitor {
// call. If neither does, prefer the side that contains any
// likely call at all.
if (!likely_a && !likely_b) {
likely_a = has_likely_tag(op->a);
likely_b = has_likely_tag(op->b);
likely_a = has_likely(op->a);
likely_b = has_likely(op->b);
}

if (likely_b && !likely_a) {
new_simplification(op->b <= op->a, op, op->b, op->a);
} else if (likely_a && !likely_b) {
new_simplification(op->a <= op->b, op, op->a, op->b);
} else if (likely_a && likely_b) {
// Likelies on both sides, continue inwards.
IRVisitor::visit(op);
}
}

void visit(const Max *op) override {
bool likely_a = has_uncaptured_likely_tag(op->a);
bool likely_b = has_uncaptured_likely_tag(op->b);
bool likely_a = has_uncaptured_likely(op->a);
bool likely_b = has_uncaptured_likely(op->b);

if (!likely_a) {
op->b.accept(this);
Expand All @@ -317,8 +332,8 @@ class FindSimplifications : public IRVisitor {
}

if (!likely_a && !likely_b) {
likely_a = has_likely_tag(op->a);
likely_b = has_likely_tag(op->b);
likely_a = has_likely(op->a);
likely_b = has_likely(op->b);
}

if (likely_b && !likely_a) {
Expand All @@ -331,13 +346,8 @@ class FindSimplifications : public IRVisitor {
void visit_select(const Expr &condition, const Expr &old, const Expr &true_value, const Expr &false_value) {
condition.accept(this);

bool likely_t = has_uncaptured_likely_tag(true_value);
bool likely_f = has_uncaptured_likely_tag(false_value);

if (!likely_t && !likely_f) {
likely_t = has_likely_tag(true_value);
likely_f = has_likely_tag(false_value);
}
bool likely_t = has_uncaptured_likely(true_value);
bool likely_f = has_uncaptured_likely(false_value);

if (!likely_t) {
false_value.accept(this);
Expand All @@ -346,6 +356,11 @@ class FindSimplifications : public IRVisitor {
true_value.accept(this);
}

if (!likely_t && !likely_f) {
likely_t = has_likely(true_value);
likely_f = has_likely(false_value);
}

if (likely_t && !likely_f) {
new_simplification(condition, old, true_value, false_value);
} else if (likely_f && !likely_t) {
Expand Down Expand Up @@ -376,7 +391,7 @@ class FindSimplifications : public IRVisitor {
// statement is marked as likely, treat it as likely true and
// partition accordingly.
IRVisitor::visit(op);
if (has_uncaptured_likely_tag(op->condition)) {
if (has_uncaptured_likely(op->condition)) {
new_simplification(op->condition, op->condition, const_true(), const_false());
}
}
Expand Down Expand Up @@ -408,15 +423,15 @@ class FindSimplifications : public IRVisitor {

void visit(const Store *op) override {
IRVisitor::visit(op);
if (has_uncaptured_likely_tag(op->predicate)) {
if (has_uncaptured_likely(op->predicate)) {
const int lanes = op->predicate.type().lanes();
new_simplification(op->predicate, op->predicate, const_true(lanes), remove_likelies(op->predicate));
}
}

void visit(const Load *op) override {
IRVisitor::visit(op);
if (has_uncaptured_likely_tag(op->predicate)) {
if (has_uncaptured_likely(op->predicate)) {
const int lanes = op->predicate.type().lanes();
new_simplification(op->predicate, op->predicate, const_true(lanes), remove_likelies(op->predicate));
}
Expand All @@ -429,6 +444,11 @@ class FindSimplifications : public IRVisitor {
ScopedBinding<> bind_invalid(expr_uses_invalid_buffers(op->value, buffers) ||
expr_uses_vars(op->value, depends_on_invalid_buffers),
depends_on_invalid_buffers, op->name);
ScopedBinding<> bind_uncaptured_likely(has_uncaptured_likely(op->value),
vars_with_uncaptured_likely, op->name);
ScopedBinding<> bind_likely(has_likely(op->value),
vars_with_likely, op->name);

vector<Simplification> old;
old.swap(simplifications);
IRVisitor::visit(op);
Expand Down Expand Up @@ -566,6 +586,18 @@ class PartitionLoops : public IRMutator {
vector<Simplification> middle_simps, prologue_simps, epilogue_simps;
bool lower_bound_is_tight = true, upper_bound_is_tight = true;
for (auto &s : finder.simplifications) {

// Devectorize the condition
while (s.condition.type().is_vector()) {
s.condition = simplify(s.condition);
if (const Broadcast *b = s.condition.as<Broadcast>()) {
s.condition = b->value;
} else {
s.condition = and_condition_over_domain(s.condition, Scope<Interval>::empty_scope());
s.tight = false;
}
}

// Solve for the interval over which this simplification is true.
s.interval = solve_for_inner_interval(s.condition, op->name);
if (s.tight) {
Expand Down Expand Up @@ -991,24 +1023,44 @@ class ExpandSelects : public IRMutator {

Expr visit(const Select *op) override {
Expr condition = mutate(op->condition);

const Call *true_likely = Call::as_intrinsic(op->true_value, {Call::likely});
const Call *false_likely = Call::as_intrinsic(op->false_value, {Call::likely});

Expr true_value = mutate(op->true_value);
Expr false_value = mutate(op->false_value);
if (const Or *o = condition.as<Or>()) {
if (is_trivial(true_value)) {
return mutate(Select::make(o->a, true_value, Select::make(o->b, true_value, false_value)));
Expr expr = Select::make(o->b, true_value, false_value);
if (false_likely) {
expr = likely(expr);
}
return mutate(Select::make(o->a, true_value, expr));
} else {
string var_name = unique_name('t');
Expr var = Variable::make(true_value.type(), var_name);
Expr expr = mutate(Select::make(o->a, var, Select::make(o->b, var, false_value)));
Expr expr = Select::make(o->b, var, false_value);
if (false_likely) {
expr = likely(expr);
}
expr = mutate(Select::make(o->a, var, expr));
return Let::make(var_name, true_value, expr);
}
} else if (const And *a = condition.as<And>()) {
if (is_trivial(false_value)) {
return mutate(Select::make(a->a, Select::make(a->b, true_value, false_value), false_value));
Expr expr = Select::make(a->b, true_value, false_value);
if (true_likely) {
expr = likely(expr);
}
return mutate(Select::make(a->a, expr, false_value));
} else {
string var_name = unique_name('t');
Expr var = Variable::make(false_value.type(), var_name);
Expr expr = mutate(Select::make(a->a, Select::make(a->b, true_value, var), var));
Expr expr = Select::make(a->b, true_value, var);
if (true_likely) {
expr = likely(expr);
}
expr = mutate(Select::make(a->a, expr, var));
return Let::make(var_name, false_value, expr);
}
} else if (const Not *n = condition.as<Not>()) {
Expand Down Expand Up @@ -1098,14 +1150,14 @@ class LowerLikelyIfInnermost : public IRMutator {

} // namespace

bool has_uncaptured_likely_tag(const Expr &e) {
HasUncapturedLikelyTag h;
bool has_uncaptured_likely_tag(const Expr &e, const Scope<> &scope) {
HasUncapturedLikelyTag h(scope);
e.accept(&h);
return h.result;
}

bool has_likely_tag(const Expr &e) {
HasLikelyTag h;
bool has_likely_tag(const Expr &e, const Scope<> &scope) {
HasLikelyTag h(scope);
e.accept(&h);
return h.result;
}
Expand Down
13 changes: 8 additions & 5 deletions src/PartitionLoops.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@
*/

#include "Expr.h"
#include "Scope.h"

namespace Halide {
namespace Internal {

/** Return true if an expression uses a likely tag that isn't captured
* by an enclosing Select, Min, or Max. */
bool has_uncaptured_likely_tag(const Expr &e);
/** Return true if an expression uses a likely tag that isn't captured by an
* enclosing Select, Min, or Max. The scope contains all vars that should be
* considered to have uncaptured likelies. */
bool has_uncaptured_likely_tag(const Expr &e, const Scope<> &scope);

/** Return true if an expression uses a likely tag. */
bool has_likely_tag(const Expr &e);
/** Return true if an expression uses a likely tag. The scope contains all vars
* in scope that should be considered to have likely tags. */
bool has_likely_tag(const Expr &e, const Scope<> &scope);

/** Partitions loop bodies into a prologue, a steady state, and an
* epilogue. Finds the steady state by hunting for use of clamped
Expand Down
20 changes: 20 additions & 0 deletions test/correctness/likely.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,26 @@ int main(int argc, char **argv) {
result = g.realize({10});
}

// Test for the bug described in https://github.com/halide/Halide/issues/7929
{
Func f, g, h;
Var x, y;

f(x, y) = x;
f.compute_root();

Param<int> p;
g = BoundaryConditions::repeat_edge(f, {{0, p}, {Expr(), Expr()}});

h(x, y) = g(x, y) + g(x, y + 1) + g(x, y + 2);

count_partitions(h, 3);

// Same thing with vectorization too.
h.vectorize(x, 8);
count_partitions(h, 3);
}

// The performance of this behavior is tested in
// test/performance/boundary_conditions.cpp

Expand Down

0 comments on commit ad0f24e

Please sign in to comment.