Skip to content

Commit

Permalink
Teach unrolling to exploit conditions in enclosing ifs (halide#7969)
Browse files Browse the repository at this point in the history
* Teach unrolling to exploit conditions in enclosing ifs

Fixes halide#7968

* Handle vectorization as well

* Remove unused usings

* Add missing print
  • Loading branch information
abadams authored and ardier committed Mar 3, 2024
1 parent b78c409 commit 1745fc3
Show file tree
Hide file tree
Showing 13 changed files with 298 additions and 89 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ SOURCE_FILES = \
BoundaryConditions.cpp \
Bounds.cpp \
BoundsInference.cpp \
BoundConstantExtentLoops.cpp \
BoundSmallAllocations.cpp \
Buffer.cpp \
Callable.cpp \
Expand Down Expand Up @@ -654,6 +655,7 @@ HEADER_FILES = \
BoundaryConditions.h \
Bounds.h \
BoundsInference.h \
BoundConstantExtentLoops.h \
BoundSmallAllocations.h \
Buffer.h \
Callable.h \
Expand Down
136 changes: 136 additions & 0 deletions src/BoundConstantExtentLoops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#include "BoundConstantExtentLoops.h"
#include "Bounds.h"
#include "CSE.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Simplify.h"
#include "SimplifyCorrelatedDifferences.h"
#include "Substitute.h"

namespace Halide {
namespace Internal {

namespace {
/** IR mutator that rewrites the extent of every unrolled or vectorized
 * loop as an integer constant. It tries, in order: plain simplification,
 * aggressive simplification using the enclosing pure lets and enclosing
 * if-conditions, and finally a constant upper bound combined with a
 * guarding if statement around the loop body. */
class BoundLoops : public IRMutator {
    using IRMutator::visit;

    // Name/value pairs of the enclosing LetStmts whose values are pure,
    // outermost first.
    std::vector<std::pair<std::string, Expr>> lets;

    // Conditions known to hold at the node being visited: the condition
    // of each enclosing IfThenElse (negated while inside its else case).
    std::vector<Expr> facts;

    // When set, an unrolled loop whose extent can't be made constant is
    // demoted to a serial loop instead of triggering a user error.
    bool permit_failed_unroll = false;

    Stmt visit(const LetStmt *op) override {
        if (!is_pure(op->value)) {
            // Impure values are not recorded for later substitution.
            return IRMutator::visit(op);
        }
        lets.emplace_back(op->name, op->value);
        Stmt result = IRMutator::visit(op);
        lets.pop_back();
        return result;
    }

    Stmt visit(const IfThenElse *op) override {
        // The condition holds throughout the then case...
        facts.push_back(op->condition);
        Stmt new_then = mutate(op->then_case);

        // ...and its negation holds throughout the else case, if any.
        Stmt new_else;
        if (op->else_case.defined()) {
            facts.back() = simplify(!op->condition);
            new_else = mutate(op->else_case);
        }
        facts.pop_back();

        const bool changed = !new_then.same_as(op->then_case) ||
                             !new_else.same_as(op->else_case);
        if (changed) {
            return IfThenElse::make(op->condition, new_then, new_else);
        }
        return op;
    }

    Stmt visit(const For *op) override {
        const bool wants_constant_extent =
            op->for_type == ForType::Unrolled ||
            op->for_type == ForType::Vectorized;
        if (is_const(op->extent) || !wants_constant_extent) {
            // Either the extent is already constant, or this kind of
            // loop doesn't need one. Just recurse.
            return IRMutator::visit(op);
        }

        Stmt body = op->body;

        // First attempt: an ordinary simplification of the extent.
        Expr extent = simplify(op->extent);
        const IntImm *e = extent.as<IntImm>();

        if (!e) {
            // Second attempt, since we're otherwise heading for a hard
            // failure: wrap the extent in every enclosing pure let
            // (innermost first, so the outermost let ends up outermost),
            // strip likely tags, substitute the lets in, and simplify
            // under the facts implied by the enclosing ifs.
            for (size_t i = lets.size(); i > 0; i--) {
                const auto &binding = lets[i - 1];
                extent = Let::make(binding.first, binding.second, extent);
            }
            extent = substitute_in_all_lets(remove_likelies(extent));
            extent = simplify(extent,
                              true,
                              Scope<Interval>::empty_scope(),
                              Scope<ModulusRemainder>::empty_scope(),
                              facts);
            e = extent.as<IntImm>();
        }

        // Declared at function scope because e may point into it below.
        Expr extent_upper;
        if (!e) {
            // Third attempt: settle for a constant upper bound on the
            // extent, and guard the body so that iterations beyond the
            // true extent do nothing.
            extent_upper = find_constant_bound(extent, Direction::Upper, Scope<Interval>());
            if (extent_upper.defined()) {
                e = extent_upper.as<IntImm>();
                Expr loop_var = Variable::make(Int(32), op->name);
                Expr in_range = loop_var < op->min + op->extent;
                body = IfThenElse::make(likely_if_innermost(in_range), body);
            }
        }

        const bool is_unrolled = (op->for_type == ForType::Unrolled);

        if (!e && permit_failed_unroll && is_unrolled) {
            // No constant extent was found, but failure is tolerated:
            // fall back to a serial loop over the original extent.
            user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. Did you mean to do this?\n";
            body = mutate(body);
            return For::make(op->name, op->min, op->extent,
                             ForType::Serial, op->partition_policy, op->device_api, std::move(body));
        }

        user_assert(e)
            << "Can only " << (is_unrolled ? "unroll" : "vectorize")
            << " for loops over a constant extent.\n"
            << "Loop over " << op->name << " has extent " << extent << ".\n";

        body = mutate(body);
        return For::make(op->name, op->min, e,
                         op->for_type, op->partition_policy, op->device_api, std::move(body));
    }

public:
    // Experimental autoschedulers may want to unroll without being
    // totally confident the loop will indeed turn out to be
    // constant-sized. If this feature continues to be important, it
    // should be exposed somewhere in the scheduling language, but it is
    // not clear how. For now an environment variable is read, which is
    // ugly but expedient.
    //
    // For the tracking issue to fix this, see
    // https://github.com/halide/Halide/issues/3479
    BoundLoops()
        : permit_failed_unroll(get_env_variable("HL_PERMIT_FAILED_UNROLL") == "1") {
    }
};

} // namespace

/** Run a freshly constructed BoundLoops mutator over the given statement
 * and return the mutated statement. */
Stmt bound_constant_extent_loops(const Stmt &s) {
    BoundLoops bounder;
    return bounder.mutate(s);
}

} // namespace Internal
} // namespace Halide
24 changes: 24 additions & 0 deletions src/BoundConstantExtentLoops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef HALIDE_BOUND_CONSTANT_EXTENT_LOOPS_H
#define HALIDE_BOUND_CONSTANT_EXTENT_LOOPS_H

/** \file
* Defines the lowering pass that enforces a constant extent on all
* vectorized or unrolled loops.
*/

#include "Expr.h"

namespace Halide {
namespace Internal {

/** Replace all loop extents of unrolled or vectorized loops with constants, by
 * substituting and simplifying as needed. Simplification may make use of the
 * values of enclosing pure let statements and of the conditions of enclosing
 * if statements. If we can't determine a constant extent, but can determine a
 * constant upper bound, inject an if statement into the body. If we can't even
 * determine a constant upper bound, throw a user error, unless the environment
 * variable HL_PERMIT_FAILED_UNROLL is set to 1 and the loop is unrolled, in
 * which case a warning is issued and the loop is turned into a serial loop
 * instead. */
Stmt bound_constant_extent_loops(const Stmt &s);

} // namespace Internal
} // namespace Halide

#endif
4 changes: 2 additions & 2 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,11 @@ class BoundsInference : public IRMutator {
}

// Dump out the region required of each stage for debugging.

/*
debug(0) << "Box required of " << producer.name
<< " by " << consumer.name
<< " stage " << consumer.stage << ":\n";
<< " stage " << consumer.stage << ":\n"
<< " used: " << b.used << "\n";
for (size_t k = 0; k < b.size(); k++) {
debug(0) << " " << b[k].min << " ... " << b[k].max << "\n";
}
Expand Down
4 changes: 3 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ set(HEADER_FILES
BoundaryConditions.h
Bounds.h
BoundsInference.h
BoundSmallAllocations.h
BoundConstantExtentLoops.h
BoundSmallAllocations.h
Buffer.h
Callable.h
CanonicalizeGPUVars.h
Expand Down Expand Up @@ -189,6 +190,7 @@ set(SOURCE_FILES
BoundaryConditions.cpp
Bounds.cpp
BoundsInference.cpp
BoundConstantExtentLoops.cpp
BoundSmallAllocations.cpp
Buffer.cpp
Callable.cpp
Expand Down
5 changes: 5 additions & 0 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "AddParameterChecks.h"
#include "AllocationBoundsInference.h"
#include "AsyncProducers.h"
#include "BoundConstantExtentLoops.h"
#include "BoundSmallAllocations.h"
#include "Bounds.h"
#include "BoundsInference.h"
Expand Down Expand Up @@ -312,6 +313,10 @@ void lower_impl(const vector<Function> &output_funcs,
s = simplify_correlated_differences(s);
log("Lowering after simplifying correlated differences:", s);

debug(1) << "Bounding constant extent loops...\n";
s = bound_constant_extent_loops(s);
log("Lowering after bounding constant extent loops:", s);

debug(1) << "Unrolling...\n";
s = unroll_loops(s);
log("Lowering after unrolling:", s);
Expand Down
14 changes: 12 additions & 2 deletions src/Simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,13 @@ Simplify::ScopedFact::~ScopedFact() {

Expr simplify(const Expr &e, bool remove_dead_let_stmts,
const Scope<Interval> &bounds,
const Scope<ModulusRemainder> &alignment) {
const Scope<ModulusRemainder> &alignment,
const std::vector<Expr> &assumptions) {
Simplify m(remove_dead_let_stmts, &bounds, &alignment);
std::vector<Simplify::ScopedFact> facts;
for (const Expr &a : assumptions) {
facts.push_back(m.scoped_truth(a));
}
Expr result = m.mutate(e, nullptr);
if (m.in_unreachable) {
return unreachable(e.type());
Expand All @@ -366,8 +371,13 @@ Expr simplify(const Expr &e, bool remove_dead_let_stmts,

Stmt simplify(const Stmt &s, bool remove_dead_let_stmts,
const Scope<Interval> &bounds,
const Scope<ModulusRemainder> &alignment) {
const Scope<ModulusRemainder> &alignment,
const std::vector<Expr> &assumptions) {
Simplify m(remove_dead_let_stmts, &bounds, &alignment);
std::vector<Simplify::ScopedFact> facts;
for (const Expr &a : assumptions) {
facts.push_back(m.scoped_truth(a));
}
Stmt result = m.mutate(s);
if (m.in_unreachable) {
return Evaluate::make(unreachable());
Expand Down
17 changes: 10 additions & 7 deletions src/Simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@
namespace Halide {
namespace Internal {

/** Perform a a wide range of simplifications to expressions and
* statements, including constant folding, substituting in trivial
* values, arithmetic rearranging, etc. Simplifies across let
* statements, so must not be called on stmts with dangling or
* repeated variable names.
/** Perform a wide range of simplifications to expressions and statements,
* including constant folding, substituting in trivial values, arithmetic
* rearranging, etc. Simplifies across let statements, so must not be called on
* stmts with dangling or repeated variable names. Can optionally be passed
* known bounds of any variables, known alignment properties, and any other
* Exprs that should be assumed to be true.
*/
// @{
Stmt simplify(const Stmt &, bool remove_dead_code = true,
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope());
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
const std::vector<Expr> &assumptions = std::vector<Expr>());
Expr simplify(const Expr &, bool remove_dead_code = true,
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope());
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
const std::vector<Expr> &assumptions = std::vector<Expr>());
// @}

/** Attempt to statically prove an expression is true using the simplifier. */
Expand Down
Loading

0 comments on commit 1745fc3

Please sign in to comment.