diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 5aea42373..2dddcc980 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -208,7 +208,7 @@ xgb.cv.mknfold <- function(dall, nfold, param) { return (ret) } xgb.cv.aggcv <- function(res, showsd = TRUE) { - header = res[[1]] + header <- res[[1]] ret <- header[1] for (i in 2:length(header)) { kv <- strsplit(header[i], ":")[[1]] diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index dd0e2c891..9bd0f0468 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -18,6 +18,9 @@ #' further details. See also inst/examples/demo.R for walkthrough example in R. #' @param data takes an \code{xgb.DMatrix} as the input. #' @param nrounds the max number of iterations +#' @param nfold number of folds used +#' @param label option field, when data is Matrix +#' @param showd boolean, whether show standard deviation of cross validation #' @param metrics, list of evaluation metrics to be used in corss validation, #' when it is not specified, the evaluation metric is chosen according to objective function. #' Possible options are: @@ -28,7 +31,6 @@ #' \item \code{auc} Area under curve #' \item \code{merror} Exact matching error, used to evaluate multi-class classification #' } -#' #' @param obj customized objective function. Returns gradient and second order #' gradient with given prediction and dtrain, #' @param feval custimized evaluation function. Returns @@ -47,13 +49,20 @@ #' @export #' xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, - showsd = TRUE, obj = NULL, feval = NULL, ...) { + showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, ...) { if (typeof(params) != "list") { stop("xgb.cv: first argument params must be list") } + if (nfold <= 1) { + stop("nfold must be bigger than 1") + } dtrain <- xgb.get.DMatrix(data, label) params <- append(params, list(...)) params <- append(params, list(silent=1)) + for (mc in metrics) { + params <- append(params, list("eval_metric"=mc)) + } + folds <- xgb.cv.mknfold(dtrain, nfold, params) history <- list() for (i in 1:nrounds) { diff --git a/R-package/inst/examples/cross_validation.R b/R-package/inst/examples/cross_validation.R index abe45354d..b46daa19f 100644 --- a/R-package/inst/examples/cross_validation.R +++ b/R-package/inst/examples/cross_validation.R @@ -3,7 +3,8 @@ require(methods) # Directly read in local file dtrain <- xgb.DMatrix("agaricus.txt.train") -history <- xgb.cv(list("max_depth"=3, "eta"=1, - "objective"="binary:logistic"), - dtrain, nround=3, nfold = 5, "eval_metric"="error") +history <- xgb.cv( data = dtrain, nround=3, nfold = 5, metrics=list("rmse","auc"), + "max_depth"=3, "eta"=1, + "objective"="binary:logistic") +