Unify logging facilities. (#3982)

* Unify logging facilities.
* Enhance `ConsoleLogger` to handle different verbosity.
* Override macros from `dmlc`.
* Don't use specialized gamma when building with GPU.
* Remove verbosity cache in monitor.
* Test monitor.
* Deprecate `silent`.
* Fix doc and messages.
* Fix python test.
* Fix silent tests.

parent fd722d60cd
commit e0a279114e
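At a glance, every component now consults one global verbosity level instead of carrying its own `silent`/`debug_verbose` flag. A minimal sketch of the intended usage, assuming only the `ConsoleLogger` interface and `LOG` macros introduced in this diff (the `main` wrapper is illustrative):

    #include <map>
    #include <string>
    #include <xgboost/logging.h>

    int main() {
      // 0 = silent, 1 = warning (default), 2 = info, 3 = debug
      std::map<std::string, std::string> args{{"verbosity", "2"}};
      xgboost::ConsoleLogger::Configure(args.cbegin(), args.cend());

      LOG(WARNING) << "emitted at verbosity >= 1";
      LOG(INFO)    << "emitted at verbosity >= 2";
      LOG(DEBUG)   << "emitted only at verbosity == 3";
      LOG(CONSOLE) << "always emitted; kIgnore bypasses the global setting";
      return 0;
    }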
@@ -33,7 +33,7 @@ evalerror <- function(preds, dtrain) {
   return(list(metric = "error", value = err))
 }

-param <- list(max_depth=2, eta=1, nthread = 2, silent=1,
+param <- list(max_depth=2, eta=1, nthread = 2, verbosity=0,
               objective=logregobj, eval_metric=evalerror)
 print ('start training with user customized objective')
 # training with customized objective, we can also do step by step training
@@ -57,7 +57,7 @@ logregobjattr <- function(preds, dtrain) {
   hess <- preds * (1 - preds)
   return(list(grad = grad, hess = hess))
 }
-param <- list(max_depth=2, eta=1, nthread = 2, silent=1,
+param <- list(max_depth=2, eta=1, nthread = 2, verbosity=0,
               objective=logregobjattr, eval_metric=evalerror)
 print ('start training with user customized objective, with additional attributes in DMatrix')
 # training with customized objective, we can also do step by step training

@@ -7,7 +7,7 @@ dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
 # note: for customized objective function, we leave objective as default
 # note: what we are getting is margin value in prediction
 # you must know what you are doing
-param <- list(max_depth=2, eta=1, nthread = 2, silent=1)
+param <- list(max_depth=2, eta=1, nthread=2, verbosity=0)
 watchlist <- list(eval = dtest)
 num_round <- 20
 # user define objective function, given prediction, return gradient and second order gradient
@@ -32,9 +32,9 @@ evalerror <- function(preds, dtrain) {
 }
 print ('start training with early Stopping setting')

-bst <- xgb.train(param, dtrain, num_round, watchlist,
+bst <- xgb.train(param, dtrain, num_round, watchlist,
                  objective = logregobj, eval_metric = evalerror, maximize = FALSE,
                  early_stopping_round = 3)
-bst <- xgb.cv(param, dtrain, num_round, nfold = 5,
+bst <- xgb.cv(param, dtrain, num_round, nfold = 5,
               objective = logregobj, eval_metric = evalerror,
               maximize = FALSE, early_stopping_rounds = 3)
@@ -32,7 +32,10 @@ extern "C" {

 namespace xgboost {
 ConsoleLogger::~ConsoleLogger() {
-  dmlc::CustomLogMessage::Log(log_stream_.str());
+  if (cur_verbosity_ == LogVerbosity::kIgnore ||
+      cur_verbosity_ <= global_verbosity_) {
+    dmlc::CustomLogMessage::Log(log_stream_.str());
+  }
 }
 TrackerLogger::~TrackerLogger() {
   dmlc::CustomLogMessage::Log(log_stream_.str());
@@ -46,10 +49,11 @@ namespace common {
 bool CheckNAN(double v) {
   return ISNAN(v);
 }
+#if !defined(XGBOOST_USE_CUDA)
 double LogGamma(double v) {
   return lgammafn(v);
 }
-
+#endif
 // customize random engine.
 void CustomGlobalRandomEngine::seed(CustomGlobalRandomEngine::result_type val) {
   // ignore the seed

@@ -23,9 +23,16 @@ General Parameters

 - Which booster to use. Can be ``gbtree``, ``gblinear`` or ``dart``; ``gbtree`` and ``dart`` use tree based models while ``gblinear`` uses linear functions.

-* ``silent`` [default=0]
+* ``silent`` [default=0] [Deprecated]

-  - 0 means printing running messages, 1 means silent mode
+  - Deprecated. Please use ``verbosity`` instead.
+
+* ``verbosity`` [default=1]
+
+  - Verbosity of printing messages. Valid values are 0 (silent),
+    1 (warning), 2 (info), 3 (debug). Sometimes XGBoost tries to change
+    configurations based on heuristics, which is displayed as warning message.
+    If there's unexpected behaviour, please try to increase value of verbosity.

 * ``nthread`` [default to maximum number of threads available if not set]

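Migrating a configuration is mechanical; a hedged C++ sketch of the old and new spellings and of how the deprecated key is resolved (precedence follows `ConsoleLogger::Configure` further down in this diff):

    #include <map>
    #include <string>
    #include <xgboost/logging.h>

    std::map<std::string, std::string> params;
    // Old (deprecated): params["silent"] = "1";
    params["verbosity"] = "0";  // new spelling of "print nothing"
    xgboost::ConsoleLogger::Configure(params.cbegin(), params.cend());
    // If both keys are present, `silent` wins for backward compatibility.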
@@ -9,8 +9,13 @@
 #define XGBOOST_LOGGING_H_

 #include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <dmlc/thread_local.h>
 #include <sstream>
+#include <map>
 #include <string>
+#include <utility>
+#include <vector>
 #include "./base.h"

 namespace xgboost {
@@ -28,8 +33,54 @@ class BaseLogger {
   std::ostringstream log_stream_;
 };

+// Parsing both silent and debug_verbose is to provide backward compatibility.
+struct ConsoleLoggerParam : public dmlc::Parameter<ConsoleLoggerParam> {
+  bool silent;  // deprecated.
+  int verbosity;
+
+  DMLC_DECLARE_PARAMETER(ConsoleLoggerParam) {
+    DMLC_DECLARE_FIELD(silent)
+        .set_default(false)
+        .describe("Do not print information during training.");
+    DMLC_DECLARE_FIELD(verbosity)
+        .set_range(0, 3)
+        .set_default(1)  // shows only warning
+        .describe("Flag to print out detailed breakdown of runtime.");
+    DMLC_DECLARE_ALIAS(verbosity, debug_verbose);
+  }
+};
+
 class ConsoleLogger : public BaseLogger {
  public:
+  enum class LogVerbosity {
+    kSilent = 0,
+    kWarning = 1,
+    kInfo = 2,   // information that may interest users.
+    kDebug = 3,  // information only interesting to developers.
+    kIgnore = 4  // ignore the global setting.
+  };
+  using LV = LogVerbosity;
+
+ private:
+  static LogVerbosity global_verbosity_;
+  static ConsoleLoggerParam param_;
+
+  LogVerbosity cur_verbosity_;
+  static void Configure(const std::map<std::string, std::string>& args);
+
+ public:
+  template <typename ArgIter>
+  static void Configure(ArgIter begin, ArgIter end) {
+    std::map<std::string, std::string> args(begin, end);
+    Configure(args);
+  }
+
+  static LogVerbosity GlobalVerbosity();
+  static LogVerbosity DefaultVerbosity();
+
+  ConsoleLogger();
+  explicit ConsoleLogger(LogVerbosity cur_verb);
+  ConsoleLogger(const std::string& file, int line, LogVerbosity cur_verb);
   ~ConsoleLogger();
 };

@@ -68,13 +119,34 @@ class LogCallbackRegistry {

 using LogCallbackRegistryStore = dmlc::ThreadLocalStore<LogCallbackRegistry>;

+// Redefine LOG_WARNING for controlling verbosity
+#if defined(LOG_WARNING)
+#undef LOG_WARNING
+#endif
+#define LOG_WARNING ::xgboost::ConsoleLogger( \
+    __FILE__, __LINE__, ::xgboost::ConsoleLogger::LogVerbosity::kWarning)
+
+// Redefine LOG_INFO for controlling verbosity
+#if defined(LOG_INFO)
+#undef LOG_INFO
+#endif
+#define LOG_INFO ::xgboost::ConsoleLogger( \
+    __FILE__, __LINE__, ::xgboost::ConsoleLogger::LogVerbosity::kInfo)
+
+#if defined(LOG_DEBUG)
+#undef LOG_DEBUG
+#endif
+#define LOG_DEBUG ::xgboost::ConsoleLogger( \
+    __FILE__, __LINE__, ::xgboost::ConsoleLogger::LogVerbosity::kDebug)
+
+// Redefine the logging macro if it does not already exist.
+#ifndef LOG
+#define LOG(severity) LOG_##severity.stream()
+#endif
+
 // Enable LOG(CONSOLE) for print messages to console.
-#define LOG_CONSOLE ::xgboost::ConsoleLogger()
+#define LOG_CONSOLE ::xgboost::ConsoleLogger( \
+    ::xgboost::ConsoleLogger::LogVerbosity::kIgnore)
 // Enable LOG(TRACKER) for print messages to tracker
 #define LOG_TRACKER ::xgboost::TrackerLogger()
 }  // namespace xgboost.
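These overrides work because `LOG(severity)` token-pastes to one of the `LOG_*` macros above, each of which constructs a temporary `ConsoleLogger` tagged with a level; the destructor (shown earlier in this diff) then compares that level against the global one before emitting. Roughly, and only as a sketch:

    // LOG(WARNING) << "msg";  expands to:
    ::xgboost::ConsoleLogger(__FILE__, __LINE__,
        ::xgboost::ConsoleLogger::LogVerbosity::kWarning).stream() << "msg";
    // When the temporary dies at the end of the statement, its destructor
    // emits the buffered text only if kWarning <= the global verbosity
    // (kIgnore always emits).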
@@ -237,7 +237,7 @@ class XGBModel(XGBModelBase):
         else:
             xgb_params['nthread'] = n_jobs

-        xgb_params['silent'] = 1 if self.silent else 0
+        xgb_params['verbosity'] = 0 if self.silent else 0

         if xgb_params['nthread'] <= 0:
             xgb_params.pop('nthread', None)
@@ -19,6 +19,7 @@
 #include <cstdio>
 #include <cstring>
 #include <vector>
+#include "./common/common.h"
 #include "./common/config.h"

@@ -33,8 +34,6 @@ enum CLITask {
 struct CLIParam : public dmlc::Parameter<CLIParam> {
   /*! \brief the task name */
   int task;
-  /*! \brief whether silent */
-  int silent;
   /*! \brief whether evaluate training statistics */
   bool eval_train;
   /*! \brief number of boosting iterations */
@@ -82,8 +81,6 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
         .add_enum("dump", kDumpModel)
         .add_enum("pred", kPredict)
         .describe("Task to be performed by the CLI program.");
-    DMLC_DECLARE_FIELD(silent).set_default(0).set_range(0, 2)
-        .describe("Silent level during the task.");
     DMLC_DECLARE_FIELD(eval_train).set_default(false)
         .describe("Whether evaluate on training data during training.");
     DMLC_DECLARE_FIELD(num_round).set_default(10).set_lower_bound(1)
@@ -125,10 +122,10 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
     DMLC_DECLARE_ALIAS(name_fmap, fmap);
   }
   // customized configure function of CLIParam
-  inline void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) {
-    this->cfg = cfg;
-    this->InitAllowUnknown(cfg);
-    for (const auto& kv : cfg) {
+  inline void Configure(const std::vector<std::pair<std::string, std::string> >& _cfg) {
+    this->cfg = _cfg;
+    this->InitAllowUnknown(_cfg);
+    for (const auto& kv : _cfg) {
       if (!strncmp("eval[", kv.first.c_str(), 5)) {
         char evname[256];
         CHECK_EQ(sscanf(kv.first.c_str(), "eval[%[^]]", evname), 1)
@@ -140,13 +137,13 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
     // constraint.
     if (name_pred == "stdout") {
       save_period = 0;
-      silent = 1;
+      this->cfg.emplace_back(std::make_pair("silent", "0"));
     }
     if (dsplit == 0 && rabit::IsDistributed()) {
       dsplit = 2;
     }
     if (rabit::GetRank() != 0) {
-      silent = 2;
+      this->cfg.emplace_back(std::make_pair("silent", "1"));
     }
   }
 };
@@ -161,15 +158,20 @@ void CLITrain(const CLIParam& param) {
   }
   // load in data.
   std::shared_ptr<DMatrix> dtrain(
-      DMatrix::Load(param.train_path, param.silent != 0, param.dsplit == 2));
+      DMatrix::Load(
+          param.train_path,
+          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+          param.dsplit == 2));
   std::vector<std::shared_ptr<DMatrix> > deval;
   std::vector<std::shared_ptr<DMatrix> > cache_mats;
   std::vector<DMatrix*> eval_datasets;
   cache_mats.push_back(dtrain);
   for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
     deval.emplace_back(
-        std::shared_ptr<DMatrix>(DMatrix::Load(param.eval_data_paths[i],
-                                               param.silent != 0, param.dsplit == 2)));
+        std::shared_ptr<DMatrix>(DMatrix::Load(
+            param.eval_data_paths[i],
+            ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+            param.dsplit == 2)));
     eval_datasets.push_back(deval.back().get());
     cache_mats.push_back(deval.back());
   }
@@ -193,17 +195,14 @@ void CLITrain(const CLIParam& param) {
       learner->InitModel();
     }
   }
-  if (param.silent == 0) {
-    LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";
-  }
+  LOG(INFO) << "Loading data: " << dmlc::GetTime() - tstart_data_load << " sec";

   // start training.
   const double start = dmlc::GetTime();
   for (int i = version / 2; i < param.num_round; ++i) {
     double elapsed = dmlc::GetTime() - start;
     if (version % 2 == 0) {
-      if (param.silent == 0) {
-        LOG(CONSOLE) << "boosting round " << i << ", " << elapsed << " sec elapsed";
-      }
+      LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
       learner->UpdateOneIter(i, dtrain.get());
       if (learner->AllowLazyCheckPoint()) {
         rabit::LazyCheckPoint(learner.get());
@@ -219,9 +218,7 @@ void CLITrain(const CLIParam& param) {
         LOG(TRACKER) << res;
       }
     } else {
-      if (param.silent < 2) {
-        LOG(CONSOLE) << res;
-      }
+      LOG(CONSOLE) << res;
     }
     if (param.save_period != 0 &&
         (i + 1) % param.save_period == 0 &&
@@ -260,10 +257,8 @@ void CLITrain(const CLIParam& param) {
     learner->Save(fo.get());
   }

-  if (param.silent == 0) {
-    double elapsed = dmlc::GetTime() - start;
-    LOG(CONSOLE) << "update end, " << elapsed << " sec in all";
-  }
+  double elapsed = dmlc::GetTime() - start;
+  LOG(INFO) << "update end, " << elapsed << " sec in all";
 }

 void CLIDumpModel(const CLIParam& param) {
@@ -310,7 +305,10 @@ void CLIPredict(const CLIParam& param) {
       << "Test dataset parameter test:data must be specified.";
   // load data
   std::unique_ptr<DMatrix> dtest(
-      DMatrix::Load(param.test_path, param.silent != 0, param.dsplit == 2));
+      DMatrix::Load(
+          param.test_path,
+          ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
+          param.dsplit == 2));
   // load model
   CHECK_NE(param.model_in, "NULL")
       << "Must specify model_in for predict";
@@ -320,14 +318,11 @@ void CLIPredict(const CLIParam& param) {
   learner->Load(fi.get());
   learner->Configure(param.cfg);

-  if (param.silent == 0) {
-    LOG(CONSOLE) << "start prediction...";
-  }
+  LOG(INFO) << "start prediction...";
   HostDeviceVector<bst_float> preds;
   learner->Predict(dtest.get(), param.pred_margin, &preds, param.ntree_limit);
-  if (param.silent == 0) {
-    LOG(CONSOLE) << "writing prediction to " << param.name_pred;
-  }
+  LOG(CONSOLE) << "writing prediction to " << param.name_pred;

   std::unique_ptr<dmlc::Stream> fo(
       dmlc::Stream::Create(param.name_pred.c_str(), "w"));
   dmlc::ostream os(fo.get());
@@ -480,7 +480,7 @@ class BulkAllocator {
   }

   template <typename... Args>
-  void Allocate(int device_idx, bool silent, Args... args) {
+  void Allocate(int device_idx, Args... args) {
     size_t size = GetSizeBytes(args...);

     char *ptr = AllocateDevice(device_idx, size, MemoryT);
@@ -116,7 +116,6 @@ inline static bool CmpSecond(const std::pair<float, unsigned> &a,
 #if XGBOOST_STRICT_R_MODE
 // check nan
 bool CheckNAN(double v);
-double LogGamma(double v);
 #else
 template<typename T>
 inline bool CheckNAN(T v) {
@@ -126,9 +125,19 @@ inline bool CheckNAN(T v) {
   return std::isnan(v);
 #endif
 }
 #endif  // XGBOOST_STRICT_R_MODE_

+// GPU version is not uploaded in CRAN anyway.
+// Specialize only when using R with CPU.
+#if XGBOOST_STRICT_R_MODE && !defined(XGBOOST_USE_CUDA)
+double LogGamma(double v);
+
+#else  // Not R or R with GPU.
+
 template<typename T>
 XGBOOST_DEVICE inline T LogGamma(T v) {
 #ifdef _MSC_VER

 #if _MSC_VER >= 1800
   return lgamma(v);
 #else
@@ -136,12 +145,15 @@ XGBOOST_DEVICE inline T LogGamma(T v) {
       ", poisson regression will be disabled")
   utils::Error("lgamma function was not available until VS2013");
   return static_cast<T>(1.0);
-#endif
+#endif  // _MSC_VER >= 1800

 #else
   return lgamma(v);
 #endif
 }
-#endif  // XGBOOST_STRICT_R_MODE_
+
+#endif  // XGBOOST_STRICT_R_MODE && !defined(XGBOOST_USE_CUDA)

 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_MATH_H_
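The net effect of this block: under the strict R build on CPU, `LogGamma` stays a host-only wrapper around R's `lgammafn`, while every other configuration (including R with CUDA) gets the `XGBOOST_DEVICE` template, which is safe to call from device code. A hedged host-side sketch:

    #include "src/common/math.h"  // path as laid out in this repository

    // The same call site compiles against either branch; only the backing
    // implementation differs (lgammafn vs. lgamma).
    double v = xgboost::common::LogGamma(4.5);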
@@ -49,15 +49,19 @@ struct Monitor {
     Timer timer;
     size_t count{0};
   };
-  bool debug_verbose = false;
   std::string label = "";
   std::map<std::string, Statistics> statistics_map;
   Timer self_timer;
+  bool IsVerbose() {
+    // Don't cache debug verbosity in here to deal with changed parameter.
+    return (ConsoleLogger::GlobalVerbosity() == ConsoleLogger::LV::kDebug);
+  }

  public:
   Monitor() { self_timer.Start(); }

   ~Monitor() {
-    if (!debug_verbose) return;
+    if (!IsVerbose()) return;

     LOG(CONSOLE) << "======== Monitor: " << label << " ========";
     for (auto &kv : statistics_map) {
@@ -70,13 +74,12 @@ struct Monitor {
     }
     self_timer.Stop();
   }
-  void Init(std::string label, bool debug_verbose) {
-    this->debug_verbose = debug_verbose;
+  void Init(std::string label) {
     this->label = label;
   }
   void Start(const std::string &name) { statistics_map[name].timer.Start(); }
   void Start(const std::string &name, GPUSet devices) {
-    if (debug_verbose) {
+    if (IsVerbose()) {
 #ifdef __CUDACC__
       for (auto device : devices) {
         cudaSetDevice(device);
@@ -91,7 +94,7 @@ struct Monitor {
     statistics_map[name].count++;
   }
   void Stop(const std::string &name, GPUSet devices) {
-    if (debug_verbose) {
+    if (IsVerbose()) {
 #ifdef __CUDACC__
       for (auto device : devices) {
         cudaSetDevice(device);
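With the cached flag gone, a `Monitor` is constructed with a label only and re-checks the global logger every time, so flipping `verbosity` between calls takes effect immediately. A minimal usage sketch against the new signature (mirroring the test added later in this diff):

    #include <xgboost/logging.h>
    #include "src/common/timer.h"  // Monitor; path as in this repository

    void TimedWork() {
      xgboost::common::Monitor monitor;
      monitor.Init("TimedWork");   // label only; no debug_verbose argument
      monitor.Start("step");
      // ... timed work ...
      monitor.Stop("step");
    }  // the destructor prints the table only when verbosity == 3 (debug)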
@@ -223,7 +223,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
         << dmat->Info().num_nonzero_ << " entries loaded from " << uri;
   }
   /* sync up number of features after matrix loaded.
-   * partitioned data will fail the train/val validation check
+   * partitioned data will fail the train/val validation check
    * since partitioned data not knowing the real number of features. */
   rabit::Allreduce<rabit::op::Max>(&dmat->Info().num_col_, 1);
   // backward compatiblity code.
@@ -26,7 +26,6 @@ struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
   std::string updater;
   float tolerance;
   size_t max_row_perbatch;
-  int debug_verbose;
   DMLC_DECLARE_PARAMETER(GBLinearTrainParam) {
     DMLC_DECLARE_FIELD(updater)
         .set_default("shotgun")
@@ -38,10 +37,6 @@ struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
     DMLC_DECLARE_FIELD(max_row_perbatch)
         .set_default(std::numeric_limits<size_t>::max())
         .describe("Maximum rows per batch.");
-    DMLC_DECLARE_FIELD(debug_verbose)
-        .set_lower_bound(0)
-        .set_default(0)
-        .describe("flag to print out detailed breakdown of runtime");
   }
 };
 /*!
@@ -69,7 +64,7 @@ class GBLinear : public GradientBooster {
     param_.InitAllowUnknown(cfg);
     updater_.reset(LinearUpdater::Create(param_.updater));
     updater_->Init(cfg);
-    monitor_.Init("GBLinear ", param_.debug_verbose);
+    monitor_.Init("GBLinear");
   }
   void Load(dmlc::Stream* fi) override {
     model_.Load(fi);
@@ -45,8 +45,6 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
   std::string updater_seq;
   /*! \brief type of boosting process to run */
   int process_type;
-  // flag to print out detailed breakdown of runtime
-  int debug_verbose;
   std::string predictor;
   // declare parameters
   DMLC_DECLARE_PARAMETER(GBTreeTrainParam) {
@@ -64,10 +62,6 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
         .add_enum("update", kUpdate)
         .describe("Whether to run the normal boosting process that creates new trees,"\
                   " or to update the trees in an existing model.");
-    DMLC_DECLARE_FIELD(debug_verbose)
-        .set_lower_bound(0)
-        .set_default(0)
-        .describe("flag to print out detailed breakdown of runtime");
     // add alias
     DMLC_DECLARE_ALIAS(updater_seq, updater);
     DMLC_DECLARE_FIELD(predictor)
@@ -78,8 +72,6 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {

 /*! \brief training parameters */
 struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
-  /*! \brief whether to not print info during training */
-  bool silent;
   /*! \brief type of sampling algorithm */
   int sample_type;
   /*! \brief type of normalization algorithm */
@@ -94,9 +86,6 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
   float learning_rate;
   // declare parameters
   DMLC_DECLARE_PARAMETER(DartTrainParam) {
-    DMLC_DECLARE_FIELD(silent)
-        .set_default(false)
-        .describe("Not print information during training.");
     DMLC_DECLARE_FIELD(sample_type)
         .set_default(0)
         .add_enum("uniform", 0)
@@ -160,7 +149,7 @@ class GBTree : public GradientBooster {
     // configure predictor
     predictor_ = std::unique_ptr<Predictor>(Predictor::Create(tparam_.predictor));
     predictor_->Init(cfg, cache_);
-    monitor_.Init("GBTree", tparam_.debug_verbose);
+    monitor_.Init("GBTree");
   }

   void Load(dmlc::Stream* fi) override {
@@ -488,10 +477,8 @@ class Dart : public GBTree {
       model_.CommitModel(std::move(new_trees[gid]), gid);
     }
     size_t num_drop = NormalizeTrees(num_new_trees);
-    if (dparam_.silent != 1) {
-      LOG(INFO) << "drop " << num_drop << " trees, "
-                << "weight = " << weight_drop_.back();
-    }
+    LOG(INFO) << "drop " << num_drop << " trees, "
+              << "weight = " << weight_drop_.back();
   }

   // predict the leaf scores without dropped trees
@@ -122,8 +122,6 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
   // number of threads to use if OpenMP is enabled
   // if equals 0, use system default
   int nthread;
-  // flag to print out detailed breakdown of runtime
-  int debug_verbose;
   // flag to disable default metric
   int disable_default_eval_metric;
   // declare parameters
@@ -155,10 +153,6 @@ struct LearnerTrainParam : public dmlc::Parameter<LearnerTrainParam> {
         "Internal test flag");
     DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
         "Number of threads to use.");
-    DMLC_DECLARE_FIELD(debug_verbose)
-        .set_lower_bound(0)
-        .set_default(0)
-        .describe("flag to print out detailed breakdown of runtime");
     DMLC_DECLARE_FIELD(disable_default_eval_metric)
         .set_default(0)
         .describe("flag to disable default metric. Set to >0 to disable");
@@ -196,7 +190,7 @@ class LearnerImpl : public Learner {
     }
     // `updater` parameter was manually specified
     if (cfg_.count("updater") > 0) {
-      LOG(CONSOLE) << "DANGER AHEAD: You have manually specified `updater` "
+      LOG(WARNING) << "DANGER AHEAD: You have manually specified `updater` "
                       "parameter. The `tree_method` parameter will be ignored. "
                       "Incorrect sequence of updaters will produce undefined "
                       "behavior. For common uses, we recommend using "
@@ -217,8 +211,9 @@ class LearnerImpl : public Learner {
         cfg_["updater"] = "grow_colmaker,prune";
         break;
       case TreeMethod::kHist:
-        LOG(CONSOLE) << "Tree method is selected to be 'hist', which uses a "
-                        "single updater grow_quantile_histmaker.";
+        LOG(INFO) <<
+            "Tree method is selected to be 'hist', which uses a "
+            "single updater grow_quantile_histmaker.";
         cfg_["updater"] = "grow_quantile_histmaker";
         break;
       case TreeMethod::kGPUExact:
@@ -245,8 +240,10 @@ class LearnerImpl : public Learner {
       const std::vector<std::pair<std::string, std::string> >& args) override {
     // add to configurations
     tparam_.InitAllowUnknown(args);
-    monitor_.Init("Learner", tparam_.debug_verbose);
+    ConsoleLogger::Configure(args.cbegin(), args.cend());
+    monitor_.Init("Learner");
     cfg_.clear();

     for (const auto& kv : args) {
       if (kv.first == "eval_metric") {
         // check duplication
@@ -270,7 +267,6 @@ class LearnerImpl : public Learner {
     if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
       tparam_.dsplit = DataSplitMode::kRow;
     }
-
     if (cfg_.count("num_class") != 0) {
       cfg_["num_output_group"] = cfg_["num_class"];
       if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) {
@@ -612,15 +608,16 @@ class LearnerImpl : public Learner {
     }
     switch (current_tree_method) {
       case TreeMethod::kAuto:
-        LOG(CONSOLE) << "Tree method is automatically selected to be 'approx' "
-                        "for distributed training.";
+        LOG(WARNING) <<
+            "Tree method is automatically selected to be 'approx' "
+            "for distributed training.";
         break;
       case TreeMethod::kApprox:
         // things are okay, do nothing
         break;
       case TreeMethod::kExact:
       case TreeMethod::kHist:
-        LOG(CONSOLE) << "Tree method was set to be '"
+        LOG(WARNING) << "Tree method was set to be '"
                      << (current_tree_method == TreeMethod::kExact ?
                          "exact" : "hist")
                      << "', but only 'approx' is available for distributed "
@@ -640,14 +637,14 @@ class LearnerImpl : public Learner {
     /* Some tree methods are not available for external-memory DMatrix */
     switch (current_tree_method) {
       case TreeMethod::kAuto:
-        LOG(CONSOLE) << "Tree method is automatically set to 'approx' "
+        LOG(WARNING) << "Tree method is automatically set to 'approx' "
                         "since external-memory data matrix is used.";
         break;
       case TreeMethod::kApprox:
         // things are okay, do nothing
         break;
       case TreeMethod::kExact:
-        LOG(CONSOLE) << "Tree method was set to be 'exact', "
+        LOG(WARNING) << "Tree method was set to be 'exact', "
                         "but currently we are only able to proceed with "
                         "approximate algorithm ('approx') because external-"
                         "memory data matrix is used.";
@@ -668,7 +665,7 @@ class LearnerImpl : public Learner {
     } else if (p_train->Info().num_row_ >= (4UL << 20UL)
                && current_tree_method == TreeMethod::kAuto) {
       /* Choose tree_method='approx' automatically for large data matrix */
-      LOG(CONSOLE) << "Tree method is automatically selected to be "
+      LOG(WARNING) << "Tree method is automatically selected to be "
                       "'approx' for faster speed. To use old behavior "
                       "(exact greedy algorithm on single machine), "
                       "set tree_method to 'exact'.";
@@ -22,7 +22,6 @@ struct CoordinateTrainParam : public dmlc::Parameter<CoordinateTrainParam> {
   float reg_alpha;
   int feature_selector;
   int top_k;
-  int debug_verbose;
   // declare parameters
   DMLC_DECLARE_PARAMETER(CoordinateTrainParam) {
     DMLC_DECLARE_FIELD(learning_rate)
@@ -50,10 +49,6 @@ struct CoordinateTrainParam : public dmlc::Parameter<CoordinateTrainParam> {
         .set_default(0)
         .describe("The number of top features to select in 'thrifty' feature_selector. "
                   "The value of zero means using all the features.");
-    DMLC_DECLARE_FIELD(debug_verbose)
-        .set_lower_bound(0)
-        .set_default(0)
-        .describe("flag to print out detailed breakdown of runtime");
     // alias of parameters
     DMLC_DECLARE_ALIAS(learning_rate, eta);
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
@@ -82,7 +77,7 @@ class CoordinateUpdater : public LinearUpdater {
       const std::vector<std::pair<std::string, std::string> > &args) override {
     param.InitAllowUnknown(args);
     selector.reset(FeatureSelector::Create(param.feature_selector));
-    monitor.Init("CoordinateUpdater", param.debug_verbose);
+    monitor.Init("CoordinateUpdater");
   }
   void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
               gbm::GBLinearModel *model, double sum_instance_weight) override {
@@ -27,10 +27,8 @@ struct GPUCoordinateTrainParam
   float reg_alpha;
   int feature_selector;
   int top_k;
-  int debug_verbose;
   int n_gpus;
   int gpu_id;
-  bool silent;
   // declare parameters
   DMLC_DECLARE_PARAMETER(GPUCoordinateTrainParam) {
     DMLC_DECLARE_FIELD(learning_rate)
@@ -56,16 +54,10 @@ struct GPUCoordinateTrainParam
     DMLC_DECLARE_FIELD(top_k).set_lower_bound(0).set_default(0).describe(
         "The number of top features to select in 'thrifty' feature_selector. "
         "The value of zero means using all the features.");
-    DMLC_DECLARE_FIELD(debug_verbose)
-        .set_lower_bound(0)
-        .set_default(0)
-        .describe("flag to print out detailed breakdown of runtime");
     DMLC_DECLARE_FIELD(n_gpus).set_default(1).describe(
         "Number of devices to use.");
     DMLC_DECLARE_FIELD(gpu_id).set_default(0).describe(
         "Primary device ordinal.");
-    DMLC_DECLARE_FIELD(silent).set_default(false).describe(
-        "Do not print information during trainig.");
     // alias of parameters
     DMLC_DECLARE_ALIAS(learning_rate, eta);
     DMLC_DECLARE_ALIAS(reg_lambda, lambda);
@@ -126,7 +118,7 @@ class DeviceShard {
           std::make_pair(column_begin - col.data(), column_end - col.data()));
       row_ptr_.push_back(row_ptr_.back() + column_end - column_begin);
     }
-    ba_.Allocate(device_id_, param.silent, &data_, row_ptr_.back(), &gpair_,
+    ba_.Allocate(device_id_, &data_, row_ptr_.back(), &gpair_,
                  (row_end - row_begin) * model_param.num_output_group);

     for (int fidx = 0; fidx < batch.Size(); fidx++) {
@@ -209,7 +201,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
       const std::vector<std::pair<std::string, std::string>> &args) override {
     param.InitAllowUnknown(args);
     selector.reset(FeatureSelector::Create(param.feature_selector));
-    monitor.Init("GPUCoordinateUpdater", param.debug_verbose);
+    monitor.Init("GPUCoordinateUpdater");
   }

   void LazyInitShards(DMatrix *p_fmat,
@@ -1,12 +1,15 @@
 /*!
- * Copyright 2015 by Contributors
+ * Copyright 2015-2018 by Contributors
  * \file logging.cc
  * \brief Implementation of loggers.
  * \author Tianqi Chen
  */
 #include <rabit/rabit.h>
+#include <dmlc/parameter.h>
 #include <xgboost/logging.h>

+#include <iostream>
+#include <map>
+
 #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
 // Override logging mechanism for non-R interfaces
@@ -18,8 +21,12 @@ void dmlc::CustomLogMessage::Log(const std::string& msg) {
 }

 namespace xgboost {

 ConsoleLogger::~ConsoleLogger() {
-  dmlc::CustomLogMessage::Log(log_stream_.str());
+  if (cur_verbosity_ == LogVerbosity::kIgnore ||
+      cur_verbosity_ <= global_verbosity_) {
+    dmlc::CustomLogMessage::Log(BaseLogger::log_stream_.str());
+  }
 }

 TrackerLogger::~TrackerLogger() {
@@ -28,4 +35,71 @@ TrackerLogger::~TrackerLogger() {
 }

 }  // namespace xgboost

 #endif

+namespace xgboost {
+
+DMLC_REGISTER_PARAMETER(ConsoleLoggerParam);
+
+ConsoleLogger::LogVerbosity ConsoleLogger::global_verbosity_ =
+    ConsoleLogger::DefaultVerbosity();
+
+ConsoleLoggerParam ConsoleLogger::param_ = ConsoleLoggerParam();
+
+void ConsoleLogger::Configure(const std::map<std::string, std::string>& args) {
+  param_.InitAllowUnknown(args);
+  // Deprecated, but when trying to display deprecation message some R
+  // tests trying to catch stdout will fail.
+  if (param_.silent) {
+    global_verbosity_ = LogVerbosity::kSilent;
+    return;
+  }
+  switch (param_.verbosity) {
+    case 0:
+      global_verbosity_ = LogVerbosity::kSilent;
+      break;
+    case 1:
+      global_verbosity_ = LogVerbosity::kWarning;
+      break;
+    case 2:
+      global_verbosity_ = LogVerbosity::kInfo;
+      break;
+    case 3:
+      global_verbosity_ = LogVerbosity::kDebug;
+    default:
+      // global verbosity doesn't require kIgnore
+      break;
+  }
+}
+
+ConsoleLogger::LogVerbosity ConsoleLogger::DefaultVerbosity() {
+  return LogVerbosity::kWarning;
+}
+
+ConsoleLogger::LogVerbosity ConsoleLogger::GlobalVerbosity() {
+  return global_verbosity_;
+}
+
+ConsoleLogger::ConsoleLogger() : cur_verbosity_{LogVerbosity::kInfo} {}
+ConsoleLogger::ConsoleLogger(LogVerbosity cur_verb) :
+    cur_verbosity_{cur_verb} {}
+
+ConsoleLogger::ConsoleLogger(
+    const std::string& file, int line, LogVerbosity cur_verb) {
+  cur_verbosity_ = cur_verb;
+  switch (cur_verbosity_) {
+    case LogVerbosity::kWarning:
+      BaseLogger::log_stream_ << "WARNING: ";
+    case LogVerbosity::kDebug:
+      BaseLogger::log_stream_ << "DEBUG: ";
+    case LogVerbosity::kInfo:
+      BaseLogger::log_stream_ << "INFO: ";
+    case LogVerbosity::kIgnore:
+      BaseLogger::log_stream_ << file << ":" << line << ": ";
+      break;
+    case LogVerbosity::kSilent:
+      break;
+  }
+}
+
+}  // namespace xgboost
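Given the switch above, a quick sketch of how the two parameters resolve (a hedged illustration; the values follow the code just shown):

    using xgboost::ConsoleLogger;

    std::map<std::string, std::string> args;
    args["verbosity"] = "3";  // request debug output
    ConsoleLogger::Configure(args.cbegin(), args.cend());
    // ConsoleLogger::GlobalVerbosity() == LogVerbosity::kDebug

    args["silent"] = "True";  // deprecated switch is still honoured
    ConsoleLogger::Configure(args.cbegin(), args.cend());
    // ConsoleLogger::GlobalVerbosity() == LogVerbosity::kSilent:
    // `silent` short-circuits `verbosity` for backward compatibility.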
@@ -24,15 +24,12 @@ DMLC_REGISTRY_FILE_TAG(gpu_predictor);
 struct GPUPredictionParam : public dmlc::Parameter<GPUPredictionParam> {
   int gpu_id;
   int n_gpus;
-  bool silent;
   // declare parameters
   DMLC_DECLARE_PARAMETER(GPUPredictionParam) {
     DMLC_DECLARE_FIELD(gpu_id).set_lower_bound(0).set_default(0).describe(
         "Device ordinal for GPU prediction.");
     DMLC_DECLARE_FIELD(n_gpus).set_lower_bound(-1).set_default(1).describe(
         "Number of devices to use for prediction.");
-    DMLC_DECLARE_FIELD(silent).set_default(false).describe(
-        "Do not print information during trainig.");
   }
 };
 DMLC_REGISTER_PARAMETER(GPUPredictionParam);
@@ -34,8 +34,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   // growing policy
   enum TreeGrowPolicy { kDepthWise = 0, kLossGuide = 1 };
   int grow_policy;
-  // flag to print out detailed breakdown of runtime
-  int debug_verbose;

   //----- the rest parameters are less important ----
   // minimum amount of hessian(weight) allowed in a child
   float min_child_weight;
@@ -67,8 +66,6 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   int parallel_option;
   // option to open cacheline optimization
   bool cache_opt;
-  // whether to not print info during training.
-  bool silent;
   // whether refresh updater needs to update the leaf values
   bool refresh_leaf;
   // auxiliary data structure
@@ -107,10 +104,6 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
         .set_default(0.0f)
         .describe(
             "Minimum loss reduction required to make a further partition.");
-    DMLC_DECLARE_FIELD(debug_verbose)
-        .set_lower_bound(0)
-        .set_default(0)
-        .describe("flag to print out detailed breakdown of runtime");
     DMLC_DECLARE_FIELD(max_depth)
         .set_lower_bound(0)
         .set_default(6)
@@ -186,9 +179,6 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
     DMLC_DECLARE_FIELD(cache_opt)
         .set_default(true)
         .describe("EXP Param: Cache aware optimization.");
-    DMLC_DECLARE_FIELD(silent)
-        .set_default(false)
-        .describe("Do not print information during trainig.");
     DMLC_DECLARE_FIELD(refresh_leaf)
         .set_default(true)
         .describe("Whether the refresh updater needs to update leaf values.");
@@ -625,7 +625,7 @@ class GPUMaker : public TreeUpdater {

   void allocateAllData(int offsetSize) {
     int tmpBuffSize = ScanTempBufferSize(nVals);
-    ba.Allocate(param.gpu_id, param.silent, &vals, nVals,
+    ba.Allocate(param.gpu_id, &vals, nVals,
                 &vals_cached, nVals, &instIds, nVals, &instIds_cached, nVals,
                 &colOffsets, offsetSize, &gradsInst, nRows, &nodeAssigns, nVals,
                 &nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst,
@@ -803,7 +803,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
   int max_nodes =
       param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);

-  ba.Allocate(device_id_, param.silent,
+  ba.Allocate(device_id_,
               &gpair, n_rows,
               &ridx, n_rows,
               &position, n_rows,
@@ -833,7 +833,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
   CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
       << "Max leaves and max depth cannot both be unconstrained for "
         "gpu_hist.";
-  ba.Allocate(device_id_, param.silent, &gidx_buffer, compressed_size_bytes);
+  ba.Allocate(device_id_, &gidx_buffer, compressed_size_bytes);
   gidx_buffer.Fill(0);

   int nbits = common::detail::SymbolBits(num_symbols);
@@ -931,7 +931,7 @@ class GPUHistMakerSpecialised {
       qexpand_.reset(new ExpandQueue(DepthWise));
     }

-    monitor_.Init("updater_gpu_hist", param_.debug_verbose);
+    monitor_.Init("updater_gpu_hist");
   }

   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
@@ -966,7 +966,9 @@ class GPUHistMakerSpecialised {
       device_list_[index] = device_id;
     }

-    reducer_.Init(device_list_, param_.debug_verbose);
+    reducer_.Init(
+        device_list_,
+        ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity());

     auto batch_iter = dmat->GetRowBatches().begin();
     const SparsePage& batch = *batch_iter;
@@ -71,11 +71,9 @@ class TreePruner: public TreeUpdater {
         npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
       }
     }
-    if (!param_.silent) {
-      LOG(INFO) << "tree pruning end, " << tree.param.num_roots << " roots, "
-                << tree.NumExtraNodes() << " extra nodes, " << npruned
-                << " pruned nodes, max_depth=" << tree.MaxDepth();
-    }
+    LOG(INFO) << "tree pruning end, " << tree.param.num_roots << " roots, "
+              << tree.NumExtraNodes() << " extra nodes, " << npruned
+              << " pruned nodes, max_depth=" << tree.MaxDepth();
   }

  private:
@@ -6,6 +6,7 @@
  */
 #include <dmlc/timer.h>
 #include <rabit/rabit.h>
+#include <xgboost/logging.h>
 #include <xgboost/tree_updater.h>

 #include <cmath>
@@ -60,9 +61,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
       gmatb_.Init(gmat_, column_matrix_, param_);
     }
     is_gmat_initialized_ = true;
-    if (param_.debug_verbose > 0) {
-      LOG(INFO) << "Generating gmat: " << dmlc::GetTime() - tstart << " sec";
-    }
+    LOG(INFO) << "Generating gmat: " << dmlc::GetTime() - tstart << " sec";
   }
   // rescale learning rate according to size of trees
   float lr = param_.learning_rate;
@@ -207,32 +206,34 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,

   pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});

-  if (param_.debug_verbose > 0) {
-    double total_time = dmlc::GetTime() - gstart;
-    LOG(INFO) << "\nInitData: "
-              << std::fixed << std::setw(6) << std::setprecision(4) << time_init_data
-              << " (" << std::fixed << std::setw(5) << std::setprecision(2)
-              << time_init_data / total_time * 100 << "%)\n"
-              << "InitNewNode: "
-              << std::fixed << std::setw(6) << std::setprecision(4) << time_init_new_node
-              << " (" << std::fixed << std::setw(5) << std::setprecision(2)
-              << time_init_new_node / total_time * 100 << "%)\n"
-              << "BuildHist: "
-              << std::fixed << std::setw(6) << std::setprecision(4) << time_build_hist
-              << " (" << std::fixed << std::setw(5) << std::setprecision(2)
-              << time_build_hist / total_time * 100 << "%)\n"
-              << "EvaluateSplit: "
-              << std::fixed << std::setw(6) << std::setprecision(4) << time_evaluate_split
-              << " (" << std::fixed << std::setw(5) << std::setprecision(2)
-              << time_evaluate_split / total_time * 100 << "%)\n"
-              << "ApplySplit: "
-              << std::fixed << std::setw(6) << std::setprecision(4) << time_apply_split
-              << " (" << std::fixed << std::setw(5) << std::setprecision(2)
-              << time_apply_split / total_time * 100 << "%)\n"
-              << "========================================\n"
-              << "Total: "
-              << std::fixed << std::setw(6) << std::setprecision(4) << total_time;
-  }
+  if (ConsoleLogger::GlobalVerbosity() <= ConsoleLogger::DefaultVerbosity()) {
+    // Don't construct the following huge stream.
+    return;
+  }
+  double total_time = dmlc::GetTime() - gstart;
+  LOG(INFO) << "\nInitData: "
+            << std::fixed << std::setw(6) << std::setprecision(4) << time_init_data
+            << " (" << std::fixed << std::setw(5) << std::setprecision(2)
+            << time_init_data / total_time * 100 << "%)\n"
+            << "InitNewNode: "
+            << std::fixed << std::setw(6) << std::setprecision(4) << time_init_new_node
+            << " (" << std::fixed << std::setw(5) << std::setprecision(2)
+            << time_init_new_node / total_time * 100 << "%)\n"
+            << "BuildHist: "
+            << std::fixed << std::setw(6) << std::setprecision(4) << time_build_hist
+            << " (" << std::fixed << std::setw(5) << std::setprecision(2)
+            << time_build_hist / total_time * 100 << "%)\n"
+            << "EvaluateSplit: "
+            << std::fixed << std::setw(6) << std::setprecision(4) << time_evaluate_split
+            << " (" << std::fixed << std::setw(5) << std::setprecision(2)
+            << time_evaluate_split / total_time * 100 << "%)\n"
+            << "ApplySplit: "
+            << std::fixed << std::setw(6) << std::setprecision(4) << time_apply_split
+            << " (" << std::fixed << std::setw(5) << std::setprecision(2)
+            << time_apply_split / total_time * 100 << "%)\n"
+            << "========================================\n"
+            << "Total: "
+            << std::fixed << std::setw(6) << std::setprecision(4) << total_time;
 }

 bool QuantileHistMaker::Builder::UpdatePredictionCache(
@@ -5,7 +5,6 @@
 #include <thrust/device_vector.h>
 #include <xgboost/base.h>
 #include "../../../src/common/device_helpers.cuh"
-#include "../../../src/common/timer.h"
 #include "gtest/gtest.h"

 struct Shard { int id; };
tests/cpp/common/test_monitor.cc (new file, 32 lines)
@@ -0,0 +1,32 @@
+#include <gtest/gtest.h>
+#include <xgboost/logging.h>
+#include <string>
+#include "../../../src/common/timer.h"
+
+namespace xgboost {
+namespace common {
+TEST(Monitor, Basic) {
+  auto run_monitor =
+      []() {
+        Monitor monitor_;
+        monitor_.Init("Monitor test");
+        monitor_.Start("basic");
+        monitor_.Stop("basic");
+      };
+
+  std::map<std::string, std::string> args = {std::make_pair("verbosity", "3")};
+  ConsoleLogger::Configure(args.cbegin(), args.cend());
+  testing::internal::CaptureStderr();
+  run_monitor();
+  std::string output = testing::internal::GetCapturedStderr();
+  ASSERT_NE(output.find("Monitor"), std::string::npos);
+
+  args = {std::make_pair("verbosity", "2")};
+  ConsoleLogger::Configure(args.cbegin(), args.cend());
+  testing::internal::CaptureStderr();
+  run_monitor();
+  output = testing::internal::GetCapturedStderr();
+  ASSERT_EQ(output.find("Monitor"), std::string::npos);
+}
+}  // namespace common
+}  // namespace xgboost
tests/cpp/test_logging.cc (new file, 47 lines)
@@ -0,0 +1,47 @@
+#include <map>
+
+#include <gtest/gtest.h>
+#include <xgboost/logging.h>
+
+namespace xgboost {
+
+TEST(Logging, Basic) {
+  std::map<std::string, std::string> args {};
+  args["verbosity"] = "0";  // silent
+  ConsoleLogger::Configure(args.cbegin(), args.cend());
+  testing::internal::CaptureStderr();
+  std::string output = testing::internal::GetCapturedStderr();
+  ASSERT_EQ(output.length(), 0);
+
+  args["verbosity"] = "3";  // debug
+  ConsoleLogger::Configure(args.cbegin(), args.cend());
+
+  testing::internal::CaptureStderr();
+  LOG(WARNING) << "Test Log Warning.";
+  output = testing::internal::GetCapturedStderr();
+  ASSERT_NE(output.find("WARNING"), std::string::npos);
+
+  testing::internal::CaptureStderr();
+  LOG(INFO) << "Test Log Info";
+  output = testing::internal::GetCapturedStderr();
+  ASSERT_NE(output.find("Test Log Info"), std::string::npos);
+
+  testing::internal::CaptureStderr();
+  LOG(DEBUG) << "Test Log Debug.";
+  output = testing::internal::GetCapturedStderr();
+  ASSERT_NE(output.find("DEBUG"), std::string::npos);
+
+  args["silent"] = "True";
+  ConsoleLogger::Configure(args.cbegin(), args.cend());
+  testing::internal::CaptureStderr();
+  LOG(INFO) << "Test Log Info";
+  output = testing::internal::GetCapturedStderr();
+  ASSERT_EQ(output.length(), 0);
+
+  testing::internal::CaptureStderr();
+  LOG(CONSOLE) << "Test Log Console";
+  output = testing::internal::GetCapturedStderr();
+  ASSERT_NE(output.find("Test Log Console"), std::string::npos);
+}
+
+}  // namespace xgboost
@@ -252,7 +252,7 @@ TEST(GpuHist, EvaluateSplits) {

   // Copy cut matrix to device.
   DeviceShard<GradientPairPrecise>::DeviceHistCutMatrix cut;
-  shard->ba.Allocate(0, true,
+  shard->ba.Allocate(0,
                      &(shard->cut_.feature_segments), cmat.row_ptr.size(),
                      &(shard->cut_.min_fvalue), cmat.min_val.size(),
                      &(shard->cut_.gidx_fvalue_map), 24,
@@ -315,7 +315,6 @@ TEST(GpuHist, ApplySplit) {
   int constexpr n_cols = 8;

   TrainParam param;
-  param.silent = true;

   // Initialize shard
   for (size_t i = 0; i < n_cols; ++i) {
@@ -330,7 +329,7 @@ TEST(GpuHist, ApplySplit) {
   shard->node_sum_gradients.resize(3);

   shard->ridx_segments[0] = Segment(0, n_rows);
-  shard->ba.Allocate(0, true, &(shard->ridx), n_rows,
+  shard->ba.Allocate(0, &(shard->ridx), n_rows,
                      &(shard->position), n_rows);
   shard->row_stride = n_cols;
   thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
@@ -367,8 +366,7 @@ TEST(GpuHist, ApplySplit) {
   size_t compressed_size_bytes =
       common::CompressedBufferWriter::CalculateBufferSize(
           row_stride * n_rows, num_symbols);
-  shard->ba.Allocate(0, param.silent,
-                     &(shard->gidx_buffer), compressed_size_bytes);
+  shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes);

   common::CompressedBufferWriter wr(num_symbols);
   std::vector<int> h_gidx (n_rows * row_stride);
@@ -290,7 +290,7 @@ class TestBasic(unittest.TestCase):
         assert len(cv) == (4)

     def test_cv_explicit_fold_indices_labels(self):
-        params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective':
+        params = {'max_depth': 2, 'eta': 1, 'verbosity': 0, 'objective':
                   'reg:linear'}
         N = 100
         F = 3