make it possible to use a list of pre-defined CV folds in xgb.cv
This commit is contained in:
@@ -50,7 +50,9 @@
|
||||
#' @param feval customized evaluation function. Returns
|
||||
#' \code{list(metric='metric-name', value='metric-value')} with given
|
||||
#' prediction and dtrain.
|
||||
#' @param stratified \code{boolean}, whether sampling of folds should be stratified by the values of labels in \code{data}
|
||||
#' @param stratified \code{boolean} whether sampling of folds should be stratified by the values of labels in \code{data}
|
||||
#' @param folds \code{list} of pre-defined CV folds to use instead of randomly generated ones (each element must be a vector of the indices belonging to that fold).
|
||||
#' If folds are supplied, the \code{nfold} and \code{stratified} parameters are ignored.
|
||||
#' @param verbose \code{boolean}, print the statistics during the process
|
||||
#' @param ... other parameters to pass to \code{params}.
|
||||
#'
|
||||
@@ -84,10 +86,16 @@
|
||||
#'
|
||||
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
|
||||
prediction = FALSE, showsd = TRUE, metrics=list(),
|
||||
obj = NULL, feval = NULL, stratified = TRUE, verbose = T,...) {
|
||||
obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T,...) {
|
||||
if (typeof(params) != "list") {
|
||||
stop("xgb.cv: first argument params must be list")
|
||||
}
|
||||
if(!is.null(folds)) {
|
||||
if(class(folds)!="list" | length(folds) < 2) {
|
||||
stop("folds must be a list with 2 or more elements that are vectors of indices for each CV-fold")
|
||||
}
|
||||
nfold <- length(folds)
|
||||
}
|
||||
if (nfold <= 1) {
|
||||
stop("nfold must be bigger than 1")
|
||||
}
|
||||
@@ -102,7 +110,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
params <- append(params, list("eval_metric"=mc))
|
||||
}
|
||||
|
||||
folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified)
|
||||
xgb_folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified, folds)
|
||||
obj_type = params[['objective']]
|
||||
mat_pred = FALSE
|
||||
if (!is.null(obj_type) && obj_type=='multi:softprob')
|
||||
@@ -119,7 +127,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
for (i in 1:nrounds) {
|
||||
msg <- list()
|
||||
for (k in 1:nfold) {
|
||||
fd <- folds[[k]]
|
||||
fd <- xgb_folds[[k]]
|
||||
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
|
||||
if (i<nrounds) {
|
||||
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
|
||||
|
||||
Reference in New Issue
Block a user