[CORE] The update process for a tree model, and its application to feature importance (#1670)

* [CORE] allow updating trees in an existing model

* [CORE] in refresh updater, allow keeping old leaf values and update stats only

* [R-package] xgb.train mod to allow updating trees in an existing model

* [R-package] added check for nrounds when is_update

* [CORE] merge parameter declaration changes; unify their code style

* [CORE] move the update-process trees initialization to Configure; rename default process_type to 'default'; fix the trees and trees_to_update sizes comparison check

* [R-package] unit tests for the update process type

* [DOC] documentation for process_type parameter; improved docs for updater, Gamma and Tweedie; added some parameter aliases; metrics indentation and some were non-documented

* fix my sloppy merge conflict resolutions

* [CORE] add a TreeProcessType enum

* whitespace fix
This commit is contained in:
Vadim Khotilovich 2016-12-04 11:33:52 -06:00 committed by Tianqi Chen
parent 4398fbbe4a
commit a44032d095
6 changed files with 221 additions and 60 deletions

View File

@ -284,6 +284,8 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
# Sort the callbacks into categories
cb <- categorize.callbacks(callbacks)
# The tree updating process would need slightly different handling
is_update <- NVL(params[['process_type']], '.') == 'update'
# Construct a booster (either a new one or load from xgb_model)
handle <- xgb.Booster(params, append(watchlist, dtrain), xgb_model)
@ -294,17 +296,20 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
num_parallel_tree <- max(as.numeric(NVL(params[['num_parallel_tree']], 1)), 1)
# When the 'xgb_model' was set, find out how many boosting iterations it has
niter_skip <- 0
niter_init <- 0
if (!is.null(xgb_model)) {
niter_skip <- as.numeric(xgb.attr(bst, 'niter')) + 1
if (length(niter_skip) == 0) {
niter_skip <- xgb.ntree(bst) %/% (num_parallel_tree * num_class)
niter_init <- as.numeric(xgb.attr(bst, 'niter')) + 1
if (length(niter_init) == 0) {
niter_init <- xgb.ntree(bst) %/% (num_parallel_tree * num_class)
}
}
if(is_update && nrounds > niter_init)
stop("nrounds cannot be larger than ", niter_init, " (nrounds of xgb_model)")
# TODO: distributed code
rank <- 0
niter_skip <- ifelse(is_update, 0, niter_init)
begin_iteration <- niter_skip + 1
end_iteration <- niter_skip + nrounds
@ -337,6 +342,7 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
nrow(evaluation_log) > 0) {
# include the previous compatible history when available
if (class(xgb_model) == 'xgb.Booster' &&
!is_update &&
!is.null(xgb_model$evaluation_log) &&
all.equal(colnames(evaluation_log),
colnames(xgb_model$evaluation_log))) {

View File

@ -0,0 +1,76 @@
require(xgboost)
context("update trees in an existing model")
data(agaricus.train, package = 'xgboost')
data(agaricus.test, package = 'xgboost')
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
test_that("updating the model works", {
watchlist = list(train = dtrain, test = dtest)
cb = list(cb.evaluation.log()) # to run silent, but store eval. log
# no-subsampling
p1 <- list(objective = "binary:logistic", max_depth = 2, eta = 0.05, nthread = 2)
set.seed(11)
bst1 <- xgb.train(p1, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb)
tr1 <- xgb.model.dt.tree(model = bst1)
# with subsampling
p2 <- modifyList(p1, list(subsample = 0.1))
set.seed(11)
bst2 <- xgb.train(p2, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb)
tr2 <- xgb.model.dt.tree(model = bst2)
# the same no-subsampling boosting with an extra 'refresh' updater:
p1r <- modifyList(p1, list(updater = 'grow_colmaker,prune,refresh', refresh_leaf = FALSE))
set.seed(11)
bst1r <- xgb.train(p1r, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb)
tr1r <- xgb.model.dt.tree(model = bst1r)
# all should be the same when no subsampling
expect_equal(bst1$evaluation_log, bst1r$evaluation_log)
expect_equal(tr1, tr1r, tolerance = 0.00001, check.attributes = FALSE)
# the same boosting with subsampling with an extra 'refresh' updater:
p2r <- modifyList(p2, list(updater = 'grow_colmaker,prune,refresh', refresh_leaf = FALSE))
set.seed(11)
bst2r <- xgb.train(p2r, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb)
tr2r <- xgb.model.dt.tree(model = bst2r)
# should be the same evaluation but different gains and larger cover
expect_equal(bst2$evaluation_log, bst2r$evaluation_log)
expect_equal(tr2[Feature == 'Leaf']$Quality, tr2r[Feature == 'Leaf']$Quality)
expect_gt(sum(abs(tr2[Feature != 'Leaf']$Quality - tr2r[Feature != 'Leaf']$Quality)), 100)
expect_gt(sum(tr2r$Cover) / sum(tr2$Cover), 1.5)
# process type 'update' for no-subsampling model, refreshing the tree stats AND leaves from training data:
p1u <- modifyList(p1, list(process_type = 'update', updater = 'refresh', refresh_leaf = TRUE))
bst1u <- xgb.train(p1u, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb, xgb_model = bst1)
tr1u <- xgb.model.dt.tree(model = bst1u)
# all should be the same when no subsampling
expect_equal(bst1$evaluation_log, bst1u$evaluation_log)
expect_equal(tr1, tr1u, tolerance = 0.00001, check.attributes = FALSE)
# process type 'update' for model with subsampling, refreshing only the tree stats from training data:
p2u <- modifyList(p2, list(process_type = 'update', updater = 'refresh', refresh_leaf = FALSE))
bst2u <- xgb.train(p2u, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb, xgb_model = bst2)
tr2u <- xgb.model.dt.tree(model = bst2u)
# should be the same evaluation but different gains and larger cover
expect_equal(bst2$evaluation_log, bst2u$evaluation_log)
expect_equal(tr2[Feature == 'Leaf']$Quality, tr2u[Feature == 'Leaf']$Quality)
expect_gt(sum(abs(tr2[Feature != 'Leaf']$Quality - tr2u[Feature != 'Leaf']$Quality)), 100)
expect_gt(sum(tr2u$Cover) / sum(tr2$Cover), 1.5)
# the results should be the same as for the model with an extra 'refresh' updater
expect_equal(bst2r$evaluation_log, bst2u$evaluation_log)
expect_equal(tr2r, tr2u, tolerance = 0.00001, check.attributes = FALSE)
# process type 'update' for no-subsampling model, refreshing only the tree stats from TEST data:
p1ut <- modifyList(p1, list(process_type = 'update', updater = 'refresh', refresh_leaf = FALSE))
bst1ut <- xgb.train(p1ut, dtest, nrounds = 10, watchlist, verbose = 0, callbacks = cb, xgb_model = bst1)
tr1ut <- xgb.model.dt.tree(model = bst1ut)
# should be the same evaluations but different gains and smaller cover (test data is smaller)
expect_equal(bst1$evaluation_log, bst1ut$evaluation_log)
expect_equal(tr1[Feature == 'Leaf']$Quality, tr1ut[Feature == 'Leaf']$Quality)
expect_gt(sum(abs(tr1[Feature != 'Leaf']$Quality - tr1ut[Feature != 'Leaf']$Quality)), 100)
expect_lt(sum(tr1ut$Cover) / sum(tr1$Cover), 0.5)
})

View File

@ -25,10 +25,10 @@ General Parameters
Parameters for Tree Booster
---------------------------
* eta [default=0.3]
* eta [default=0.3, alias: learning_rate]
- step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.
- range: [0,1]
* gamma [default=0]
* gamma [default=0, alias: min_split_loss]
- minimum loss reduction required to make a further partition on a leaf node of the tree. The larger, the more conservative the algorithm will be.
- range: [0,∞]
* max_depth [default=6]
@ -49,9 +49,9 @@ Parameters for Tree Booster
* colsample_bylevel [default=1]
- subsample ratio of columns for each split, in each level.
- range: (0,1]
* lambda [default=1]
* lambda [default=1, alias: reg_lambda]
- L2 regularization term on weights, increase this value will make model more conservative.
* alpha [default=0]
* alpha [default=0, alias: reg_alpha]
- L1 regularization term on weights, increase this value will make model more conservative.
* tree_method, string [default='auto']
- The tree construction algorithm used in XGBoost(see description in the [reference paper](http://arxiv.org/abs/1603.02754))
@ -73,8 +73,27 @@ Parameters for Tree Booster
- range: (0, 1)
* scale_pos_weight, [default=1]
- Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative cases) / sum(positive cases) See [Parameters Tuning](how_to/param_tuning.md) for more discussion. Also see Higgs Kaggle competition demo for examples: [R](../demo/kaggle-higgs/higgs-train.R ), [py1](../demo/kaggle-higgs/higgs-numpy.py ), [py2](../demo/kaggle-higgs/higgs-cv.py ), [py3](../demo/guide-python/cross_validation.py)
* updater_seq, [default="grow_colmaker,prune"]
- A comma separated string mentioning tThe sequence of Tree updaters that should be run. A tree updater is a pluggable operation performed on the tree at every step using the gradient information. Tree updaters can be registered using the plugin system provided.
* updater, [default='grow_colmaker,prune']
- A comma separated string defining the sequence of tree updaters to run, providing a modular way to construct and to modify the trees. This is an advanced parameter that is usually set automatically, depending on some other parameters. However, it could be also set explicitely by a user. The following updater plugins exist:
- 'grow_colmaker': non-distributed column-based construction of trees.
- 'distcol': distributed tree construction with column-based data splitting mode.
- 'grow_histmaker': distributed tree construction with row-based data splitting based on global proposal of histogram counting.
- 'grow_local_histmaker': based on local histogram counting.
- 'grow_skmaker': uses the approximate sketching algorithm.
- 'sync': synchronizes trees in all distributed nodes.
- 'refresh': refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
- 'prune': prunes the splits where loss < min_split_loss (or gamma).
- In a distributed setting, the implicit updater sequence value would be adjusted as follows:
- 'grow_histmaker,prune' when dsplit='row' (or default) and prob_buffer_row == 1 (or default); or when data has multiple sparse pages
- 'grow_histmaker,refresh,prune' when dsplit='row' and prob_buffer_row < 1
- 'distcol' when dsplit='col'
* refresh_leaf, [default=1]
- This is a parameter of the 'refresh' updater plugin. When this flag is true, tree leafs as well as tree nodes' stats are updated. When it is false, only node stats are updated.
* process_type, [default='default']
- A type of boosting process to run.
- Choices: {'default', 'update'}
- 'default': the normal boosting process which creates new trees.
- 'update': starts from an existing model and only updates its trees. In each boosting iteration, a tree from the initial model is taken, a specified sequence of updater plugins is run for that tree, and a modified tree is added to the new model. The new model would have either the same or smaller number of trees, depending on the number of boosting iteratons performed. Currently, the following built-in updater plugins could be meaningfully used with this process type: 'refresh', 'prune'. With 'update', one cannot use updater plugins that create new nrees.
Additional parameters for Dart Booster
--------------------------------------
@ -100,17 +119,21 @@ Additional parameters for Dart Booster
Parameters for Linear Booster
-----------------------------
* lambda [default=0]
* lambda [default=0, alias: reg_lambda]
- L2 regularization term on weights, increase this value will make model more conservative.
* alpha [default=0]
* alpha [default=0, alias: reg_alpha]
- L1 regularization term on weights, increase this value will make model more conservative.
* lambda_bias
- L2 regularization term on bias, default 0 (no L1 reg on bias because it is not important)
* lambda_bias [default=0, alias: reg_lambda_bias]
- L2 regularization term on bias (no L1 reg on bias because it is not important)
Parameters for Tweedie Regression
---------------------------------
* tweedie_variance_power [default=1.5]
- Parameter that controls the variance of the tweedie distribution. Set closer to 2 to shift towards a gamma distribution and closer to 1 to shift towards a poisson distribution.
- parameter that controls the variance of the Tweedie distribution
- var(y) ~ E(y)^tweedie_variance_power
- range: (1,2)
- set closer to 2 to shift towards a gamma distribution
- set closer to 1 to shift towards a Poisson distribution.
Learning Task Parameters
------------------------
@ -125,9 +148,8 @@ Specify the learning task and the corresponding learning objective. The objectiv
- "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
- "multi:softprob" --same as softmax, but output a vector of ndata * nclass, which can be further reshaped to ndata, nclass matrix. The result contains predicted probability of each data point belonging to each class.
- "rank:pairwise" --set XGBoost to do ranking task by minimizing the pairwise loss
- "reg:gamma" --gamma regression for severity data, output mean of gamma distribution
- "reg:tweedie" --tweedie regression for insurance data
- tweedie_variance_power is set to 1.5 by default in tweedie regression and must be in the range [1, 2)
- "reg:gamma" --gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be [gamma-distributed](https://en.wikipedia.org/wiki/Gamma_distribution#Applications)
- "reg:tweedie" --Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be [Tweedie-distributed](https://en.wikipedia.org/wiki/Tweedie_distribution#Applications).
* base_score [ default=0.5 ]
- the initial prediction score of all instances, global bias
- for sufficient number of iterations, changing this value will not have too much effect.
@ -139,6 +161,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
- "mae": [mean absolute error](https://en.wikipedia.org/wiki/Mean_absolute_error)
- "logloss": negative [log-likelihood](http://en.wikipedia.org/wiki/Log-likelihood)
- "error": Binary classification error rate. It is calculated as #(wrong cases)/#(all cases). For the predictions, the evaluation will regard the instances with prediction value larger than 0.5 as positive instances, and the others as negative instances.
- "error@t": a different than 0.5 binary classification threshold value could be specified by providing a numerical value through 't'.
- "merror": Multiclass classification error rate. It is calculated as #(wrong cases)/#(all cases).
- "mlogloss": [Multiclass logloss](https://www.kaggle.com/wiki/MultiClassLogLoss)
- "auc": [Area under the curve](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_curve) for ranking evaluation.
@ -147,7 +170,10 @@ Specify the learning task and the corresponding learning objective. The objectiv
- "ndcg@n","map@n": n can be assigned as an integer to cut off the top positions in the lists for evaluation.
- "ndcg-","map-","ndcg@n-","map@n-": In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions.
training repeatedly
- "gamma-deviance": [residual deviance for gamma regression]
- "poisson-nloglik": negative log-likelihood for Poisson regression
- "gamma-nloglik": negative log-likelihood for gamma regression
- "gamma-deviance": residual deviance for gamma regression
- "tweedie-nloglik": negative log-likelihood for Tweedie regression (at a specified value of the tweedie_variance_power parameter)
* seed [ default=0 ]
- random number seed.

View File

@ -26,6 +26,12 @@ namespace gbm {
DMLC_REGISTRY_FILE_TAG(gbtree);
// boosting process types
enum TreeProcessType {
kDefault,
kUpdate
};
/*! \brief training parameters */
struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
/*!
@ -35,13 +41,24 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
int num_parallel_tree;
/*! \brief tree updater sequence */
std::string updater_seq;
/*! \brief type of boosting process to run */
int process_type;
// declare parameters
DMLC_DECLARE_PARAMETER(GBTreeTrainParam) {
DMLC_DECLARE_FIELD(num_parallel_tree).set_lower_bound(1).set_default(1)
DMLC_DECLARE_FIELD(num_parallel_tree)
.set_default(1)
.set_lower_bound(1)
.describe("Number of parallel trees constructed during each iteration."\
" This option is used to support boosted random forest");
DMLC_DECLARE_FIELD(updater_seq).set_default("grow_colmaker,prune")
DMLC_DECLARE_FIELD(updater_seq)
.set_default("grow_colmaker,prune")
.describe("Tree updater sequence.");
DMLC_DECLARE_FIELD(process_type)
.set_default(kDefault)
.add_enum("default", kDefault)
.add_enum("update", kUpdate)
.describe("Whether to run the normal boosting process that creates new trees,"\
" or to update the trees in an existing model.");
// add alias
DMLC_DECLARE_ALIAS(updater_seq, updater);
}
@ -63,21 +80,30 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
float learning_rate;
// declare parameters
DMLC_DECLARE_PARAMETER(DartTrainParam) {
DMLC_DECLARE_FIELD(silent).set_default(false)
DMLC_DECLARE_FIELD(silent)
.set_default(false)
.describe("Not print information during training.");
DMLC_DECLARE_FIELD(sample_type).set_default(0)
DMLC_DECLARE_FIELD(sample_type)
.set_default(0)
.add_enum("uniform", 0)
.add_enum("weighted", 1)
.describe("Different types of sampling algorithm.");
DMLC_DECLARE_FIELD(normalize_type).set_default(0)
DMLC_DECLARE_FIELD(normalize_type)
.set_default(0)
.add_enum("tree", 0)
.add_enum("forest", 1)
.describe("Different types of normalization algorithm.");
DMLC_DECLARE_FIELD(rate_drop).set_range(0.0f, 1.0f).set_default(0.0f)
DMLC_DECLARE_FIELD(rate_drop)
.set_range(0.0f, 1.0f)
.set_default(0.0f)
.describe("Parameter of how many trees are dropped.");
DMLC_DECLARE_FIELD(skip_drop).set_range(0.0f, 1.0f).set_default(0.0f)
DMLC_DECLARE_FIELD(skip_drop)
.set_range(0.0f, 1.0f)
.set_default(0.0f)
.describe("Parameter of whether to drop trees.");
DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
DMLC_DECLARE_FIELD(learning_rate)
.set_lower_bound(0.0f)
.set_default(0.3f)
.describe("Learning rate(step size) of update.");
DMLC_DECLARE_ALIAS(learning_rate, eta);
}
@ -157,12 +183,21 @@ class GBTree : public GradientBooster {
for (const auto& up : updaters) {
up->Init(cfg);
}
// for the 'update' process_type, move trees into trees_to_update
if (tparam.process_type == kUpdate && trees_to_update.size() == 0u) {
for (size_t i = 0; i < trees.size(); ++i) {
trees_to_update.push_back(std::move(trees[i]));
}
trees.clear();
mparam.num_trees = 0;
}
}
void Load(dmlc::Stream* fi) override {
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
<< "GBTree: invalid model file";
trees.clear();
trees_to_update.clear();
for (int i = 0; i < mparam.num_trees; ++i) {
std::unique_ptr<RegTree> ptr(new RegTree());
ptr->Load(fi);
@ -386,11 +421,20 @@ class GBTree : public GradientBooster {
ret->clear();
// create the trees
for (int i = 0; i < tparam.num_parallel_tree; ++i) {
if (tparam.process_type == kDefault) {
// create new tree
std::unique_ptr<RegTree> ptr(new RegTree());
ptr->param.InitAllowUnknown(this->cfg);
ptr->InitModel();
new_trees.push_back(ptr.get());
ret->push_back(std::move(ptr));
} else if (tparam.process_type == kUpdate) {
CHECK_LT(trees.size(), trees_to_update.size());
// move an existing tree from trees_to_update
auto t = std::move(trees_to_update[trees.size()]);
new_trees.push_back(t.get());
ret->push_back(std::move(t));
}
}
// update the trees
for (auto& up : updaters) {
@ -493,6 +537,8 @@ class GBTree : public GradientBooster {
GBTreeModelParam mparam;
/*! \brief vector of trees stored in the model */
std::vector<std::unique_ptr<RegTree> > trees;
/*! \brief for the update process, a place to keep the initial trees */
std::vector<std::unique_ptr<RegTree> > trees_to_update;
/*! \brief some information indicator of the tree, reserved */
std::vector<int> tree_info;
// ----training fields----

View File

@ -64,6 +64,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
bool cache_opt;
// whether to not print info during training.
bool silent;
// whether refresh updater needs to update the leaf values
bool refresh_leaf;
// auxiliary data structure
std::vector<int> monotone_constraints;
// declare the parameters
@ -75,10 +77,11 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
DMLC_DECLARE_FIELD(min_split_loss)
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe(
"Minimum loss reduction required to make a further partition.");
DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6).describe(
"Maximum depth of the tree.");
.describe("Minimum loss reduction required to make a further partition.");
DMLC_DECLARE_FIELD(max_depth)
.set_lower_bound(0)
.set_default(6)
.describe("Maximum depth of the tree.");
DMLC_DECLARE_FIELD(min_child_weight)
.set_lower_bound(0.0f)
.set_default(1.0f)
@ -100,8 +103,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
DMLC_DECLARE_FIELD(max_delta_step)
.set_lower_bound(0.0f)
.set_default(0.0f)
.describe(
"Maximum delta step we allow each tree's weight estimate to be. "
.describe("Maximum delta step we allow each tree's weight estimate to be. "\
"If the value is set to 0, it means there is no constraint");
DMLC_DECLARE_FIELD(subsample)
.set_range(0.0f, 1.0f)
@ -114,8 +116,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
DMLC_DECLARE_FIELD(colsample_bytree)
.set_range(0.0f, 1.0f)
.set_default(1.0f)
.describe(
"Subsample ratio of columns, resample on each tree construction.");
.describe("Subsample ratio of columns, resample on each tree construction.");
DMLC_DECLARE_FIELD(opt_dense_col)
.set_range(0.0f, 1.0f)
.set_default(1.0f)
@ -127,8 +128,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
DMLC_DECLARE_FIELD(sketch_ratio)
.set_lower_bound(0.0f)
.set_default(2.0f)
.describe("EXP Param: Sketch accuracy related parameter of approximate "
"algorithm.");
.describe("EXP Param: Sketch accuracy related parameter of approximate algorithm.");
DMLC_DECLARE_FIELD(size_leaf_vector)
.set_lower_bound(0)
.set_default(0)
@ -136,10 +136,15 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
DMLC_DECLARE_FIELD(parallel_option)
.set_default(0)
.describe("Different types of parallelization algorithm.");
DMLC_DECLARE_FIELD(cache_opt).set_default(true).describe(
"EXP Param: Cache aware optimization.");
DMLC_DECLARE_FIELD(silent).set_default(false).describe(
"Do not print information during training.");
DMLC_DECLARE_FIELD(cache_opt)
.set_default(true)
.describe("EXP Param: Cache aware optimization.");
DMLC_DECLARE_FIELD(silent)
.set_default(false)
.describe("Do not print information during trainig.");
DMLC_DECLARE_FIELD(refresh_leaf)
.set_default(true)
.describe("Whether the refresh updater needs to update leaf values.");
DMLC_DECLARE_FIELD(monotone_constraints)
.set_default(std::vector<int>())
.describe("Constraint of variable monotonicity");

View File

@ -134,7 +134,9 @@ class TreeRefresher: public TreeUpdater {
tree.stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
if (tree[nid].is_leaf()) {
if (param.refresh_leaf) {
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
}
} else {
tree.stat(nid).loss_chg = static_cast<bst_float>(
gstats[tree[nid].cleft()].CalcGain(param) +