enable returning prediction in cv

This commit is contained in:
hetong007 2015-01-20 14:12:45 -08:00
parent 89d5e67b78
commit 947f0a926d
3 changed files with 37 additions and 13 deletions

View File

@ -39,7 +39,7 @@ setMethod("getinfo", signature = "xgb.DMatrix",
if (name != "nrow"){
ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost")
} else {
ret <- .Call("XGDMatrixNumRow_R", object, PACKAGE = "xgboost")
ret <- xgb.numrow(object)
}
return(ret)
})

View File

@ -131,7 +131,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
}
# iteratively evaluate one iteration
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = FALSE) {
if (class(booster) != "xgb.Booster") {
stop("xgb.eval: first argument must be type xgb.Booster")
}
@ -169,18 +169,27 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
} else {
msg <- ""
}
if (prediction){
preds <- predict(booster,watchlist[[2]])
return(list(msg,preds))
}
return(msg)
}
#------------------------------------------
# helper functions for cross validation
#
xgb.cv.mknfold <- function(dall, nfold, param) {
randidx <- sample(1 : xgb.numrow(dall))
kstep <- length(randidx) / nfold
idset <- list()
for (i in 1:nfold) {
idset[[i]] <- randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ]
if (nfold <= 1) {
stop("nfold must be bigger than 1")
}
randidx <- sample(1 : xgb.numrow(dall))
kstep <- length(randidx) %/% nfold
idset <- list()
for (i in 1:(nfold-1)) {
idset[[i]] = randidx[1:kstep]
randidx = setdiff(randidx,idset[[i]])
}
idset[[nfold]] = randidx
ret <- list()
for (k in 1:nfold) {
dtest <- slice(dall, idset[[k]])
@ -193,7 +202,7 @@ xgb.cv.mknfold <- function(dall, nfold, param) {
dtrain <- slice(dall, didx)
bst <- xgb.Booster(param, list(dtrain, dtest))
watchlist = list(train=dtrain, test=dtest)
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist)
ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=idset[[k]])
}
return (ret)
}

View File

@ -31,6 +31,9 @@
#' @param nrounds the max number of iterations
#' @param nfold number of folds used
#' @param label option field, when data is Matrix
#' @param missing Missing is only used when input is dense matrix, pick a float
# value that represents missing value. Sometime a data use 0 or other extreme value to represents missing values.
#' @param prediction A logical value indicating whether to return the prediction vector.
#' @param showsd \code{boolean}, whether show standard deviation of cross validation
#' @param metrics, list of evaluation metrics to be used in corss validation,
#' when it is not specified, the evaluation metric is chosen according to objective function.
@ -71,7 +74,8 @@
#' @export
#'
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, verbose = T,...) {
prediction = FALSE, showsd = TRUE, metrics=list(),
obj = NULL, feval = NULL, verbose = T,...) {
if (typeof(params) != "list") {
stop("xgb.cv: first argument params must be list")
}
@ -90,13 +94,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
}
folds <- xgb.cv.mknfold(dtrain, nfold, params)
predictValues <- rep(0,xgb.numrow(dtrain))
history <- c()
for (i in 1:nrounds) {
msg <- list()
for (k in 1:nfold) {
fd <- folds[[k]]
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
if (!prediction){
msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
} else {
res <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction)
predictValues[fd$index] <- res[[2]]
msg[[k]] <- res[[1]] %>% str_split("\t") %>% .[[1]]
}
}
ret <- xgb.cv.aggcv(msg, showsd)
history <- c(history, ret)
@ -115,5 +126,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
split <- str_split(string = history, pattern = "\t")
for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.list %>% {vec <- .; rbindlist(list(dt, vec), use.names = F, fill = F)}
dt
if (prediction) {
return(list(dt,predictValues))
}
return(dt)
}