Add a new verbose parameter to print progress during the process (set to TRUE by default so that the behavior of existing code does not change) + source code refactoring

This commit is contained in:
El Potaeto 2015-01-02 11:21:53 +01:00
parent 4d0d65837d
commit cdea1685e5
3 changed files with 18 additions and 18 deletions

View File

@ -25,5 +25,4 @@ importFrom(stringr,str_extract)
importFrom(stringr,str_extract_all) importFrom(stringr,str_extract_all)
importFrom(stringr,str_match) importFrom(stringr,str_match)
importFrom(stringr,str_replace) importFrom(stringr,str_replace)
importFrom(stringr,str_replace_all)
importFrom(stringr,str_split) importFrom(stringr,str_split)

View File

@ -8,8 +8,8 @@
#' @importFrom data.table := #' @importFrom data.table :=
#' @importFrom data.table rbindlist #' @importFrom data.table rbindlist
#' @importFrom stringr str_extract_all #' @importFrom stringr str_extract_all
#' @importFrom stringr str_extract
#' @importFrom stringr str_split #' @importFrom stringr str_split
#' @importFrom stringr str_replace_all
#' @importFrom stringr str_replace #' @importFrom stringr str_replace
#' @importFrom stringr str_match #' @importFrom stringr str_match
#' #'
@ -31,7 +31,7 @@
#' @param nrounds the max number of iterations #' @param nrounds the max number of iterations
#' @param nfold number of folds used #' @param nfold number of folds used
#' @param label option field, when data is Matrix #' @param label option field, when data is Matrix
#' @param showsd boolean, whether show standard deviation of cross validation #' @param showsd \code{boolean}, whether show standard deviation of cross validation
#' @param metrics, list of evaluation metrics to be used in corss validation, #' @param metrics, list of evaluation metrics to be used in corss validation,
#' when it is not specified, the evaluation metric is chosen according to objective function. #' when it is not specified, the evaluation metric is chosen according to objective function.
#' Possible options are: #' Possible options are:
@ -49,9 +49,10 @@
#' prediction and dtrain, #' prediction and dtrain,
#' @param missing Missing is only used when input is dense matrix, pick a float #' @param missing Missing is only used when input is dense matrix, pick a float
# value that represents missing value. Sometime a data use 0 or other extreme value to represents missing values. # value that represents missing value. Sometime a data use 0 or other extreme value to represents missing values.
#' @param verbose \code{boolean}, print the statistics during the process.
#' @param ... other parameters to pass to \code{params}. #' @param ... other parameters to pass to \code{params}.
#' #'
#' @return a \code{data.table} with each mean and standard deviation stat for training set and test set. #' @return A \code{data.table} with each mean and standard deviation stat for training set and test set.
#' #'
#' @details #' @details
#' This is the cross validation function for xgboost #' This is the cross validation function for xgboost
@ -66,10 +67,11 @@
#' dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) #' dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
#' history <- xgb.cv(data = dtrain, nround=3, nfold = 5, metrics=list("rmse","auc"), #' history <- xgb.cv(data = dtrain, nround=3, nfold = 5, metrics=list("rmse","auc"),
#' "max.depth"=3, "eta"=1, "objective"="binary:logistic") #' "max.depth"=3, "eta"=1, "objective"="binary:logistic")
#' print(history)
#' @export #' @export
#' #'
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL, xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, ...) { showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, verbose = T,...) {
if (typeof(params) != "list") { if (typeof(params) != "list") {
stop("xgb.cv: first argument params must be list") stop("xgb.cv: first argument params must be list")
} }
@ -94,28 +96,24 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
for (k in 1:nfold) { for (k in 1:nfold) {
fd <- folds[[k]] fd <- folds[[k]]
succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj) succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj)
msg[[k]] <- strsplit(xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval), msg[[k]] <- xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval) %>% str_split("\t") %>% .[[1]]
"\t")[[1]]
} }
ret <- xgb.cv.aggcv(msg, showsd) ret <- xgb.cv.aggcv(msg, showsd)
history <- c(history, ret) history <- c(history, ret)
cat(paste(ret, "\n", sep="")) if(verbose) paste(ret, "\n", sep="") %>% cat
} }
colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace_all("-", ".") colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".")
colnamesMean <- paste(colnames, "mean") colnamesMean <- paste(colnames, "mean")
colnamesStd <- paste(colnames, "std") colnamesStd <- paste(colnames, "std")
colnames <- c() colnames <- c()
for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i]) for(i in 1:length(colnamesMean)) colnames <- c(colnames, colnamesMean[i], colnamesStd[i])
type <- rep(x = "numeric", times = length(colnames)) type <- rep(x = "numeric", times = length(colnames))
dt <- read.table(text = "", colClasses = type, col.names = colnames) %>% as.data.table dt <- read.table(text = "", colClasses = type, col.names = colnames) %>% as.data.table
split = str_split(string = history, pattern = "\t") split = str_split(string = history, pattern = "\t")
for(line in split){
dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d.\\d*") %>% unlist %>% as.list %>% {vec <- .;rbindlist(list(dt, vec), use.names = F, fill = F)} for(line in split) dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d.\\d*") %>% unlist %>% as.list %>% {vec <- .; rbindlist(list(dt, vec), use.names = F, fill = F)}
}
dt dt
} }

View File

@ -6,7 +6,7 @@
\usage{ \usage{
xgb.cv(params = list(), data, nrounds, nfold, label = NULL, xgb.cv(params = list(), data, nrounds, nfold, label = NULL,
missing = NULL, showsd = TRUE, metrics = list(), obj = NULL, missing = NULL, showsd = TRUE, metrics = list(), obj = NULL,
feval = NULL, ...) feval = NULL, verbose = T, ...)
} }
\arguments{ \arguments{
\item{params}{the list of parameters. Commonly used ones are: \item{params}{the list of parameters. Commonly used ones are:
@ -34,7 +34,7 @@ xgb.cv(params = list(), data, nrounds, nfold, label = NULL,
\item{missing}{Missing is only used when input is dense matrix, pick a float} \item{missing}{Missing is only used when input is dense matrix, pick a float}
\item{showsd}{boolean, whether show standard deviation of cross validation} \item{showsd}{\code{boolean}, whether show standard deviation of cross validation}
\item{metrics,}{list of evaluation metrics to be used in corss validation, \item{metrics,}{list of evaluation metrics to be used in corss validation,
when it is not specified, the evaluation metric is chosen according to objective function. when it is not specified, the evaluation metric is chosen according to objective function.
@ -54,10 +54,12 @@ gradient with given prediction and dtrain,}
\code{list(metric='metric-name', value='metric-value')} with given \code{list(metric='metric-name', value='metric-value')} with given
prediction and dtrain,} prediction and dtrain,}
\item{verbose}{\code{boolean}, print the statistics during the process.}
\item{...}{other parameters to pass to \code{params}.} \item{...}{other parameters to pass to \code{params}.}
} }
\value{ \value{
a \code{data.table} with each mean and standard deviation stat for training set and test set. A \code{data.table} with each mean and standard deviation stat for training set and test set.
} }
\description{ \description{
The cross valudation function of xgboost The cross valudation function of xgboost
@ -75,5 +77,6 @@ data(agaricus.train, package='xgboost')
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
history <- xgb.cv(data = dtrain, nround=3, nfold = 5, metrics=list("rmse","auc"), history <- xgb.cv(data = dtrain, nround=3, nfold = 5, metrics=list("rmse","auc"),
"max.depth"=3, "eta"=1, "objective"="binary:logistic") "max.depth"=3, "eta"=1, "objective"="binary:logistic")
print(history)
} }