Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: improve RUNTIME_CHECK and RUNTIME_ASSERT #5434

Merged
merged 13 commits into from
Jul 27, 2022
53 changes: 34 additions & 19 deletions dbms/src/Common/Exception.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <errno.h>
#include <string.h>


namespace DB
{
namespace ErrorCodes
Expand All @@ -35,7 +34,6 @@ extern const int UNKNOWN_EXCEPTION;
extern const int CANNOT_TRUNCATE_FILE;
} // namespace ErrorCodes


void throwFromErrno(const std::string & s, int code, int e)
{
const size_t buf_size = 128;
Expand All @@ -54,14 +52,18 @@ void throwFromErrno(const std::string & s, int code, int e)
strcpy(buf, unknown_message);
strcpy(buf + strlen(unknown_message), code);
}
throw ErrnoException(s + ", errno: " + toString(e) + ", strerror: " + std::string(buf), code, e);
throw ErrnoException(s + ", errno: " + toString(e) + ", strerror: " + std::string(buf),
code,
e);
#else
throw ErrnoException(s + ", errno: " + toString(e) + ", strerror: " + std::string(strerror_r(e, buf, sizeof(buf))), code, e);
throw ErrnoException(s + ", errno: " + toString(e) + ", strerror: " + std::string(strerror_r(e, buf, sizeof(buf))),
code,
e);
#endif
}


void tryLogCurrentException(const char * log_name, const std::string & start_of_message)
void tryLogCurrentException(const char * log_name,
const std::string & start_of_message)
{
tryLogCurrentException(&Poco::Logger::get(log_name), start_of_message);
}
Expand All @@ -75,19 +77,22 @@ void tryLogCurrentException(const char * log_name, const std::string & start_of_
{ \
}

void tryLogCurrentException(const LoggerPtr & logger, const std::string & start_of_message)
void tryLogCurrentException(const LoggerPtr & logger,
const std::string & start_of_message)
{
TRY_LOG_CURRENT_EXCEPTION(logger, start_of_message);
}

void tryLogCurrentException(Poco::Logger * logger, const std::string & start_of_message)
void tryLogCurrentException(Poco::Logger * logger,
const std::string & start_of_message)
{
TRY_LOG_CURRENT_EXCEPTION(logger, start_of_message);
}

#undef TRY_LOG_CURRENT_EXCEPTION

std::string getCurrentExceptionMessage(bool with_stacktrace, bool check_embedded_stacktrace)
std::string getCurrentExceptionMessage(bool with_stacktrace,
bool check_embedded_stacktrace)
{
std::stringstream stream;

Expand All @@ -103,8 +108,10 @@ std::string getCurrentExceptionMessage(bool with_stacktrace, bool check_embedded
{
try
{
stream << "Poco::Exception. Code: " << ErrorCodes::POCO_EXCEPTION << ", e.code() = " << e.code()
<< ", e.displayText() = " << e.displayText() << ", e.what() = " << e.what();
stream << "Poco::Exception. Code: " << ErrorCodes::POCO_EXCEPTION
<< ", e.code() = " << e.code()
<< ", e.displayText() = " << e.displayText()
<< ", e.what() = " << e.what();
}
catch (...)
{
Expand All @@ -120,7 +127,8 @@ std::string getCurrentExceptionMessage(bool with_stacktrace, bool check_embedded
if (status)
name += " (demangling status: " + toString(status) + ")";

stream << "std::exception. Code: " << ErrorCodes::STD_EXCEPTION << ", type: " << name << ", e.what() = " << e.what();
stream << "std::exception. Code: " << ErrorCodes::STD_EXCEPTION
<< ", type: " << name << ", e.what() = " << e.what();
}
catch (...)
{
Expand All @@ -136,7 +144,8 @@ std::string getCurrentExceptionMessage(bool with_stacktrace, bool check_embedded
if (status)
name += " (demangling status: " + toString(status) + ")";

stream << "Unknown exception. Code: " << ErrorCodes::UNKNOWN_EXCEPTION << ", type: " << name;
stream << "Unknown exception. Code: " << ErrorCodes::UNKNOWN_EXCEPTION
<< ", type: " << name;
}
catch (...)
{
Expand All @@ -146,7 +155,6 @@ std::string getCurrentExceptionMessage(bool with_stacktrace, bool check_embedded
return stream.str();
}


int getCurrentExceptionCode()
{
try
Expand All @@ -171,15 +179,13 @@ int getCurrentExceptionCode()
}
}


void rethrowFirstException(const Exceptions & exceptions)
{
for (const auto & exception : exceptions)
if (exception)
std::rethrow_exception(exception);
}


std::string getExceptionMessage(const Exception & e, bool with_stacktrace, bool check_embedded_stacktrace)
{
std::stringstream stream;
Expand All @@ -200,7 +206,8 @@ std::string getExceptionMessage(const Exception & e, bool with_stacktrace, bool
}
}

stream << "Code: " << e.code() << ", e.displayText() = " << text << ", e.what() = " << e.what();
stream << "Code: " << e.code() << ", e.displayText() = " << text
<< ", e.what() = " << e.what();

if (with_stacktrace && !has_embedded_stack_trace)
stream << ", Stack trace:\n\n"
Expand All @@ -225,7 +232,6 @@ std::string getExceptionMessage(std::exception_ptr e, bool with_stacktrace)
}
}


std::string ExecutionStatus::serializeText() const
{
WriteBufferFromOwnString wb;
Expand Down Expand Up @@ -254,11 +260,20 @@ bool ExecutionStatus::tryDeserializeText(const std::string & data)
return true;
}

ExecutionStatus ExecutionStatus::fromCurrentException(const std::string & start_of_message)
ExecutionStatus
ExecutionStatus::fromCurrentException(const std::string & start_of_message)
{
String msg = (start_of_message.empty() ? "" : (start_of_message + ": ")) + getCurrentExceptionMessage(false, true);
return ExecutionStatus(getCurrentExceptionCode(), msg);
}

namespace exception_details
{
const LoggerPtr & getDefaultFatalLogger()
{
static const auto logger = std::make_shared<Logger>("DefaultFatal", "");
return logger;
}
} // namespace exception_details

} // namespace DB
158 changes: 109 additions & 49 deletions dbms/src/Common/Exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,17 @@

#pragma once

#include <Common/Logger.h>
#include <Common/StackTrace.h>
#include <Poco/Exception.h>
#include <fmt/format.h>
#include <common/defines.h>

#include <cerrno>
#include <memory>
#include <vector>


namespace Poco
{
class Logger;
}


namespace DB
{
class Logger;
using LoggerPtr = std::shared_ptr<Logger>;

class Exception : public Poco::Exception
{
public:
Expand Down Expand Up @@ -73,8 +64,8 @@ class Exception : public Poco::Exception
StackTrace trace;
};


/// Contains an additional member `saved_errno`. See the throwFromErrno function.
/// Contains an additional member `saved_errno`. See the throwFromErrno
/// function.
class ErrnoException : public Exception
{
public:
Expand All @@ -97,46 +88,48 @@ class ErrnoException : public Exception
int saved_errno;
};


using Exceptions = std::vector<std::exception_ptr>;


[[noreturn]] void throwFromErrno(const std::string & s, int code = 0, int e = errno);


/** Try to write an exception to the log (and forget about it).
* Can be used in destructors in the catch-all block.
*/
void tryLogCurrentException(const char * log_name, const std::string & start_of_message = "");
void tryLogCurrentException(const LoggerPtr & logger, const std::string & start_of_message = "");
void tryLogCurrentException(Poco::Logger * logger, const std::string & start_of_message = "");

* Can be used in destructors in the catch-all block.
*/
void tryLogCurrentException(const char * log_name,
const std::string & start_of_message = "");
void tryLogCurrentException(const LoggerPtr & logger,
const std::string & start_of_message = "");
void tryLogCurrentException(Poco::Logger * logger,
const std::string & start_of_message = "");

/** Prints current exception in canonical format.
* with_stacktrace - prints stack trace for DB::Exception.
* check_embedded_stacktrace - if DB::Exception has embedded stacktrace then
* only this stack trace will be printed.
*/
std::string getCurrentExceptionMessage(bool with_stacktrace, bool check_embedded_stacktrace = false);
* with_stacktrace - prints stack trace for DB::Exception.
* check_embedded_stacktrace - if DB::Exception has embedded stacktrace then
* only this stack trace will be printed.
*/
std::string getCurrentExceptionMessage(bool with_stacktrace,
bool check_embedded_stacktrace = false);

/// Returns error code from ErrorCodes
int getCurrentExceptionCode();


/// An execution status of any piece of code, contains return code and optional error
/// An execution status of any piece of code, contains return code and optional
/// error
struct ExecutionStatus
{
int code = 0;
std::string message;

ExecutionStatus() = default;

explicit ExecutionStatus(int return_code, const std::string & exception_message = "")
explicit ExecutionStatus(int return_code,
const std::string & exception_message = "")
: code(return_code)
, message(exception_message)
{}

static ExecutionStatus fromCurrentException(const std::string & start_of_message = "");
static ExecutionStatus
fromCurrentException(const std::string & start_of_message = "");

std::string serializeText() const;

Expand All @@ -145,14 +138,11 @@ struct ExecutionStatus
bool tryDeserializeText(const std::string & data);
};


std::string getExceptionMessage(const Exception & e, bool with_stacktrace, bool check_embedded_stacktrace = false);
std::string getExceptionMessage(std::exception_ptr e, bool with_stacktrace);


void rethrowFirstException(const Exceptions & exceptions);


template <typename T>
std::enable_if_t<std::is_pointer_v<T>, T> exception_cast(std::exception_ptr e)
{
Expand All @@ -172,28 +162,98 @@ std::enable_if_t<std::is_pointer_v<T>, T> exception_cast(std::exception_ptr e)

namespace exception_details
{
inline std::string generateFormattedMessage(const char * condition)
{
return fmt::format("Assert {} fail!", condition);
}

template <typename T, typename... Args>
inline std::string generateLogMessage(const char * condition, T && fmt_str, Args &&... args)
inline std::string generateFormattedMessage(const char * condition, T && fmt_str, Args &&... args)
{
return FmtBuffer().fmtAppend("Assert {} fail! ", condition).fmtAppend(fmt_str, std::forward<Args>(args)...).toString();
}

template <typename... Args>
inline Poco::Message generateLogMessage(const std::string & logger_name, const char * filename, int lineno, const char * condition, Args &&... args)
{
return Poco::Message(
logger_name,
generateFormattedMessage(condition, std::forward<Args>(args)...),
Poco::Message::PRIO_FATAL,
filename,
lineno);
}

const LoggerPtr & getDefaultFatalLogger();

inline void log(const char * filename, int lineno, const char * condition, const LoggerPtr & logger)
{
if (likely(logger->fatal()))
{
auto message = generateLogMessage(logger->name(), filename, lineno, condition);
logger->log(message);
}
}

inline void log(const char * filename, int lineno, const char * condition)
{
log(filename, lineno, condition, getDefaultFatalLogger());
}

template <typename... Args>
inline void log(const char * filename, int lineno, const char * condition, const LoggerPtr & logger, const char * fmt_str, Args &&... args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about removing const char * fmt_str?
If so, inline void log(const char * filename, int lineno, const char * condition, const LoggerPtr & logger) can be removed too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

{
if (logger->fatal())
{
auto message = generateLogMessage(
logger->name(),
filename,
lineno,
condition,
fmt_str,
std::forward<Args>(args)...);
logger->log(message);
}
}

template <typename... Args>
inline void log(const char * filename, int lineno, const char * condition, const char * fmt_str, Args &&... args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this fmt_str is used for overloadding from const LoggerPtr & or this version will always be chosen even when there's a const LoggerPtr &.

{
return fmt::format(std::forward<T>(fmt_str), condition, std::forward<Args>(args)...);
log(filename, lineno, condition, getDefaultFatalLogger(), fmt_str, std::forward<Args>(args)...);
}
} // namespace exception_details

#define RUNTIME_CHECK(condition, ExceptionType, ...) \
do \
{ \
if (unlikely(!(condition))) \
throw ExceptionType(__VA_ARGS__); \
/// Usage:
/// ```
/// RUNTIME_CHECK(a != b, Exception("{} does not equal to {}", a, b));
/// ```
#define RUNTIME_CHECK(condition, ExceptionGenerationCode) \
do \
{ \
if (unlikely(!(condition))) \
throw(ExceptionGenerationCode); \
} while (false)

#define RUNTIME_ASSERT(condition, logger, ...) \
do \
{ \
if (unlikely(!(condition))) \
{ \
LOG_FATAL((logger), exception_details::generateLogMessage(#condition, "Assert {} fail! " __VA_ARGS__)); \
std::terminate(); \
} \
/// Usage:
/// ```
/// RUNTIME_ASSERT(a != b);
/// RUNTIME_ASSERT(a != b, "fail");
/// RUNTIME_ASSERT(a != b, "{} does not equal to {}", a, b);
/// RUNTIME_ASSERT(a != b, logger);
/// RUNTIME_ASSERT(a != b, logger, "{} does not equal to {}", a, b);
/// ```
#define RUNTIME_ASSERT(condition, ...) \
do \
{ \
if (unlikely(!(condition))) \
{ \
exception_details::log( \
&__FILE__[LogFmtDetails::getFileNameOffset(__FILE__)], \
__LINE__, \
#condition, \
##__VA_ARGS__); \
std::terminate(); \
} \
} while (false)

} // namespace DB
Loading