116 lines
3.8 KiB
R
116 lines
3.8 KiB
R
#' eXtreme Gradient Boosting (Tree) library
#'
#' A simple interface for xgboost in R
#'
#' @param data takes \code{matrix}, \code{dgCMatrix}, local data file or
#'   \code{xgb.DMatrix}.
#' @param label the response variable. User should not set this field
#'   if data is a local data file or an \code{xgb.DMatrix}.
#' @param params the list of parameters. Commonly used ones are:
#' \itemize{
#'   \item \code{objective} objective function, common ones are
#'   \itemize{
#'     \item \code{reg:linear} linear regression
#'     \item \code{binary:logistic} logistic regression for classification
#'   }
#'   \item \code{eta} step size of each boosting step
#'   \item \code{max_depth} maximum depth of the tree
#'   \item \code{nthread} number of threads used in training; if not set,
#'     all threads are used
#' }
#'
#' See \url{https://github.com/tqchen/xgboost/wiki/Parameters} for
#' further details. See also inst/examples/demo.R for walkthrough example in R.
#' @param nrounds the max number of iterations
#' @param verbose If 0, xgboost will stay silent. If 1, xgboost will print
#'   information of performance. If 2, xgboost will print information of both
#'   performance and construction progress information
#' @param ... other parameters to pass to \code{params}; they are appended
#'   after the entries already in \code{params}.
#'
#' @return The model object returned by \code{xgb.train}.
#'
#' @details
#' This is the modeling function for xgboost.
#'
#' Parallelization is automatically enabled if OpenMP is present.
#' Number of threads can also be manually specified via "nthread" parameter
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#' train <- agaricus.train
#' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#'                eta = 1, nrounds = 2, objective = "binary:logistic")
#' pred <- predict(bst, test$data)
#'
#' @export
#'
xgboost <- function(data = NULL, label = NULL, params = list(), nrounds,
                    verbose = 1, ...) {
  dtrain <- xgb.get.DMatrix(data, label)
  # Arguments supplied through `...` are merged into the parameter list.
  params <- append(params, list(...))

  # The training matrix is watched only when verbose > 0; when verbose is 0
  # an empty watchlist is passed instead.
  watchlist <- if (verbose > 0) list(train = dtrain) else list()

  xgb.train(params, dtrain, nrounds, watchlist, verbose = verbose)
}

#' Training part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'   \item \code{label} the label for each record
#'   \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 127 columns.
#' }
#'
#' @references
#' \url{https://archive.ics.uci.edu/ml/datasets/Mushroom}
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [\url{http://archive.ics.uci.edu/ml}]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @seealso \code{\link{agaricus.test}} for the corresponding test split.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.train
#' @usage data(agaricus.train)
#' @format A list containing a label vector, and a \code{dgCMatrix} object
#' with 6513 rows and 127 variables
NULL

#' Test part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'   \item \code{label} the label for each record
#'   \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 127 columns.
#' }
#'
#' @references
#' \url{https://archive.ics.uci.edu/ml/datasets/Mushroom}
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [\url{http://archive.ics.uci.edu/ml}]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @seealso \code{\link{agaricus.train}} for the corresponding training split.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.test
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a \code{dgCMatrix} object
#' with 1611 rows and 127 variables
NULL