Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback interface to re-direct console output #3438

Merged
merged 3 commits into from
Jul 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ if(USE_AVX)
add_definitions(-DXGBOOST_USE_AVX)
endif()

# enable custom logging
add_definitions(-DDMLC_LOG_CUSTOMIZE=1)

# compiled code customizations for R package
if(R_LIB)
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ endif
endif

export LDFLAGS= -pthread -lm $(ADD_LDFLAGS) $(DMLC_LDFLAGS) $(PLUGIN_LDFLAGS)
export CFLAGS= -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS)
export CFLAGS= -DDMLC_LOG_CUSTOMIZE=1 -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS)
CFLAGS += -I$(DMLC_CORE)/include -I$(RABIT)/include -I$(GTEST_PATH)/include
#java include path
export JAVAINCFLAGS = -I${JAVA_HOME}/include -I./java
Expand Down
9 changes: 9 additions & 0 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*)
*/
XGB_DLL const char *XGBGetLastError(void);

/*!
* \brief register callback function for LOG(INFO) messages -- helpful messages
* that are not errors.
* Note: this function can be called by multiple threads. The callback function
* will run on the thread that registered it
* \return 0 for success, -1 for failure
*/
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*));

/*!
* \brief load a data matrix
* \param fname the name of the file
Expand Down
18 changes: 18 additions & 0 deletions include/xgboost/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define XGBOOST_LOGGING_H_

#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <sstream>
#include "./base.h"

Expand Down Expand Up @@ -37,6 +38,23 @@ class TrackerLogger : public BaseLogger {
~TrackerLogger();
};

class LogCallbackRegistry {
public:
using Callback = void (*)(const char*);
LogCallbackRegistry()
: log_callback_([] (const char* msg) { std::cerr << msg << std::endl; }) {}
inline void Register(Callback log_callback) {
this->log_callback_ = log_callback;
}
inline Callback Get() const {
return log_callback_;
}
private:
Callback log_callback_;
};

using LogCallbackRegistryStore = dmlc::ThreadLocalStore<LogCallbackRegistry>;

// redefines the logging macro if not existed
#ifndef LOG
#define LOG(severity) LOG_##severity.stream()
Expand Down
15 changes: 15 additions & 0 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,28 @@ def from_cstr_to_pystr(data, length):
return res


def _log_callback(msg):
"""Redirect logs from native library into Python console"""
print("{0:s}".format(py_str(msg)))


def _get_log_callback_func():
"""Wrap log_callback() method in ctypes callback type"""
# pylint: disable=invalid-name
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
return CALLBACK(_log_callback)


def _load_lib():
"""Load xgboost Library."""
lib_path = find_lib_path()
if len(lib_path) == 0:
return None
lib = ctypes.cdll.LoadLibrary(lib_path[0])
lib.XGBGetLastError.restype = ctypes.c_char_p
lib.callback = _get_log_callback_func()
if lib.XGBRegisterLogCallback(lib.callback) != 0:
raise XGBoostError(lib.XGBGetLastError())
return lib


Expand Down
7 changes: 7 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ struct XGBAPIThreadLocalEntry {
// define the threadlocal store.
using XGBAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;

int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_BEGIN();
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
registry->Register(callback);
API_END();
}

int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out) {
Expand Down
16 changes: 12 additions & 4 deletions src/logging.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@
#include <iostream>
#include "./common/sync.h"

namespace xgboost {
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// Override logging mechanism for non-R interfaces
void dmlc::CustomLogMessage::Log(const std::string& msg) {
const xgboost::LogCallbackRegistry* registry
= xgboost::LogCallbackRegistryStore::Get();
auto callback = registry->Get();
callback(msg.c_str());
}

#if XGBOOST_CUSTOMIZE_LOGGER == 0
namespace xgboost {
ConsoleLogger::~ConsoleLogger() {
std::cerr << log_stream_.str() << std::endl;
dmlc::CustomLogMessage::Log(log_stream_.str());
}

TrackerLogger::~TrackerLogger() {
log_stream_ << '\n';
rabit::TrackerPrint(log_stream_.str());
}
#endif

} // namespace xgboost
#endif