diff --git a/CMakeLists.txt b/CMakeLists.txt index ddb231c94..f3b879925 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,6 +53,8 @@ if(USE_AVX) add_definitions(-DXGBOOST_USE_AVX) endif() +# enable custom logging +add_definitions(-DDMLC_LOG_CUSTOMIZE=1) # compiled code customizations for R package if(R_LIB) diff --git a/Makefile b/Makefile index c7d8a032c..9ad197ee5 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ endif endif export LDFLAGS= -pthread -lm $(ADD_LDFLAGS) $(DMLC_LDFLAGS) $(PLUGIN_LDFLAGS) -export CFLAGS= -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS) +export CFLAGS= -DDMLC_LOG_CUSTOMIZE=1 -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS) CFLAGS += -I$(DMLC_CORE)/include -I$(RABIT)/include -I$(GTEST_PATH)/include #java include path export JAVAINCFLAGS = -I${JAVA_HOME}/include -I./java diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index cfa64b142..a64e50b36 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -96,6 +96,15 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*) */ XGB_DLL const char *XGBGetLastError(void); +/*! + * \brief register callback function for LOG(INFO) messages -- helpful messages + * that are not errors. + * Note: this function can be called by multiple threads. The callback function + * will run on the thread that registered it + * \return 0 for success, -1 for failure + */ +XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)); + /*! * \brief load a data matrix * \param fname the name of the file diff --git a/include/xgboost/logging.h b/include/xgboost/logging.h index 4228087e1..c15c0491d 100644 --- a/include/xgboost/logging.h +++ b/include/xgboost/logging.h @@ -9,6 +9,7 @@ #define XGBOOST_LOGGING_H_ #include +#include #include #include "./base.h" @@ -37,6 +38,23 @@ class TrackerLogger : public BaseLogger { ~TrackerLogger(); }; +class LogCallbackRegistry { + public: + using Callback = void (*)(const char*); + LogCallbackRegistry() + : log_callback_([] (const char* msg) { std::cerr << msg << std::endl; }) {} + inline void Register(Callback log_callback) { + this->log_callback_ = log_callback; + } + inline Callback Get() const { + return log_callback_; + } + private: + Callback log_callback_; +}; + +using LogCallbackRegistryStore = dmlc::ThreadLocalStore; + // redefines the logging macro if not existed #ifndef LOG #define LOG(severity) LOG_##severity.stream() diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 47e05b614..c03321a20 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -100,6 +100,18 @@ def from_cstr_to_pystr(data, length): return res +def _log_callback(msg): + """Redirect logs from native library into Python console""" + print("{0:s}".format(py_str(msg))) + + +def _get_log_callback_func(): + """Wrap log_callback() method in ctypes callback type""" + # pylint: disable=invalid-name + CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p) + return CALLBACK(_log_callback) + + def _load_lib(): """Load xgboost Library.""" lib_path = find_lib_path() @@ -107,6 +119,9 @@ def _load_lib(): return None lib = ctypes.cdll.LoadLibrary(lib_path[0]) lib.XGBGetLastError.restype = ctypes.c_char_p + lib.callback = _get_log_callback_func() + if lib.XGBRegisterLogCallback(lib.callback) != 0: + raise XGBoostError(lib.XGBGetLastError()) return lib diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 340940366..988f43225 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -202,6 +202,13 @@ struct XGBAPIThreadLocalEntry { // define the threadlocal store. using XGBAPIThreadLocalStore = dmlc::ThreadLocalStore; +int XGBRegisterLogCallback(void (*callback)(const char*)) { + API_BEGIN(); + LogCallbackRegistry* registry = LogCallbackRegistryStore::Get(); + registry->Register(callback); + API_END(); +} + int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) { diff --git a/src/logging.cc b/src/logging.cc index 934bb0414..6f5eb384b 100644 --- a/src/logging.cc +++ b/src/logging.cc @@ -8,16 +8,24 @@ #include #include "./common/sync.h" -namespace xgboost { +#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 +// Override logging mechanism for non-R interfaces +void dmlc::CustomLogMessage::Log(const std::string& msg) { + const xgboost::LogCallbackRegistry* registry + = xgboost::LogCallbackRegistryStore::Get(); + auto callback = registry->Get(); + callback(msg.c_str()); +} -#if XGBOOST_CUSTOMIZE_LOGGER == 0 +namespace xgboost { ConsoleLogger::~ConsoleLogger() { - std::cerr << log_stream_.str() << std::endl; + dmlc::CustomLogMessage::Log(log_stream_.str()); } TrackerLogger::~TrackerLogger() { log_stream_ << '\n'; rabit::TrackerPrint(log_stream_.str()); } -#endif + } // namespace xgboost +#endif