Add callback interface to re-direct console output (#3438)
* Add callback interface to re-direct console output * Exempt TrackerLogger from custom logging * Fix lint
This commit is contained in:
parent
45bf4fbffb
commit
48d6e68690
@ -53,6 +53,8 @@ if(USE_AVX)
|
|||||||
add_definitions(-DXGBOOST_USE_AVX)
|
add_definitions(-DXGBOOST_USE_AVX)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# enable custom logging
|
||||||
|
add_definitions(-DDMLC_LOG_CUSTOMIZE=1)
|
||||||
|
|
||||||
# compiled code customizations for R package
|
# compiled code customizations for R package
|
||||||
if(R_LIB)
|
if(R_LIB)
|
||||||
|
|||||||
2
Makefile
2
Makefile
@ -68,7 +68,7 @@ endif
|
|||||||
endif
|
endif
|
||||||
|
|
||||||
export LDFLAGS= -pthread -lm $(ADD_LDFLAGS) $(DMLC_LDFLAGS) $(PLUGIN_LDFLAGS)
|
export LDFLAGS= -pthread -lm $(ADD_LDFLAGS) $(DMLC_LDFLAGS) $(PLUGIN_LDFLAGS)
|
||||||
export CFLAGS= -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS)
|
export CFLAGS= -DDMLC_LOG_CUSTOMIZE=1 -std=c++11 -Wall -Wno-unknown-pragmas -Iinclude $(ADD_CFLAGS) $(PLUGIN_CFLAGS)
|
||||||
CFLAGS += -I$(DMLC_CORE)/include -I$(RABIT)/include -I$(GTEST_PATH)/include
|
CFLAGS += -I$(DMLC_CORE)/include -I$(RABIT)/include -I$(GTEST_PATH)/include
|
||||||
#java include path
|
#java include path
|
||||||
export JAVAINCFLAGS = -I${JAVA_HOME}/include -I./java
|
export JAVAINCFLAGS = -I${JAVA_HOME}/include -I./java
|
||||||
|
|||||||
@ -96,6 +96,15 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*)
|
|||||||
*/
|
*/
|
||||||
XGB_DLL const char *XGBGetLastError(void);
|
XGB_DLL const char *XGBGetLastError(void);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief register callback function for LOG(INFO) messages -- helpful messages
|
||||||
|
* that are not errors.
|
||||||
|
* Note: this function can be called by multiple threads. The callback function
|
||||||
|
* will run on the thread that registered it
|
||||||
|
* \return 0 for success, -1 for failure
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*));
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief load a data matrix
|
* \brief load a data matrix
|
||||||
* \param fname the name of the file
|
* \param fname the name of the file
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
#define XGBOOST_LOGGING_H_
|
#define XGBOOST_LOGGING_H_
|
||||||
|
|
||||||
#include <dmlc/logging.h>
|
#include <dmlc/logging.h>
|
||||||
|
#include <dmlc/thread_local.h>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include "./base.h"
|
#include "./base.h"
|
||||||
|
|
||||||
@ -37,6 +38,23 @@ class TrackerLogger : public BaseLogger {
|
|||||||
~TrackerLogger();
|
~TrackerLogger();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class LogCallbackRegistry {
|
||||||
|
public:
|
||||||
|
using Callback = void (*)(const char*);
|
||||||
|
LogCallbackRegistry()
|
||||||
|
: log_callback_([] (const char* msg) { std::cerr << msg << std::endl; }) {}
|
||||||
|
inline void Register(Callback log_callback) {
|
||||||
|
this->log_callback_ = log_callback;
|
||||||
|
}
|
||||||
|
inline Callback Get() const {
|
||||||
|
return log_callback_;
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
Callback log_callback_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using LogCallbackRegistryStore = dmlc::ThreadLocalStore<LogCallbackRegistry>;
|
||||||
|
|
||||||
// redefines the logging macro if not existed
|
// redefines the logging macro if not existed
|
||||||
#ifndef LOG
|
#ifndef LOG
|
||||||
#define LOG(severity) LOG_##severity.stream()
|
#define LOG(severity) LOG_##severity.stream()
|
||||||
|
|||||||
@ -100,6 +100,18 @@ def from_cstr_to_pystr(data, length):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _log_callback(msg):
|
||||||
|
"""Redirect logs from native library into Python console"""
|
||||||
|
print("{0:s}".format(py_str(msg)))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_log_callback_func():
|
||||||
|
"""Wrap log_callback() method in ctypes callback type"""
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
|
||||||
|
return CALLBACK(_log_callback)
|
||||||
|
|
||||||
|
|
||||||
def _load_lib():
|
def _load_lib():
|
||||||
"""Load xgboost Library."""
|
"""Load xgboost Library."""
|
||||||
lib_path = find_lib_path()
|
lib_path = find_lib_path()
|
||||||
@ -107,6 +119,9 @@ def _load_lib():
|
|||||||
return None
|
return None
|
||||||
lib = ctypes.cdll.LoadLibrary(lib_path[0])
|
lib = ctypes.cdll.LoadLibrary(lib_path[0])
|
||||||
lib.XGBGetLastError.restype = ctypes.c_char_p
|
lib.XGBGetLastError.restype = ctypes.c_char_p
|
||||||
|
lib.callback = _get_log_callback_func()
|
||||||
|
if lib.XGBRegisterLogCallback(lib.callback) != 0:
|
||||||
|
raise XGBoostError(lib.XGBGetLastError())
|
||||||
return lib
|
return lib
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -202,6 +202,13 @@ struct XGBAPIThreadLocalEntry {
|
|||||||
// define the threadlocal store.
|
// define the threadlocal store.
|
||||||
using XGBAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
|
using XGBAPIThreadLocalStore = dmlc::ThreadLocalStore<XGBAPIThreadLocalEntry>;
|
||||||
|
|
||||||
|
int XGBRegisterLogCallback(void (*callback)(const char*)) {
|
||||||
|
API_BEGIN();
|
||||||
|
LogCallbackRegistry* registry = LogCallbackRegistryStore::Get();
|
||||||
|
registry->Register(callback);
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|
||||||
int XGDMatrixCreateFromFile(const char *fname,
|
int XGDMatrixCreateFromFile(const char *fname,
|
||||||
int silent,
|
int silent,
|
||||||
DMatrixHandle *out) {
|
DMatrixHandle *out) {
|
||||||
|
|||||||
@ -8,16 +8,24 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include "./common/sync.h"
|
#include "./common/sync.h"
|
||||||
|
|
||||||
namespace xgboost {
|
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
|
||||||
|
// Override logging mechanism for non-R interfaces
|
||||||
|
void dmlc::CustomLogMessage::Log(const std::string& msg) {
|
||||||
|
const xgboost::LogCallbackRegistry* registry
|
||||||
|
= xgboost::LogCallbackRegistryStore::Get();
|
||||||
|
auto callback = registry->Get();
|
||||||
|
callback(msg.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
#if XGBOOST_CUSTOMIZE_LOGGER == 0
|
namespace xgboost {
|
||||||
ConsoleLogger::~ConsoleLogger() {
|
ConsoleLogger::~ConsoleLogger() {
|
||||||
std::cerr << log_stream_.str() << std::endl;
|
dmlc::CustomLogMessage::Log(log_stream_.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
TrackerLogger::~TrackerLogger() {
|
TrackerLogger::~TrackerLogger() {
|
||||||
log_stream_ << '\n';
|
log_stream_ << '\n';
|
||||||
rabit::TrackerPrint(log_stream_.str());
|
rabit::TrackerPrint(log_stream_.str());
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
#endif
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user