From 947f0a926ddf12c66a31b412e8c9261cd16888b8 Mon Sep 17 00:00:00 2001 From: hetong007 Date: Tue, 20 Jan 2015 14:12:45 -0800 Subject: [PATCH] enable returning prediction in cv --- R-package/R/getinfo.xgb.DMatrix.R | 2 +- R-package/R/utils.R | 25 +++++++++++++++++-------- R-package/R/xgb.cv.R | 23 +++++++++++++++++++---- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/R-package/R/getinfo.xgb.DMatrix.R b/R-package/R/getinfo.xgb.DMatrix.R index 53ca5748c..6e291fe62 100644 --- a/R-package/R/getinfo.xgb.DMatrix.R +++ b/R-package/R/getinfo.xgb.DMatrix.R @@ -39,7 +39,7 @@ setMethod("getinfo", signature = "xgb.DMatrix", if (name != "nrow"){ ret <- .Call("XGDMatrixGetInfo_R", object, name, PACKAGE = "xgboost") } else { - ret <- .Call("XGDMatrixNumRow_R", object, PACKAGE = "xgboost") + ret <- xgb.numrow(object) } return(ret) }) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 34ce003db..412132891 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -131,7 +131,7 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) { } # iteratively evaluate one iteration -xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) { +xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = FALSE) { if (class(booster) != "xgb.Booster") { stop("xgb.eval: first argument must be type xgb.Booster") } @@ -169,18 +169,27 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) { } else { msg <- "" } + if (prediction){ + preds <- predict(booster,watchlist[[2]]) + return(list(msg,preds)) + } return(msg) -} +} #------------------------------------------ # helper functions for cross validation # xgb.cv.mknfold <- function(dall, nfold, param) { - randidx <- sample(1 : xgb.numrow(dall)) - kstep <- length(randidx) / nfold - idset <- list() - for (i in 1:nfold) { - idset[[i]] <- randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ] + if (nfold <= 1) { + stop("nfold must be bigger than 1") } + randidx <- sample(1 : xgb.numrow(dall)) + kstep <- length(randidx) %/% nfold + idset <- list() + for (i in 1:(nfold-1)) { + idset[[i]] = randidx[1:kstep] + randidx = setdiff(randidx,idset[[i]]) + } + idset[[nfold]] = randidx ret <- list() for (k in 1:nfold) { dtest <- slice(dall, idset[[k]]) @@ -193,7 +202,7 @@ xgb.cv.mknfold <- function(dall, nfold, param) { dtrain <- slice(dall, didx) bst <- xgb.Booster(param, list(dtrain, dtest)) watchlist = list(train=dtrain, test=dtest) - ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist) + ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist, index=idset[[k]]) } return (ret) } diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index b071f08a7..e562610f1 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -31,6 +31,9 @@ #' @param nrounds the max number of iterations #' @param nfold number of folds used #' @param label option field, when data is Matrix +#' @param missing Missing is only used when input is dense matrix, pick a float +# value that represents missing value. Sometime a data use 0 or other extreme value to represents missing values. +#' @param prediction A logical value indicating whether to return the prediction vector. #' @param showsd \code{boolean}, whether show standard deviation of cross validation #' @param metrics, list of evaluation metrics to be used in corss validation, #' when it is not specified, the evaluation metric is chosen according to objective function. @@ -71,7 +74,8 @@ #' @export #' xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL, - showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, verbose = T,...) { + prediction = FALSE, showsd = TRUE, metrics=list(), + obj = NULL, feval = NULL, verbose = T,...) { if (typeof(params) != "list") { stop("xgb.cv: first argument params must be list") } @@ -90,13 +94,20 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = } folds <- xgb.cv.mknfold(dtrain, nfold, params) + predictValues <- rep(0,xgb.numrow(dtrain)) history <- c() for (i in 1:nrounds) { msg <- list() for (k in 1:nfold) { fd <- folds[[k]] - succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj) - msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]] + succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj) + if (!prediction){ + msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]] + } else { + res <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval, prediction) + predictValues[fd$index] <- res[[2]] + msg[[k]] <- res[[1]] %>% str_split("\t") %>% .[[1]] + } } ret <- xgb.cv.aggcv(msg, showsd) history <- c(history, ret) @@ -115,5 +126,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = split <- str_split(string = history, pattern = "\t") for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d*\\.+\\d*") %>% unlist %>% as.list %>% {vec <- .; rbindlist(list(dt, vec), use.names = F, fill = F)} - dt + + if (prediction) { + return(list(dt,predictValues)) + } + return(dt) } \ No newline at end of file