From 0ecd6c08f370a046bf31bf1a763a48b1332d2f40 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 5 Sep 2014 22:34:32 -0700 Subject: [PATCH 1/3] add cross validation --- R-package/NAMESPACE | 1 + R-package/R/utils.R | 44 ++++++++++++++++++++++++++++++++----- R-package/R/xgb.cv.R | 22 +++++++++++++++---- R-package/src/xgboost_R.cpp | 4 ++++ R-package/src/xgboost_R.h | 5 +++++ 5 files changed, 66 insertions(+), 10 deletions(-) diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 4a7cb9465..491231a11 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -8,6 +8,7 @@ export(xgb.dump) export(xgb.load) export(xgb.save) export(xgb.train) +export(xgb.cv) export(xgboost) exportMethods(predict) import(methods) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index d979660ca..5aea42373 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -103,6 +103,10 @@ xgb.get.DMatrix <- function(data, label = NULL) { } return (dtrain) } +xgb.numrow <- function(dmat) { + nrow <- .Call("XGDMatrixNumRow_R", dmat, PACKAGE="xgboost") + return(nrow) +} # iteratively update booster with customized statistics xgb.iter.boost <- function(booster, dtrain, gpair) { if (class(booster) != "xgb.Booster") { @@ -174,23 +178,51 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) { } } else { msg <- "" - } + } return(msg) } #------------------------------------------ # helper functions for cross validation # -xgb.cv.mknfold <- function(dall, nfold, param, metrics=list(), fpreproc = NULL) { +xgb.cv.mknfold <- function(dall, nfold, param) { randidx <- sample(1 : xgb.numrow(dall)) kstep <- length(randidx) / nfold idset <- list() for (i in 1:nfold) { - idset = append(idset, randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ]) + idset[[i]] <- randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ] } ret <- list() for (k in 1:nfold) { - + dtest <- slice(dall, idset[[k]]) + didx = c() + for (i in 1:nfold) { + if (i != k) { + didx <- append(didx, 
idset[[i]]) + } + } + dtrain <- slice(dall, didx) + bst <- xgb.Booster(param, list(dtrain, dtest)) + watchlist = list(train=dtrain, test=dtest) + ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist) } - + return (ret) +} +xgb.cv.aggcv <- function(res, showsd = TRUE) { + header = res[[1]] + ret <- header[1] + for (i in 2:length(header)) { + kv <- strsplit(header[i], ":")[[1]] + ret <- paste(ret, "\t", kv[1], ":", sep="") + stats <- c() + stats[1] <- as.numeric(kv[2]) + for (j in 2:length(res)) { + tkv <- strsplit(res[[j]][i], ":")[[1]] + stats[j] <- as.numeric(tkv[2]) + } + ret <- paste(ret, sprintf("%f", mean(stats)), sep="") + if (showsd) { + ret <- paste(ret, sprintf("+%f", sd(stats)), sep="") + } + } + return (ret) } - diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 089acb838..dd0e2c891 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -46,12 +46,26 @@ #' #' @export #' -xgb.cv <- function(params=list(), data, nrounds, metrics=list(), label = NULL, - obj = NULL, feval = NULL, ...) { +xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, + showsd = TRUE, obj = NULL, feval = NULL, ...) 
{ if (typeof(params) != "list") { stop("xgb.cv: first argument params must be list") } dtrain <- xgb.get.DMatrix(data, label) - params = append(params, list(...)) - + params <- append(params, list(...)) + params <- append(params, list(silent=1)) + folds <- xgb.cv.mknfold(dtrain, nfold, params) + history <- list() + for (i in 1:nrounds) { + msg <- list() + for (k in 1:nfold) { + fd <- folds[[k]] + succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj) + msg[[k]] <- strsplit(xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval), "\t")[[1]] + } + ret <- xgb.cv.aggcv(msg, showsd) + history <- append(history, ret) + cat(paste(ret, "\n", sep="")) + } + return (history) } diff --git a/R-package/src/xgboost_R.cpp b/R-package/src/xgboost_R.cpp index bbb99615a..5f3b2dde4 100644 --- a/R-package/src/xgboost_R.cpp +++ b/R-package/src/xgboost_R.cpp @@ -174,6 +174,10 @@ extern "C" { _WrapperEnd(); return ret; } + SEXP XGDMatrixNumRow_R(SEXP handle) { + bst_ulong nrow = XGDMatrixNumRow(R_ExternalPtrAddr(handle)); + return ScalarInteger(static_cast<int>(nrow)); + } // functions related to booster void _BoosterFinalizer(SEXP ext) { if (R_ExternalPtrAddr(ext) == NULL) return; diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index c988ff1e5..9453bf061 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -65,6 +65,11 @@ extern "C" { * \return info vector */ SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field); + /*! + * \brief return number of rows + * \param handle an instance of data matrix + */ + SEXP XGDMatrixNumRow_R(SEXP handle); /*! 
* \brief create xgboost learner * \param dmats a list of dmatrix handles that will be cached From 831a102d48a33986c0619e7f7a265d974a408a9e Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 5 Sep 2014 22:36:59 -0700 Subject: [PATCH 2/3] add cv --- R-package/inst/examples/cross_validation.R | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 R-package/inst/examples/cross_validation.R diff --git a/R-package/inst/examples/cross_validation.R b/R-package/inst/examples/cross_validation.R new file mode 100644 index 000000000..abe45354d --- /dev/null +++ b/R-package/inst/examples/cross_validation.R @@ -0,0 +1,9 @@ +require(xgboost) +require(methods) +# Directly read in local file +dtrain <- xgb.DMatrix("agaricus.txt.train") + +history <- xgb.cv(list("max_depth"=3, "eta"=1, + "objective"="binary:logistic"), + dtrain, nround=3, nfold = 5, "eval_metric"="error") + From ab238ff8313b053a089abfa2d0d6893909d095d0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 5 Sep 2014 22:46:09 -0700 Subject: [PATCH 3/3] chg cv --- R-package/R/utils.R | 2 +- R-package/R/xgb.cv.R | 13 +++++++++++-- R-package/inst/examples/cross_validation.R | 7 ++++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index 5aea42373..2dddcc980 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -208,7 +208,7 @@ xgb.cv.mknfold <- function(dall, nfold, param) { return (ret) } xgb.cv.aggcv <- function(res, showsd = TRUE) { - header = res[[1]] + header <- res[[1]] ret <- header[1] for (i in 2:length(header)) { kv <- strsplit(header[i], ":")[[1]] diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index dd0e2c891..9bd0f0468 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -18,6 +18,9 @@ #' further details. See also inst/examples/demo.R for walkthrough example in R. #' @param data takes an \code{xgb.DMatrix} as the input. 
#' @param nrounds the max number of iterations +#' @param nfold number of folds used +#' @param label optional field, when data is Matrix +#' @param showsd boolean, whether to show standard deviation of cross validation #' @param metrics, list of evaluation metrics to be used in corss validation, #' when it is not specified, the evaluation metric is chosen according to objective function. #' Possible options are: @@ -28,7 +31,6 @@ #' \item \code{auc} Area under curve #' \item \code{merror} Exact matching error, used to evaluate multi-class classification #' } -#' #' @param obj customized objective function. Returns gradient and second order #' gradient with given prediction and dtrain, #' @param feval custimized evaluation function. Returns @@ -47,13 +49,20 @@ #' @export #' xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, - showsd = TRUE, obj = NULL, feval = NULL, ...) { + showsd = TRUE, metrics=list(), obj = NULL, feval = NULL, ...) { if (typeof(params) != "list") { stop("xgb.cv: first argument params must be list") } + if (nfold <= 1) { + stop("nfold must be bigger than 1") + } dtrain <- xgb.get.DMatrix(data, label) params <- append(params, list(...)) params <- append(params, list(silent=1)) + for (mc in metrics) { + params <- append(params, list("eval_metric"=mc)) + } + folds <- xgb.cv.mknfold(dtrain, nfold, params) history <- list() for (i in 1:nrounds) { diff --git a/R-package/inst/examples/cross_validation.R b/R-package/inst/examples/cross_validation.R index abe45354d..b46daa19f 100644 --- a/R-package/inst/examples/cross_validation.R +++ b/R-package/inst/examples/cross_validation.R @@ -3,7 +3,8 @@ require(methods) # Directly read in local file dtrain <- xgb.DMatrix("agaricus.txt.train") -history <- xgb.cv(list("max_depth"=3, "eta"=1, - "objective"="binary:logistic"), - dtrain, nround=3, nfold = 5, "eval_metric"="error") +history <- xgb.cv( data = dtrain, nround=3, nfold = 5, metrics=list("rmse","auc"), + "max_depth"=3, "eta"=1, + 
"objective"="binary:logistic") +