enable returning prediction in cv

This commit is contained in:
hetong007
2015-01-20 14:12:45 -08:00
parent 89d5e67b78
commit 947f0a926d
3 changed files with 37 additions and 13 deletions

View File

@@ -31,6 +31,9 @@
#' @param nrounds the max number of iterations
#' @param nfold number of folds used
#' @param label option field, when data is Matrix
#' @param missing Missing is only used when input is dense matrix, pick a float
# value that represents missing value. Sometime a data use 0 or other extreme value to represents missing values.
#' @param prediction A logical value indicating whether to return the prediction vector.
#' @param showsd \code{boolean}, whether show standard deviation of cross validation
#' @param metrics, list of evaluation metrics to be used in corss validation,
#' when it is not specified, the evaluation metric is chosen according to objective function.
@@ -71,7 +74,8 @@
#' @export
#'
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, verbose = T,...) {
prediction = FALSE, showsd = TRUE, metrics=list(),
obj = NULL, feval = NULL, verbose = T,...) {
if (typeof(params) != "list") {
stop("xgb.cv: first argument params must be list")
}
@@ -90,13 +94,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
}
folds <- xgb.cv.mknfold(dtrain, nfold, params)
predictValues <- rep(0,xgb.numrow(dtrain))
history <- c()
for (i in 1:nrounds) {
msg <- list()
for (k in 1:nfold) {
fd <- folds[[k]]
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
if (!prediction){
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
} else {
res <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction)
predictValues[fd$index] <- res[[2]]
msg[[k]] <- res[[1]] %>% str_split("\t") %>% .[[1]]
}
}
ret <- xgb.cv.aggcv(msg, showsd)
history <- c(history, ret)
@@ -115,5 +126,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
split <- str_split(string = history, pattern = "\t")
for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.list %>% {vec <- .; rbindlist(list(dt, vec), use.names = F, fill = F)}
dt
if (prediction) {
return(list(dt,predictValues))
}
return(dt)
}