Skip to content

Commit

Permalink
Add callback interface to re-direct console output (dmlc#3438)
Browse files Browse the repository at this point in the history
* Add callback interface to re-direct console output

* Exempt TrackerLogger from custom logging

* Fix lint
  • Loading branch information
hcho3 authored and Nan Zhu committed Jul 26, 2018
1 parent 6a184f1 commit 5f51aef
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ if(USE_AVX)
add_definitions(-DXGBOOST_USE_AVX)
endif()

# enable custom logging
add_definitions(-DDMLC_LOG_CUSTOMIZE=1)

# compiled code customizations for R package
if(R_LIB)
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ endif
endif

export LDFLAGS= -pthread -lm $(ADD_LDFLAGS) $(DMLC_LDFLAGS) $(PLUGIN_LDFLAGS)
export CFLAGS= -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS)
export CFLAGS= -DDMLC_LOG_CUSTOMIZE=1 -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS)
CFLAGS += -I$(DMLC_CORE)/include -I$(RABIT)/include -I$(GTEST_PATH)/include
#java include path
export JAVAINCFLAGS = -I${JAVA_HOME}/include -I./java
Expand Down
9 changes: 9 additions & 0 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*)
*/
XGB_DLL const char *XGBGetLastError(void);

/*!
* \brief register callback function for LOG(INFO) messages -- helpful messages
* that are not errors.
* Note: this function can be called by multiple threads. The callback function
* will run on the thread that registered it
* \return 0 for success, -1 for failure
*/
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*));

/*!
* \brief load a data matrix
* \param fname the name of the file
Expand Down
18 changes: 18 additions & 0 deletions include/xgboost/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define XGBOOST_LOGGING_H_

#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <sstream>
#include "./base.h"

Expand Down Expand Up @@ -37,6 +38,23 @@ class TrackerLogger : public BaseLogger {
~TrackerLogger();
};

class LogCallbackRegistry {
public:
using Callback = void (*)(const char*);
LogCallbackRegistry()
: log_callback_([] (const char* msg) { std::cerr << msg << std::endl; }) {}
inline void Register(Callback log_callback) {
this->log_callback_ = log_callback;
}
inline Callback Get() const {
return log_callback_;
}
private:
Callback log_callback_;
};

using LogCallbackRegistryStore = dmlc::ThreadLocalStore<LogCallbackRegistry>;

// redefines the logging macro if not existed
#ifndef LOG
#define LOG(severity) LOG_##severity.stream()
Expand Down
15 changes: 15 additions & 0 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,28 @@ def from_cstr_to_pystr(data, length):
return res


def _log_callback(msg):
"""Redirect logs from native library into Python console"""
print("{0:s}".format(py_str(msg)))


def _get_log_callback_func():
"""Wrap log_callback() method in ctypes callback type"""
# pylint: disable=invalid-name
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
return CALLBACK(_log_callback)


def _load_lib():
"""Load xgboost Library."""
lib_path = find_lib_path()
if len(lib_path) == 0:
return None
lib = ctypes.cdll.LoadLibrary(lib_path[0])
lib.XGBGetLastError.restype = ctypes.c_char_p
lib.callback = _get_log_callback_func()
if lib.XGBRegisterLogCallback(lib.callback) != 0:
raise XGBoostError(lib.XGBGetLastError())
return lib


Expand Down
7 changes: 7 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ struct XGBAPIThreadLocalEntry {
// define the threadlocal store.
using XGBAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;

int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_BEGIN();
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
registry->Register(callback);
API_END();
}

int XGDMatrixCreateFromFile(const char *fname,
int silent,
DMatrixHandle *out) {
Expand Down
16 changes: 12 additions & 4 deletions src/logging.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@
#include <iostream>
#include "./common/sync.h"

namespace xgboost {
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// Override logging mechanism for non-R interfaces
void dmlc::CustomLogMessage::Log(const std::string& msg) {
const xgboost::LogCallbackRegistry* registry
= xgboost::LogCallbackRegistryStore::Get();
auto callback = registry->Get();
callback(msg.c_str());
}

#if XGBOOST_CUSTOMIZE_LOGGER == 0
namespace xgboost {
ConsoleLogger::~ConsoleLogger() {
std::cerr << log_stream_.str() << std::endl;
dmlc::CustomLogMessage::Log(log_stream_.str());
}

TrackerLogger::~TrackerLogger() {
log_stream_ << '\n';
rabit::TrackerPrint(log_stream_.str());
}
#endif

} // namespace xgboost
#endif

0 comments on commit 5f51aef

Please sign in to comment.