forked from halide/Halide
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Teach unrolling to exploit conditions in enclosing ifs (halide#7969)
* Teach unrolling to exploit conditions in enclosing ifs. Fixes halide#7968
* Handle vectorization as well
* Remove unused usings
* Add missing print
- Loading branch information
Showing
13 changed files
with
298 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
#include "BoundConstantExtentLoops.h" | ||
#include "Bounds.h" | ||
#include "CSE.h" | ||
#include "IRMutator.h" | ||
#include "IROperator.h" | ||
#include "Simplify.h" | ||
#include "SimplifyCorrelatedDifferences.h" | ||
#include "Substitute.h" | ||
|
||
namespace Halide { | ||
namespace Internal { | ||
|
||
namespace { | ||
class BoundLoops : public IRMutator { | ||
using IRMutator::visit; | ||
|
||
std::vector<std::pair<std::string, Expr>> lets; | ||
|
||
Stmt visit(const LetStmt *op) override { | ||
if (is_pure(op->value)) { | ||
lets.emplace_back(op->name, op->value); | ||
Stmt s = IRMutator::visit(op); | ||
lets.pop_back(); | ||
return s; | ||
} else { | ||
return IRMutator::visit(op); | ||
} | ||
} | ||
|
||
std::vector<Expr> facts; | ||
Stmt visit(const IfThenElse *op) override { | ||
facts.push_back(op->condition); | ||
Stmt then_case = mutate(op->then_case); | ||
Stmt else_case; | ||
if (op->else_case.defined()) { | ||
facts.back() = simplify(!op->condition); | ||
else_case = mutate(op->else_case); | ||
} | ||
facts.pop_back(); | ||
if (then_case.same_as(op->then_case) && | ||
else_case.same_as(op->else_case)) { | ||
return op; | ||
} else { | ||
return IfThenElse::make(op->condition, then_case, else_case); | ||
} | ||
} | ||
|
||
Stmt visit(const For *op) override { | ||
if (is_const(op->extent)) { | ||
// Nothing needs to be done | ||
return IRMutator::visit(op); | ||
} | ||
|
||
if (op->for_type == ForType::Unrolled || | ||
op->for_type == ForType::Vectorized) { | ||
// Give it one last chance to simplify to an int | ||
Expr extent = simplify(op->extent); | ||
Stmt body = op->body; | ||
const IntImm *e = extent.as<IntImm>(); | ||
|
||
if (e == nullptr) { | ||
// We're about to hard fail. Get really aggressive | ||
// with the simplifier. | ||
for (auto it = lets.rbegin(); it != lets.rend(); it++) { | ||
extent = Let::make(it->first, it->second, extent); | ||
} | ||
extent = remove_likelies(extent); | ||
extent = substitute_in_all_lets(extent); | ||
extent = simplify(extent, | ||
true, | ||
Scope<Interval>::empty_scope(), | ||
Scope<ModulusRemainder>::empty_scope(), | ||
facts); | ||
e = extent.as<IntImm>(); | ||
} | ||
|
||
Expr extent_upper; | ||
if (e == nullptr) { | ||
// Still no luck. Try taking an upper bound and | ||
// injecting an if statement around the body. | ||
extent_upper = find_constant_bound(extent, Direction::Upper, Scope<Interval>()); | ||
if (extent_upper.defined()) { | ||
e = extent_upper.as<IntImm>(); | ||
body = | ||
IfThenElse::make(likely_if_innermost(Variable::make(Int(32), op->name) < | ||
op->min + op->extent), | ||
body); | ||
} | ||
} | ||
|
||
if (e == nullptr && permit_failed_unroll && op->for_type == ForType::Unrolled) { | ||
// Still no luck, but we're allowed to fail. Rewrite | ||
// to a serial loop. | ||
user_warning << "HL_PERMIT_FAILED_UNROLL is allowing us to unroll a non-constant loop into a serial loop. Did you mean to do this?\n"; | ||
body = mutate(body); | ||
return For::make(op->name, op->min, op->extent, | ||
ForType::Serial, op->partition_policy, op->device_api, std::move(body)); | ||
} | ||
|
||
user_assert(e) | ||
<< "Can only " << (op->for_type == ForType::Unrolled ? "unroll" : "vectorize") | ||
<< " for loops over a constant extent.\n" | ||
<< "Loop over " << op->name << " has extent " << extent << ".\n"; | ||
body = mutate(body); | ||
|
||
return For::make(op->name, op->min, e, | ||
op->for_type, op->partition_policy, op->device_api, std::move(body)); | ||
} else { | ||
return IRMutator::visit(op); | ||
} | ||
} | ||
bool permit_failed_unroll = false; | ||
|
||
public: | ||
BoundLoops() { | ||
// Experimental autoschedulers may want to unroll without | ||
// being totally confident the loop will indeed turn out | ||
// to be constant-sized. If this feature continues to be | ||
// important, we need to expose it in the scheduling | ||
// language somewhere, but how? For now we do something | ||
// ugly and expedient. | ||
|
||
// For the tracking issue to fix this, see | ||
// https://github.com/halide/Halide/issues/3479 | ||
permit_failed_unroll = get_env_variable("HL_PERMIT_FAILED_UNROLL") == "1"; | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
// Entry point of the pass: run a fresh BoundLoops mutator over the
// statement and hand back whatever it produces.
Stmt bound_constant_extent_loops(const Stmt &s) {
    BoundLoops pass;
    return pass.mutate(s);
}
|
||
} // namespace Internal | ||
} // namespace Halide |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef HALIDE_BOUND_CONSTANT_EXTENT_LOOPS_H | ||
#define HALIDE_BOUND_CONSTANT_EXTENT_LOOPS_H | ||
|
||
/** \file | ||
* Defines the lowering pass that enforces a constant extent on all | ||
* vectorized or unrolled loops. | ||
*/ | ||
|
||
#include "Expr.h" | ||
|
||
namespace Halide { | ||
namespace Internal { | ||
|
||
/** Replace all loop extents of unrolled or vectorized loops with constants, by | ||
* substituting and simplifying as needed. The values of enclosing pure | ||
* LetStmts are substituted in, and the conditions of enclosing if statements | ||
* are passed to the simplifier as assumptions. If we can't determine a constant | ||
* extent, but can determine a constant upper bound, inject an if statement into | ||
* the body. If we can't even determine a constant upper bound, throw a user | ||
* error. The one exception: when the environment variable | ||
* HL_PERMIT_FAILED_UNROLL is set to "1", an unrolled loop that would otherwise | ||
* fail is emitted as a serial loop with a user warning instead. | ||
* | ||
* \param s The statement to process. | ||
* \return The rewritten statement. */ | ||
Stmt bound_constant_extent_loops(const Stmt &s); | ||
|
||
} // namespace Internal | ||
} // namespace Halide | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.