make it possible to use a list of pre-defined CV folds in xgb.cv

This commit is contained in:
Vadim Khotilovich
2015-04-03 13:24:04 -05:00
parent c03b42054f
commit 31b0e53cd4
3 changed files with 42 additions and 28 deletions

View File

@@ -214,43 +214,45 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = F
#------------------------------------------
# helper functions for cross validation
#
xgb.cv.mknfold <- function(dall, nfold, param, stratified) {
xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) {
if (nfold <= 1) {
stop("nfold must be bigger than 1")
}
randidx <- sample(1 : xgb.numrow(dall))
y <- getinfo(dall, 'label')
if (stratified & length(y) == length(randidx)) {
y <- y[randidx]
# By default assume that y is a classification label,
# and only leave it numeric for the reg:linear objective.
# WARNING: if there would be any other objectives with truly
# numerical labels, they currently would not be treated correctly!
if (param[['objective']] != 'reg:linear') y <- factor(y)
idset <- xgb.createFolds(y, nfold)
} else {
# make simple non-stratified folds
kstep <- length(randidx) %/% nfold
idset <- list()
for (i in 1:(nfold-1)) {
idset[[i]] = randidx[1:kstep]
randidx = setdiff(randidx,idset[[i]])
if(is.null(folds)) {
y <- getinfo(dall, 'label')
randidx <- sample(1 : xgb.numrow(dall))
if (stratified & length(y) == length(randidx)) {
y <- y[randidx]
# By default assume that y is a classification label,
# and only leave it numeric for the reg:linear objective.
# WARNING: if there would be any other objectives with truly
# numerical labels, they currently would not be treated correctly!
if (param[['objective']] != 'reg:linear') y <- factor(y)
folds <- xgb.createFolds(y, nfold)
} else {
# make simple non-stratified folds
kstep <- length(randidx) %/% nfold
folds <- list()
for (i in 1:(nfold-1)) {
folds[[i]] = randidx[1:kstep]
randidx = setdiff(randidx, folds[[i]])
}
folds[[nfold]] = randidx
}
idset[[nfold]] = randidx
}
ret <- list()
for (k in 1:nfold) {
dtest <- slice(dall, idset[[k]])
dtest <- slice(dall, folds[[k]])
didx = c()
for (i in 1:nfold) {
if (i != k) {
didx <- append(didx, idset[[i]])
didx <- append(didx, folds[[i]])
}
}
dtrain <- slice(dall, didx)
bst <- xgb.Booster(param, list(dtrain, dtest))
watchlist = list(train=dtrain, test=dtest)
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=idset[[k]])
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=folds[[k]])
}
return (ret)
}