[CORE] The update process for a tree model, and its application to feature importance (#1670)
* [CORE] allow updating trees in an existing model * [CORE] in refresh updater, allow keeping old leaf values and update stats only * [R-package] xgb.train mod to allow updating trees in an existing model * [R-package] added check for nrounds when is_update * [CORE] merge parameter declaration changes; unify their code style * [CORE] move the update-process trees initialization to Configure; rename default process_type to 'default'; fix the trees and trees_to_update sizes comparison check * [R-package] unit tests for the update process type * [DOC] documentation for process_type parameter; improved docs for updater, Gamma and Tweedie; added some parameter aliases; metrics indentation and some were non-documented * fix my sloppy merge conflict resolutions * [CORE] add a TreeProcessType enum * whitespace fix
This commit is contained in:
committed by
Tianqi Chen
parent
4398fbbe4a
commit
a44032d095
@@ -284,7 +284,9 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
# Sort the callbacks into categories
|
||||
cb <- categorize.callbacks(callbacks)
|
||||
|
||||
|
||||
# The tree updating process would need slightly different handling
|
||||
is_update <- NVL(params[['process_type']], '.') == 'update'
|
||||
|
||||
# Construct a booster (either a new one or load from xgb_model)
|
||||
handle <- xgb.Booster(params, append(watchlist, dtrain), xgb_model)
|
||||
bst <- xgb.handleToBooster(handle)
|
||||
@@ -294,17 +296,20 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
num_parallel_tree <- max(as.numeric(NVL(params[['num_parallel_tree']], 1)), 1)
|
||||
|
||||
# When the 'xgb_model' was set, find out how many boosting iterations it has
|
||||
niter_skip <- 0
|
||||
niter_init <- 0
|
||||
if (!is.null(xgb_model)) {
|
||||
niter_skip <- as.numeric(xgb.attr(bst, 'niter')) + 1
|
||||
if (length(niter_skip) == 0) {
|
||||
niter_skip <- xgb.ntree(bst) %/% (num_parallel_tree * num_class)
|
||||
niter_init <- as.numeric(xgb.attr(bst, 'niter')) + 1
|
||||
if (length(niter_init) == 0) {
|
||||
niter_init <- xgb.ntree(bst) %/% (num_parallel_tree * num_class)
|
||||
}
|
||||
}
|
||||
if(is_update && nrounds > niter_init)
|
||||
stop("nrounds cannot be larger than ", niter_init, " (nrounds of xgb_model)")
|
||||
|
||||
# TODO: distributed code
|
||||
rank <- 0
|
||||
|
||||
niter_skip <- ifelse(is_update, 0, niter_init)
|
||||
begin_iteration <- niter_skip + 1
|
||||
end_iteration <- niter_skip + nrounds
|
||||
|
||||
@@ -337,6 +342,7 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
nrow(evaluation_log) > 0) {
|
||||
# include the previous compatible history when available
|
||||
if (class(xgb_model) == 'xgb.Booster' &&
|
||||
!is_update &&
|
||||
!is.null(xgb_model$evaluation_log) &&
|
||||
all.equal(colnames(evaluation_log),
|
||||
colnames(xgb_model$evaluation_log))) {
|
||||
|
||||
Reference in New Issue
Block a user