Add callback interface to re-direct console output (#3438)

* Add callback interface to re-direct console output

* Exempt TrackerLogger from custom logging

* Fix lint
This commit is contained in:
Philip Hyunsu Cho 2018-07-05 11:32:30 -07:00 committed by GitHub
parent 45bf4fbffb
commit 48d6e68690
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 64 additions and 5 deletions

View File

@ -53,6 +53,8 @@ if(USE_AVX)
add_definitions(-DXGBOOST_USE_AVX) add_definitions(-DXGBOOST_USE_AVX)
endif() endif()
# enable custom logging
add_definitions(-DDMLC_LOG_CUSTOMIZE=1)
# compiled code customizations for R package # compiled code customizations for R package
if(R_LIB) if(R_LIB)

View File

@ -68,7 +68,7 @@ endif
endif endif
export LDFLAGS= -pthread -lm $(ADD_LDFLAGS) $(DMLC_LDFLAGS) $(PLUGIN_LDFLAGS) export LDFLAGS= -pthread -lm $(ADD_LDFLAGS) $(DMLC_LDFLAGS) $(PLUGIN_LDFLAGS)
export CFLAGS= -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS) export CFLAGS= -DDMLC_LOG_CUSTOMIZE=1 -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS)
CFLAGS += -I$(DMLC_CORE)/include -I$(RABIT)/include -I$(GTEST_PATH)/include CFLAGS += -I$(DMLC_CORE)/include -I$(RABIT)/include -I$(GTEST_PATH)/include
#java include path #java include path
export JAVAINCFLAGS = -I${JAVA_HOME}/include -I./java export JAVAINCFLAGS = -I${JAVA_HOME}/include -I./java

View File

@ -96,6 +96,15 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*)
*/ */
XGB_DLL const char *XGBGetLastError(void); XGB_DLL const char *XGBGetLastError(void);
/*!
* \brief register callback function for LOG(INFO) messages -- helpful messages
* that are not errors.
* Note: this function can be called by multiple threads. The callback function
* will run on the thread that registered it
* \return 0 for success, -1 for failure
*/
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*));
/*! /*!
* \brief load a data matrix * \brief load a data matrix
* \param fname the name of the file * \param fname the name of the file

View File

@ -9,6 +9,7 @@
#define XGBOOST_LOGGING_H_ #define XGBOOST_LOGGING_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <sstream> #include <sstream>
#include "./base.h" #include "./base.h"
@ -37,6 +38,23 @@ class TrackerLogger : public BaseLogger {
~TrackerLogger(); ~TrackerLogger();
}; };
class LogCallbackRegistry {
public:
using Callback = void (*)(const char*);
LogCallbackRegistry()
: log_callback_([] (const char* msg) { std::cerr << msg << std::endl; }) {}
inline void Register(Callback log_callback) {
this->log_callback_ = log_callback;
}
inline Callback Get() const {
return log_callback_;
}
private:
Callback log_callback_;
};
using LogCallbackRegistryStore = dmlc::ThreadLocalStore<LogCallbackRegistry>;
// redefines the logging macro if not existed // redefines the logging macro if not existed
#ifndef LOG #ifndef LOG
#define LOG(severity) LOG_##severity.stream() #define LOG(severity) LOG_##severity.stream()

View File

@ -100,6 +100,18 @@ def from_cstr_to_pystr(data, length):
return res return res
def _log_callback(msg):
"""Redirect logs from native library into Python console"""
print("{0:s}".format(py_str(msg)))
def _get_log_callback_func():
"""Wrap log_callback() method in ctypes callback type"""
# pylint: disable=invalid-name
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
return CALLBACK(_log_callback)
def _load_lib(): def _load_lib():
"""Load xgboost Library.""" """Load xgboost Library."""
lib_path = find_lib_path() lib_path = find_lib_path()
@ -107,6 +119,9 @@ def _load_lib():
return None return None
lib = ctypes.cdll.LoadLibrary(lib_path[0]) lib = ctypes.cdll.LoadLibrary(lib_path[0])
lib.XGBGetLastError.restype = ctypes.c_char_p lib.XGBGetLastError.restype = ctypes.c_char_p
lib.callback = _get_log_callback_func()
if lib.XGBRegisterLogCallback(lib.callback) != 0:
raise XGBoostError(lib.XGBGetLastError())
return lib return lib

View File

@ -202,6 +202,13 @@ struct XGBAPIThreadLocalEntry {
// define the threadlocal store. // define the threadlocal store.
using XGBAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>; using XGBAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
int XGBRegisterLogCallback(void (*callback)(const char*)) {
API_BEGIN();
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
registry->Register(callback);
API_END();
}
int XGDMatrixCreateFromFile(const char *fname, int XGDMatrixCreateFromFile(const char *fname,
int silent, int silent,
DMatrixHandle *out) { DMatrixHandle *out) {

View File

@ -8,16 +8,24 @@
#include <iostream> #include <iostream>
#include "./common/sync.h" #include "./common/sync.h"
namespace xgboost { #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// Override logging mechanism for non-R interfaces
void dmlc::CustomLogMessage::Log(const std::string& msg) {
const xgboost::LogCallbackRegistry* registry
= xgboost::LogCallbackRegistryStore::Get();
auto callback = registry->Get();
callback(msg.c_str());
}
#if XGBOOST_CUSTOMIZE_LOGGER == 0 namespace xgboost {
ConsoleLogger::~ConsoleLogger() { ConsoleLogger::~ConsoleLogger() {
std::cerr << log_stream_.str() << std::endl; dmlc::CustomLogMessage::Log(log_stream_.str());
} }
TrackerLogger::~TrackerLogger() { TrackerLogger::~TrackerLogger() {
log_stream_ << '\n'; log_stream_ << '\n';
rabit::TrackerPrint(log_stream_.str()); rabit::TrackerPrint(log_stream_.str());
} }
#endif
} // namespace xgboost } // namespace xgboost
#endif