Skip to content

Commit

Permalink
Cinn schedule error (PaddlePaddle#54983)
Browse files Browse the repository at this point in the history
* [CINN] Schedule error message optimization

* format code style

* add test

* fix format

* using CINN_THROW and using flags

* optimize error msg

* do not use abtract class of error hanlder

* fix header
  • Loading branch information
ZzSean authored and cqulilujia committed Jul 24, 2023
1 parent 6c84447 commit cc36007
Show file tree
Hide file tree
Showing 9 changed files with 373 additions and 23 deletions.
43 changes: 43 additions & 0 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_schedule_error.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/remove_schedule_block.h"
Expand Down Expand Up @@ -156,6 +157,48 @@ void test_split_and_fuse2(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}

void TestSplitThrow() {
Context::Global().ResetNameId();
Expr M(32);
Expr N(32);
Expr P(32);

Target target = common::DefaultHostTarget();

Placeholder<float> A("A", {M, N});
auto B = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B");

auto stages = CreateStages({A, B});

auto func = cinn::lang::LowerVec(
"test_split_throw", stages, {A, B}, {}, {}, nullptr, target, true);
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(
mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral);
auto fused = ir_sch.Fuse("B", {0, 1});
// statement that cause the exception
auto splited = ir_sch.Split(fused, {-1, -1});

auto loops = ir_sch.GetLoops("B");
fused = ir_sch.Fuse(loops);
splited = ir_sch.Split(fused, {256, -1});

Module::Builder builder("module1", target);
for (auto& i : func) {
builder.AddFunction(i);
}
auto module = builder.Build();
CodeGenC codegen(target);
codegen.SetInlineBuiltinCodes(false);
auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}
TEST(IrSchedule, split_throw) {
ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet);
}

TEST(IrSchedule, reorder1) {
Context::Global().ResetNameId();
Expr M(32);
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ gather_srcs(
ir.cc
ir_base.cc
ir_schedule.cc
ir_schedule_error.cc
ir_schedule_util.cc
ir_visitor.cc
ir_printer.cc
Expand Down
44 changes: 38 additions & 6 deletions paddle/cinn/ir/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_operators.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_schedule_error.h"
#include "paddle/cinn/ir/ir_schedule_util.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
Expand All @@ -41,6 +42,8 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"

DECLARE_int32(cinn_schedule_error_message_level);

namespace cinn {
namespace ir {

Expand All @@ -50,8 +53,15 @@ namespace ir {
class ScheduleImpl {
public:
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false)
: module_expr_(module_expr), debug_flag_(debug_flag) {}
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral)
: module_expr_(module_expr), debug_flag_(debug_flag) {
err_msg_level_ = static_cast<ScheduleErrorMessageLevel>(
FLAGS_cinn_schedule_error_message_level ||
static_cast<int>(err_msg_level));
}
explicit ScheduleImpl(ModuleExpr&& module_expr)
: module_expr_(std::move(module_expr)) {}

Expand Down Expand Up @@ -129,8 +139,26 @@ class ScheduleImpl {

ModuleExpr module_expr_;
bool debug_flag_{false};
ScheduleErrorMessageLevel err_msg_level_ =
ScheduleErrorMessageLevel::kGeneral;
};

/** \brief A macro that guards the beginning of each implementation of schedule
*/
#define CINN_IR_SCHEDULE_BEGIN() try {
/**
* \brief A macro that pairs with `CINN_IR_SCHEDULE_BEGIN`, handling potential
* errors and error message printing.
* @param primitive A string representing the kind of schedule primitive.
* @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message
* printing
*/
#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \
} \
catch (const IRScheduleErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(primitive, err_msg_level)); \
}

std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
const std::vector<int>& factors) {
CHECK(loop.As<ir::For>())
Expand All @@ -147,7 +175,10 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
<< ") at loop:\n"
<< loop;

auto processed_factors = ValidateFactors(factors, tot_extent);
std::vector<int> processed_factors;
CINN_IR_SCHEDULE_BEGIN();
processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_);
CINN_IR_SCHEDULE_END("split", this->err_msg_level_);
int prod_size = std::accumulate(processed_factors.begin(),
processed_factors.end(),
1,
Expand Down Expand Up @@ -1194,7 +1225,6 @@ struct LoopReconstructor : public ir::IRMutator<> {
return utils::Join(new_var_names, ",");
}

private:
public:
/*! \brief The root block */
Expr root_;
Expand Down Expand Up @@ -2286,8 +2316,10 @@ IRSchedule::IRSchedule() {}

IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag) {
impl_ = std::make_unique<ScheduleImpl>(module_expr, debug_flag);
bool debug_flag,
ScheduleErrorMessageLevel err_msg_level) {
impl_ =
std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
this->InitSeed(rand_seed);
}

Expand Down
18 changes: 17 additions & 1 deletion paddle/cinn/ir/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@
namespace cinn {
namespace ir {

/**
* \brief Indicates the level of printing error message in the current Schedule
*/
enum class ScheduleErrorMessageLevel : int32_t {
/** \brief Print an error message in short mode.
* Short mode shows which and where the schedule error happens*/
kGeneral = 0,
/** \brief Print an error message in detailed mode.
* Detailed mode shows which and where the schedule error happens, and the
* schedule input parameters.
*/
kDetailed = 1,
};

/**
* A struct representing a module that contains Expr. This struct is only used
* in Schedule process.
Expand Down Expand Up @@ -70,7 +84,9 @@ class IRSchedule {
IRSchedule();
explicit IRSchedule(const ModuleExpr& modexpr,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false);
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral);
IRSchedule(ir::ModuleExpr&& mod_expr,
ScheduleDesc&& trace,
utils::LinearRandomEngine::StateType rand_seed = -1);
Expand Down
74 changes: 74 additions & 0 deletions paddle/cinn/ir/ir_schedule_error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/ir_schedule_error.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"

namespace cinn {
namespace ir {

std::string IRScheduleErrorHandler::GeneralErrorMessage() const {
return this->err_msg_;
}

std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
std::ostringstream os;
os << GeneralErrorMessage();
os << "[Expr info] The Expr of current schedule is: "
<< this->module_expr_.GetExprs() << std::endl;
return os.str();
}

std::string IRScheduleErrorHandler::FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const {
std::ostringstream os;
std::string err_msg = err_msg_level == ScheduleErrorMessageLevel::kDetailed
? DetailedErrorMessage()
: GeneralErrorMessage();

os << "[IRScheduleError] An error occurred in the scheduel primitive <"
<< primitive << ">. " << std::endl;
os << "[Error info] " << err_msg;
return os.str();
}

std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx) {
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the "
"factor at position "
<< idx << " is " << factor << std::endl;
return os.str();
}

std::string InferFactorErrorMessage() {
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or have "
"more than one -1!"
<< std::endl;
return os.str();
}

std::string FactorProductErrorMessage() {
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
return os.str();
}

} // namespace ir
} // namespace cinn
Loading

0 comments on commit cc36007

Please sign in to comment.