From 9d64e24b8dd1f65c7c12ce4941a455e7f1a89067 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 25 Sep 2023 09:38:12 -0700 Subject: [PATCH] Fix #7851 In one place in PartitionLoops and in another place in the simplifier we were neglecting to consider nested vectorization. I added the fuzzer output as a new test, because I have no idea how I'd generate this error with human-readable code. It stems from an interaction of several tail strategies. --- src/IR.cpp | 2 + src/PartitionLoops.cpp | 2 +- src/Simplify_Stmts.cpp | 4 ++ test/correctness/CMakeLists.txt | 1 + test/correctness/fuzz_schedule.cpp | 60 ++++++++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 test/correctness/fuzz_schedule.cpp diff --git a/src/IR.cpp b/src/IR.cpp index 244d142cfb60..bf53aaac476c 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -574,6 +574,8 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { internal_assert(condition.defined() && then_case.defined()) << "IfThenElse of undefined\n"; // else_case may be null. + internal_assert(condition.type().is_scalar()) << "IfThenElse with vector condition\n"; + IfThenElse *node = new IfThenElse; node->condition = std::move(condition); node->then_case = std::move(then_case); diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index d3b296dc0165..477b91173f19 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -263,7 +263,7 @@ class FindSimplifications : public IRVisitor { } condition = remove_likelies(condition); Simplification s = {condition, std::move(old), std::move(likely_val), std::move(unlikely_val), true}; - if (s.condition.type().is_vector()) { + while (s.condition.type().is_vector()) { s.condition = simplify(s.condition); if (const Broadcast *b = s.condition.as()) { s.condition = b->value; diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 5be05e42e6c6..6a8e53ccfa73 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -328,6 +328,10 @@ Stmt Simplify::visit(const Store *op) { const Load *load = value.as(); const Broadcast *scalar_pred = predicate.as(); + if (scalar_pred && !scalar_pred->value.type().is_scalar()) { + // Nested vectorization + scalar_pred = nullptr; + } ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment); diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index a4a25eeae87f..0fe26d266ace 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -115,6 +115,7 @@ tests(GROUPS correctness fuse_gpu_threads.cpp fused_where_inner_extent_is_zero.cpp fuzz_float_stores.cpp + fuzz_schedule.cpp gameoflife.cpp gather.cpp gpu_allocation_cache.cpp diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp new file mode 100644 index 000000000000..4a027c02cb51 --- /dev/null +++ b/test/correctness/fuzz_schedule.cpp @@ -0,0 +1,60 @@ +#include "Halide.h" + +using namespace Halide; + +void check_blur_output(const Buffer &out, const Buffer &correct) { + for (int y = 0; y < out.height(); y++) { + for (int x = 0; x < out.width(); x++) { + if (out(x, y) != correct(x, y)) { + printf("out(%d, %d) = %d instead of %d\n", + x, y, out(x, y), correct(x, y)); + exit(1); + } + } + } +} + +int main(int argc, char **argv) { + // This test is for schedules that crash the compiler found via fuzzing that + // are hard to otherwise reproduce. We don't need to check the output. + + Buffer correct; + { + // An unscheduled instance to act as a reference output + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + correct = blurry.realize({32, 32}); + } + + // https://github.com/halide/Halide/issues/7851 + { + Func input("input"); + Func local_sum("local_sum"); + Func blurry("blurry"); + Var x("x"), y("y"); + input(x, y) = 2 * x + 5 * y; + RDom r(-2, 5, -2, 5); + local_sum(x, y) = 0; + local_sum(x, y) += input(x + r.x, y + r.y); + blurry(x, y) = cast(local_sum(x, y) / 25); + Var yo("yo"), yi("yi"), xo("xo"), xi("xi"), yo_x_f("yo_x_f"), yo_x_fo("yo_x_fo"), yo_x_fi("yo_x_fi"); + blurry.split(y, yo, yi, 2, TailStrategy::RoundUp).fuse(yo, x, yo_x_f).vectorize(yi).split(yo_x_f, yo_x_fo, yo_x_fi, 2, TailStrategy::Predicate).reorder(yo_x_fo, yo_x_fi, yi); + input.split(y, yo, yi, 2, TailStrategy::PredicateStores).fuse(yo, x, yo_x_f).vectorize(yi).split(yo_x_f, yo_x_fo, yo_x_fi, 2, TailStrategy::Predicate).reorder(yo_x_fo, yo_x_fi, yi); + blurry.store_root(); + input.compute_at(blurry, yi); + Pipeline p({blurry}); + Buffer buf = p.realize({32, 32}); + check_blur_output(buf, correct); + } + + printf("Success!\n"); + + return 0; +}