enable returning prediction in cv

This commit is contained in:
hetong007
2015-01-20 14:12:45 -08:00
parent 89d5e67b78
commit 947f0a926d
3 changed files with 37 additions and 13 deletions

View File

@@ -131,7 +131,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
}
# iteratively evaluate one iteration
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = FALSE) {
if (class(booster) != "xgb.Booster") {
stop("xgb.eval: first argument must be type xgb.Booster")
}
@@ -169,18 +169,27 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
} else {
msg <- ""
}
if (prediction){
preds <- predict(booster,watchlist[[2]])
return(list(msg,preds))
}
return(msg)
}
}
#------------------------------------------
# helper functions for cross validation
#
xgb.cv.mknfold <- function(dall, nfold, param) {
randidx <- sample(1 : xgb.numrow(dall))
kstep <- length(randidx) / nfold
idset <- list()
for (i in 1:nfold) {
idset[[i]] <- randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ]
if (nfold <= 1) {
stop("nfold must be bigger than 1")
}
randidx <- sample(1 : xgb.numrow(dall))
kstep <- length(randidx) %/% nfold
idset <- list()
for (i in 1:(nfold-1)) {
idset[[i]] = randidx[1:kstep]
randidx = setdiff(randidx,idset[[i]])
}
idset[[nfold]] = randidx
ret <- list()
for (k in 1:nfold) {
dtest <- slice(dall, idset[[k]])
@@ -193,7 +202,7 @@ xgb.cv.mknfold <- function(dall, nfold, param) {
dtrain <- slice(dall, didx)
bst <- xgb.Booster(param, list(dtrain, dtest))
watchlist = list(train=dtrain, test=dtest)
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist)
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=idset[[k]])
}
return (ret)
}