@@ -2,7 +2,7 @@
|
||||
#'
|
||||
#' The cross validation function of xgboost
|
||||
#'
|
||||
#' @param params the list of parameters. The complete list of parameters is
|
||||
#' @param params the list of parameters. The complete list of parameters is
|
||||
#' available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below
|
||||
#' is a shorter summary:
|
||||
#' \itemize{
|
||||
@@ -137,20 +137,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
# stop("Either 'eval_metric' or 'feval' must be provided for CV")
|
||||
|
||||
# Check the labels
|
||||
if ( (inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
|
||||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
|
||||
if ((inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
|
||||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
|
||||
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
|
||||
} else if (inherits(data, 'xgb.DMatrix')) {
|
||||
if (!is.null(label))
|
||||
warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
|
||||
cv_label = getinfo(data, 'label')
|
||||
cv_label <- getinfo(data, 'label')
|
||||
} else {
|
||||
cv_label = label
|
||||
cv_label <- label
|
||||
}
|
||||
|
||||
# CV folds
|
||||
if(!is.null(folds)) {
|
||||
if(!is.list(folds) || length(folds) < 2)
|
||||
if (!is.null(folds)) {
|
||||
if (!is.list(folds) || length(folds) < 2)
|
||||
stop("'folds' must be a list with 2 or more elements that are vectors of indices for each CV-fold")
|
||||
nfold <- length(folds)
|
||||
} else {
|
||||
@@ -165,7 +165,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
|
||||
# verbosity & evaluation printing callback:
|
||||
params <- c(params, list(silent = 1))
|
||||
print_every_n <- max( as.integer(print_every_n), 1L)
|
||||
print_every_n <- max(as.integer(print_every_n), 1L)
|
||||
if (!has.callbacks(callbacks, 'cb.print.evaluation') && verbose) {
|
||||
callbacks <- add.cb(callbacks, cb.print.evaluation(print_every_n, showsd = showsd))
|
||||
}
|
||||
@@ -196,20 +196,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
bst_folds <- lapply(seq_along(folds), function(k) {
|
||||
dtest <- slice(dall, folds[[k]])
|
||||
# code originally contributed by @RolandASc on stackoverflow
|
||||
if(is.null(train_folds))
|
||||
if (is.null(train_folds))
|
||||
dtrain <- slice(dall, unlist(folds[-k]))
|
||||
else
|
||||
dtrain <- slice(dall, train_folds[[k]])
|
||||
handle <- xgb.Booster.handle(params, list(dtrain, dtest))
|
||||
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test=dtest), index = folds[[k]])
|
||||
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]])
|
||||
})
|
||||
rm(dall)
|
||||
# a "basket" to collect some results from callbacks
|
||||
basket <- list()
|
||||
|
||||
# extract parameters that can affect the relationship b/w #trees and #iterations
|
||||
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1)
|
||||
num_parallel_tree <- max(as.numeric(NVL(params[['num_parallel_tree']], 1)), 1)
|
||||
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1) # nolint
|
||||
num_parallel_tree <- max(as.numeric(NVL(params[['num_parallel_tree']], 1)), 1) # nolint
|
||||
|
||||
# those are fixed for CV (no training continuation)
|
||||
begin_iteration <- 1
|
||||
@@ -226,7 +226,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
})
|
||||
msg <- simplify2array(msg)
|
||||
bst_evaluation <- rowMeans(msg)
|
||||
bst_evaluation_err <- sqrt(rowMeans(msg^2) - bst_evaluation^2)
|
||||
bst_evaluation_err <- sqrt(rowMeans(msg^2) - bst_evaluation^2) # nolint
|
||||
|
||||
for (f in cb$post_iter) f()
|
||||
|
||||
@@ -285,10 +285,10 @@ print.xgb.cv.synchronous <- function(x, verbose = FALSE, ...) {
|
||||
}
|
||||
if (!is.null(x$params)) {
|
||||
cat('params (as set within xgb.cv):\n')
|
||||
cat( ' ',
|
||||
paste(names(x$params),
|
||||
paste0('"', unlist(x$params), '"'),
|
||||
sep = ' = ', collapse = ', '), '\n', sep = '')
|
||||
cat(' ',
|
||||
paste(names(x$params),
|
||||
paste0('"', unlist(x$params), '"'),
|
||||
sep = ' = ', collapse = ', '), '\n', sep = '')
|
||||
}
|
||||
if (!is.null(x$callbacks) && length(x$callbacks) > 0) {
|
||||
cat('callbacks:\n')
|
||||
|
||||
Reference in New Issue
Block a user