Merge pull request #278 from khotilov/custom_loss_cv_fix

Improved logic in stratified CV
This commit is contained in:
Tong He 2015-05-01 14:46:05 -07:00
commit 4ff6697d83

View File

@ -211,6 +211,7 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = F
}
return(msg)
}
#------------------------------------------
# helper functions for cross validation
#
@ -219,15 +220,30 @@ xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) {
stop("nfold must be bigger than 1")
}
if(is.null(folds)) {
if (exists('objective', where=param) && strtrim(param[['objective']], 5) == 'rank:') {
stop("\tAutomatic creation of CV-folds is not implemented for ranking!\n",
"\tConsider providing pre-computed CV-folds through the folds parameter.")
}
y <- getinfo(dall, 'label')
randidx <- sample(1 : xgb.numrow(dall))
if (stratified & length(y) == length(randidx)) {
y <- y[randidx]
# By default assume that y is a classification label,
# and only leave it numeric for the reg:linear objective.
# WARNING: if there would be any other objectives with truly
# numerical labels, they currently would not be treated correctly!
if (param[['objective']] != 'reg:linear') y <- factor(y)
#
# WARNING: some heuristic logic is employed to identify classification setting!
#
# For classification, need to convert y labels to factor before making the folds,
# and then do stratification by factor levels.
# For regression, leave y numeric and do stratification by quantiles.
if (exists('objective', where=param)) {
# If 'objective' provided in params, assume that y is a classification label
# unless objective is reg:linear
if (param[['objective']] != 'reg:linear') y <- factor(y)
} else {
# If no 'objective' given in params, it means that user either wants to use
# the default 'reg:linear' objective or has provided a custom obj function.
# Here, assume classification setting when y has 5 or less unique values:
if (length(unique(y)) <= 5) y <- factor(y)
}
folds <- xgb.createFolds(y, nfold)
} else {
# make simple non-stratified folds