From cc360071a13e0b75376fcf8b477f574835eb723c Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Thu, 13 Jul 2023 14:04:18 +0800 Subject: [PATCH] Cinn schedule error (#54983) * [CINN] Schedule error message optimization * format code style * add test * fix format * using CINN_THROW and using flags * optimize error msg * do not use abtract class of error hanlder * fix header --- paddle/cinn/backends/ir_schedule_test.cc | 43 ++++++ paddle/cinn/ir/CMakeLists.txt | 1 + paddle/cinn/ir/ir_schedule.cc | 44 +++++- paddle/cinn/ir/ir_schedule.h | 18 ++- paddle/cinn/ir/ir_schedule_error.cc | 74 ++++++++++ paddle/cinn/ir/ir_schedule_error.h | 175 +++++++++++++++++++++++ paddle/cinn/ir/ir_schedule_util.cc | 32 +++-- paddle/cinn/ir/ir_schedule_util.h | 4 +- paddle/cinn/runtime/flags.cc | 5 + 9 files changed, 373 insertions(+), 23 deletions(-) create mode 100644 paddle/cinn/ir/ir_schedule_error.cc create mode 100644 paddle/cinn/ir/ir_schedule_error.h diff --git a/paddle/cinn/backends/ir_schedule_test.cc b/paddle/cinn/backends/ir_schedule_test.cc index 427d4e0767c06..4b2b7abbb3604 100644 --- a/paddle/cinn/backends/ir_schedule_test.cc +++ b/paddle/cinn/backends/ir_schedule_test.cc @@ -25,6 +25,7 @@ #include "paddle/cinn/backends/codegen_cuda_dev.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/ir_schedule_error.h" #include "paddle/cinn/lang/lower.h" #include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/remove_schedule_block.h" @@ -156,6 +157,48 @@ void test_split_and_fuse2(void* _args, int32_t num_args) ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); } +void TestSplitThrow() { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + auto stages = CreateStages({A, B}); + + auto func = cinn::lang::LowerVec( + "test_split_throw", stages, {A, B}, {}, {}, nullptr, target, true); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch( + mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral); + auto fused = ir_sch.Fuse("B", {0, 1}); + // statement that cause the exception + auto splited = ir_sch.Split(fused, {-1, -1}); + + auto loops = ir_sch.GetLoops("B"); + fused = ir_sch.Fuse(loops); + splited = ir_sch.Split(fused, {256, -1}); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); +} +TEST(IrSchedule, split_throw) { + ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet); +} + TEST(IrSchedule, reorder1) { Context::Global().ResetNameId(); Expr M(32); diff --git a/paddle/cinn/ir/CMakeLists.txt b/paddle/cinn/ir/CMakeLists.txt index 533acdd680236..fad631ec34dd7 100755 --- a/paddle/cinn/ir/CMakeLists.txt +++ b/paddle/cinn/ir/CMakeLists.txt @@ -8,6 +8,7 @@ gather_srcs( ir.cc ir_base.cc ir_schedule.cc + ir_schedule_error.cc ir_schedule_util.cc ir_visitor.cc ir_printer.cc diff --git a/paddle/cinn/ir/ir_schedule.cc b/paddle/cinn/ir/ir_schedule.cc index de48a7c28d8f1..288e2db4832a0 100644 --- a/paddle/cinn/ir/ir_schedule.cc +++ b/paddle/cinn/ir/ir_schedule.cc @@ -33,6 +33,7 @@ #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_operators.h" #include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/ir_schedule_error.h" #include "paddle/cinn/ir/ir_schedule_util.h" #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/lang/compute.h" @@ -41,6 +42,8 @@ #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/utils/string.h" +DECLARE_int32(cinn_schedule_error_message_level); + namespace cinn { namespace ir { @@ -50,8 +53,15 @@ namespace ir { class ScheduleImpl { public: ScheduleImpl() = default; - explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false) - : module_expr_(module_expr), debug_flag_(debug_flag) {} + explicit ScheduleImpl(const ModuleExpr& module_expr, + bool debug_flag = false, + ScheduleErrorMessageLevel err_msg_level = + ScheduleErrorMessageLevel::kGeneral) + : module_expr_(module_expr), debug_flag_(debug_flag) { + err_msg_level_ = static_cast( + FLAGS_cinn_schedule_error_message_level || + static_cast(err_msg_level)); + } explicit ScheduleImpl(ModuleExpr&& module_expr) : module_expr_(std::move(module_expr)) {} @@ -129,8 +139,26 @@ class ScheduleImpl { ModuleExpr module_expr_; bool debug_flag_{false}; + ScheduleErrorMessageLevel err_msg_level_ = + ScheduleErrorMessageLevel::kGeneral; }; +/** \brief A macro that guards the beginning of each implementation of schedule + */ +#define CINN_IR_SCHEDULE_BEGIN() try { +/** + * \brief A macro that pairs with `CINN_IR_SCHEDULE_BEGIN`, handling potential + * errors and error message printing. + * @param primitive A string representing the kind of schedule primitive. + * @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message + * printing + */ +#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \ + } \ + catch (const IRScheduleErrorHandler& err_hanlder) { \ + CINN_THROW(err_hanlder.FormatErrorMessage(primitive, err_msg_level)); \ + } + std::vector ScheduleImpl::Split(const Expr& loop, const std::vector& factors) { CHECK(loop.As()) @@ -147,7 +175,10 @@ std::vector ScheduleImpl::Split(const Expr& loop, << ") at loop:\n" << loop; - auto processed_factors = ValidateFactors(factors, tot_extent); + std::vector processed_factors; + CINN_IR_SCHEDULE_BEGIN(); + processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_); + CINN_IR_SCHEDULE_END("split", this->err_msg_level_); int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, @@ -1194,7 +1225,6 @@ struct LoopReconstructor : public ir::IRMutator<> { return utils::Join(new_var_names, ","); } - private: public: /*! \brief The root block */ Expr root_; @@ -2286,8 +2316,10 @@ IRSchedule::IRSchedule() {} IRSchedule::IRSchedule(const ModuleExpr& module_expr, utils::LinearRandomEngine::StateType rand_seed, - bool debug_flag) { - impl_ = std::make_unique(module_expr, debug_flag); + bool debug_flag, + ScheduleErrorMessageLevel err_msg_level) { + impl_ = + std::make_unique(module_expr, debug_flag, err_msg_level); this->InitSeed(rand_seed); } diff --git a/paddle/cinn/ir/ir_schedule.h b/paddle/cinn/ir/ir_schedule.h index 2689eb48a27e5..a68fdbbd26282 100644 --- a/paddle/cinn/ir/ir_schedule.h +++ b/paddle/cinn/ir/ir_schedule.h @@ -29,6 +29,20 @@ namespace cinn { namespace ir { +/** + * \brief Indicates the level of printing error message in the current Schedule + */ +enum class ScheduleErrorMessageLevel : int32_t { + /** \brief Print an error message in short mode. + * Short mode shows which and where the schedule error happens*/ + kGeneral = 0, + /** \brief Print an error message in detailed mode. + * Detailed mode shows which and where the schedule error happens, and the + * schedule input parameters. + */ + kDetailed = 1, +}; + /** * A struct representing a module that contains Expr. This struct is only used * in Schedule process. @@ -70,7 +84,9 @@ class IRSchedule { IRSchedule(); explicit IRSchedule(const ModuleExpr& modexpr, utils::LinearRandomEngine::StateType rand_seed = -1, - bool debug_flag = false); + bool debug_flag = false, + ScheduleErrorMessageLevel err_msg_level = + ScheduleErrorMessageLevel::kGeneral); IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed = -1); diff --git a/paddle/cinn/ir/ir_schedule_error.cc b/paddle/cinn/ir/ir_schedule_error.cc new file mode 100644 index 0000000000000..b6d2c2f94d4f3 --- /dev/null +++ b/paddle/cinn/ir/ir_schedule_error.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/ir_schedule_error.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_printer.h" + +namespace cinn { +namespace ir { + +std::string IRScheduleErrorHandler::GeneralErrorMessage() const { + return this->err_msg_; +} + +std::string IRScheduleErrorHandler::DetailedErrorMessage() const { + std::ostringstream os; + os << GeneralErrorMessage(); + os << "[Expr info] The Expr of current schedule is: " + << this->module_expr_.GetExprs() << std::endl; + return os.str(); +} + +std::string IRScheduleErrorHandler::FormatErrorMessage( + const std::string& primitive, + const ScheduleErrorMessageLevel& err_msg_level) const { + std::ostringstream os; + std::string err_msg = err_msg_level == ScheduleErrorMessageLevel::kDetailed + ? DetailedErrorMessage() + : GeneralErrorMessage(); + + os << "[IRScheduleError] An error occurred in the scheduel primitive <" + << primitive << ">. " << std::endl; + os << "[Error info] " << err_msg; + return os.str(); +} + +std::string NegativeFactorErrorMessage(const int64_t& factor, + const size_t& idx) { + std::ostringstream os; + os << "The params in factors of Split should be positive. However, the " + "factor at position " + << idx << " is " << factor << std::endl; + return os.str(); +} + +std::string InferFactorErrorMessage() { + std::ostringstream os; + os << "The params in factors of Split should not be less than -1 or have " + "more than one -1!" + << std::endl; + return os.str(); +} + +std::string FactorProductErrorMessage() { + std::ostringstream os; + os << "In Split, the factors' product should be not larger than or equal " + "to original loop's extent!" + << std::endl; + return os.str(); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_schedule_error.h b/paddle/cinn/ir/ir_schedule_error.h new file mode 100644 index 0000000000000..eb4a70175a8ca --- /dev/null +++ b/paddle/cinn/ir/ir_schedule_error.h @@ -0,0 +1,175 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef __GNUC__ +#include // for __cxa_demangle +#endif // __GNUC__ + +#if !defined(_WIN32) +#include // dladdr +#include // sleep, usleep +#else // _WIN32 +#ifndef NOMINMAX +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#endif +#include // GetModuleFileName, Sleep +#endif + +#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include "paddle/cinn/ir/ir_schedule.h" + +namespace cinn { +namespace ir { + +namespace enforce { + +#ifdef __GNUC__ +inline std::string demangle(std::string name) { + int status = -4; // some arbitrary value to eliminate the compiler warning + std::unique_ptr res{ + abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; + return (status == 0) ? res.get() : name; +} +#else +inline std::string demangle(std::string name) { return name; } +#endif + +static std::string GetErrorSumaryString(const std::string& what, + const char* file, + int line) { + std::ostringstream sout; + sout << "\n----------------------\nError Message " + "Summary:\n----------------------\n"; + sout << what << "(at " << file << " : " << line << ")" << std::endl; + return sout.str(); +} + +static std::string GetCurrentTraceBackString() { + std::ostringstream sout; + sout << "\n\n--------------------------------------\n"; + sout << "C++ Traceback (most recent call last):"; + sout << "\n--------------------------------------\n"; +#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) + static constexpr int TRACE_STACK_LIMIT = 100; + + void* call_stack[TRACE_STACK_LIMIT]; + auto size = backtrace(call_stack, TRACE_STACK_LIMIT); + auto symbols = backtrace_symbols(call_stack, size); + Dl_info info; + int idx = 0; + int end_idx = 0; + for (int i = size - 1; i >= end_idx; --i) { + if (dladdr(call_stack[i], &info) && info.dli_sname) { + auto demangled = demangle(info.dli_sname); + std::string path(info.dli_fname); + // C++ traceback info are from core.so + if (path.substr(path.length() - 3).compare(".so") == 0) { + sout << idx++ << " " << demangled << "\n"; + } + } + } + free(symbols); +#else + sout << "Not support stack backtrace yet.\n"; +#endif + return sout.str(); +} + +static std::string GetTraceBackString(const std::string& what, + const char* file, + int line) { + return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line); +} + +struct EnforceNotMet : public std::exception { + public: + EnforceNotMet(const std::string& str, const char* file, int line) + : err_str_(GetTraceBackString(str, file, line)) {} + + const char* what() const noexcept override { return err_str_.c_str(); } + + private: + std::string err_str_; +}; + +#define CINN_THROW(...) \ + do { \ + try { \ + throw enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \ + } catch (const std::exception& e) { \ + std::cout << e.what() << std::endl; \ + throw; \ + } \ + } while (0) +} // namespace enforce + +/** + * This handler is dealing with the errors happen in in the current + * Scheduling. + */ +class IRScheduleErrorHandler { + public: + /** + * \brief constructor + * \param err_msg the error message + */ + explicit IRScheduleErrorHandler(const std::string& err_msg, + const ModuleExpr& module_expr) + : err_msg_(err_msg), module_expr_(module_expr) {} + + /** + * \brief Returns a short error message corresponding to the kGeneral error + * level. + */ + std::string GeneralErrorMessage() const; + + /** + * \brief Returns a detailed error message corresponding to the kDetailed + * error level. + */ + std::string DetailedErrorMessage() const; + + /** + * \brief Returns a detailed error message corresponding to the kDetailed + * error level. + */ + std::string FormatErrorMessage( + const std::string& primitive, + const ScheduleErrorMessageLevel& err_msg_level) const; + + private: + ModuleExpr module_expr_; + std::string err_msg_; +}; + +std::string NegativeFactorErrorMessage(const int64_t& factor, + const size_t& idx); + +std::string InferFactorErrorMessage(); + +std::string FactorProductErrorMessage(); + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_schedule_util.cc b/paddle/cinn/ir/ir_schedule_util.cc index 4b7ca20648742..4399716796d03 100644 --- a/paddle/cinn/ir/ir_schedule_util.cc +++ b/paddle/cinn/ir/ir_schedule_util.cc @@ -220,19 +220,22 @@ void ReplaceExpr(Expr* source, } std::vector ValidateFactors(const std::vector& factors, - int total_extent) { + int total_extent, + const ModuleExpr& module_expr) { CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check."; bool has_minus_one = false; int product = 1; + int idx = -1; for (auto& i : factors) { - CHECK(i != 0) - << "The params in factors of Split should not be 0! Please check."; - CHECK(i >= -1) << "The params in factors of Split should not be less than " - "-1! Please check."; - if (i == -1) { - CHECK(!has_minus_one) << "The params in factors of Split should not have " - "more than one -1! Please check."; + idx++; + if (i == 0 || i < -1) { + throw IRScheduleErrorHandler(NegativeFactorErrorMessage(i, idx), + module_expr); + } else if (i == -1) { + if (has_minus_one) { + throw IRScheduleErrorHandler(InferFactorErrorMessage(), module_expr); + } has_minus_one = true; } else { product *= i; @@ -240,15 +243,14 @@ std::vector ValidateFactors(const std::vector& factors, } std::vector validated_factors = factors; if (!has_minus_one) { - CHECK_GE(product, total_extent) - << "In Split, the factors' product should be equal to original loop's " - "extent! Please check."; + if (product < total_extent) { + throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr); + } return validated_factors; } else { - CHECK_LE(product, total_extent) - << "In Split, when there is -1 in factors, the other factors' product " - "should be <= " - "original loop's extent! Please check."; + if (product > total_extent) { + throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr); + } int minus_one_candidate = static_cast( ceil(static_cast(total_extent) / static_cast(product))); for (int i = 0; i < validated_factors.size(); ++i) { diff --git a/paddle/cinn/ir/ir_schedule_util.h b/paddle/cinn/ir/ir_schedule_util.h index 762cd166d2004..33a8337b8f911 100644 --- a/paddle/cinn/ir/ir_schedule_util.h +++ b/paddle/cinn/ir/ir_schedule_util.h @@ -23,6 +23,7 @@ #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_mutator.h" +#include "paddle/cinn/ir/ir_schedule_error.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/utils/random_engine.h" #include "paddle/cinn/utils/string.h" @@ -248,7 +249,8 @@ void ReplaceExpr(Expr* source, * @return return The valiated factors. */ std::vector ValidateFactors(const std::vector& factors, - int total_extent); + int total_extent, + const ModuleExpr& module_expr); void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis); diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 6a6f13a741bfc..05b181a315a18 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -164,6 +164,11 @@ DEFINE_int32(cinn_profiler_state, "Specify the ProfilerState by Int in CINN, 0 for kDisabled, 1 for " "kCPU, 2 for kCUDA, 3 for kAll, default 0."); +DEFINE_int32(cinn_schedule_error_message_level, + Int32FromEnv("FLAGS_cinn_schedule_error_message_level", 0), + "Specify the level of printing error message in the schedule." + "0 means short, 1 means detailed."); + namespace cinn { namespace runtime {