[CORE] The update process for a tree model, and its application to feature importance (#1670)
* [CORE] allow updating trees in an existing model * [CORE] in refresh updater, allow keeping old leaf values and update stats only * [R-package] xgb.train mod to allow updating trees in an existing model * [R-package] added check for nrounds when is_update * [CORE] merge parameter declaration changes; unify their code style * [CORE] move the update-process trees initialization to Configure; rename default process_type to 'default'; fix the trees and trees_to_update sizes comparison check * [R-package] unit tests for the update process type * [DOC] documentation for process_type parameter; improved docs for updater, Gamma and Tweedie; added some parameter aliases; metrics indentation and some were non-documented * fix my sloppy merge conflict resolutions * [CORE] add a TreeProcessType enum * whitespace fix
This commit is contained in:
parent
4398fbbe4a
commit
a44032d095
@ -284,7 +284,9 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
|||||||
# Sort the callbacks into categories
|
# Sort the callbacks into categories
|
||||||
cb <- categorize.callbacks(callbacks)
|
cb <- categorize.callbacks(callbacks)
|
||||||
|
|
||||||
|
# The tree updating process would need slightly different handling
|
||||||
|
is_update <- NVL(params[['process_type']], '.') == 'update'
|
||||||
|
|
||||||
# Construct a booster (either a new one or load from xgb_model)
|
# Construct a booster (either a new one or load from xgb_model)
|
||||||
handle <- xgb.Booster(params, append(watchlist, dtrain), xgb_model)
|
handle <- xgb.Booster(params, append(watchlist, dtrain), xgb_model)
|
||||||
bst <- xgb.handleToBooster(handle)
|
bst <- xgb.handleToBooster(handle)
|
||||||
@ -294,17 +296,20 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
|||||||
num_parallel_tree <- max(as.numeric(NVL(params[['num_parallel_tree']], 1)), 1)
|
num_parallel_tree <- max(as.numeric(NVL(params[['num_parallel_tree']], 1)), 1)
|
||||||
|
|
||||||
# When the 'xgb_model' was set, find out how many boosting iterations it has
|
# When the 'xgb_model' was set, find out how many boosting iterations it has
|
||||||
niter_skip <- 0
|
niter_init <- 0
|
||||||
if (!is.null(xgb_model)) {
|
if (!is.null(xgb_model)) {
|
||||||
niter_skip <- as.numeric(xgb.attr(bst, 'niter')) + 1
|
niter_init <- as.numeric(xgb.attr(bst, 'niter')) + 1
|
||||||
if (length(niter_skip) == 0) {
|
if (length(niter_init) == 0) {
|
||||||
niter_skip <- xgb.ntree(bst) %/% (num_parallel_tree * num_class)
|
niter_init <- xgb.ntree(bst) %/% (num_parallel_tree * num_class)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if(is_update && nrounds > niter_init)
|
||||||
|
stop("nrounds cannot be larger than ", niter_init, " (nrounds of xgb_model)")
|
||||||
|
|
||||||
# TODO: distributed code
|
# TODO: distributed code
|
||||||
rank <- 0
|
rank <- 0
|
||||||
|
|
||||||
|
niter_skip <- ifelse(is_update, 0, niter_init)
|
||||||
begin_iteration <- niter_skip + 1
|
begin_iteration <- niter_skip + 1
|
||||||
end_iteration <- niter_skip + nrounds
|
end_iteration <- niter_skip + nrounds
|
||||||
|
|
||||||
@ -337,6 +342,7 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
|||||||
nrow(evaluation_log) > 0) {
|
nrow(evaluation_log) > 0) {
|
||||||
# include the previous compatible history when available
|
# include the previous compatible history when available
|
||||||
if (class(xgb_model) == 'xgb.Booster' &&
|
if (class(xgb_model) == 'xgb.Booster' &&
|
||||||
|
!is_update &&
|
||||||
!is.null(xgb_model$evaluation_log) &&
|
!is.null(xgb_model$evaluation_log) &&
|
||||||
all.equal(colnames(evaluation_log),
|
all.equal(colnames(evaluation_log),
|
||||||
colnames(xgb_model$evaluation_log))) {
|
colnames(xgb_model$evaluation_log))) {
|
||||||
|
|||||||
76
R-package/tests/testthat/test_update.R
Normal file
76
R-package/tests/testthat/test_update.R
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
require(xgboost)
|
||||||
|
|
||||||
|
context("update trees in an existing model")
|
||||||
|
|
||||||
|
data(agaricus.train, package = 'xgboost')
|
||||||
|
data(agaricus.test, package = 'xgboost')
|
||||||
|
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
|
||||||
|
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
|
||||||
|
|
||||||
|
test_that("updating the model works", {
|
||||||
|
watchlist = list(train = dtrain, test = dtest)
|
||||||
|
cb = list(cb.evaluation.log()) # to run silent, but store eval. log
|
||||||
|
|
||||||
|
# no-subsampling
|
||||||
|
p1 <- list(objective = "binary:logistic", max_depth = 2, eta = 0.05, nthread = 2)
|
||||||
|
set.seed(11)
|
||||||
|
bst1 <- xgb.train(p1, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb)
|
||||||
|
tr1 <- xgb.model.dt.tree(model = bst1)
|
||||||
|
|
||||||
|
# with subsampling
|
||||||
|
p2 <- modifyList(p1, list(subsample = 0.1))
|
||||||
|
set.seed(11)
|
||||||
|
bst2 <- xgb.train(p2, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb)
|
||||||
|
tr2 <- xgb.model.dt.tree(model = bst2)
|
||||||
|
|
||||||
|
# the same no-subsampling boosting with an extra 'refresh' updater:
|
||||||
|
p1r <- modifyList(p1, list(updater = 'grow_colmaker,prune,refresh', refresh_leaf = FALSE))
|
||||||
|
set.seed(11)
|
||||||
|
bst1r <- xgb.train(p1r, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb)
|
||||||
|
tr1r <- xgb.model.dt.tree(model = bst1r)
|
||||||
|
# all should be the same when no subsampling
|
||||||
|
expect_equal(bst1$evaluation_log, bst1r$evaluation_log)
|
||||||
|
expect_equal(tr1, tr1r, tolerance = 0.00001, check.attributes = FALSE)
|
||||||
|
|
||||||
|
# the same boosting with subsampling with an extra 'refresh' updater:
|
||||||
|
p2r <- modifyList(p2, list(updater = 'grow_colmaker,prune,refresh', refresh_leaf = FALSE))
|
||||||
|
set.seed(11)
|
||||||
|
bst2r <- xgb.train(p2r, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb)
|
||||||
|
tr2r <- xgb.model.dt.tree(model = bst2r)
|
||||||
|
# should be the same evaluation but different gains and larger cover
|
||||||
|
expect_equal(bst2$evaluation_log, bst2r$evaluation_log)
|
||||||
|
expect_equal(tr2[Feature == 'Leaf']$Quality, tr2r[Feature == 'Leaf']$Quality)
|
||||||
|
expect_gt(sum(abs(tr2[Feature != 'Leaf']$Quality - tr2r[Feature != 'Leaf']$Quality)), 100)
|
||||||
|
expect_gt(sum(tr2r$Cover) / sum(tr2$Cover), 1.5)
|
||||||
|
|
||||||
|
# process type 'update' for no-subsampling model, refreshing the tree stats AND leaves from training data:
|
||||||
|
p1u <- modifyList(p1, list(process_type = 'update', updater = 'refresh', refresh_leaf = TRUE))
|
||||||
|
bst1u <- xgb.train(p1u, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb, xgb_model = bst1)
|
||||||
|
tr1u <- xgb.model.dt.tree(model = bst1u)
|
||||||
|
# all should be the same when no subsampling
|
||||||
|
expect_equal(bst1$evaluation_log, bst1u$evaluation_log)
|
||||||
|
expect_equal(tr1, tr1u, tolerance = 0.00001, check.attributes = FALSE)
|
||||||
|
|
||||||
|
# process type 'update' for model with subsampling, refreshing only the tree stats from training data:
|
||||||
|
p2u <- modifyList(p2, list(process_type = 'update', updater = 'refresh', refresh_leaf = FALSE))
|
||||||
|
bst2u <- xgb.train(p2u, dtrain, nrounds = 10, watchlist, verbose = 0, callbacks = cb, xgb_model = bst2)
|
||||||
|
tr2u <- xgb.model.dt.tree(model = bst2u)
|
||||||
|
# should be the same evaluation but different gains and larger cover
|
||||||
|
expect_equal(bst2$evaluation_log, bst2u$evaluation_log)
|
||||||
|
expect_equal(tr2[Feature == 'Leaf']$Quality, tr2u[Feature == 'Leaf']$Quality)
|
||||||
|
expect_gt(sum(abs(tr2[Feature != 'Leaf']$Quality - tr2u[Feature != 'Leaf']$Quality)), 100)
|
||||||
|
expect_gt(sum(tr2u$Cover) / sum(tr2$Cover), 1.5)
|
||||||
|
# the results should be the same as for the model with an extra 'refresh' updater
|
||||||
|
expect_equal(bst2r$evaluation_log, bst2u$evaluation_log)
|
||||||
|
expect_equal(tr2r, tr2u, tolerance = 0.00001, check.attributes = FALSE)
|
||||||
|
|
||||||
|
# process type 'update' for no-subsampling model, refreshing only the tree stats from TEST data:
|
||||||
|
p1ut <- modifyList(p1, list(process_type = 'update', updater = 'refresh', refresh_leaf = FALSE))
|
||||||
|
bst1ut <- xgb.train(p1ut, dtest, nrounds = 10, watchlist, verbose = 0, callbacks = cb, xgb_model = bst1)
|
||||||
|
tr1ut <- xgb.model.dt.tree(model = bst1ut)
|
||||||
|
# should be the same evaluations but different gains and smaller cover (test data is smaller)
|
||||||
|
expect_equal(bst1$evaluation_log, bst1ut$evaluation_log)
|
||||||
|
expect_equal(tr1[Feature == 'Leaf']$Quality, tr1ut[Feature == 'Leaf']$Quality)
|
||||||
|
expect_gt(sum(abs(tr1[Feature != 'Leaf']$Quality - tr1ut[Feature != 'Leaf']$Quality)), 100)
|
||||||
|
expect_lt(sum(tr1ut$Cover) / sum(tr1$Cover), 0.5)
|
||||||
|
})
|
||||||
@ -25,10 +25,10 @@ General Parameters
|
|||||||
|
|
||||||
Parameters for Tree Booster
|
Parameters for Tree Booster
|
||||||
---------------------------
|
---------------------------
|
||||||
* eta [default=0.3]
|
* eta [default=0.3, alias: learning_rate]
|
||||||
- step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.
|
- step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.
|
||||||
- range: [0,1]
|
- range: [0,1]
|
||||||
* gamma [default=0]
|
* gamma [default=0, alias: min_split_loss]
|
||||||
- minimum loss reduction required to make a further partition on a leaf node of the tree. The larger, the more conservative the algorithm will be.
|
- minimum loss reduction required to make a further partition on a leaf node of the tree. The larger, the more conservative the algorithm will be.
|
||||||
- range: [0,∞]
|
- range: [0,∞]
|
||||||
* max_depth [default=6]
|
* max_depth [default=6]
|
||||||
@ -49,9 +49,9 @@ Parameters for Tree Booster
|
|||||||
* colsample_bylevel [default=1]
|
* colsample_bylevel [default=1]
|
||||||
- subsample ratio of columns for each split, in each level.
|
- subsample ratio of columns for each split, in each level.
|
||||||
- range: (0,1]
|
- range: (0,1]
|
||||||
* lambda [default=1]
|
* lambda [default=1, alias: reg_lambda]
|
||||||
- L2 regularization term on weights, increase this value will make model more conservative.
|
- L2 regularization term on weights, increase this value will make model more conservative.
|
||||||
* alpha [default=0]
|
* alpha [default=0, alias: reg_alpha]
|
||||||
- L1 regularization term on weights, increase this value will make model more conservative.
|
- L1 regularization term on weights, increase this value will make model more conservative.
|
||||||
* tree_method, string [default='auto']
|
* tree_method, string [default='auto']
|
||||||
- The tree construction algorithm used in XGBoost(see description in the [reference paper](http://arxiv.org/abs/1603.02754))
|
- The tree construction algorithm used in XGBoost(see description in the [reference paper](http://arxiv.org/abs/1603.02754))
|
||||||
@ -73,8 +73,27 @@ Parameters for Tree Booster
|
|||||||
- range: (0, 1)
|
- range: (0, 1)
|
||||||
* scale_pos_weight, [default=1]
|
* scale_pos_weight, [default=1]
|
||||||
- Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative cases) / sum(positive cases) See [Parameters Tuning](how_to/param_tuning.md) for more discussion. Also see Higgs Kaggle competition demo for examples: [R](../demo/kaggle-higgs/higgs-train.R ), [py1](../demo/kaggle-higgs/higgs-numpy.py ), [py2](../demo/kaggle-higgs/higgs-cv.py ), [py3](../demo/guide-python/cross_validation.py)
|
- Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: sum(negative cases) / sum(positive cases) See [Parameters Tuning](how_to/param_tuning.md) for more discussion. Also see Higgs Kaggle competition demo for examples: [R](../demo/kaggle-higgs/higgs-train.R ), [py1](../demo/kaggle-higgs/higgs-numpy.py ), [py2](../demo/kaggle-higgs/higgs-cv.py ), [py3](../demo/guide-python/cross_validation.py)
|
||||||
* updater_seq, [default="grow_colmaker,prune"]
|
* updater, [default='grow_colmaker,prune']
|
||||||
- A comma separated string mentioning tThe sequence of Tree updaters that should be run. A tree updater is a pluggable operation performed on the tree at every step using the gradient information. Tree updaters can be registered using the plugin system provided.
|
- A comma separated string defining the sequence of tree updaters to run, providing a modular way to construct and to modify the trees. This is an advanced parameter that is usually set automatically, depending on some other parameters. However, it could be also set explicitely by a user. The following updater plugins exist:
|
||||||
|
- 'grow_colmaker': non-distributed column-based construction of trees.
|
||||||
|
- 'distcol': distributed tree construction with column-based data splitting mode.
|
||||||
|
- 'grow_histmaker': distributed tree construction with row-based data splitting based on global proposal of histogram counting.
|
||||||
|
- 'grow_local_histmaker': based on local histogram counting.
|
||||||
|
- 'grow_skmaker': uses the approximate sketching algorithm.
|
||||||
|
- 'sync': synchronizes trees in all distributed nodes.
|
||||||
|
- 'refresh': refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
|
||||||
|
- 'prune': prunes the splits where loss < min_split_loss (or gamma).
|
||||||
|
- In a distributed setting, the implicit updater sequence value would be adjusted as follows:
|
||||||
|
- 'grow_histmaker,prune' when dsplit='row' (or default) and prob_buffer_row == 1 (or default); or when data has multiple sparse pages
|
||||||
|
- 'grow_histmaker,refresh,prune' when dsplit='row' and prob_buffer_row < 1
|
||||||
|
- 'distcol' when dsplit='col'
|
||||||
|
* refresh_leaf, [default=1]
|
||||||
|
- This is a parameter of the 'refresh' updater plugin. When this flag is true, tree leafs as well as tree nodes' stats are updated. When it is false, only node stats are updated.
|
||||||
|
* process_type, [default='default']
|
||||||
|
- A type of boosting process to run.
|
||||||
|
- Choices: {'default', 'update'}
|
||||||
|
- 'default': the normal boosting process which creates new trees.
|
||||||
|
- 'update': starts from an existing model and only updates its trees. In each boosting iteration, a tree from the initial model is taken, a specified sequence of updater plugins is run for that tree, and a modified tree is added to the new model. The new model would have either the same or smaller number of trees, depending on the number of boosting iteratons performed. Currently, the following built-in updater plugins could be meaningfully used with this process type: 'refresh', 'prune'. With 'update', one cannot use updater plugins that create new nrees.
|
||||||
|
|
||||||
Additional parameters for Dart Booster
|
Additional parameters for Dart Booster
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
@ -100,17 +119,21 @@ Additional parameters for Dart Booster
|
|||||||
|
|
||||||
Parameters for Linear Booster
|
Parameters for Linear Booster
|
||||||
-----------------------------
|
-----------------------------
|
||||||
* lambda [default=0]
|
* lambda [default=0, alias: reg_lambda]
|
||||||
- L2 regularization term on weights, increase this value will make model more conservative.
|
- L2 regularization term on weights, increase this value will make model more conservative.
|
||||||
* alpha [default=0]
|
* alpha [default=0, alias: reg_alpha]
|
||||||
- L1 regularization term on weights, increase this value will make model more conservative.
|
- L1 regularization term on weights, increase this value will make model more conservative.
|
||||||
* lambda_bias
|
* lambda_bias [default=0, alias: reg_lambda_bias]
|
||||||
- L2 regularization term on bias, default 0 (no L1 reg on bias because it is not important)
|
- L2 regularization term on bias (no L1 reg on bias because it is not important)
|
||||||
|
|
||||||
Parameters for Tweedie Regression
|
Parameters for Tweedie Regression
|
||||||
---------------------------------
|
---------------------------------
|
||||||
* tweedie_variance_power [default=1.5]
|
* tweedie_variance_power [default=1.5]
|
||||||
- Parameter that controls the variance of the tweedie distribution. Set closer to 2 to shift towards a gamma distribution and closer to 1 to shift towards a poisson distribution.
|
- parameter that controls the variance of the Tweedie distribution
|
||||||
|
- var(y) ~ E(y)^tweedie_variance_power
|
||||||
|
- range: (1,2)
|
||||||
|
- set closer to 2 to shift towards a gamma distribution
|
||||||
|
- set closer to 1 to shift towards a Poisson distribution.
|
||||||
|
|
||||||
Learning Task Parameters
|
Learning Task Parameters
|
||||||
------------------------
|
------------------------
|
||||||
@ -125,9 +148,8 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
|||||||
- "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
|
- "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
|
||||||
- "multi:softprob" --same as softmax, but output a vector of ndata * nclass, which can be further reshaped to ndata, nclass matrix. The result contains predicted probability of each data point belonging to each class.
|
- "multi:softprob" --same as softmax, but output a vector of ndata * nclass, which can be further reshaped to ndata, nclass matrix. The result contains predicted probability of each data point belonging to each class.
|
||||||
- "rank:pairwise" --set XGBoost to do ranking task by minimizing the pairwise loss
|
- "rank:pairwise" --set XGBoost to do ranking task by minimizing the pairwise loss
|
||||||
- "reg:gamma" --gamma regression for severity data, output mean of gamma distribution
|
- "reg:gamma" --gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be [gamma-distributed](https://en.wikipedia.org/wiki/Gamma_distribution#Applications)
|
||||||
- "reg:tweedie" --tweedie regression for insurance data
|
- "reg:tweedie" --Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be [Tweedie-distributed](https://en.wikipedia.org/wiki/Tweedie_distribution#Applications).
|
||||||
- tweedie_variance_power is set to 1.5 by default in tweedie regression and must be in the range [1, 2)
|
|
||||||
* base_score [ default=0.5 ]
|
* base_score [ default=0.5 ]
|
||||||
- the initial prediction score of all instances, global bias
|
- the initial prediction score of all instances, global bias
|
||||||
- for sufficient number of iterations, changing this value will not have too much effect.
|
- for sufficient number of iterations, changing this value will not have too much effect.
|
||||||
@ -135,19 +157,23 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
|||||||
- evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and error for classification, mean average precision for ranking )
|
- evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and error for classification, mean average precision for ranking )
|
||||||
- User can add multiple evaluation metrics, for python user, remember to pass the metrics in as list of parameters pairs instead of map, so that latter 'eval_metric' won't override previous one
|
- User can add multiple evaluation metrics, for python user, remember to pass the metrics in as list of parameters pairs instead of map, so that latter 'eval_metric' won't override previous one
|
||||||
- The choices are listed below:
|
- The choices are listed below:
|
||||||
- "rmse": [root mean square error](http://en.wikipedia.org/wiki/Root_mean_square_error)
|
- "rmse": [root mean square error](http://en.wikipedia.org/wiki/Root_mean_square_error)
|
||||||
- "mae": [mean absolute error](https://en.wikipedia.org/wiki/Mean_absolute_error)
|
- "mae": [mean absolute error](https://en.wikipedia.org/wiki/Mean_absolute_error)
|
||||||
- "logloss": negative [log-likelihood](http://en.wikipedia.org/wiki/Log-likelihood)
|
- "logloss": negative [log-likelihood](http://en.wikipedia.org/wiki/Log-likelihood)
|
||||||
- "error": Binary classification error rate. It is calculated as #(wrong cases)/#(all cases). For the predictions, the evaluation will regard the instances with prediction value larger than 0.5 as positive instances, and the others as negative instances.
|
- "error": Binary classification error rate. It is calculated as #(wrong cases)/#(all cases). For the predictions, the evaluation will regard the instances with prediction value larger than 0.5 as positive instances, and the others as negative instances.
|
||||||
- "merror": Multiclass classification error rate. It is calculated as #(wrong cases)/#(all cases).
|
- "error@t": a different than 0.5 binary classification threshold value could be specified by providing a numerical value through 't'.
|
||||||
- "mlogloss": [Multiclass logloss](https://www.kaggle.com/wiki/MultiClassLogLoss)
|
- "merror": Multiclass classification error rate. It is calculated as #(wrong cases)/#(all cases).
|
||||||
- "auc": [Area under the curve](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_curve) for ranking evaluation.
|
- "mlogloss": [Multiclass logloss](https://www.kaggle.com/wiki/MultiClassLogLoss)
|
||||||
- "ndcg":[Normalized Discounted Cumulative Gain](http://en.wikipedia.org/wiki/NDCG)
|
- "auc": [Area under the curve](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_curve) for ranking evaluation.
|
||||||
- "map":[Mean average precision](http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision)
|
- "ndcg":[Normalized Discounted Cumulative Gain](http://en.wikipedia.org/wiki/NDCG)
|
||||||
- "ndcg@n","map@n": n can be assigned as an integer to cut off the top positions in the lists for evaluation.
|
- "map":[Mean average precision](http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision)
|
||||||
- "ndcg-","map-","ndcg@n-","map@n-": In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions.
|
- "ndcg@n","map@n": n can be assigned as an integer to cut off the top positions in the lists for evaluation.
|
||||||
|
- "ndcg-","map-","ndcg@n-","map@n-": In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions.
|
||||||
training repeatedly
|
training repeatedly
|
||||||
- "gamma-deviance": [residual deviance for gamma regression]
|
- "poisson-nloglik": negative log-likelihood for Poisson regression
|
||||||
|
- "gamma-nloglik": negative log-likelihood for gamma regression
|
||||||
|
- "gamma-deviance": residual deviance for gamma regression
|
||||||
|
- "tweedie-nloglik": negative log-likelihood for Tweedie regression (at a specified value of the tweedie_variance_power parameter)
|
||||||
* seed [ default=0 ]
|
* seed [ default=0 ]
|
||||||
- random number seed.
|
- random number seed.
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,12 @@ namespace gbm {
|
|||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(gbtree);
|
DMLC_REGISTRY_FILE_TAG(gbtree);
|
||||||
|
|
||||||
|
// boosting process types
|
||||||
|
enum TreeProcessType {
|
||||||
|
kDefault,
|
||||||
|
kUpdate
|
||||||
|
};
|
||||||
|
|
||||||
/*! \brief training parameters */
|
/*! \brief training parameters */
|
||||||
struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||||
/*!
|
/*!
|
||||||
@ -35,13 +41,24 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
|||||||
int num_parallel_tree;
|
int num_parallel_tree;
|
||||||
/*! \brief tree updater sequence */
|
/*! \brief tree updater sequence */
|
||||||
std::string updater_seq;
|
std::string updater_seq;
|
||||||
|
/*! \brief type of boosting process to run */
|
||||||
|
int process_type;
|
||||||
// declare parameters
|
// declare parameters
|
||||||
DMLC_DECLARE_PARAMETER(GBTreeTrainParam) {
|
DMLC_DECLARE_PARAMETER(GBTreeTrainParam) {
|
||||||
DMLC_DECLARE_FIELD(num_parallel_tree).set_lower_bound(1).set_default(1)
|
DMLC_DECLARE_FIELD(num_parallel_tree)
|
||||||
|
.set_default(1)
|
||||||
|
.set_lower_bound(1)
|
||||||
.describe("Number of parallel trees constructed during each iteration."\
|
.describe("Number of parallel trees constructed during each iteration."\
|
||||||
" This option is used to support boosted random forest");
|
" This option is used to support boosted random forest");
|
||||||
DMLC_DECLARE_FIELD(updater_seq).set_default("grow_colmaker,prune")
|
DMLC_DECLARE_FIELD(updater_seq)
|
||||||
|
.set_default("grow_colmaker,prune")
|
||||||
.describe("Tree updater sequence.");
|
.describe("Tree updater sequence.");
|
||||||
|
DMLC_DECLARE_FIELD(process_type)
|
||||||
|
.set_default(kDefault)
|
||||||
|
.add_enum("default", kDefault)
|
||||||
|
.add_enum("update", kUpdate)
|
||||||
|
.describe("Whether to run the normal boosting process that creates new trees,"\
|
||||||
|
" or to update the trees in an existing model.");
|
||||||
// add alias
|
// add alias
|
||||||
DMLC_DECLARE_ALIAS(updater_seq, updater);
|
DMLC_DECLARE_ALIAS(updater_seq, updater);
|
||||||
}
|
}
|
||||||
@ -63,21 +80,30 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
|
|||||||
float learning_rate;
|
float learning_rate;
|
||||||
// declare parameters
|
// declare parameters
|
||||||
DMLC_DECLARE_PARAMETER(DartTrainParam) {
|
DMLC_DECLARE_PARAMETER(DartTrainParam) {
|
||||||
DMLC_DECLARE_FIELD(silent).set_default(false)
|
DMLC_DECLARE_FIELD(silent)
|
||||||
|
.set_default(false)
|
||||||
.describe("Not print information during training.");
|
.describe("Not print information during training.");
|
||||||
DMLC_DECLARE_FIELD(sample_type).set_default(0)
|
DMLC_DECLARE_FIELD(sample_type)
|
||||||
|
.set_default(0)
|
||||||
.add_enum("uniform", 0)
|
.add_enum("uniform", 0)
|
||||||
.add_enum("weighted", 1)
|
.add_enum("weighted", 1)
|
||||||
.describe("Different types of sampling algorithm.");
|
.describe("Different types of sampling algorithm.");
|
||||||
DMLC_DECLARE_FIELD(normalize_type).set_default(0)
|
DMLC_DECLARE_FIELD(normalize_type)
|
||||||
|
.set_default(0)
|
||||||
.add_enum("tree", 0)
|
.add_enum("tree", 0)
|
||||||
.add_enum("forest", 1)
|
.add_enum("forest", 1)
|
||||||
.describe("Different types of normalization algorithm.");
|
.describe("Different types of normalization algorithm.");
|
||||||
DMLC_DECLARE_FIELD(rate_drop).set_range(0.0f, 1.0f).set_default(0.0f)
|
DMLC_DECLARE_FIELD(rate_drop)
|
||||||
|
.set_range(0.0f, 1.0f)
|
||||||
|
.set_default(0.0f)
|
||||||
.describe("Parameter of how many trees are dropped.");
|
.describe("Parameter of how many trees are dropped.");
|
||||||
DMLC_DECLARE_FIELD(skip_drop).set_range(0.0f, 1.0f).set_default(0.0f)
|
DMLC_DECLARE_FIELD(skip_drop)
|
||||||
|
.set_range(0.0f, 1.0f)
|
||||||
|
.set_default(0.0f)
|
||||||
.describe("Parameter of whether to drop trees.");
|
.describe("Parameter of whether to drop trees.");
|
||||||
DMLC_DECLARE_FIELD(learning_rate).set_lower_bound(0.0f).set_default(0.3f)
|
DMLC_DECLARE_FIELD(learning_rate)
|
||||||
|
.set_lower_bound(0.0f)
|
||||||
|
.set_default(0.3f)
|
||||||
.describe("Learning rate(step size) of update.");
|
.describe("Learning rate(step size) of update.");
|
||||||
DMLC_DECLARE_ALIAS(learning_rate, eta);
|
DMLC_DECLARE_ALIAS(learning_rate, eta);
|
||||||
}
|
}
|
||||||
@ -157,12 +183,21 @@ class GBTree : public GradientBooster {
|
|||||||
for (const auto& up : updaters) {
|
for (const auto& up : updaters) {
|
||||||
up->Init(cfg);
|
up->Init(cfg);
|
||||||
}
|
}
|
||||||
|
// for the 'update' process_type, move trees into trees_to_update
|
||||||
|
if (tparam.process_type == kUpdate && trees_to_update.size() == 0u) {
|
||||||
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
|
trees_to_update.push_back(std::move(trees[i]));
|
||||||
|
}
|
||||||
|
trees.clear();
|
||||||
|
mparam.num_trees = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Load(dmlc::Stream* fi) override {
|
void Load(dmlc::Stream* fi) override {
|
||||||
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
|
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
|
||||||
<< "GBTree: invalid model file";
|
<< "GBTree: invalid model file";
|
||||||
trees.clear();
|
trees.clear();
|
||||||
|
trees_to_update.clear();
|
||||||
for (int i = 0; i < mparam.num_trees; ++i) {
|
for (int i = 0; i < mparam.num_trees; ++i) {
|
||||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||||
ptr->Load(fi);
|
ptr->Load(fi);
|
||||||
@ -386,11 +421,20 @@ class GBTree : public GradientBooster {
|
|||||||
ret->clear();
|
ret->clear();
|
||||||
// create the trees
|
// create the trees
|
||||||
for (int i = 0; i < tparam.num_parallel_tree; ++i) {
|
for (int i = 0; i < tparam.num_parallel_tree; ++i) {
|
||||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
if (tparam.process_type == kDefault) {
|
||||||
ptr->param.InitAllowUnknown(this->cfg);
|
// create new tree
|
||||||
ptr->InitModel();
|
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||||
new_trees.push_back(ptr.get());
|
ptr->param.InitAllowUnknown(this->cfg);
|
||||||
ret->push_back(std::move(ptr));
|
ptr->InitModel();
|
||||||
|
new_trees.push_back(ptr.get());
|
||||||
|
ret->push_back(std::move(ptr));
|
||||||
|
} else if (tparam.process_type == kUpdate) {
|
||||||
|
CHECK_LT(trees.size(), trees_to_update.size());
|
||||||
|
// move an existing tree from trees_to_update
|
||||||
|
auto t = std::move(trees_to_update[trees.size()]);
|
||||||
|
new_trees.push_back(t.get());
|
||||||
|
ret->push_back(std::move(t));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// update the trees
|
// update the trees
|
||||||
for (auto& up : updaters) {
|
for (auto& up : updaters) {
|
||||||
@ -493,6 +537,8 @@ class GBTree : public GradientBooster {
|
|||||||
GBTreeModelParam mparam;
|
GBTreeModelParam mparam;
|
||||||
/*! \brief vector of trees stored in the model */
|
/*! \brief vector of trees stored in the model */
|
||||||
std::vector<std::unique_ptr<RegTree> > trees;
|
std::vector<std::unique_ptr<RegTree> > trees;
|
||||||
|
/*! \brief for the update process, a place to keep the initial trees */
|
||||||
|
std::vector<std::unique_ptr<RegTree> > trees_to_update;
|
||||||
/*! \brief some information indicator of the tree, reserved */
|
/*! \brief some information indicator of the tree, reserved */
|
||||||
std::vector<int> tree_info;
|
std::vector<int> tree_info;
|
||||||
// ----training fields----
|
// ----training fields----
|
||||||
|
|||||||
@ -64,6 +64,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
bool cache_opt;
|
bool cache_opt;
|
||||||
// whether to not print info during training.
|
// whether to not print info during training.
|
||||||
bool silent;
|
bool silent;
|
||||||
|
// whether refresh updater needs to update the leaf values
|
||||||
|
bool refresh_leaf;
|
||||||
// auxiliary data structure
|
// auxiliary data structure
|
||||||
std::vector<int> monotone_constraints;
|
std::vector<int> monotone_constraints;
|
||||||
// declare the parameters
|
// declare the parameters
|
||||||
@ -75,10 +77,11 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
DMLC_DECLARE_FIELD(min_split_loss)
|
DMLC_DECLARE_FIELD(min_split_loss)
|
||||||
.set_lower_bound(0.0f)
|
.set_lower_bound(0.0f)
|
||||||
.set_default(0.0f)
|
.set_default(0.0f)
|
||||||
.describe(
|
.describe("Minimum loss reduction required to make a further partition.");
|
||||||
"Minimum loss reduction required to make a further partition.");
|
DMLC_DECLARE_FIELD(max_depth)
|
||||||
DMLC_DECLARE_FIELD(max_depth).set_lower_bound(0).set_default(6).describe(
|
.set_lower_bound(0)
|
||||||
"Maximum depth of the tree.");
|
.set_default(6)
|
||||||
|
.describe("Maximum depth of the tree.");
|
||||||
DMLC_DECLARE_FIELD(min_child_weight)
|
DMLC_DECLARE_FIELD(min_child_weight)
|
||||||
.set_lower_bound(0.0f)
|
.set_lower_bound(0.0f)
|
||||||
.set_default(1.0f)
|
.set_default(1.0f)
|
||||||
@ -100,9 +103,8 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
DMLC_DECLARE_FIELD(max_delta_step)
|
DMLC_DECLARE_FIELD(max_delta_step)
|
||||||
.set_lower_bound(0.0f)
|
.set_lower_bound(0.0f)
|
||||||
.set_default(0.0f)
|
.set_default(0.0f)
|
||||||
.describe(
|
.describe("Maximum delta step we allow each tree's weight estimate to be. "\
|
||||||
"Maximum delta step we allow each tree's weight estimate to be. "
|
"If the value is set to 0, it means there is no constraint");
|
||||||
"If the value is set to 0, it means there is no constraint");
|
|
||||||
DMLC_DECLARE_FIELD(subsample)
|
DMLC_DECLARE_FIELD(subsample)
|
||||||
.set_range(0.0f, 1.0f)
|
.set_range(0.0f, 1.0f)
|
||||||
.set_default(1.0f)
|
.set_default(1.0f)
|
||||||
@ -114,8 +116,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
DMLC_DECLARE_FIELD(colsample_bytree)
|
DMLC_DECLARE_FIELD(colsample_bytree)
|
||||||
.set_range(0.0f, 1.0f)
|
.set_range(0.0f, 1.0f)
|
||||||
.set_default(1.0f)
|
.set_default(1.0f)
|
||||||
.describe(
|
.describe("Subsample ratio of columns, resample on each tree construction.");
|
||||||
"Subsample ratio of columns, resample on each tree construction.");
|
|
||||||
DMLC_DECLARE_FIELD(opt_dense_col)
|
DMLC_DECLARE_FIELD(opt_dense_col)
|
||||||
.set_range(0.0f, 1.0f)
|
.set_range(0.0f, 1.0f)
|
||||||
.set_default(1.0f)
|
.set_default(1.0f)
|
||||||
@ -127,8 +128,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
DMLC_DECLARE_FIELD(sketch_ratio)
|
DMLC_DECLARE_FIELD(sketch_ratio)
|
||||||
.set_lower_bound(0.0f)
|
.set_lower_bound(0.0f)
|
||||||
.set_default(2.0f)
|
.set_default(2.0f)
|
||||||
.describe("EXP Param: Sketch accuracy related parameter of approximate "
|
.describe("EXP Param: Sketch accuracy related parameter of approximate algorithm.");
|
||||||
"algorithm.");
|
|
||||||
DMLC_DECLARE_FIELD(size_leaf_vector)
|
DMLC_DECLARE_FIELD(size_leaf_vector)
|
||||||
.set_lower_bound(0)
|
.set_lower_bound(0)
|
||||||
.set_default(0)
|
.set_default(0)
|
||||||
@ -136,10 +136,15 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
|||||||
DMLC_DECLARE_FIELD(parallel_option)
|
DMLC_DECLARE_FIELD(parallel_option)
|
||||||
.set_default(0)
|
.set_default(0)
|
||||||
.describe("Different types of parallelization algorithm.");
|
.describe("Different types of parallelization algorithm.");
|
||||||
DMLC_DECLARE_FIELD(cache_opt).set_default(true).describe(
|
DMLC_DECLARE_FIELD(cache_opt)
|
||||||
"EXP Param: Cache aware optimization.");
|
.set_default(true)
|
||||||
DMLC_DECLARE_FIELD(silent).set_default(false).describe(
|
.describe("EXP Param: Cache aware optimization.");
|
||||||
"Do not print information during training.");
|
DMLC_DECLARE_FIELD(silent)
|
||||||
|
.set_default(false)
|
||||||
|
.describe("Do not print information during trainig.");
|
||||||
|
DMLC_DECLARE_FIELD(refresh_leaf)
|
||||||
|
.set_default(true)
|
||||||
|
.describe("Whether the refresh updater needs to update leaf values.");
|
||||||
DMLC_DECLARE_FIELD(monotone_constraints)
|
DMLC_DECLARE_FIELD(monotone_constraints)
|
||||||
.set_default(std::vector<int>())
|
.set_default(std::vector<int>())
|
||||||
.describe("Constraint of variable monotonicity");
|
.describe("Constraint of variable monotonicity");
|
||||||
|
|||||||
@ -134,7 +134,9 @@ class TreeRefresher: public TreeUpdater {
|
|||||||
tree.stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
|
tree.stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
|
||||||
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
|
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
|
||||||
if (tree[nid].is_leaf()) {
|
if (tree[nid].is_leaf()) {
|
||||||
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
if (param.refresh_leaf) {
|
||||||
|
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
tree.stat(nid).loss_chg = static_cast<bst_float>(
|
tree.stat(nid).loss_chg = static_cast<bst_float>(
|
||||||
gstats[tree[nid].cleft()].CalcGain(param) +
|
gstats[tree[nid].cleft()].CalcGain(param) +
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user