Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to pass explicit RDom to Function::define_update #8284

Merged
merged 4 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,9 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
val = substitute_self_reference(val, func_name, intm.function(), vars_rename);
update_vals[i] = val;
}
intm(update_args) = Tuple(update_vals);
// There may not actually be a reference to the RDom in the args or values,
// so we use Function::define_update, which lets us pass an explicit RDom.
intm.function().define_update(update_args, update_vals, intm_rdom.domain());

// Determine the dims and schedule of the update definition of the
// intermediate Func. We copy over the schedule from the original
Expand Down
19 changes: 18 additions & 1 deletion src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ void Function::create_output_buffers(const std::vector<Type> &types, int dims) c
}
}

void Function::define_update(const vector<Expr> &_args, vector<Expr> values) {
void Function::define_update(const vector<Expr> &_args, vector<Expr> values, const ReductionDomain &rdom) {
int update_idx = static_cast<int>(contents->updates.size());

user_assert(!name().empty())
Expand Down Expand Up @@ -767,6 +767,23 @@ void Function::define_update(const vector<Expr> &_args, vector<Expr> values) {
for (const auto &value : values) {
value.accept(&check);
}
if (!check.reduction_domain.defined()) {
// Use the provided one
check.reduction_domain = rdom;
} else if (rdom.defined()) {
// This is an internal error because the ability to pass an explicit
// RDom is not exposed to the front-end. At the time of writing this is
// only used by rfactor.
internal_assert(rdom.same_as(check.reduction_domain))
<< "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
<< "Explicit reduction domain passed to Function::define_update, "
<< "but another reduction domain was referred to by the args or values.\n"
<< "Explicit reduction domain passed:\n"
<< RDom(rdom) << "\n"
<< "Found reduction domain:\n"
<< RDom(check.reduction_domain) << "\n";
}

if (check.reduction_domain.defined()) {
check.unbound_reduction_vars_ok = true;
check.reduction_domain.predicate().accept(&check);
Expand Down
17 changes: 9 additions & 8 deletions src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "Definition.h"
#include "Expr.h"
#include "FunctionPtr.h"
#include "Reduction.h"
#include "Schedule.h"

namespace Halide {
Expand Down Expand Up @@ -117,15 +118,15 @@ class Function {
* reduction domain */
void define(const std::vector<std::string> &args, std::vector<Expr> values);

/** Add an update definition to this function. It must already
* have a pure definition but not an update definition, and the
* length of args must match the length of args used in the pure
* definition. 'value' must depend on some reduction domain, and
* may contain variables from that domain as well as pure
* variables. Any pure variables must also appear as Variables in
* the args array, and they must have the same name as the pure
/** Add an update definition to this function. It must already have a pure
* definition but not an update definition, and the length of args must
* match the length of args used in the pure definition. 'value' may depend
* on some reduction domain, and may contain variables from that domain as
* well as pure variables. A reduction domain may also be introduced by
* passing it as the last argument. Any pure variables must also appear as
* Variables in the args array, and they must have the same name as the pure
* definition's argument in the same index. */
void define_update(const std::vector<Expr> &args, std::vector<Expr> values);
void define_update(const std::vector<Expr> &args, std::vector<Expr> values, const ReductionDomain &rdom = ReductionDomain{});

/** Accept a visitor to visit all of the definitions and arguments
* of this function. */
Expand Down
40 changes: 40 additions & 0 deletions test/correctness/rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,45 @@ int self_assignment_rfactor_test() {
return 0;
}

int inlined_rfactor_with_disappearing_rvar_test() {
ImageParam in1(Float(32), 1);

Var x("x"), r("r"), u("u");
RVar ro("ro"), ri("ri");
Func f("f"), g("g");
Func sum1("sum1");

RDom rdom(0, 16);
g(r, x) = in1(x);
f(x) = sum(rdom, g(rdom, x), sum1);

{
// Some of the autoschedulers execute code like the below, which can
// erase an RDom from the LHS and RHS of a Func, but not from the dims
// list, which confused the implementation of rfactor (see
// https://github.com/halide/Halide/issues/8282)
using namespace Halide::Internal;
std::vector<Function> outputs = {f.function()};
auto env = build_environment(outputs);

for (auto &iter : env) {
iter.second.lock_loop_levels();
}

inline_function(sum1.function(), g.function());
}

sum1.compute_root()
.update(0)
.split(rdom, ro, ri, 8, TailStrategy::GuardWithIf)
.rfactor({{ro, u}})
.compute_root();

// This would crash with a missing symbol error prior to #8282 being fixed.
f.compile_jit();
return 0;
}

} // namespace

int main(int argc, char **argv) {
Expand Down Expand Up @@ -1032,6 +1071,7 @@ int main(int argc, char **argv) {
{"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test},
{"complex multiply rfactor test", complex_multiply_rfactor_test},
{"argmin rfactor test", argmin_rfactor_test},
{"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test},
};

using Sharder = Halide::Internal::Test::Sharder;
Expand Down
Loading