Skip to content

Commit

Permalink
Add ability to pass explicit RDom to Function::define_update (#8284)
Browse files Browse the repository at this point in the history
* Add ability to pass explicit RDom to Function::define_update

And use it in rfactor. There are cases where an RDom is attached to the
original Func but not actually referred to in the LHS or RHS.

Fixes #8282

* Fix comment
  • Loading branch information
abadams authored Jun 23, 2024
1 parent 22367de commit 61df9ba
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 10 deletions.
4 changes: 3 additions & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,9 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
val = substitute_self_reference(val, func_name, intm.function(), vars_rename);
update_vals[i] = val;
}
intm(update_args) = Tuple(update_vals);
// There may not actually be a reference to the RDom in the args or values,
// so we use Function::define_update, which lets pass pass an explicit RDom.
intm.function().define_update(update_args, update_vals, intm_rdom.domain());

// Determine the dims and schedule of the update definition of the
// intermediate Func. We copy over the schedule from the original
Expand Down
19 changes: 18 additions & 1 deletion src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ void Function::create_output_buffers(const std::vector<Type> &types, int dims) c
}
}

void Function::define_update(const vector<Expr> &_args, vector<Expr> values) {
void Function::define_update(const vector<Expr> &_args, vector<Expr> values, const ReductionDomain &rdom) {
int update_idx = static_cast<int>(contents->updates.size());

user_assert(!name().empty())
Expand Down Expand Up @@ -767,6 +767,23 @@ void Function::define_update(const vector<Expr> &_args, vector<Expr> values) {
for (const auto &value : values) {
value.accept(&check);
}
if (!check.reduction_domain.defined()) {
// Use the provided one
check.reduction_domain = rdom;
} else if (rdom.defined()) {
// This is an internal error because the ability to pass an explicit
// RDom is not exposed to the front-end. At the time of writing this is
// only used by rfactor.
internal_assert(rdom.same_as(check.reduction_domain))
<< "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
<< "Explicit reduction domain passed to Function::define_update, "
<< "but another reduction domain was referred to by the args or values.\n"
<< "Explicit reduction domain passed:\n"
<< RDom(rdom) << "\n"
<< "Found reduction domain:\n"
<< RDom(check.reduction_domain) << "\n";
}

if (check.reduction_domain.defined()) {
check.unbound_reduction_vars_ok = true;
check.reduction_domain.predicate().accept(&check);
Expand Down
17 changes: 9 additions & 8 deletions src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "Definition.h"
#include "Expr.h"
#include "FunctionPtr.h"
#include "Reduction.h"
#include "Schedule.h"

namespace Halide {
Expand Down Expand Up @@ -117,15 +118,15 @@ class Function {
* reduction domain */
void define(const std::vector<std::string> &args, std::vector<Expr> values);

/** Add an update definition to this function. It must already
* have a pure definition but not an update definition, and the
* length of args must match the length of args used in the pure
* definition. 'value' must depend on some reduction domain, and
* may contain variables from that domain as well as pure
* variables. Any pure variables must also appear as Variables in
* the args array, and they must have the same name as the pure
/** Add an update definition to this function. It must already have a pure
* definition but not an update definition, and the length of args must
* match the length of args used in the pure definition. 'value' may depend
* on some reduction domain may contain variables from that domain as well
* as pure variables. A reduction domain may also be introduced by passing
* it as the last argument. Any pure variables must also appear as Variables
* in the args array, and they must have the same name as the pure
* definition's argument in the same index. */
void define_update(const std::vector<Expr> &args, std::vector<Expr> values);
void define_update(const std::vector<Expr> &args, std::vector<Expr> values, const ReductionDomain &rdom = ReductionDomain{});

/** Accept a visitor to visit all of the definitions and arguments
* of this function. */
Expand Down
40 changes: 40 additions & 0 deletions test/correctness/rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,45 @@ int self_assignment_rfactor_test() {
return 0;
}

int inlined_rfactor_with_disappearing_rvar_test() {
ImageParam in1(Float(32), 1);

Var x("x"), r("r"), u("u");
RVar ro("ro"), ri("ri");
Func f("f"), g("g");
Func sum1("sum1");

RDom rdom(0, 16);
g(r, x) = in1(x);
f(x) = sum(rdom, g(rdom, x), sum1);

{
// Some of the autoschedulers execute code like the below, which can
// erase an RDom from the LHS and RHS of a Func, but not from the dims
// list, which confused the implementation of rfactor (see
// https://github.com/halide/Halide/issues/8282)
using namespace Halide::Internal;
std::vector<Function> outputs = {f.function()};
auto env = build_environment(outputs);

for (auto &iter : env) {
iter.second.lock_loop_levels();
}

inline_function(sum1.function(), g.function());
}

sum1.compute_root()
.update(0)
.split(rdom, ro, ri, 8, TailStrategy::GuardWithIf)
.rfactor({{ro, u}})
.compute_root();

// This would crash with a missing symbol error prior to #8282 being fixed.
f.compile_jit();
return 0;
}

} // namespace

int main(int argc, char **argv) {
Expand Down Expand Up @@ -1032,6 +1071,7 @@ int main(int argc, char **argv) {
{"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test},
{"complex multiply rfactor test", complex_multiply_rfactor_test},
{"argmin rfactor test", argmin_rfactor_test},
{"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test},
};

using Sharder = Halide::Internal::Test::Sharder;
Expand Down

0 comments on commit 61df9ba

Please sign in to comment.