Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cinn schedule error #54983

Merged
merged 18 commits into from
Jul 13, 2023
42 changes: 42 additions & 0 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,48 @@ void test_split_and_fuse2(void* _args, int32_t num_args)
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}

void TestSplitThrow() {
Context::Global().ResetNameId();
Expr M(32);
Expr N(32);
Expr P(32);

Target target = common::DefaultHostTarget();

Placeholder<float> A("A", {M, N});
auto B = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B");

auto stages = CreateStages({A, B});

auto func = cinn::lang::LowerVec(
"test_split_throw", stages, {A, B}, {}, {}, nullptr, target, true);
auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(
mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral);
auto fused = ir_sch.Fuse("B", {0, 1});
// statement that cause the exception
auto splited = ir_sch.Split(fused, {-1, -1});

auto loops = ir_sch.GetLoops("B");
fused = ir_sch.Fuse(loops);
splited = ir_sch.Split(fused, {256, -1});

Module::Builder builder("module1", target);
for (auto& i : func) {
builder.AddFunction(i);
}
auto module = builder.Build();
CodeGenC codegen(target);
codegen.SetInlineBuiltinCodes(false);
auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}
TEST(IrSchedule, split_throw) {
ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet);
}

TEST(IrSchedule, reorder1) {
Context::Global().ResetNameId();
Expr M(32);
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ gather_srcs(
ir.cc
ir_base.cc
ir_schedule.cc
ir_schedule_error.cc
ir_schedule_util.cc
ir_visitor.cc
ir_printer.cc
Expand Down
43 changes: 37 additions & 6 deletions paddle/cinn/ir/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"

DECLARE_int32(cinn_schedule_error_message_level);

namespace cinn {
namespace ir {

Expand All @@ -50,8 +52,15 @@ namespace ir {
class ScheduleImpl {
public:
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false)
: module_expr_(module_expr), debug_flag_(debug_flag) {}
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral)
: module_expr_(module_expr), debug_flag_(debug_flag) {
err_msg_level_ = static_cast<ScheduleErrorMessageLevel>(
FLAGS_cinn_schedule_error_message_level ||
static_cast<int>(err_msg_level));
}
explicit ScheduleImpl(ModuleExpr&& module_expr)
: module_expr_(std::move(module_expr)) {}

Expand Down Expand Up @@ -129,8 +138,26 @@ class ScheduleImpl {

ModuleExpr module_expr_;
bool debug_flag_{false};
ScheduleErrorMessageLevel err_msg_level_ =
ScheduleErrorMessageLevel::kGeneral;
};

/** \brief A macro that guards the beginning of each implementation of schedule
*/
#define CINN_IR_SCHEDULE_BEGIN() try {
/**
* \brief A macro that pairs with `CINN_IR_SCHEDULE_BEGIN`, handling potential
* errors and error message printing.
* @param primitive A string representing the kind of schedule primitive.
* @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message
* printing
*/
#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \
} \
catch (const IRScheduleErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(primitive, err_msg_level)); \
}

std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
const std::vector<int>& factors) {
CHECK(loop.As<ir::For>())
Expand All @@ -147,7 +174,10 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
<< ") at loop:\n"
<< loop;

auto processed_factors = ValidateFactors(factors, tot_extent);
std::vector<int> processed_factors;
CINN_IR_SCHEDULE_BEGIN();
processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_);
CINN_IR_SCHEDULE_END("split", this->err_msg_level_);
int prod_size = std::accumulate(processed_factors.begin(),
processed_factors.end(),
1,
Expand Down Expand Up @@ -1194,7 +1224,6 @@ struct LoopReconstructor : public ir::IRMutator<> {
return utils::Join(new_var_names, ",");
}

private:
public:
/*! \brief The root block */
Expr root_;
Expand Down Expand Up @@ -2286,8 +2315,10 @@ IRSchedule::IRSchedule() {}

IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag) {
impl_ = std::make_unique<ScheduleImpl>(module_expr, debug_flag);
bool debug_flag,
ScheduleErrorMessageLevel err_msg_level) {
impl_ =
std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
this->InitSeed(rand_seed);
}

Expand Down
5 changes: 4 additions & 1 deletion paddle/cinn/ir/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_schedule_error.h"
#include "paddle/cinn/ir/schedule_desc.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/utils/random_engine.h"
Expand Down Expand Up @@ -70,7 +71,9 @@ class IRSchedule {
IRSchedule();
explicit IRSchedule(const ModuleExpr& modexpr,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false);
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral);
IRSchedule(ir::ModuleExpr&& mod_expr,
ScheduleDesc&& trace,
utils::LinearRandomEngine::StateType rand_seed = -1);
Expand Down
35 changes: 35 additions & 0 deletions paddle/cinn/ir/ir_schedule_error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/ir_schedule_error.h"

namespace cinn {
namespace ir {

std::string IRScheduleErrorHandler::FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const {
std::ostringstream os;
std::string err_msg = err_msg_level == ScheduleErrorMessageLevel::kDetailed
? DetailedErrorMessage()
: GeneralErrorMessage();

os << "[IRScheduleError] An error occurred in the scheduel primitive <"
<< primitive << ">. " << std::endl;
os << "[Error info] " << err_msg;
return os.str();
}

} // namespace ir
} // namespace cinn
173 changes: 173 additions & 0 deletions paddle/cinn/ir/ir_schedule_error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle
#endif // __GNUC__

#if !defined(_WIN32)
#include <dlfcn.h> // dladdr
#include <unistd.h> // sleep, usleep
#else // _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h> // GetModuleFileName, Sleep
#endif

#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
#include <execinfo.h>
#endif

#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace cinn {
namespace ir {

namespace enforce {

#ifdef __GNUC__
inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
std::unique_ptr<char, void (*)(void*)> res{
abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
return (status == 0) ? res.get() : name;
}
#else
inline std::string demangle(std::string name) { return name; }
#endif

static std::string GetErrorSumaryString(const std::string& what,
const char* file,
int line) {
std::ostringstream sout;
sout << "\n----------------------\nError Message "
"Summary:\n----------------------\n";
sout << what << "(at " << file << " : " << line << ")" << std::endl;
return sout.str();
}

static std::string GetCurrentTraceBackString() {
std::ostringstream sout;
sout << "\n\n--------------------------------------\n";
sout << "C++ Traceback (most recent call last):";
sout << "\n--------------------------------------\n";
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
static constexpr int TRACE_STACK_LIMIT = 100;

void* call_stack[TRACE_STACK_LIMIT];
auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
auto symbols = backtrace_symbols(call_stack, size);
Dl_info info;
int idx = 0;
int end_idx = 0;
for (int i = size - 1; i >= end_idx; --i) {
if (dladdr(call_stack[i], &info) && info.dli_sname) {
auto demangled = demangle(info.dli_sname);
std::string path(info.dli_fname);
// C++ traceback info are from core.so
if (path.substr(path.length() - 3).compare(".so") == 0) {
sout << idx++ << " " << demangled << "\n";
}
}
}
free(symbols);
#else
sout << "Not support stack backtrace yet.\n";
#endif
return sout.str();
}

static std::string GetTraceBackString(const std::string& what,
const char* file,
int line) {
return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line);
// return GetErrorSumaryString(what, file, line);
zhhsplendid marked this conversation as resolved.
Show resolved Hide resolved
}

struct EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& str, const char* file, int line)
: err_str_(GetTraceBackString(str, file, line)) {}

const char* what() const noexcept override { return err_str_.c_str(); }

private:
std::string err_str_;
};

#define CINN_THROW(...) \
do { \
try { \
throw enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this same as Paddle enforce? As I remember, we discussed offline that we use PADDLE_THROW when we found the definition, otherwise we define by ourself, then we can handle it with both CINN-only and Paddle-CINN.

Does this implementation also handle both CINN-only and Paddle-CINN? Do we reuse Paddle code?

Copy link
Contributor Author

@ZzSean ZzSean Jul 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_THROW is more complicated and has many redundant functions. If we want to reuse Paddle code we have to include more Paddle header files, thus we cannot build CINN-ONLY.
Using CINN_THROW defined by ourselves maybe a good solution for now, 'cause there is IR_THROW in PADDLE_IR to deal with the similar problem.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, not best but fine to me now.

} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} while (0)
} // namespace enforce

/**
* \brief Indicates the level of printing error message in the current Schedule
*/
enum class ScheduleErrorMessageLevel : int32_t {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the ErrorMessage use at other places? If so, should we remove Schedule in naming? Then we can use the code you implemented all over the CINN!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically no problem. But its namespace is cinn::ir now, i don't know if it is convenient/necessary to use this in the same level of ir, like hlir, optim, etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about changing the name and the location of this file when we have the need later?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, can we change the namespace, location of file in this PR? We would like to reuse for other places.

/** \brief Print an error message in short mode*/
zhhsplendid marked this conversation as resolved.
Show resolved Hide resolved
kGeneral = 0,
/** \brief Print an error message in detailed mode*/
kDetailed = 1,
};

/**
* This handler is dealing with the errors happen in in the current Scheduling.
*/
class IRScheduleErrorHandler : public std::runtime_error {
public:
IRScheduleErrorHandler() : std::runtime_error("") {}
/**
* \brief constructor
* \param s the error message
*/
explicit IRScheduleErrorHandler(const std::string& s)
: std::runtime_error(s) {}

/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const;

/**
* \brief Returns a short error message corresponding to the kGeneral error
* level.
*/
virtual std::string GeneralErrorMessage() const = 0;

/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
virtual std::string DetailedErrorMessage() const = 0;
};

} // namespace ir
} // namespace cinn
Loading