From 984102e586a3c2af1cc9fc6f15e8526cfbebdec4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 5 Sep 2014 20:34:41 -0700 Subject: [PATCH] style cleanup, incomplete CV --- R-package/R/utils.R | 101 +++++++++++++++++++++++++++++++--------- R-package/R/xgb.cv.R | 57 +++++++++++++++++++++++ R-package/R/xgb.train.R | 38 +++------------ R-package/R/xgboost.R | 21 +++------ 4 files changed, 148 insertions(+), 69 deletions(-) create mode 100644 R-package/R/xgb.cv.R diff --git a/R-package/R/utils.R b/R-package/R/utils.R index da602478a..d979660ca 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -81,20 +81,28 @@ xgb.predict <- function(booster, dmat, outputmargin = FALSE) { ## ----the following are low level iteratively function, not needed if ## you do not want to use them --------------------------------------- - -# iteratively update booster with dtrain -xgb.iter.update <- function(booster, dtrain, iter) { - if (class(booster) != "xgb.Booster") { - stop("xgb.iter.update: first argument must be type xgb.Booster") +# get dmatrix from data, label +xgb.get.DMatrix <- function(data, label = NULL) { + inClass <- class(data) + if (inClass == "dgCMatrix" || inClass == "matrix") { + if (is.null(label)) { + stop("xgboost: need label when data is a matrix") + } + dtrain <- xgb.DMatrix(data, label = label) + } else { + if (!is.null(label)) { + warning("xgboost: label will be ignored.") + } + if (inClass == "character") { + dtrain <- xgb.DMatrix(data) + } else if (inClass == "xgb.DMatrix") { + dtrain <- data + } else { + stop("xgboost: Invalid input of data") + } } - if (class(dtrain) != "xgb.DMatrix") { - stop("xgb.iter.update: second argument must be type xgb.DMatrix") - } - .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain, - PACKAGE = "xgboost") - return(TRUE) + return (dtrain) } - # iteratively update booster with customized statistics xgb.iter.boost <- function(booster, dtrain, gpair) { if (class(booster) != "xgb.Booster") { @@ -108,8 +116,28 @@ xgb.iter.boost <- function(booster, dtrain, gpair) { return(TRUE) } +# iteratively update booster with dtrain +xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) { + if (class(booster) != "xgb.Booster") { + stop("xgb.iter.update: first argument must be type xgb.Booster") + } + if (class(dtrain) != "xgb.DMatrix") { + stop("xgb.iter.update: second argument must be type xgb.DMatrix") + } + + if (is.null(obj)) { + .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain, + PACKAGE = "xgboost") + } else { + pred <- xgb.predict(bst, dtrain) + gpair <- obj(pred, dtrain) + succ <- xgb.iter.boost(bst, dtrain, gpair) + } + return(TRUE) +} + # iteratively evaluate one iteration -xgb.iter.eval <- function(booster, watchlist, iter) { +xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) { if (class(booster) != "xgb.Booster") { stop("xgb.eval: first argument must be type xgb.Booster") } @@ -122,18 +150,47 @@ xgb.iter.eval <- function(booster, watchlist, iter) { } } if (length(watchlist) != 0) { - evnames <- list() - for (i in 1:length(watchlist)) { - w <- watchlist[i] - if (length(names(w)) == 0) { - stop("xgb.eval: name tag must be presented for every elements in watchlist") + if (is.null(feval)) { + evnames <- list() + for (i in 1:length(watchlist)) { + w <- watchlist[i] + if (length(names(w)) == 0) { + stop("xgb.eval: name tag must be presented for every elements in watchlist") + } + evnames <- append(evnames, names(w)) + } + msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, + evnames, PACKAGE = "xgboost") + } else { + msg <- paste("[", iter, "]", sep="") + for (j in 1:length(watchlist)) { + w <- watchlist[j] + if (length(names(w)) == 0) { + stop("xgb.eval: name tag must be presented for every elements in watchlist") + } + ret <- feval(xgb.predict(bst, w[[1]]), w[[1]]) + msg <- paste(msg, "\t", names(w), "-", ret$metric, ":", ret$value, sep="") } - evnames <- append(evnames, names(w)) } - msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, - evnames, PACKAGE = "xgboost") } else { msg <- "" - } + } return(msg) } +#------------------------------------------ +# helper functions for cross validation +# +xgb.cv.mknfold <- function(dall, nfold, param, metrics=list(), fpreproc = NULL) { + randidx <- sample(1 : xgb.numrow(dall)) + kstep <- length(randidx) / nfold + idset <- list() + for (i in 1:nfold) { + idset = append(idset, randidx[ ((i-1) * kstep + 1) : min(i * kstep, length(randidx)) ]) + } + ret <- list() + for (k in 1:nfold) { + + } + +} + diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R new file mode 100644 index 000000000..089acb838 --- /dev/null +++ b/R-package/R/xgb.cv.R @@ -0,0 +1,57 @@ +#' eXtreme Gradient Boosting Training +#' +#' The training function of xgboost +#' +#' @param params the list of parameters. Commonly used ones are: +#' \itemize{ +#' \item \code{objective} objective function, common ones are +#' \itemize{ +#' \item \code{reg:linear} linear regression +#' \item \code{binary:logistic} logistic regression for classification +#' } +#' \item \code{eta} step size of each boosting step +#' \item \code{max_depth} maximum depth of the tree +#' \item \code{nthread} number of thread used in training, if not set, all threads are used +#' } +#' +#' See \url{https://github.com/tqchen/xgboost/wiki/Parameters} for +#' further details. See also inst/examples/demo.R for walkthrough example in R. +#' @param data takes an \code{xgb.DMatrix} as the input. +#' @param nrounds the max number of iterations +#' @param metrics, list of evaluation metrics to be used in corss validation, +#' when it is not specified, the evaluation metric is chosen according to objective function. +#' Possible options are: +#' \itemize{ +#' \item \code{error} binary classification error rate +#' \item \code{rmse} Rooted mean square error +#' \item \code{logloss} negative log-likelihood function +#' \item \code{auc} Area under curve +#' \item \code{merror} Exact matching error, used to evaluate multi-class classification +#' } +#' +#' @param obj customized objective function. Returns gradient and second order +#' gradient with given prediction and dtrain, +#' @param feval custimized evaluation function. Returns +#' \code{list(metric='metric-name', value='metric-value')} with given +#' prediction and dtrain, +#' @param ... other parameters to pass to \code{params}. +#' +#' @details +#' This is the cross validation function for xgboost +#' +#' Parallelization is automatically enabled if OpenMP is present. +#' Number of threads can also be manually specified via "nthread" parameter. +#' +#' This function only accepts an \code{xgb.DMatrix} object as the input. +#' +#' @export +#' +xgb.cv <- function(params=list(), data, nrounds, metrics=list(), label = NULL, + obj = NULL, feval = NULL, ...) { + if (typeof(params) != "list") { + stop("xgb.cv: first argument params must be list") + } + dtrain <- xgb.get.DMatrix(data, label) + params = append(params, list(...)) + +} diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 58a575d03..d29bad569 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -16,7 +16,7 @@ #' #' See \url{https://github.com/tqchen/xgboost/wiki/Parameters} for #' further details. See also inst/examples/demo.R for walkthrough example in R. -#' @param dtrain takes an \code{xgb.DMatrix} as the input. +#' @param data takes an \code{xgb.DMatrix} as the input. #' @param nrounds the max number of iterations #' @param watchlist what information should be printed when \code{verbose=1} or #' \code{verbose=2}. Watchlist is used to specify validation set monitoring @@ -64,8 +64,9 @@ #' bst <- xgb.train(param, dtrain, nround = 2, watchlist, logregobj, evalerror) #' @export #' -xgb.train <- function(params=list(), dtrain, nrounds, watchlist = list(), +xgb.train <- function(params=list(), data, nrounds, watchlist = list(), obj = NULL, feval = NULL, ...) { + dtrain <- data if (typeof(params) != "list") { stop("xgb.train: first argument params must be list") } @@ -75,37 +76,10 @@ xgb.train <- function(params=list(), dtrain, nrounds, watchlist = list(), params = append(params, list(...)) bst <- xgb.Booster(params, append(watchlist, dtrain)) for (i in 1:nrounds) { - if (is.null(obj)) { - succ <- xgb.iter.update(bst, dtrain, i - 1) - } else { - pred <- xgb.predict(bst, dtrain) - gpair <- obj(pred, dtrain) - succ <- xgb.iter.boost(bst, dtrain, gpair) - } + succ <- xgb.iter.update(bst, dtrain, i - 1, obj) if (length(watchlist) != 0) { - if (is.null(feval)) { - msg <- xgb.iter.eval(bst, watchlist, i - 1) - cat(msg) - cat("\n") - } else { - cat("[") - cat(i) - cat("]") - for (j in 1:length(watchlist)) { - w <- watchlist[j] - if (length(names(w)) == 0) { - stop("xgb.eval: name tag must be presented for every elements in watchlist") - } - ret <- feval(xgb.predict(bst, w[[1]]), w[[1]]) - cat("\t") - cat(names(w)) - cat("-") - cat(ret$metric) - cat(":") - cat(ret$value) - } - cat("\n") - } + msg <- xgb.iter.eval(bst, watchlist, i - 1, feval) + cat(paste(msg, "\n", sep="")) } } return(bst) diff --git a/R-package/R/xgboost.R b/R-package/R/xgboost.R index 6f4633fb8..f3b5c66ec 100644 --- a/R-package/R/xgboost.R +++ b/R-package/R/xgboost.R @@ -40,19 +40,7 @@ #' xgboost <- function(data = NULL, label = NULL, params = list(), nrounds, verbose = 1, ...) { - inClass <- class(data) - if (inClass == "dgCMatrix" || inClass == "matrix") { - if (is.null(label)) - stop("xgboost: need label when data is a matrix") - dtrain <- xgb.DMatrix(data, label = label) - } else { - if (!is.null(label)) - warning("xgboost: label will be ignored.") - if (inClass == "character") - dtrain <- xgb.DMatrix(data) else if (inClass == "xgb.DMatrix") - dtrain <- data else stop("xgboost: Invalid input of data") - } - + dtrain <- xgb.get.DMatrix(data, label) if (verbose > 1) { silent <- 0 } else { @@ -62,8 +50,11 @@ xgboost <- function(data = NULL, label = NULL, params = list(), nrounds, params <- append(params, list(silent = silent)) params <- append(params, list(...)) - if (verbose > 0) - watchlist <- list(train = dtrain) else watchlist <- list() + if (verbose > 0) { + watchlist <- list(train = dtrain) + } else { + watchlist <- list() + } bst <- xgb.train(params, dtrain, nrounds, watchlist)