From a524a51a06f979e5ea4d6a14e5c6368c88549dd0 Mon Sep 17 00:00:00 2001 From: El Potaeto Date: Thu, 1 Jan 2015 16:05:43 +0100 Subject: [PATCH] return history as data.table for cross validation + documentation --- R-package/NAMESPACE | 1 + R-package/R/xgb.cv.R | 23 ++++++++++++++++++----- R-package/man/xgb.cv.Rd | 3 +++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 1714d2044..7e0bfa8ac 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -18,5 +18,6 @@ importClassesFrom(Matrix,dgCMatrix) importClassesFrom(Matrix,dgeMatrix) importFrom(data.table,":=") importFrom(data.table,data.table) +importFrom(data.table,rbindlist) importFrom(magrittr,"%>%") importFrom(stringr,str_extract) diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 02870b772..3a9fd9b86 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -1,7 +1,12 @@ #' Cross Validation #' #' The cross valudation function of xgboost -#' +#' +#' @importFrom data.table data.table +#' @importFrom magrittr %>% +#' @importFrom data.table := +#' @importFrom data.table rbindlist +#' @importFrom stringr str_extract #' @param params the list of parameters. Commonly used ones are: #' \itemize{ #' \item \code{objective} objective function, common ones are @@ -40,6 +45,8 @@ # value that represents missing value. Sometime a data use 0 or other extreme value to represents missing values. #' @param ... other parameters to pass to \code{params}. #' +#' @return a \code{data.table} with each mean and standard deviation stat for training set and test set. +#' #' @details #' This is the cross validation function for xgboost #' @@ -88,9 +95,15 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = history <- c(history, ret) cat(paste(ret, "\n", sep="")) } - return (history) + + dt <- data.table(train_rmse_mean=numeric(), train_rmse_std=numeric(), train_auc_mean=numeric(), train_auc_std=numeric(), test_rmse_mean=numeric(), test_rmse_std=numeric(), test_auc_mean=numeric(), test_auc_std=numeric()) + + split = str_split(string = history, pattern = "\t") + for(line in split){ + dt <- line[2:length(line)] %>% str_extract_all(pattern = "\\d.\\d*") %>% unlist %>% as.list %>% {vec <- .;rbindlist(list(dt, vec), use.names = F, fill = F)} + } + dt } -xgb.cv.strip.numeric <- function(x) { - as.numeric(strsplit(regmatches(x, regexec("test-(.*):(.*)$", x))[[1]][3], "\\+")[[1]]) -} + + diff --git a/R-package/man/xgb.cv.Rd b/R-package/man/xgb.cv.Rd index 271182625..19f04ee79 100644 --- a/R-package/man/xgb.cv.Rd +++ b/R-package/man/xgb.cv.Rd @@ -56,6 +56,9 @@ prediction and dtrain,} \item{...}{other parameters to pass to \code{params}.} } +\value{ +a \code{data.table} with each mean and standard deviation stat for training set and test set. +} \description{ The cross valudation function of xgboost }