[R] Refactor callback structure and attributes (#9957)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -142,7 +142,7 @@ check.custom.eval <- function(env = parent.frame()) {
|
||||
if (!is.null(env$feval) &&
|
||||
is.null(env$maximize) && (
|
||||
!is.null(env$early_stopping_rounds) ||
|
||||
has.callbacks(env$callbacks, 'cb.early.stop')))
|
||||
has.callbacks(env$callbacks, "early_stop")))
|
||||
stop("Please set 'maximize' to indicate whether the evaluation metric needs to be maximized or not")
|
||||
}
|
||||
|
||||
|
||||
@@ -1071,6 +1071,10 @@ xgb.best_iteration <- function(bst) {
|
||||
#' coef(model)
|
||||
#' @export
|
||||
coef.xgb.Booster <- function(object, ...) {
|
||||
return(.internal.coef.xgb.Booster(object, add_names = TRUE))
|
||||
}
|
||||
|
||||
.internal.coef.xgb.Booster <- function(object, add_names = TRUE) {
|
||||
booster_type <- xgb.booster_type(object)
|
||||
if (booster_type != "gblinear") {
|
||||
stop("Coefficients are not defined for Booster type ", booster_type)
|
||||
@@ -1089,21 +1093,27 @@ coef.xgb.Booster <- function(object, ...) {
|
||||
intercepts <- weights[seq(sep + 1, length(weights))]
|
||||
intercepts <- intercepts + as.numeric(base_score)
|
||||
|
||||
feature_names <- xgb.feature_names(object)
|
||||
if (!NROW(feature_names)) {
|
||||
# This mimics the default naming in R which names columns as "V1..N"
|
||||
# when names are needed but not available
|
||||
feature_names <- paste0("V", seq(1L, num_feature))
|
||||
if (add_names) {
|
||||
feature_names <- xgb.feature_names(object)
|
||||
if (!NROW(feature_names)) {
|
||||
# This mimics the default naming in R which names columns as "V1..N"
|
||||
# when names are needed but not available
|
||||
feature_names <- paste0("V", seq(1L, num_feature))
|
||||
}
|
||||
feature_names <- c("(Intercept)", feature_names)
|
||||
}
|
||||
feature_names <- c("(Intercept)", feature_names)
|
||||
if (n_cols == 1L) {
|
||||
out <- c(intercepts, coefs)
|
||||
names(out) <- feature_names
|
||||
if (add_names) {
|
||||
names(out) <- feature_names
|
||||
}
|
||||
} else {
|
||||
coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE)
|
||||
dim(intercepts) <- c(1L, n_cols)
|
||||
out <- rbind(intercepts, coefs)
|
||||
row.names(out) <- feature_names
|
||||
if (add_names) {
|
||||
row.names(out) <- feature_names
|
||||
}
|
||||
# TODO: if a class names attributes is added,
|
||||
# should use those names here.
|
||||
}
|
||||
@@ -1255,12 +1265,9 @@ print.xgb.Booster <- function(x, ...) {
|
||||
cat(" ", paste(attr_names, collapse = ", "), "\n")
|
||||
}
|
||||
|
||||
if (!is.null(R_attrs$callbacks) && length(R_attrs$callbacks) > 0) {
|
||||
cat('callbacks:\n')
|
||||
lapply(callback.calls(R_attrs$callbacks), function(x) {
|
||||
cat(' ')
|
||||
print(x)
|
||||
})
|
||||
additional_attr <- setdiff(names(R_attrs), .reserved_cb_names)
|
||||
if (NROW(additional_attr)) {
|
||||
cat("callbacks:\n ", paste(additional_attr, collapse = ", "), "\n")
|
||||
}
|
||||
|
||||
if (!is.null(R_attrs$evaluation_log)) {
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
#' that NA values should be considered as 'missing' by the algorithm.
|
||||
#' Sometimes, 0 or other extreme value might be used to represent missing values.
|
||||
#' @param prediction A logical value indicating whether to return the test fold predictions
|
||||
#' from each CV model. This parameter engages the \code{\link{cb.cv.predict}} callback.
|
||||
#' from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.
|
||||
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
|
||||
#' @param metrics, list of evaluation metrics to be used in cross validation,
|
||||
#' when it is not specified, the evaluation metric is chosen according to objective function.
|
||||
@@ -57,17 +57,17 @@
|
||||
#' @param verbose \code{boolean}, print the statistics during the process
|
||||
#' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
|
||||
#' Default is 1 which means all messages are printed. This parameter is passed to the
|
||||
#' \code{\link{cb.print.evaluation}} callback.
|
||||
#' \code{\link{xgb.cb.print.evaluation}} callback.
|
||||
#' @param early_stopping_rounds If \code{NULL}, the early stopping function is not triggered.
|
||||
#' If set to an integer \code{k}, training with a validation set will stop if the performance
|
||||
#' doesn't improve for \code{k} rounds.
|
||||
#' Setting this parameter engages the \code{\link{cb.early.stop}} callback.
|
||||
#' Setting this parameter engages the \code{\link{xgb.cb.early.stop}} callback.
|
||||
#' @param maximize If \code{feval} and \code{early_stopping_rounds} are set,
|
||||
#' then this parameter must be set as well.
|
||||
#' When it is \code{TRUE}, it means the larger the evaluation score the better.
|
||||
#' This parameter is passed to the \code{\link{cb.early.stop}} callback.
|
||||
#' This parameter is passed to the \code{\link{xgb.cb.early.stop}} callback.
|
||||
#' @param callbacks a list of callback functions to perform various task during boosting.
|
||||
#' See \code{\link{callbacks}}. Some of the callbacks are automatically created depending on the
|
||||
#' See \code{\link{xgb.Callback}}. Some of the callbacks are automatically created depending on the
|
||||
#' parameters' values. User can provide either existing or their own callback methods in order
|
||||
#' to customize the training process.
|
||||
#' @param ... other parameters to pass to \code{params}.
|
||||
@@ -90,25 +90,25 @@
|
||||
#' \itemize{
|
||||
#' \item \code{call} a function call.
|
||||
#' \item \code{params} parameters that were passed to the xgboost library. Note that it does not
|
||||
#' capture parameters changed by the \code{\link{cb.reset.parameters}} callback.
|
||||
#' \item \code{callbacks} callback functions that were either automatically assigned or
|
||||
#' explicitly passed.
|
||||
#' capture parameters changed by the \code{\link{xgb.cb.reset.parameters}} callback.
|
||||
#' \item \code{evaluation_log} evaluation history stored as a \code{data.table} with the
|
||||
#' first column corresponding to iteration number and the rest corresponding to the
|
||||
#' CV-based evaluation means and standard deviations for the training and test CV-sets.
|
||||
#' It is created by the \code{\link{cb.evaluation.log}} callback.
|
||||
#' It is created by the \code{\link{xgb.cb.evaluation.log}} callback.
|
||||
#' \item \code{niter} number of boosting iterations.
|
||||
#' \item \code{nfeatures} number of features in training data.
|
||||
#' \item \code{folds} the list of CV folds' indices - either those passed through the \code{folds}
|
||||
#' parameter or randomly generated.
|
||||
#' \item \code{best_iteration} iteration number with the best evaluation metric value
|
||||
#' (only available with early stopping).
|
||||
#' \item \code{pred} CV prediction values available when \code{prediction} is set.
|
||||
#' It is either vector or matrix (see \code{\link{cb.cv.predict}}).
|
||||
#' \item \code{models} a list of the CV folds' models. It is only available with the explicit
|
||||
#' setting of the \code{cb.cv.predict(save_models = TRUE)} callback.
|
||||
#' }
|
||||
#'
|
||||
#' Plus other potential elements that are the result of callbacks, such as a list `cv_predict` with
|
||||
#' a sub-element `pred` when passing `prediction = TRUE`, which is added by the \link{xgb.cb.cv.predict}
|
||||
#' callback (note that one can also pass it manually under `callbacks` with different settings,
|
||||
#' such as saving also the models created during cross validation); or a list `early_stop` which
|
||||
#' will contain elements such as `best_iteration` when using the early stopping callback (\link{xgb.cb.early.stop}).
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
#' dtrain <- with(agaricus.train, xgb.DMatrix(data, label = label, nthread = 2))
|
||||
@@ -160,32 +160,38 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
||||
folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
|
||||
}
|
||||
|
||||
# Callbacks
|
||||
tmp <- .process.callbacks(callbacks, is_cv = TRUE)
|
||||
callbacks <- tmp$callbacks
|
||||
cb_names <- tmp$cb_names
|
||||
rm(tmp)
|
||||
|
||||
# Early stopping callback
|
||||
if (!is.null(early_stopping_rounds) && !("early_stop" %in% cb_names)) {
|
||||
callbacks <- add.callback(
|
||||
callbacks,
|
||||
xgb.cb.early.stop(
|
||||
early_stopping_rounds,
|
||||
maximize = maximize,
|
||||
verbose = verbose
|
||||
),
|
||||
as_first_elt = TRUE
|
||||
)
|
||||
}
|
||||
# verbosity & evaluation printing callback:
|
||||
params <- c(params, list(silent = 1))
|
||||
print_every_n <- max(as.integer(print_every_n), 1L)
|
||||
if (!has.callbacks(callbacks, 'cb.print.evaluation') && verbose) {
|
||||
callbacks <- add.cb(callbacks, cb.print.evaluation(print_every_n, showsd = showsd))
|
||||
if (verbose && !("print_evaluation" %in% cb_names)) {
|
||||
callbacks <- add.callback(callbacks, xgb.cb.print.evaluation(print_every_n, showsd = showsd))
|
||||
}
|
||||
# evaluation log callback: always is on in CV
|
||||
evaluation_log <- list()
|
||||
if (!has.callbacks(callbacks, 'cb.evaluation.log')) {
|
||||
callbacks <- add.cb(callbacks, cb.evaluation.log())
|
||||
}
|
||||
# Early stopping callback
|
||||
stop_condition <- FALSE
|
||||
if (!is.null(early_stopping_rounds) &&
|
||||
!has.callbacks(callbacks, 'cb.early.stop')) {
|
||||
callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds,
|
||||
maximize = maximize, verbose = verbose))
|
||||
if (!("evaluation_log" %in% cb_names)) {
|
||||
callbacks <- add.callback(callbacks, xgb.cb.evaluation.log())
|
||||
}
|
||||
# CV-predictions callback
|
||||
if (prediction &&
|
||||
!has.callbacks(callbacks, 'cb.cv.predict')) {
|
||||
callbacks <- add.cb(callbacks, cb.cv.predict(save_models = FALSE))
|
||||
if (prediction && !("cv_predict" %in% cb_names)) {
|
||||
callbacks <- add.callback(callbacks, xgb.cb.cv.predict(save_models = FALSE))
|
||||
}
|
||||
# Sort the callbacks into categories
|
||||
cb <- categorize.callbacks(callbacks)
|
||||
|
||||
|
||||
# create the booster-folds
|
||||
# train_folds
|
||||
@@ -211,9 +217,6 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
||||
bst <- bst$bst
|
||||
list(dtrain = dtrain, bst = bst, watchlist = list(train = dtrain, test = dtest), index = folds[[k]])
|
||||
})
|
||||
rm(dall)
|
||||
# a "basket" to collect some results from callbacks
|
||||
basket <- list()
|
||||
|
||||
# extract parameters that can affect the relationship b/w #trees and #iterations
|
||||
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1) # nolint
|
||||
@@ -222,10 +225,25 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
||||
begin_iteration <- 1
|
||||
end_iteration <- nrounds
|
||||
|
||||
.execute.cb.before.training(
|
||||
callbacks,
|
||||
bst_folds,
|
||||
dall,
|
||||
NULL,
|
||||
begin_iteration,
|
||||
end_iteration
|
||||
)
|
||||
|
||||
# synchronous CV boosting: run CV folds' models within each iteration
|
||||
for (iteration in begin_iteration:end_iteration) {
|
||||
|
||||
for (f in cb$pre_iter) f()
|
||||
.execute.cb.before.iter(
|
||||
callbacks,
|
||||
bst_folds,
|
||||
dall,
|
||||
NULL,
|
||||
iteration
|
||||
)
|
||||
|
||||
msg <- lapply(bst_folds, function(fd) {
|
||||
xgb.iter.update(
|
||||
@@ -242,27 +260,36 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
||||
)
|
||||
})
|
||||
msg <- simplify2array(msg)
|
||||
# Note: these variables might look unused here, but they are used in the callbacks
|
||||
bst_evaluation <- rowMeans(msg) # nolint
|
||||
bst_evaluation_err <- apply(msg, 1, sd) # nolint
|
||||
|
||||
for (f in cb$post_iter) f()
|
||||
should_stop <- .execute.cb.after.iter(
|
||||
callbacks,
|
||||
bst_folds,
|
||||
dall,
|
||||
NULL,
|
||||
iteration,
|
||||
msg
|
||||
)
|
||||
|
||||
if (stop_condition) break
|
||||
if (should_stop) break
|
||||
}
|
||||
for (f in cb$finalize) f(finalize = TRUE)
|
||||
cb_outputs <- .execute.cb.after.training(
|
||||
callbacks,
|
||||
bst_folds,
|
||||
dall,
|
||||
NULL,
|
||||
iteration,
|
||||
msg
|
||||
)
|
||||
|
||||
# the CV result
|
||||
ret <- list(
|
||||
call = match.call(),
|
||||
params = params,
|
||||
callbacks = callbacks,
|
||||
evaluation_log = evaluation_log,
|
||||
niter = end_iteration,
|
||||
nfeatures = ncol(data),
|
||||
niter = iteration,
|
||||
nfeatures = ncol(dall),
|
||||
folds = folds
|
||||
)
|
||||
ret <- c(ret, basket)
|
||||
ret <- c(ret, cb_outputs)
|
||||
|
||||
class(ret) <- 'xgb.cv.synchronous'
|
||||
return(invisible(ret))
|
||||
@@ -308,23 +335,16 @@ print.xgb.cv.synchronous <- function(x, verbose = FALSE, ...) {
|
||||
paste0('"', unlist(x$params), '"'),
|
||||
sep = ' = ', collapse = ', '), '\n', sep = '')
|
||||
}
|
||||
if (!is.null(x$callbacks) && length(x$callbacks) > 0) {
|
||||
cat('callbacks:\n')
|
||||
lapply(callback.calls(x$callbacks), function(x) {
|
||||
cat(' ')
|
||||
print(x)
|
||||
})
|
||||
}
|
||||
|
||||
for (n in c('niter', 'best_iteration')) {
|
||||
if (is.null(x[[n]]))
|
||||
if (is.null(x$early_stop[[n]]))
|
||||
next
|
||||
cat(n, ': ', x[[n]], '\n', sep = '')
|
||||
cat(n, ': ', x$early_stop[[n]], '\n', sep = '')
|
||||
}
|
||||
|
||||
if (!is.null(x$pred)) {
|
||||
if (!is.null(x$cv_predict$pred)) {
|
||||
cat('pred:\n')
|
||||
str(x$pred)
|
||||
str(x$cv_predict$pred)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -332,9 +352,9 @@ print.xgb.cv.synchronous <- function(x, verbose = FALSE, ...) {
|
||||
cat('evaluation_log:\n')
|
||||
print(x$evaluation_log, row.names = FALSE, ...)
|
||||
|
||||
if (!is.null(x$best_iteration)) {
|
||||
if (!is.null(x$early_stop$best_iteration)) {
|
||||
cat('Best iteration:\n')
|
||||
print(x$evaluation_log[x$best_iteration], row.names = FALSE, ...)
|
||||
print(x$evaluation_log[x$early_stop$best_iteration], row.names = FALSE, ...)
|
||||
}
|
||||
invisible(x)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#'
|
||||
#' @details
|
||||
#' The input file is expected to contain a model saved in an xgboost model format
|
||||
#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some
|
||||
#' using either \code{\link{xgb.save}} or \code{\link{xgb.cb.save.model}} in R, or using some
|
||||
#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and
|
||||
#' saved from there in xgboost format, could be loaded from R.
|
||||
#'
|
||||
|
||||
@@ -118,7 +118,7 @@
|
||||
#' Metrics specified in either \code{eval_metric} or \code{feval} will be computed for each
|
||||
#' of these datasets during each boosting iteration, and stored in the end as a field named
|
||||
#' \code{evaluation_log} in the resulting object. When either \code{verbose>=1} or
|
||||
#' \code{\link{cb.print.evaluation}} callback is engaged, the performance results are continuously
|
||||
#' \code{\link{xgb.cb.print.evaluation}} callback is engaged, the performance results are continuously
|
||||
#' printed out during the training.
|
||||
#' E.g., specifying \code{watchlist=list(validation1=mat1, validation2=mat2)} allows to track
|
||||
#' the performance of each round's model on mat1 and mat2.
|
||||
@@ -130,31 +130,32 @@
|
||||
#' @param verbose If 0, xgboost will stay silent. If 1, it will print information about performance.
|
||||
#' If 2, some additional information will be printed out.
|
||||
#' Note that setting \code{verbose > 0} automatically engages the
|
||||
#' \code{cb.print.evaluation(period=1)} callback function.
|
||||
#' \code{xgb.cb.print.evaluation(period=1)} callback function.
|
||||
#' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
|
||||
#' Default is 1 which means all messages are printed. This parameter is passed to the
|
||||
#' \code{\link{cb.print.evaluation}} callback.
|
||||
#' \code{\link{xgb.cb.print.evaluation}} callback.
|
||||
#' @param early_stopping_rounds If \code{NULL}, the early stopping function is not triggered.
|
||||
#' If set to an integer \code{k}, training with a validation set will stop if the performance
|
||||
#' doesn't improve for \code{k} rounds.
|
||||
#' Setting this parameter engages the \code{\link{cb.early.stop}} callback.
|
||||
#' Setting this parameter engages the \code{\link{xgb.cb.early.stop}} callback.
|
||||
#' @param maximize If \code{feval} and \code{early_stopping_rounds} are set,
|
||||
#' then this parameter must be set as well.
|
||||
#' When it is \code{TRUE}, it means the larger the evaluation score the better.
|
||||
#' This parameter is passed to the \code{\link{cb.early.stop}} callback.
|
||||
#' This parameter is passed to the \code{\link{xgb.cb.early.stop}} callback.
|
||||
#' @param save_period when it is non-NULL, model is saved to disk after every \code{save_period} rounds,
|
||||
#' 0 means save at the end. The saving is handled by the \code{\link{cb.save.model}} callback.
|
||||
#' 0 means save at the end. The saving is handled by the \code{\link{xgb.cb.save.model}} callback.
|
||||
#' @param save_name the name or path for periodically saved model file.
|
||||
#' @param xgb_model a previously built model to continue the training from.
|
||||
#' Could be either an object of class \code{xgb.Booster}, or its raw data, or the name of a
|
||||
#' file with a previously saved model.
|
||||
#' @param callbacks a list of callback functions to perform various task during boosting.
|
||||
#' See \code{\link{callbacks}}. Some of the callbacks are automatically created depending on the
|
||||
#' See \code{\link{xgb.Callback}}. Some of the callbacks are automatically created depending on the
|
||||
#' parameters' values. User can provide either existing or their own callback methods in order
|
||||
#' to customize the training process.
|
||||
#'
|
||||
#' Note that some callbacks might try to set an evaluation log - be aware that these evaluation logs
|
||||
#' are kept as R attributes, and thus do not get saved when using non-R serializaters like
|
||||
#' Note that some callbacks might try to leave attributes in the resulting model object,
|
||||
#' such as an evaluation log (a `data.table` object) - be aware that these objects are kept
|
||||
#' as R attributes, and thus do not get saved when using XGBoost's own serializaters like
|
||||
#' \link{xgb.save} (but are kept when using R serializers like \link{saveRDS}).
|
||||
#' @param ... other parameters to pass to \code{params}.
|
||||
#' @param label vector of response values. Should not be provided when data is
|
||||
@@ -206,18 +207,19 @@
|
||||
#'
|
||||
#' The following callbacks are automatically created when certain parameters are set:
|
||||
#' \itemize{
|
||||
#' \item \code{cb.print.evaluation} is turned on when \code{verbose > 0};
|
||||
#' \item \code{xgb.cb.print.evaluation} is turned on when \code{verbose > 0};
|
||||
#' and the \code{print_every_n} parameter is passed to it.
|
||||
#' \item \code{cb.evaluation.log} is on when \code{watchlist} is present.
|
||||
#' \item \code{cb.early.stop}: when \code{early_stopping_rounds} is set.
|
||||
#' \item \code{cb.save.model}: when \code{save_period > 0} is set.
|
||||
#' \item \code{xgb.cb.evaluation.log} is on when \code{watchlist} is present.
|
||||
#' \item \code{xgb.cb.early.stop}: when \code{early_stopping_rounds} is set.
|
||||
#' \item \code{xgb.cb.save.model}: when \code{save_period > 0} is set.
|
||||
#' }
|
||||
#'
|
||||
#' Note that objects of type `xgb.Booster` as returned by this function behave a bit differently
|
||||
#' from typical R objects (it's an 'altrep' list class), and it makes a separation between
|
||||
#' internal booster attributes (restricted to jsonifyable data), accessed through \link{xgb.attr}
|
||||
#' and shared between interfaces through serialization functions like \link{xgb.save}; and
|
||||
#' R-specific attributes, accessed through \link{attributes} and \link{attr}, which are otherwise
|
||||
#' R-specific attributes (typically the result from a callback), accessed through \link{attributes}
|
||||
#' and \link{attr}, which are otherwise
|
||||
#' only used in the R interface, only kept when using R's serializers like \link{saveRDS}, and
|
||||
#' not anyhow used by functions like \link{predict.xgb.Booster}.
|
||||
#'
|
||||
@@ -229,7 +231,7 @@
|
||||
#' effect elsewhere.
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link{callbacks}},
|
||||
#' \code{\link{xgb.Callback}},
|
||||
#' \code{\link{predict.xgb.Booster}},
|
||||
#' \code{\link{xgb.cv}}
|
||||
#'
|
||||
@@ -295,7 +297,7 @@
|
||||
#' objective = "binary:logistic", eval_metric = "auc")
|
||||
#' my_etas <- list(eta = c(0.5, 0.1))
|
||||
#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0,
|
||||
#' callbacks = list(cb.reset.parameters(my_etas)))
|
||||
#' callbacks = list(xgb.cb.reset.parameters(my_etas)))
|
||||
#'
|
||||
#' ## Early stopping:
|
||||
#' bst <- xgb.train(param, dtrain, nrounds = 25, watchlist,
|
||||
@@ -339,47 +341,47 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
params <- c(params, list(eval_metric = m))
|
||||
}
|
||||
|
||||
# evaluation printing callback
|
||||
params <- c(params)
|
||||
print_every_n <- max(as.integer(print_every_n), 1L)
|
||||
if (!has.callbacks(callbacks, 'cb.print.evaluation') &&
|
||||
verbose) {
|
||||
callbacks <- add.cb(callbacks, cb.print.evaluation(print_every_n))
|
||||
}
|
||||
# evaluation log callback: it is automatically enabled when watchlist is provided
|
||||
evaluation_log <- list()
|
||||
if (!has.callbacks(callbacks, 'cb.evaluation.log') &&
|
||||
length(watchlist) > 0) {
|
||||
callbacks <- add.cb(callbacks, cb.evaluation.log())
|
||||
}
|
||||
# Model saving callback
|
||||
if (!is.null(save_period) &&
|
||||
!has.callbacks(callbacks, 'cb.save.model')) {
|
||||
callbacks <- add.cb(callbacks, cb.save.model(save_period, save_name))
|
||||
}
|
||||
# Early stopping callback
|
||||
stop_condition <- FALSE
|
||||
if (!is.null(early_stopping_rounds) &&
|
||||
!has.callbacks(callbacks, 'cb.early.stop')) {
|
||||
callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds,
|
||||
maximize = maximize, verbose = verbose))
|
||||
}
|
||||
|
||||
# Sort the callbacks into categories
|
||||
cb <- categorize.callbacks(callbacks)
|
||||
params['validate_parameters'] <- TRUE
|
||||
if (!("seed" %in% names(params))) {
|
||||
params[["seed"]] <- sample(.Machine$integer.max, size = 1)
|
||||
}
|
||||
|
||||
# callbacks
|
||||
tmp <- .process.callbacks(callbacks, is_cv = FALSE)
|
||||
callbacks <- tmp$callbacks
|
||||
cb_names <- tmp$cb_names
|
||||
rm(tmp)
|
||||
|
||||
# Early stopping callback (should always come first)
|
||||
if (!is.null(early_stopping_rounds) && !("early_stop" %in% cb_names)) {
|
||||
callbacks <- add.callback(
|
||||
callbacks,
|
||||
xgb.cb.early.stop(
|
||||
early_stopping_rounds,
|
||||
maximize = maximize,
|
||||
verbose = verbose
|
||||
),
|
||||
as_first_elt = TRUE
|
||||
)
|
||||
}
|
||||
# evaluation printing callback
|
||||
print_every_n <- max(as.integer(print_every_n), 1L)
|
||||
if (verbose && !("print_evaluation" %in% cb_names)) {
|
||||
callbacks <- add.callback(callbacks, xgb.cb.print.evaluation(print_every_n))
|
||||
}
|
||||
# evaluation log callback: it is automatically enabled when watchlist is provided
|
||||
if (length(watchlist) && !("evaluation_log" %in% cb_names)) {
|
||||
callbacks <- add.callback(callbacks, xgb.cb.evaluation.log())
|
||||
}
|
||||
# Model saving callback
|
||||
if (!is.null(save_period) && !("save_model" %in% cb_names)) {
|
||||
callbacks <- add.callback(callbacks, xgb.cb.save.model(save_period, save_name))
|
||||
}
|
||||
|
||||
# The tree updating process would need slightly different handling
|
||||
is_update <- NVL(params[['process_type']], '.') == 'update'
|
||||
|
||||
past_evaluation_log <- NULL
|
||||
if (inherits(xgb_model, "xgb.Booster")) {
|
||||
past_evaluation_log <- attributes(xgb_model)$evaluation_log
|
||||
}
|
||||
|
||||
# Construct a booster (either a new one or load from xgb_model)
|
||||
bst <- xgb.Booster(
|
||||
params = params,
|
||||
@@ -394,11 +396,6 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
dtrain
|
||||
)
|
||||
|
||||
# extract parameters that can affect the relationship b/w #trees and #iterations
|
||||
# Note: it might look like these aren't used, but they need to be defined in this
|
||||
# environment for the callbacks for work correctly.
|
||||
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1) # nolint
|
||||
|
||||
if (is_update && nrounds > niter_init)
|
||||
stop("nrounds cannot be larger than ", niter_init, " (nrounds of xgb_model)")
|
||||
|
||||
@@ -406,20 +403,36 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
begin_iteration <- niter_skip + 1
|
||||
end_iteration <- niter_skip + nrounds
|
||||
|
||||
.execute.cb.before.training(
|
||||
callbacks,
|
||||
bst,
|
||||
dtrain,
|
||||
watchlist,
|
||||
begin_iteration,
|
||||
end_iteration
|
||||
)
|
||||
|
||||
# the main loop for boosting iterations
|
||||
for (iteration in begin_iteration:end_iteration) {
|
||||
|
||||
for (f in cb$pre_iter) f()
|
||||
|
||||
xgb.iter.update(
|
||||
bst = bst,
|
||||
dtrain = dtrain,
|
||||
iter = iteration - 1,
|
||||
obj = obj
|
||||
.execute.cb.before.iter(
|
||||
callbacks,
|
||||
bst,
|
||||
dtrain,
|
||||
watchlist,
|
||||
iteration
|
||||
)
|
||||
|
||||
xgb.iter.update(
|
||||
bst = bst,
|
||||
dtrain = dtrain,
|
||||
iter = iteration - 1,
|
||||
obj = obj
|
||||
)
|
||||
|
||||
bst_evaluation <- NULL
|
||||
if (length(watchlist) > 0) {
|
||||
bst_evaluation <- xgb.iter.eval( # nolint: object_usage_linter
|
||||
bst_evaluation <- xgb.iter.eval(
|
||||
bst = bst,
|
||||
watchlist = watchlist,
|
||||
iter = iteration - 1,
|
||||
@@ -427,36 +440,46 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
)
|
||||
}
|
||||
|
||||
for (f in cb$post_iter) f()
|
||||
should_stop <- .execute.cb.after.iter(
|
||||
callbacks,
|
||||
bst,
|
||||
dtrain,
|
||||
watchlist,
|
||||
iteration,
|
||||
bst_evaluation
|
||||
)
|
||||
|
||||
if (stop_condition) break
|
||||
if (should_stop) break
|
||||
}
|
||||
for (f in cb$finalize) f(finalize = TRUE)
|
||||
|
||||
# store the evaluation results
|
||||
keep_evaluation_log <- FALSE
|
||||
if (length(evaluation_log) > 0 && nrow(evaluation_log) > 0) {
|
||||
keep_evaluation_log <- TRUE
|
||||
# include the previous compatible history when available
|
||||
if (inherits(xgb_model, 'xgb.Booster') &&
|
||||
!is_update &&
|
||||
!is.null(past_evaluation_log) &&
|
||||
isTRUE(all.equal(colnames(evaluation_log),
|
||||
colnames(past_evaluation_log)))) {
|
||||
evaluation_log <- rbindlist(list(past_evaluation_log, evaluation_log))
|
||||
}
|
||||
}
|
||||
cb_outputs <- .execute.cb.after.training(
|
||||
callbacks,
|
||||
bst,
|
||||
dtrain,
|
||||
watchlist,
|
||||
iteration,
|
||||
bst_evaluation
|
||||
)
|
||||
|
||||
extra_attrs <- list(
|
||||
call = match.call(),
|
||||
params = params,
|
||||
callbacks = callbacks
|
||||
params = params
|
||||
)
|
||||
if (keep_evaluation_log) {
|
||||
extra_attrs$evaluation_log <- evaluation_log
|
||||
}
|
||||
|
||||
curr_attrs <- attributes(bst)
|
||||
attributes(bst) <- c(curr_attrs, extra_attrs)
|
||||
if (NROW(curr_attrs)) {
|
||||
curr_attrs <- curr_attrs[
|
||||
setdiff(
|
||||
names(curr_attrs),
|
||||
c(names(extra_attrs), names(cb_outputs))
|
||||
)
|
||||
]
|
||||
}
|
||||
curr_attrs <- c(extra_attrs, curr_attrs)
|
||||
if (NROW(cb_outputs)) {
|
||||
curr_attrs <- c(curr_attrs, cb_outputs)
|
||||
}
|
||||
attributes(bst) <- curr_attrs
|
||||
|
||||
return(bst)
|
||||
}
|
||||
|
||||
@@ -82,12 +82,8 @@ NULL
|
||||
NULL
|
||||
|
||||
# Various imports
|
||||
#' @importClassesFrom Matrix dgCMatrix dgeMatrix dgRMatrix
|
||||
#' @importFrom Matrix colSums
|
||||
#' @importClassesFrom Matrix dgCMatrix dgRMatrix CsparseMatrix
|
||||
#' @importFrom Matrix sparse.model.matrix
|
||||
#' @importFrom Matrix sparseVector
|
||||
#' @importFrom Matrix sparseMatrix
|
||||
#' @importFrom Matrix t
|
||||
#' @importFrom data.table data.table
|
||||
#' @importFrom data.table is.data.table
|
||||
#' @importFrom data.table as.data.table
|
||||
@@ -103,6 +99,7 @@ NULL
|
||||
#' @importFrom stats coef
|
||||
#' @importFrom stats predict
|
||||
#' @importFrom stats median
|
||||
#' @importFrom stats sd
|
||||
#' @importFrom stats variable.names
|
||||
#' @importFrom utils head
|
||||
#' @importFrom graphics barplot
|
||||
|
||||
Reference in New Issue
Block a user