[R] CB naming change; cv-prediction as CB; add.cb function to ensure proper CB order; docs; minor fixes + changes
This commit is contained in:
@@ -15,18 +15,19 @@
|
||||
#' the environment from which they are called from, which is a fairly uncommon thing to do in R.
|
||||
#'
|
||||
#' To write a custom callback closure, make sure you first understand the main concepts about R envoronments.
|
||||
#' Check either the R docs on \code{\link[base]{environment}} or the
|
||||
#' \href{http://adv-r.had.co.nz/Environments.html}{Environments chapter} from Hadley Wickham's "Advanced R" book.
|
||||
#' Then take a look at the code of \code{cb.reset_learning_rate} for a simple example,
|
||||
#' and see the \code{cb.log_evaluation} code for something more involved.
|
||||
#' Also, you would need to get familiar with the objects available inside of the \code{xgb.train} internal environment.
|
||||
#' Check either R documentation on \code{\link[base]{environment}} or the
|
||||
#' \href{http://adv-r.had.co.nz/Environments.html}{Environments chapter} from the "Advanced R"
|
||||
#' book by Hadley Wickham. Further, the best option is to read the code of some of the existing callbacks -
|
||||
#' choose ones that do something similar to what you want to achieve. Also, you would need to get familiar
|
||||
#' with the objects available inside of the \code{xgb.train} and \code{xgb.cv} internal environments.
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link{cb.print_evaluation}},
|
||||
#' \code{\link{cb.log_evaluation}},
|
||||
#' \code{\link{cb.reset_parameters}},
|
||||
#' \code{\link{cb.early_stop}},
|
||||
#' \code{\link{cb.save_model}},
|
||||
#' \code{\link{cb.print.evaluation}},
|
||||
#' \code{\link{cb.evaluation.log}},
|
||||
#' \code{\link{cb.reset.parameters}},
|
||||
#' \code{\link{cb.early.stop}},
|
||||
#' \code{\link{cb.save.model}},
|
||||
#' \code{\link{cb.cv.predict}},
|
||||
#' \code{\link{xgb.train}},
|
||||
#' \code{\link{xgb.cv}}
|
||||
#'
|
||||
@@ -55,7 +56,7 @@ NULL
|
||||
#' \code{\link{callbacks}}
|
||||
#'
|
||||
#' @export
|
||||
cb.print_evaluation <- function(period=1) {
|
||||
cb.print.evaluation <- function(period=1) {
|
||||
|
||||
callback <- function(env = parent.frame()) {
|
||||
if (length(env$bst_evaluation) == 0 ||
|
||||
@@ -67,12 +68,12 @@ cb.print_evaluation <- function(period=1) {
|
||||
if ((i-1) %% period == 0 ||
|
||||
i == env$begin_iteration ||
|
||||
i == env$end_iteration) {
|
||||
msg <- format_eval_string(i, env$bst_evaluation, env$bst_evaluation_err)
|
||||
msg <- format.eval.string(i, env$bst_evaluation, env$bst_evaluation_err)
|
||||
cat(msg, '\n')
|
||||
}
|
||||
}
|
||||
attr(callback, 'call') <- match.call()
|
||||
attr(callback, 'name') <- 'cb.print_evaluation'
|
||||
attr(callback, 'name') <- 'cb.print.evaluation'
|
||||
callback
|
||||
}
|
||||
|
||||
@@ -100,7 +101,7 @@ cb.print_evaluation <- function(period=1) {
|
||||
#' \code{\link{callbacks}}
|
||||
#'
|
||||
#' @export
|
||||
cb.log_evaluation <- function() {
|
||||
cb.evaluation.log <- function() {
|
||||
|
||||
mnames <- NULL
|
||||
|
||||
@@ -147,7 +148,7 @@ cb.log_evaluation <- function() {
|
||||
list(c(iter = env$iteration, ev)))
|
||||
}
|
||||
attr(callback, 'call') <- match.call()
|
||||
attr(callback, 'name') <- 'cb.log_evaluation'
|
||||
attr(callback, 'name') <- 'cb.evaluation.log'
|
||||
callback
|
||||
}
|
||||
|
||||
@@ -178,17 +179,27 @@ cb.log_evaluation <- function() {
|
||||
#' \code{\link{callbacks}}
|
||||
#'
|
||||
#' @export
|
||||
cb.reset_parameters <- function(new_params) {
|
||||
cb.reset.parameters <- function(new_params) {
|
||||
|
||||
if (typeof(new_params) != "list")
|
||||
stop("'new_params' must be a list")
|
||||
pnames <- gsub("\\.", "_", names(new_params))
|
||||
# TODO: restrict the set of parameters that could be reset?
|
||||
nrounds <- NULL
|
||||
|
||||
# run some checks in the begining
|
||||
init <- function(env) {
|
||||
nrounds <<- env$end_iteration - env$begin_iteration + 1
|
||||
|
||||
if (is.null(env$bst) && is.null(env$bst_folds))
|
||||
stop("Parent frame has neither 'bst' nor 'bst_folds'")
|
||||
|
||||
# Some parameters are not allowed to be changed,
|
||||
# since changing them would simply wreck some chaos
|
||||
not_allowed <- pnames %in%
|
||||
c('num_class', 'num_output_group', 'size_leaf_vector', 'updater_seq')
|
||||
if (any(not_allowed))
|
||||
stop('Parameters ', paste(pnames[not_allowed]), " cannot be changed during boosting.")
|
||||
|
||||
for (n in pnames) {
|
||||
p <- new_params[[n]]
|
||||
if (is.function(p)) {
|
||||
@@ -223,7 +234,7 @@ cb.reset_parameters <- function(new_params) {
|
||||
}
|
||||
attr(callback, 'is_pre_iteration') <- TRUE
|
||||
attr(callback, 'call') <- match.call()
|
||||
attr(callback, 'name') <- 'cb.reset_parameters'
|
||||
attr(callback, 'name') <- 'cb.reset.parameters'
|
||||
callback
|
||||
}
|
||||
|
||||
@@ -246,15 +257,15 @@ cb.reset_parameters <- function(new_params) {
|
||||
#' This callback function determines the condition for early stopping
|
||||
#' by setting the \code{stop_condition = TRUE} flag in its calling frame.
|
||||
#'
|
||||
#' The following additional fields are assigned to the model R object:
|
||||
#' The following additional fields are assigned to the model's R object:
|
||||
#' \itemize{
|
||||
#' \item \code{best_score} the evaluation score at the best iteration
|
||||
#' \item \code{best_iteration} at which boosting iteration the best score has occurred (1-based index)
|
||||
#' \item \code{best_ntreelimit} to use with the \code{ntreelimit} parameter in \code{predict}.
|
||||
#' It differs from \code{best_iteration} in multiclass or random forest settings.
|
||||
#' It differs from \code{best_iteration} in multiclass or random forest settings.
|
||||
#' }
|
||||
#'
|
||||
#' The Same values are also stored as xgb-attributes, however:
|
||||
#' The Same values are also stored as xgb-attributes:
|
||||
#' \itemize{
|
||||
#' \item \code{best_iteration} is stored as a 0-based iteration index (for interoperability of binary models)
|
||||
#' \item \code{best_msg} message string is also stored.
|
||||
@@ -266,22 +277,22 @@ cb.reset_parameters <- function(new_params) {
|
||||
#' \code{stop_condition},
|
||||
#' \code{bst_evaluation},
|
||||
#' \code{rank},
|
||||
#' \code{bst} or \code{bst_folds},
|
||||
#' \code{bst} (or \code{bst_folds} and \code{basket}),
|
||||
#' \code{iteration},
|
||||
#' \code{begin_iteration},
|
||||
#' \code{end_iteration},
|
||||
#' \code{num_parallel_tree},
|
||||
#' \code{num_class}.
|
||||
#' \code{num_parallel_tree}.
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link{callbacks}},
|
||||
#' \code{\link{xgb.attr}}
|
||||
#'
|
||||
#' @export
|
||||
cb.early_stop <- function(stopping_rounds, maximize=FALSE,
|
||||
cb.early.stop <- function(stopping_rounds, maximize=FALSE,
|
||||
metric_name=NULL, verbose=TRUE) {
|
||||
# state variables
|
||||
best_iteration <- -1
|
||||
best_ntreelimit <- -1
|
||||
best_score <- Inf
|
||||
best_msg <- NULL
|
||||
metric_idx <- 1
|
||||
@@ -331,24 +342,23 @@ cb.early_stop <- function(stopping_rounds, maximize=FALSE,
|
||||
xgb.attributes(env$bst$handle) <- list(best_iteration = best_iteration - 1,
|
||||
best_score = best_score)
|
||||
}
|
||||
} else if (is.null(env$bst_folds)) {
|
||||
stop("Parent frame has neither 'bst' nor 'bst_folds'")
|
||||
} else if (is.null(env$bst_folds) || is.null(env$basket)) {
|
||||
stop("Parent frame has neither 'bst' nor ('bst_folds' and 'basket')")
|
||||
}
|
||||
}
|
||||
|
||||
finalizer <- function(env) {
|
||||
best_ntreelimit = best_iteration * env$num_parallel_tree * env$num_class
|
||||
if (!is.null(env$bst)) {
|
||||
attr_best_score = as.numeric(xgb.attr(env$bst$handle, 'best_score'))
|
||||
if (best_score != attr_best_score)
|
||||
stop("Inconsistent 'best_score' between the state: ", best_score,
|
||||
stop("Inconsistent 'best_score' values between the closure state: ", best_score,
|
||||
" and the xgb.attr: ", attr_best_score)
|
||||
env$bst$best_score = best_score
|
||||
env$bst$best_iteration = best_iteration
|
||||
env$bst$best_ntreelimit = best_ntreelimit
|
||||
env$bst$best_score = best_score
|
||||
} else {
|
||||
attr(env$bst_folds, 'best_iteration') <- best_iteration
|
||||
attr(env$bst_folds, 'best_ntreelimit') <- best_ntreelimit
|
||||
env$basket$best_iteration <- best_iteration
|
||||
env$basket$best_ntreelimit <- best_ntreelimit
|
||||
}
|
||||
}
|
||||
|
||||
@@ -365,16 +375,17 @@ cb.early_stop <- function(stopping_rounds, maximize=FALSE,
|
||||
if (( maximize && score > best_score) ||
|
||||
(!maximize && score < best_score)) {
|
||||
|
||||
best_msg <<- format_eval_string(i, env$bst_evaluation, env$bst_evaluation_err)
|
||||
best_msg <<- format.eval.string(i, env$bst_evaluation, env$bst_evaluation_err)
|
||||
best_score <<- score
|
||||
best_iteration <<- i
|
||||
best_ntreelimit <<- best_iteration * env$num_parallel_tree
|
||||
# save the property to attributes, so they will occur in checkpoint
|
||||
if (!is.null(env$bst)) {
|
||||
xgb.attributes(env$bst) <- list(
|
||||
best_iteration = best_iteration - 1, # convert to 0-based index
|
||||
best_score = best_score,
|
||||
best_msg = best_msg,
|
||||
best_ntreelimit = best_iteration * env$num_parallel_tree * env$num_class)
|
||||
best_ntreelimit = best_ntreelimit)
|
||||
}
|
||||
} else if (i - best_iteration >= stopping_rounds) {
|
||||
env$stop_condition <- TRUE
|
||||
@@ -384,7 +395,7 @@ cb.early_stop <- function(stopping_rounds, maximize=FALSE,
|
||||
}
|
||||
}
|
||||
attr(callback, 'call') <- match.call()
|
||||
attr(callback, 'name') <- 'cb.early_stop'
|
||||
attr(callback, 'name') <- 'cb.early.stop'
|
||||
callback
|
||||
}
|
||||
|
||||
@@ -412,7 +423,7 @@ cb.early_stop <- function(stopping_rounds, maximize=FALSE,
|
||||
#' \code{\link{callbacks}}
|
||||
#'
|
||||
#' @export
|
||||
cb.save_model <- function(save_period = 0, save_name = "xgboost.model") {
|
||||
cb.save.model <- function(save_period = 0, save_name = "xgboost.model") {
|
||||
|
||||
if (save_period < 0)
|
||||
stop("'save_period' cannot be negative")
|
||||
@@ -426,7 +437,80 @@ cb.save_model <- function(save_period = 0, save_name = "xgboost.model") {
|
||||
xgb.save(env$bst, sprintf(save_name, env$iteration))
|
||||
}
|
||||
attr(callback, 'call') <- match.call()
|
||||
attr(callback, 'name') <- 'cb.save_model'
|
||||
attr(callback, 'name') <- 'cb.save.model'
|
||||
callback
|
||||
}
|
||||
|
||||
|
||||
#' Callback closure for returning cross-validation based predictions.
|
||||
#'
|
||||
#' @param save_models a flag for whether to save the folds' models.
|
||||
#'
|
||||
#' @details
|
||||
#' This callback function saves predictions for all of the test folds,
|
||||
#' and also allows to save the folds' models.
|
||||
#'
|
||||
#' It is a "finalizer" callback and it uses early stopping information whenever it is available,
|
||||
#' thus it must be run after the early stopping callback if the early stopping is used.
|
||||
#'
|
||||
#' Callback function expects the following values to be set in its calling frame:
|
||||
#' \code{bst_folds},
|
||||
#' \code{basket},
|
||||
#' \code{data},
|
||||
#' \code{end_iteration},
|
||||
#' \code{num_parallel_tree},
|
||||
#' \code{num_class}.
|
||||
#'
|
||||
#' @return
|
||||
#' Predictions are returned inside of the \code{pred} element, which is either a vector or a matrix,
|
||||
#' depending on the number of prediction outputs per data row. The order of predictions corresponds
|
||||
#' to the order of rows in the original dataset. Note that when a custom \code{folds} list is
|
||||
#' provided in \code{xgb.cv}, the predictions would only be returned properly when this list is a
|
||||
#' non-overlapping list of k sets of indices, as in a standard k-fold CV. The predictions would not be
|
||||
#' meaningful when user-profided folds have overlapping indices as in, e.g., random sampling splits.
|
||||
#' When some of the indices in the training dataset are not included into user-provided \code{folds},
|
||||
#' their prediction value would be \code{NA}.
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link{callbacks}}
|
||||
#'
|
||||
#' @export
|
||||
cb.cv.predict <- function(save_models = FALSE) {
|
||||
|
||||
finalizer <- function(env) {
|
||||
if (is.null(env$basket) || is.null(env$bst_folds))
|
||||
stop("'cb.cv.predict' callback requires 'basket' and 'bst_folds' lists in its calling frame")
|
||||
|
||||
N <- nrow(env$data)
|
||||
pred <- ifelse(env$num_class > 1,
|
||||
matrix(NA_real_, N, env$num_class),
|
||||
rep(NA_real_, N))
|
||||
|
||||
ntreelimit <- NVL(env$basket$best_ntreelimit,
|
||||
env$end_iteration * env$num_parallel_tree)
|
||||
for (fd in env$bst_folds) {
|
||||
pr <- predict(fd$bst, fd$watchlist[[2]], ntreelimit = ntreelimit, reshape = TRUE)
|
||||
if (is.matrix(pred)) {
|
||||
pred[fd$index,] <- pr
|
||||
} else {
|
||||
pred[fd$index] <- pr
|
||||
}
|
||||
}
|
||||
env$basket$pred <- pred
|
||||
if (save_models) {
|
||||
env$basket$models <- lapply(env$bst_folds, function(fd) {
|
||||
xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1
|
||||
xgb.Booster.check(xgb.handleToBooster(fd$bst), saveraw = TRUE)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
callback <- function(env = parent.frame(), finalize = FALSE) {
|
||||
if (finalize)
|
||||
return(finalizer(env))
|
||||
}
|
||||
attr(callback, 'call') <- match.call()
|
||||
attr(callback, 'name') <- 'cb.cv.predict'
|
||||
callback
|
||||
}
|
||||
|
||||
@@ -436,7 +520,7 @@ cb.save_model <- function(save_period = 0, save_name = "xgboost.model") {
|
||||
#
|
||||
|
||||
# Format the evaluation metric string
|
||||
format_eval_string <- function(iter, eval_res, eval_err=NULL) {
|
||||
format.eval.string <- function(iter, eval_res, eval_err=NULL) {
|
||||
if (length(eval_res) == 0)
|
||||
stop('no evaluation results')
|
||||
enames <- names(eval_res)
|
||||
@@ -454,47 +538,68 @@ format_eval_string <- function(iter, eval_res, eval_err=NULL) {
|
||||
}
|
||||
|
||||
# Extract callback names from the list of callbacks
|
||||
callback.names <- function(cb.list) {
|
||||
unlist(lapply(cb.list, function(x) attr(x, 'name')))
|
||||
callback.names <- function(cb_list) {
|
||||
unlist(lapply(cb_list, function(x) attr(x, 'name')))
|
||||
}
|
||||
|
||||
# Extract callback calls from the list of callbacks
|
||||
callback.calls <- function(cb.list) {
|
||||
unlist(lapply(cb.list, function(x) attr(x, 'call')))
|
||||
callback.calls <- function(cb_list) {
|
||||
unlist(lapply(cb_list, function(x) attr(x, 'call')))
|
||||
}
|
||||
|
||||
# Add a callback cb to the list and make sure that
|
||||
# cb.early.stop and cb.cv.predict are at the end of the list
|
||||
# with cb.cv.predict being the last (when present)
|
||||
add.cb <- function(cb_list, cb) {
|
||||
cb_list <- c(cb_list, cb)
|
||||
names(cb_list) <- callback.names(cb_list)
|
||||
if ('cb.early.stop' %in% names(cb_list)) {
|
||||
cb_list <- c(cb_list, cb_list['cb.early.stop'])
|
||||
# this removes only the first one
|
||||
cb_list['cb.early.stop'] <- NULL
|
||||
}
|
||||
if ('cb.cv.predict' %in% names(cb_list)) {
|
||||
cb_list <- c(cb_list, cb_list['cb.cv.predict'])
|
||||
cb_list['cb.cv.predict'] <- NULL
|
||||
}
|
||||
cb_list
|
||||
}
|
||||
|
||||
# Sort callbacks list into categories
|
||||
categorize.callbacks <- function(cb.list) {
|
||||
categorize.callbacks <- function(cb_list) {
|
||||
list(
|
||||
pre_iter = Filter(function(x) {
|
||||
pre <- attr(x, 'is_pre_iteration')
|
||||
!is.null(pre) && pre
|
||||
}, cb.list),
|
||||
}, cb_list),
|
||||
post_iter = Filter(function(x) {
|
||||
pre <- attr(x, 'is_pre_iteration')
|
||||
is.null(pre) || !pre
|
||||
}, cb.list),
|
||||
}, cb_list),
|
||||
finalize = Filter(function(x) {
|
||||
'finalize' %in% names(formals(x))
|
||||
}, cb.list)
|
||||
}, cb_list)
|
||||
)
|
||||
}
|
||||
|
||||
# Check whether all callback functions with names given by 'query.names' are present in the 'cb.list'.
|
||||
has.callbacks <- function(cb.list, query.names) {
|
||||
if (length(cb.list) < length(query.names))
|
||||
# Check whether all callback functions with names given by 'query_names' are present in the 'cb_list'.
|
||||
has.callbacks <- function(cb_list, query_names) {
|
||||
if (length(cb_list) < length(query_names))
|
||||
return(FALSE)
|
||||
if (!is.list(cb.list) ||
|
||||
!all(sapply(cb.list, class) == 'function'))
|
||||
stop('`cb.list`` must be a list of callback functions')
|
||||
cb.names <- callback.names(cb.list)
|
||||
if (!is.character(cb.names) ||
|
||||
length(cb.names) != length(cb.list) ||
|
||||
any(cb.names == ""))
|
||||
stop('All callbacks in the `cb.list` must have a non-empty `name` attribute')
|
||||
if (!is.character(query.names) ||
|
||||
length(query.names) == 0 ||
|
||||
any(query.names == ""))
|
||||
stop('query.names must be a non-empty vector of non-empty character names')
|
||||
return(all(query.names %in% cb.names))
|
||||
if (!is.list(cb_list) ||
|
||||
any(sapply(cb_list, class) != 'function')) {
|
||||
stop('`cb_list`` must be a list of callback functions')
|
||||
}
|
||||
cb_names <- callback.names(cb_list)
|
||||
if (!is.character(cb_names) ||
|
||||
length(cb_names) != length(cb_list) ||
|
||||
any(cb_names == "")) {
|
||||
stop('All callbacks in the `cb_list` must have a non-empty `name` attribute')
|
||||
}
|
||||
if (!is.character(query_names) ||
|
||||
length(query_names) == 0 ||
|
||||
any(query_names == "")) {
|
||||
stop('query_names must be a non-empty vector of non-empty character names')
|
||||
}
|
||||
return(all(query_names %in% cb_names))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user