diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 4a7cb9465..491231a11 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -8,6 +8,7 @@ export(xgb.dump) export(xgb.load) export(xgb.save) export(xgb.train) +export(xgb.cv) export(xgboost) exportMethods(predict) import(methods) diff --git a/R-package/R/utils.R b/R-package/R/utils.R index d979660ca..5aea42373 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -103,6 +103,10 @@ xgb.get.DMatrix <- function(data, label = NULL) { } return (dtrain) } +xgb.numrow <- function(dmat) { + nrow <- .Call("XGDMatrixNumRow_R", dmat, PACKAGE="xgboost") + return(nrow) +} # iteratively update booster with customized statistics xgb.iter.boost <- function(booster, dtrain, gpair) { if (class(booster) != "xgb.Booster") { @@ -174,23 +178,51 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) { } } else { msg <- "" - } + } return(msg) } #------------------------------------------ # helper functions for cross validation # -xgb.cv.mknfold <- function(dall, nfold, param, metrics=list(), fpreproc = NULL) { +xgb.cv.mknfold <- function(dall, nfold, param) { randidx <- sample(1 : xgb.numrow(dall)) kstep <- length(randidx) / nfold idset <- list() for (i in 1:nfold) { - idset = append(idset, randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ]) + idset[[i]] <- randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ] } ret <- list() for (k in 1:nfold) { - + dtest <- slice(dall, idset[[k]]) + didx = c() + for (i in 1:nfold) { + if (i != k) { + didx <- append(didx, idset[[i]]) + } + } + dtrain <- slice(dall, didx) + bst <- xgb.Booster(param, list(dtrain, dtest)) + watchlist = list(train=dtrain, test=dtest) + ret[[k]] <- list(dtrain=dtrain, booster=bst, watchlist=watchlist) } - + return (ret) +} +xgb.cv.aggcv <- function(res, showsd = TRUE) { + header = res[[1]] + ret <- header[1] + for (i in 2:length(header)) { + kv <- strsplit(header[i], ":")[[1]] + ret <- paste(ret, "\t", kv[1], ":", sep="") + stats <- c() + stats[1] <- as.numeric(kv[2]) + for (j in 2:length(res)) { + tkv <- strsplit(res[[j]][i], ":")[[1]] + stats[j] <- as.numeric(tkv[2]) + } + ret <- paste(ret, sprintf("%f", mean(stats)), sep="") + if (showsd) { + ret <- paste(ret, sprintf("+%f", sd(stats)), sep="") + } + } + return (ret) } - diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 089acb838..dd0e2c891 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -46,12 +46,26 @@ #' #' @export #' -xgb.cv <- function(params=list(), data, nrounds, metrics=list(), label = NULL, - obj = NULL, feval = NULL, ...) { +xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, + showsd = TRUE, obj = NULL, feval = NULL, ...) { if (typeof(params) != "list") { stop("xgb.cv: first argument params must be list") } dtrain <- xgb.get.DMatrix(data, label) - params = append(params, list(...)) - + params <- append(params, list(...)) + params <- append(params, list(silent=1)) + folds <- xgb.cv.mknfold(dtrain, nfold, params) + history <- list() + for (i in 1:nrounds) { + msg <- list() + for (k in 1:nfold) { + fd <- folds[[k]] + succ <- xgb.iter.update(fd$booster, fd$dtrain, i - 1, obj) + msg[[k]] <- strsplit(xgb.iter.eval(fd$booster, fd$watchlist, i - 1, feval), "\t")[[1]] + } + ret <- xgb.cv.aggcv(msg, showsd) + history <- append(history, ret) + cat(paste(ret, "\n", sep="")) + } + return (history) } diff --git a/R-package/src/xgboost_R.cpp b/R-package/src/xgboost_R.cpp index bbb99615a..5f3b2dde4 100644 --- a/R-package/src/xgboost_R.cpp +++ b/R-package/src/xgboost_R.cpp @@ -174,6 +174,10 @@ extern "C" { _WrapperEnd(); return ret; } + SEXP XGDMatrixNumRow_R(SEXP handle) { + bst_ulong nrow = XGDMatrixNumRow(R_ExternalPtrAddr(handle)); + return ScalarInteger(static_cast(nrow)); + } // functions related to booster void _BoosterFinalizer(SEXP ext) { if (R_ExternalPtrAddr(ext) == NULL) return; diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index c988ff1e5..9453bf061 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -65,6 +65,11 @@ extern "C" { * \return info vector */ SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field); + /*! + * \brief return number of rows + * \param handle a instance of data matrix + */ + SEXP XGDMatrixNumRow_R(SEXP handle); /*! * \brief create xgboost learner * \param dmats a list of dmatrix handles that will be cached