xgboost/R-package/R/xgboost.R
2024-08-20 13:33:13 +08:00

1090 lines
35 KiB
R

prescreen.parameters <- function(params) {
  # Validates the extra training parameters passed through '...' in 'xgboost()'.
  # Returns 'params' as a list with NULL entries dropped, erroring on
  # parameters that are not supported by the 'xgboost()' interface.
  if (!NROW(params)) {
    return(list())
  }
  if (!is.list(params)) {
    stop("'params' must be a list or NULL.")
  }
  # Drop NULL elements. Note: 'is.null()' on the whole list is always a
  # single FALSE, so the check must be applied element-wise.
  params <- params[!vapply(params, is.null, logical(1))]
  if ("num_class" %in% names(params)) {
    stop("'num_class' cannot be manually specified for 'xgboost()'. Pass a factor 'y' instead.")
  }
  if ("process_type" %in% names(params)) {
    # 'process_type' values other than "default" require an already-fitted
    # booster, which 'xgboost()' does not accept.
    if (params$process_type != "default") {
      stop("Non-default 'process_type' is not supported for 'xgboost()'. Try 'xgb.train()'.")
    }
  }
  return(params)
}
prescreen.objective <- function(objective) {
  # Validates a user-supplied 'objective' string, erroring on values that
  # 'xgboost()' cannot handle. Called for its side effect only.
  if (!is.null(objective)) {
    # Validate the type first, so the membership test below is guaranteed
    # to operate on a single non-NA string (previously a malformed
    # 'objective' reached the '%in%' test before being rejected).
    if (!is.character(objective) || length(objective) != 1L || is.na(objective)) {
      stop("'objective' must be a single character/string variable.")
    }
    if (objective %in% .OBJECTIVES_NON_DEFAULT_MODE()) {
      stop(
        "Objectives with non-default prediction mode (",
        paste(.OBJECTIVES_NON_DEFAULT_MODE(), collapse = ", "),
        ") are not supported in 'xgboost()'. Try 'xgb.train()'."
      )
    }
  }
}
process.base.margin <- function(base_margin, nrows, ncols) {
  # Validates and normalizes a user-supplied 'base_margin' so it can be
  # attached to an 'xgb.DMatrix'. Returns NULL when nothing was passed,
  # a numeric vector when 'ncols' is 1, and a dense matrix otherwise.
  # Errors when dimensions do not match the expected 'nrows' / 'ncols'.
  if (!NROW(base_margin)) {
    return(NULL)
  }
  if (is.array(base_margin) && length(dim(base_margin)) > 2) {
    stop(
      "'base_margin' should not have more than 2 dimensions for any objective (got: ",
      length(dim(base_margin)),
      " dimensions)."
    )
  }
  # Sparse inputs are densified up-front (with a warning) since all checks
  # and conversions below assume dense semantics.
  if (inherits(base_margin, c("sparseMatrix", "sparseVector"))) {
    warning(
      "Got a sparse matrix type (class: ",
      paste(class(base_margin), collapse = ", "),
      ") for 'base_margin'. Will convert to dense matrix."
    )
    base_margin <- as.matrix(base_margin)
  }
  if (NROW(base_margin) != nrows) {
    stop(
      "'base_margin' has incorrect number of rows. Expected: ",
      nrows,
      ". Got: ",
      NROW(base_margin)
    )
  }
  if (ncols == 1L) {
    # Single-output objective: collapse 1-column tabular inputs down to a
    # plain numeric vector.
    if (inherits(base_margin, c("matrix", "data.frame"))) {
      if (ncol(base_margin) != 1L) {
        stop("'base_margin' should be a 1-d vector for the given objective and data.")
      }
      base_margin <- if (is.data.frame(base_margin)) base_margin[[1L]] else base_margin[, 1L]
    }
    if (!is.numeric(base_margin)) {
      base_margin <- as.numeric(base_margin)
    }
    return(base_margin)
  }
  # Multi-output objective: require a tabular input with exactly 'ncols'
  # columns and return it as a matrix.
  allowed_multicol_classes <- c("matrix", "data.frame")
  if (!inherits(base_margin, allowed_multicol_classes)) {
    stop(
      "'base_margin' should be a matrix with ",
      ncols,
      " columns for the given objective and data. Got class: ",
      paste(class(base_margin), collapse = ", ")
    )
  }
  if (ncol(base_margin) != ncols) {
    stop(
      "'base_margin' has incorrect number of columns. Expected: ",
      ncols,
      ". Got: ",
      ncol(base_margin)
    )
  }
  if (!is.matrix(base_margin)) {
    base_margin <- as.matrix(base_margin)
  }
  return(base_margin)
}
process.y.margin.and.objective <- function(
  y,
  base_margin,
  objective,
  params
) {
  # Determines the training objective from the type of 'y' (or validates a
  # user-supplied one against it) and builds the pieces needed downstream:
  #   - params: 'objective' (+ 'num_class' for multi-class classification)
  #   - metadata: info used for formatting predictions ('y_levels',
  #     'y_names', 'n_targets')
  #   - dmatrix_args: processed 'label' / censoring bounds / 'base_margin'
  if (!NROW(y)) {
    stop("Passed empty 'y'.")
  }
  if (is.array(y) && length(dim(y)) > 2) {
    stop(
      "'y' should not have more than 2 dimensions for any objective (got: ",
      length(dim(y)),
      ")."
    )
  }
  if (inherits(y, c("sparseMatrix", "sparseVector"))) {
    warning(
      "Got a sparse matrix type (class: ",
      paste(class(y), collapse = ", "),
      ") for 'y'. Will convert to dense matrix."
    )
    y <- as.matrix(y)
  }
  # Character and logical inputs are classification targets: convert them to
  # factors (for logicals, TRUE becomes the last / "positive" level).
  if (is.character(y)) {
    if (!is.vector(y)) {
      if (NCOL(y) > 1L) {
        stop("Multi-column categorical 'y' is not supported.")
      }
      y <- as.vector(y)
    }
    y <- factor(y)
  }
  if (is.logical(y)) {
    if (!is.vector(y)) {
      if (NCOL(y) > 1L) {
        stop("Multi-column logical/boolean 'y' is not supported.")
      }
      y <- as.vector(y)
    }
    y <- factor(y, c(FALSE, TRUE))
  }
  if (is.factor(y)) {
    # Classification: binary or multi-class depending on number of levels.
    y_levels <- levels(y)
    if (length(y_levels) < 2) {
      stop("Factor 'y' has less than 2 levels.")
    }
    if (length(y_levels) == 2) {
      if (is.null(objective)) {
        objective <- "binary:logistic"
      } else {
        if (!(objective %in% .BINARY_CLASSIF_OBJECTIVES())) {
          stop(
            "Got binary 'y' - supported objectives for this data are: ",
            paste(.BINARY_CLASSIF_OBJECTIVES(), collapse = ", "),
            ". Was passed: ",
            objective
          )
        }
      }
      if (!is.null(base_margin)) {
        base_margin <- process.base.margin(base_margin, length(y), 1)
      }
      out <- list(
        params = list(
          objective = objective
        ),
        metadata = list(
          y_levels = y_levels,
          n_targets = 1
        )
      )
    } else { # length(levels) > 2
      if (is.null(objective)) {
        objective <- "multi:softprob"
      } else {
        if (!(objective %in% .MULTICLASS_CLASSIF_OBJECTIVES())) {
          stop(
            "Got non-binary factor 'y' - supported objectives for this data are: ",
            paste(.MULTICLASS_CLASSIF_OBJECTIVES(), collapse = ", "),
            ". Was passed: ",
            objective
          )
        }
      }
      if (!is.null(base_margin)) {
        # Multi-class objectives need one margin column per class.
        base_margin <- process.base.margin(base_margin, length(y), length(y_levels))
      }
      out <- list(
        params = list(
          objective = objective,
          num_class = length(y_levels)
        ),
        metadata = list(
          y_levels = y_levels,
          n_targets = length(y_levels)
        )
      )
    }
    # XGBoost expects zero-based class labels.
    out$dmatrix_args <- list(
      label = as.numeric(y) - 1,
      base_margin = base_margin
    )
  } else if (inherits(y, "Surv")) {
    # Survival analysis: 'y' is a 'Surv' object from the 'survival' package.
    y_attr <- attributes(y)
    supported_surv_types <- c("left", "right", "interval")
    if (!(y_attr$type %in% supported_surv_types)) {
      stop(
        "Survival objectives are only supported for types: ",
        paste(supported_surv_types, collapse = ", "),
        ". Was passed: ",
        y_attr$type
      )
    }
    if (is.null(objective)) {
      objective <- "survival:aft"
    } else {
      if (y_attr$type == "right") {
        if (!(objective %in% .SURVIVAL_RIGHT_CENSORING_OBJECTIVES())) {
          stop(
            "Got right-censored 'y' variable - supported objectives for this data are: ",
            paste(.SURVIVAL_RIGHT_CENSORING_OBJECTIVES(), collapse = ", "),
            ". Was passed: ",
            objective
          )
        }
      } else {
        if (!(objective %in% .SURVIVAL_ALL_CENSORING_OBJECTIVES())) {
          stop(
            "Got ", y_attr$type, "-censored 'y' variable - supported objectives for this data are:",
            paste(.SURVIVAL_ALL_CENSORING_OBJECTIVES(), collapse = ", "),
            ". Was passed: ",
            objective
          )
        }
      }
    }
    if (!is.null(base_margin)) {
      base_margin <- process.base.margin(base_margin, nrow(y), 1)
    }
    out <- list(
      params = list(
        objective = objective
      ),
      metadata = list(
        n_targets = 1
      )
    )
    # Note: the 'Surv' object class that is passed as 'y' might have either 2 or 3 columns
    # depending on the type of censoring, and the last column in both cases is the one that
    # indicates the observation type (e.g. censored / uncensored).
    # In the case of interval censoring, the second column will not always have values with
    # infinites filled in. For more information, see the code behind the 'print.Surv' method.
    if (objective == "survival:cox") {
      # Can only get here when using right censoring
      if (y_attr$type != "right") {
        stop("Internal error.")
      }
      # Cox labels: negative time encodes a censored observation.
      out$dmatrix_args <- list(
        label = y[, 1L] * (2 * (y[, 2L] - 0.5))
      )
    } else {
      # AFT objectives take lower/upper bound labels encoding the censoring.
      if (y_attr$type == "left") {
        lb <- ifelse(
          y[, 2L] == 0,
          0,
          y[, 1L]
        )
        ub <- y[, 1L]
        out$dmatrix_args <- list(
          label_lower_bound = lb,
          label_upper_bound = ub
        )
      } else if (y_attr$type == "right") {
        lb <- y[, 1L]
        ub <- ifelse(
          y[, 2L] == 0,
          Inf,
          y[, 1L]
        )
        out$dmatrix_args <- list(
          label_lower_bound = lb,
          label_upper_bound = ub
        )
      } else if (y_attr$type == "interval") {
        # Status codes in column 3: 0 = right-censored, 2 = left-censored,
        # 3 = interval, otherwise an exact event time.
        out$dmatrix_args <- list(
          label_lower_bound = ifelse(y[, 3L] == 2, 0, y[, 1L]),
          label_upper_bound = ifelse(
            y[, 3L] == 0, Inf,
            ifelse(y[, 3L] == 3, y[, 2L], y[, 1L])
          )
        )
      }
      if (min(out$dmatrix_args$label_lower_bound) < 0) {
        stop("Survival objectives are only defined for non-negative 'y'.")
      }
    }
    out$dmatrix_args$base_margin <- base_margin
  } else if (is.vector(y)) {
    # Plain numeric vector: regression.
    if (is.null(objective)) {
      objective <- "reg:squarederror"
    } else if (!(objective %in% .REGRESSION_OBJECTIVES())) {
      stop(
        "Got numeric 'y' - supported objectives for this data are: ",
        paste(.REGRESSION_OBJECTIVES(), collapse = ", "),
        ". Was passed: ",
        objective
      )
    }
    # Quantile regression with multiple alphas produces one output per alpha.
    n_targets <- 1L
    if (objective == "reg:quantileerror" && NROW(params$quantile_alpha) > 1) {
      n_targets <- NROW(params$quantile_alpha)
    }
    if (!is.null(base_margin)) {
      base_margin <- process.base.margin(base_margin, length(y), n_targets)
    }
    out <- list(
      params = list(
        objective = objective
      ),
      metadata = list(
        n_targets = n_targets
      ),
      dmatrix_args = list(
        label = as.numeric(y),
        base_margin = base_margin
      )
    )
  } else if (is.data.frame(y)) {
    if (ncol(y) == 1L) {
      # Single-column data.frame: process the column as a vector instead.
      return(process.y.margin.and.objective(y[[1L]], base_margin, objective, params))
    }
    if (is.null(objective)) {
      objective <- "reg:squarederror"
    } else if (!(objective %in% .MULTI_TARGET_OBJECTIVES())) {
      stop(
        "Got multi-column 'y' - supported objectives for this data are: ",
        paste(.MULTI_TARGET_OBJECTIVES(), collapse = ", "),
        ". Was passed: ",
        objective
      )
    }
    y_names <- names(y)
    y <- lapply(y, function(x) {
      if (!inherits(x, c("numeric", "integer"))) {
        stop(
          "Multi-target 'y' only supports 'numeric' and 'integer' types. Got: ",
          paste(class(x), collapse = ", ")
        )
      }
      return(as.numeric(x))
    })
    y <- as.data.frame(y) |> as.matrix()
    if (!is.null(base_margin)) {
      # Fix: use nrow(y) here - 'y' has been converted to a matrix at this
      # point, so the previous 'length(y)' counted all cells rather than
      # rows (compare with the matrix branch below).
      base_margin <- process.base.margin(base_margin, nrow(y), ncol(y))
    }
    out <- list(
      params = list(
        objective = objective
      ),
      dmatrix_args = list(
        label = y,
        base_margin = base_margin
      ),
      metadata = list(
        y_names = y_names,
        n_targets = ncol(y)
      )
    )
  } else if (is.matrix(y)) {
    if (ncol(y) == 1L) {
      # Single-column matrix: process it as a vector instead.
      return(process.y.margin.and.objective(as.vector(y), base_margin, objective, params))
    }
    if (!is.null(objective) && !(objective %in% .MULTI_TARGET_OBJECTIVES())) {
      stop(
        "Got multi-column 'y' - supported objectives for this data are: ",
        paste(.MULTI_TARGET_OBJECTIVES(), collapse = ", "),
        ". Was passed: ",
        objective
      )
    }
    if (is.null(objective)) {
      objective <- "reg:squarederror"
    }
    y_names <- colnames(y)
    if (storage.mode(y) != "double") {
      storage.mode(y) <- "double"
    }
    if (!is.null(base_margin)) {
      base_margin <- process.base.margin(base_margin, nrow(y), ncol(y))
    }
    out <- list(
      params = list(
        objective = objective
      ),
      dmatrix_args = list(
        label = y,
        base_margin = base_margin
      ),
      metadata = list(
        n_targets = ncol(y)
      )
    )
    # Only keep column names when they fully label the targets.
    if (NROW(y_names) == ncol(y)) {
      out$metadata$y_names <- y_names
    }
  } else {
    stop("Passed 'y' object with unsupported class: ", paste(class(y), collapse = ", "))
  }
  return(out)
}
process.row.weights <- function(w, lst_args) {
  # Attaches observation weights 'w' (if any) to the DMatrix arguments,
  # validating that their length matches the number of rows in 'y'.
  # Returns 'lst_args' unchanged when 'w' is NULL.
  if (is.null(w)) {
    return(lst_args)
  }
  dm_args <- lst_args$dmatrix_args
  # The expected length comes from whichever label field was filled in by
  # process.y.margin.and.objective (regular label, or survival bounds).
  if ("label" %in% names(dm_args)) {
    expected_len <- NROW(dm_args$label)
  } else if ("label_lower_bound" %in% names(dm_args)) {
    expected_len <- length(dm_args$label_lower_bound)
  } else {
    stop("Internal error.")
  }
  if (!is.numeric(w)) {
    w <- as.numeric(w)
  }
  if (length(w) != expected_len) {
    stop(
      "'weights' must be a 1-d vector with the same length as 'y' (",
      length(w), " vs. ", expected_len, ")."
    )
  }
  lst_args$dmatrix_args$weight <- w
  return(lst_args)
}
check.nthreads <- function(nthreads) {
  # Coerces 'nthreads' to a single non-negative integer. NULL maps to one
  # thread; vectors are truncated to their first element; invalid values
  # (non-numeric, NA, negative, fractional) raise an error.
  if (is.null(nthreads)) {
    return(1L)
  }
  has_valid_type <- inherits(nthreads, c("numeric", "integer")) && NROW(nthreads) > 0
  if (!has_valid_type) {
    stop("'nthreads' must be a positive scalar value.")
  }
  if (length(nthreads) > 1L) {
    # Silently keep only the first element when a vector was passed.
    nthreads <- nthreads[1L]
  }
  if (is.na(nthreads) || nthreads < 0) {
    stop("Passed invalid 'nthreads': ", nthreads)
  }
  # Reject fractional values before the integer conversion truncates them.
  if (is.numeric(nthreads) && floor(nthreads) != nthreads) {
    stop("'nthreads' must be an integer.")
  }
  return(as.integer(nthreads))
}
check.can.use.qdm <- function(x, params) {
  # Decides whether a QuantileDMatrix can be used for the given parameters:
  # the 'gblinear' booster and the 'exact' / 'approx' tree methods require
  # a regular DMatrix instead. ('x' is accepted for interface symmetry.)
  param_names <- names(params)
  if ("booster" %in% param_names && params$booster == "gblinear") {
    return(FALSE)
  }
  if ("tree_method" %in% param_names && params$tree_method %in% c("exact", "approx")) {
    return(FALSE)
  }
  return(TRUE)
}
process.x.and.col.args <- function(
  x,
  monotone_constraints,
  interaction_constraints,
  feature_weights,
  lst_args,
  use_qdm
) {
  # Validates 'x' and the per-column arguments (feature weights, monotone
  # and interaction constraints), matching them to the columns of 'x' by
  # name where applicable, and stores the processed results into 'lst_args'
  # (under 'dmatrix_args' / 'params'). Returns the updated 'lst_args'.
  if (is.null(x)) {
    stop("'x' cannot be NULL.")
  }
  if (inherits(x, "xgb.DMatrix")) {
    stop("Cannot pass 'xgb.DMatrix' as 'x' to 'xgboost()'. Try 'xgb.train()' instead.")
  }
  supported_x_types <- c("data.frame", "matrix", "dgTMatrix", "dgCMatrix", "dgRMatrix")
  if (!inherits(x, supported_x_types)) {
    stop(
      "'x' must be one of the following classes: ",
      paste(supported_x_types, collapse = ", "),
      ". Got: ",
      paste(class(x), collapse = ", ")
    )
  }
  # When a QuantileDMatrix will be built from sparse input, convert to
  # row-major (CSR / dgRMatrix) representation first.
  if (use_qdm && inherits(x, "sparseMatrix") && !inherits(x, "dgRMatrix")) {
    x <- methods::as(x, "RsparseMatrix")
    if (!inherits(x, "RsparseMatrix")) {
      stop("Internal error: casting sparse matrix did not yield 'dgRMatrix'.")
    }
  }
  if (NROW(feature_weights)) {
    if (is.list(feature_weights)) {
      feature_weights <- unlist(feature_weights)
    }
    if (!inherits(feature_weights, c("numeric", "integer"))) {
      stop("'feature_weights' must be a numeric vector or named list matching to columns of 'x'.")
    }
    # When both the weights and 'x' carry names, reorder the weights to
    # follow the column order of 'x'.
    if (NROW(names(feature_weights)) && NROW(colnames(x))) {
      matched <- match(colnames(x), names(feature_weights))
      matched <- matched[!is.na(matched)]
      matched <- matched[!duplicated(matched)]
      # A partial name match means some columns of 'x' have no weight.
      if (length(matched) > 0 && length(matched) < length(feature_weights)) {
        stop(
          "'feature_weights' names do not contain all columns of 'x'. Missing: ",
          utils::head(setdiff(colnames(x), names(feature_weights)))
        )
      }
      if (length(matched)) {
        feature_weights <- feature_weights[matched]
      } else {
        # No name overlap at all: fall back to positional interpretation.
        warning("Names of 'feature_weights' do not match with 'x'. Names will be ignored.")
      }
    }
    lst_args$dmatrix_args$feature_weights <- unname(feature_weights)
  }
  if (NROW(monotone_constraints)) {
    if (NROW(monotone_constraints) > ncol(x)) {
      stop(
        "'monotone_constraints' contains more entries than there are columns in 'x' (",
        NROW(monotone_constraints), " vs. ", ncol(x), ")."
      )
    }
    if (is.list(monotone_constraints)) {
      # Named-list form: entries are matched to columns of 'x' by name;
      # unreferenced columns default to an unconstrained value of 0.
      if (!NROW(names(monotone_constraints))) {
        stop(
          "If passing 'monotone_constraints' as a named list,",
          " must have names matching to columns of 'x'."
        )
      }
      if (!NROW(colnames(x))) {
        stop("If passing 'monotone_constraints' as a named list, 'x' must have column names.")
      }
      if (anyDuplicated(names(monotone_constraints))) {
        stop(
          "'monotone_constraints' contains duplicated names: ",
          paste(
            names(monotone_constraints)[duplicated(names(monotone_constraints))] |> utils::head(),
            collapse = ", "
          )
        )
      }
      if (NROW(setdiff(names(monotone_constraints), colnames(x)))) {
        stop(
          "'monotone_constraints' contains column names not present in 'x': ",
          paste(utils::head(names(monotone_constraints)), collapse = ", ")
        )
      }
      vec_monotone_constr <- rep(0, ncol(x))
      matched <- match(names(monotone_constraints), colnames(x))
      vec_monotone_constr[matched] <- unlist(monotone_constraints)
      lst_args$params$monotone_constraints <- unname(vec_monotone_constr)
    } else if (inherits(monotone_constraints, c("numeric", "integer"))) {
      if (NROW(names(monotone_constraints)) && NROW(colnames(x))) {
        if (length(monotone_constraints) < ncol(x)) {
          # Named vector covering only a subset of columns: reprocess it as
          # a named list so unreferenced columns default to zero.
          return(
            process.x.and.col.args(
              x,
              as.list(monotone_constraints),
              interaction_constraints,
              feature_weights,
              lst_args,
              use_qdm
            )
          )
        } else {
          # Full-length named vector: reorder it to the column order of 'x'.
          matched <- match(names(monotone_constraints), colnames(x))
          matched <- matched[!is.na(matched)]
          matched <- matched[!duplicated(matched)]
          if (length(matched)) {
            monotone_constraints <- monotone_constraints[matched]
          } else {
            warning("Names of 'monotone_constraints' do not match with 'x'. Names will be ignored.")
          }
        }
      } else {
        # Unnamed vector: interpreted positionally, so it must cover every
        # column of 'x'.
        if (length(monotone_constraints) != ncol(x)) {
          stop(
            "If passing 'monotone_constraints' as unnamed vector or not using column names,",
            " must have length matching to number of columns in 'x'. Got: ",
            length(monotone_constraints), " (vs. ", ncol(x), ")"
          )
        }
      }
      lst_args$params$monotone_constraints <- unname(monotone_constraints)
    } else if (is.character(monotone_constraints)) {
      # Character form (e.g. "(1,-1)") is passed through to the core
      # library as-is.
      lst_args$params$monotone_constraints <- monotone_constraints
    } else {
      stop(
        "Passed unsupported type for 'monotone_constraints': ",
        paste(class(monotone_constraints), collapse = ", ")
      )
    }
  }
  if (NROW(interaction_constraints)) {
    if (!is.list(interaction_constraints)) {
      stop("'interaction_constraints' must be a list of vectors.")
    }
    cnames <- colnames(x)
    # Each element is resolved to zero-based column indices, which is what
    # the core XGBoost library expects.
    lst_args$params$interaction_constraints <- lapply(interaction_constraints, function(idx) {
      if (!NROW(idx)) {
        stop("Elements in 'interaction_constraints' cannot be empty.")
      }
      if (is.character(idx)) {
        # Character entries are matched against column names of 'x'.
        if (!NROW(cnames)) {
          stop(
            "Passed a character vector for 'interaction_constraints', but 'x' ",
            "has no column names to match them against."
          )
        }
        out <- match(idx, cnames) - 1L
        if (anyNA(out)) {
          stop(
            "'interaction_constraints' contains column names not present in 'x': ",
            paste(utils::head(idx[which(is.na(out))]), collapse = ", ")
          )
        }
        return(out)
      } else if (inherits(idx, c("numeric", "integer"))) {
        # Numeric entries are base-1 column indices.
        if (anyNA(idx)) {
          stop("'interaction_constraints' cannot contain NA values.")
        }
        if (min(idx) < 1) {
          stop("Column indices for 'interaction_constraints' must follow base-1 indexing.")
        }
        if (max(idx) > ncol(x)) {
          stop("'interaction_constraints' contains invalid column indices.")
        }
        if (is.numeric(idx)) {
          if (any(idx != floor(idx))) {
            stop(
              "'interaction_constraints' must contain only integer indices. Got non-integer: ",
              paste(utils::head(idx[which(idx != floor(idx))]), collapse = ", ")
            )
          }
        }
        return(idx - 1L)
      } else {
        stop(
          "Elements in 'interaction_constraints' must be vectors of types ",
          "'integer', 'numeric', or 'character'. Got: ",
          paste(class(idx), collapse = ", ")
        )
      }
    })
  }
  lst_args$dmatrix_args$data <- x
  return(lst_args)
}
#' @noMd
#' @export
#' @title Fit XGBoost Model
#' @description Fits an XGBoost model (boosted decision tree ensemble) to given x/y data.
#'
#' See the tutorial \href{https://xgboost.readthedocs.io/en/stable/tutorials/model.html}{
#' Introduction to Boosted Trees} for a longer explanation of what XGBoost does.
#'
#' This function is intended to provide a more user-friendly interface for XGBoost that follows
#' R's conventions for model fitting and predictions, but which doesn't expose all of the
#' possible functionalities of the core XGBoost library.
#'
#' See \link{xgb.train} for a more flexible low-level alternative which is similar across different
#' language bindings of XGBoost and which exposes the full library's functionalities.
#' @details For package authors using `xgboost` as a dependency, it is highly recommended to use
#' \link{xgb.train} in package code instead of `xgboost()`, since it has a more stable interface
#' and performs fewer data conversions and copies along the way.
#' @references \itemize{
#' \item Chen, Tianqi, and Carlos Guestrin. "Xgboost: A scalable tree boosting system."
#' Proceedings of the 22nd acm sigkdd international conference on knowledge discovery and
#' data mining. 2016.
#' \item \url{https://xgboost.readthedocs.io/en/stable/}
#' }
#' @param x The features / covariates. Can be passed as:\itemize{
#' \item A numeric or integer `matrix`.
#' \item A `data.frame`, in which all columns are one of the following types:\itemize{
#' \item `numeric`
#' \item `integer`
#' \item `logical`
#' \item `factor`
#' }
#'
#' Columns of `factor` type will be assumed to be categorical, while other column types will
#' be assumed to be numeric.
#' \item A sparse matrix from the `Matrix` package, either as `dgCMatrix` or `dgRMatrix` class.
#' }
#'
#' Note that categorical features are only supported for `data.frame` inputs, and are automatically
#' determined based on their types. See \link{xgb.train} with \link{xgb.DMatrix} for more flexible
#' variants that would allow something like categorical features on sparse matrices.
#' @param y The response variable. Allowed values are:\itemize{
#' \item A numeric or integer vector (for regression tasks).
#' \item A factor or character vector (for binary and multi-class classification tasks).
#' \item A logical (boolean) vector (for binary classification tasks).
#' \item A numeric or integer matrix or `data.frame` with numeric/integer columns
#' (for multi-task regression tasks).
#' \item A `Surv` object from the `survival` package (for survival tasks).
#' }
#'
#' If `objective` is `NULL`, the right task will be determined automatically based on
#' the class of `y`.
#'
#' If `objective` is not `NULL`, it must match with the type of `y` - e.g. `factor` types of `y`
#' can only be used with classification objectives and vice-versa.
#'
#' For binary classification, the last factor level of `y` will be used as the "positive"
#' class - that is, the numbers from `predict` will reflect the probabilities of belonging to this
#' class instead of to the first factor level. If `y` is a `logical` vector, then `TRUE` will be
#' set as the last level.
#' @param objective Optimization objective to minimize based on the supplied data, to be passed
#' by name as a string / character (e.g. `reg:absoluteerror`). See the
#' \href{https://xgboost.readthedocs.io/en/stable/parameter.html#learning-task-parameters}{
#' Learning Task Parameters} page for more detailed information on allowed values.
#'
#' If `NULL` (the default), will be automatically determined from `y` according to the following
#' logic:\itemize{
#' \item If `y` is a factor with 2 levels, will use `binary:logistic`.
#' \item If `y` is a factor with more than 2 levels, will use `multi:softprob` (number of classes
#' will be determined automatically, should not be passed under `params`).
#' \item If `y` is a `Surv` object from the `survival` package, will use `survival:aft` (note that
#' the only types supported are left / right / interval censored).
#' \item Otherwise, will use `reg:squarederror`.
#' }
#'
#' If `objective` is not `NULL`, it must match with the type of `y` - e.g. `factor` types of `y`
#' can only be used with classification objectives and vice-versa.
#'
#' Note that not all possible `objective` values supported by the core XGBoost library are allowed
#' here - for example, objectives which are a variation of another but with a different default
#' prediction type (e.g. `multi:softmax` vs. `multi:softprob`) are not allowed, and neither are
#' ranking objectives, nor custom objectives at the moment.
#' @param nrounds Number of boosting iterations / rounds.
#'
#' Note that the number of default boosting rounds here is not automatically tuned, and different
#' problems will have vastly different optimal numbers of boosting rounds.
#' @param weights Sample weights for each row in `x` and `y`. If `NULL` (the default), each row
#' will have the same weight.
#'
#' If not `NULL`, should be passed as a numeric vector with length matching to the number of
#' rows in `x`.
#' @param verbosity Verbosity of printing messages. Valid values are 0 (silent), 1 (warning),
#' 2 (info), and 3 (debug).
#' @param nthreads Number of parallel threads to use. If passing zero, will use all CPU threads.
#' @param seed Seed to use for random number generation. If passing `NULL`, will draw a random
#' number using R's PRNG system to use as seed.
#' @param monotone_constraints Optional monotonicity constraints for features.
#'
#' Can be passed either as a named list (when `x` has column names), or as a vector. If passed
#' as a vector and `x` has column names, will try to match the elements by name.
#'
#' A value of `+1` for a given feature makes the model predictions / scores constrained to be
#' a monotonically increasing function of that feature (that is, as the value of the feature
#' increases, the model prediction cannot decrease), while a value of `-1` makes it a monotonically
#' decreasing function. A value of zero imposes no constraint.
#'
#' The input for `monotone_constraints` can be a subset of the columns of `x` if named, in which
#' case the columns that are not referred to in `monotone_constraints` will be assumed to have
#' a value of zero (no constraint imposed on the model for those features).
#'
#' See the tutorial \href{https://xgboost.readthedocs.io/en/stable/tutorials/monotonic.html}{
#' Monotonic Constraints} for a more detailed explanation.
#' @param interaction_constraints Constraints for interaction representing permitted interactions.
#' The constraints must be specified in the form of a list of vectors referencing columns in the
#' data, e.g. `list(c(1, 2), c(3, 4, 5))` (with these numbers being column indices, numeration
#' starting at 1 - i.e. the first sublist references the first and second columns) or
#' `list(c("Sepal.Length", "Sepal.Width"), c("Petal.Length", "Petal.Width"))` (references
#' columns by names), where each vector is a group of indices of features that are allowed to
#' interact with each other.
#'
#' See the tutorial
#' \href{https://xgboost.readthedocs.io/en/stable/tutorials/feature_interaction_constraint.html}{
#' Feature Interaction Constraints} for more information.
#' @param feature_weights Feature weights for column sampling.
#'
#' Can be passed either as a vector with length matching to columns of `x`, or as a named
#' list (only if `x` has column names) with names matching to columns of 'x'. If it is a
#' named vector, will try to match the entries to column names of `x` by name.
#'
#' If `NULL` (the default), all columns will have the same weight.
#' @param base_margin Base margin used for boosting from existing model.
#'
#' If passing it, will start the gradient boosting procedure from the scores that are provided
#' here - for example, one can pass the raw scores from a previous model, or some per-observation
#' offset, or similar.
#'
#' Should be either a numeric vector or numeric matrix (for multi-class and multi-target objectives)
#' with the same number of rows as `x` and number of columns corresponding to number of optimization
#' targets, and should be in the untransformed scale (for example, for objective `binary:logistic`,
#' it should have log-odds, not probabilities; and for objective `multi:softprob`, should have
#' number of columns matching to number of classes in the data).
#'
#' Note that, if it contains more than one column, then columns will not be matched by name to
#' the corresponding `y` - `base_margin` should have the same column order that the model will use
#' (for example, for objective `multi:softprob`, columns of `base_margin` will be matched against
#' `levels(y)` by their position, regardless of what `colnames(base_margin)` returns).
#'
#' If `NULL`, will start from zero, but note that for most objectives, an intercept is usually
#' added (controllable through parameter `base_score` instead) when `base_margin` is not passed.
#' @param ... Other training parameters. See the online documentation
#' \href{https://xgboost.readthedocs.io/en/stable/parameter.html}{XGBoost Parameters} for
#' details about possible values and what they do.
#'
#' Note that not all possible values from the core XGBoost library are allowed as `params` for
#' 'xgboost()' - in particular, values which require an already-fitted booster object (such as
#' `process_type`) are not accepted here.
#' @return A model object, inheriting from both `xgboost` and `xgb.Booster`. Compared to the regular
#' `xgb.Booster` model class produced by \link{xgb.train}, this `xgboost` class will have an
#' additional attribute `metadata` containing information which is used for formatting prediction
#' outputs, such as class names for classification problems.
#' @examples
#' library(xgboost)
#' data(mtcars)
#'
#' # Fit a small regression model on the mtcars data
#' model_regression <- xgboost(mtcars[, -1], mtcars$mpg, nthreads = 1, nrounds = 3)
#' predict(model_regression, mtcars, validate_features = TRUE)
#'
#' # Task objective is determined automatically according to the type of 'y'
#' data(iris)
#' model_classif <- xgboost(iris[, -5], iris$Species, nthreads = 1, nrounds = 5)
#' predict(model_classif, iris, validate_features = TRUE)
xgboost <- function(
  x,
  y,
  objective = NULL,
  nrounds = 100L,
  weights = NULL,
  verbosity = 0L,
  nthreads = parallel::detectCores(),
  seed = 0L,
  monotone_constraints = NULL,
  interaction_constraints = NULL,
  feature_weights = NULL,
  base_margin = NULL,
  ...
) {
  # Note: '...' is a workaround, to be removed later by making all parameters be arguments
  params <- list(...)
  params <- prescreen.parameters(params)
  prescreen.objective(objective)
  use_qdm <- check.can.use.qdm(x, params)
  # Build objective/metadata/DMatrix arguments from 'y', then fold in the
  # row weights and the column-related arguments.
  lst_args <- process.y.margin.and.objective(y, base_margin, objective, params)
  lst_args <- process.row.weights(weights, lst_args)
  lst_args <- process.x.and.col.args(
    x,
    monotone_constraints,
    interaction_constraints,
    feature_weights,
    lst_args,
    use_qdm
  )
  # 'max_bin' is a QuantileDMatrix construction argument rather than a
  # training parameter, so it is forwarded to the DMatrix when QDM is used.
  if (use_qdm && "max_bin" %in% names(params)) {
    lst_args$dmatrix_args$max_bin <- params$max_bin
  }
  nthreads <- check.nthreads(nthreads)
  lst_args$dmatrix_args$nthread <- nthreads
  lst_args$params$nthread <- nthreads
  # As documented: a NULL seed means drawing one from R's own PRNG, keeping
  # results reproducible through set.seed(). Previously, passing NULL
  # silently removed the 'seed' entry from the parameter list instead.
  if (is.null(seed)) {
    seed <- sample.int(.Machine$integer.max, 1L)
  }
  lst_args$params$seed <- seed
  params <- c(lst_args$params, params)
  dm_constructor <- if (use_qdm) xgb.QuantileDMatrix else xgb.DMatrix
  dm <- do.call(dm_constructor, lst_args$dmatrix_args)
  model <- xgb.train(
    params = params,
    data = dm,
    nrounds = nrounds,
    verbose = verbosity
  )
  # 'metadata' drives prediction formatting (e.g. class names), and 'call'
  # is shown by print.xgboost.
  attributes(model)$metadata <- lst_args$metadata
  attributes(model)$call <- match.call()
  class(model) <- c("xgboost", class(model))
  return(model)
}
#' @method print xgboost
#' @export
print.xgboost <- function(x, ...) {
  # Prints a human-readable summary of an 'xgboost' model. Returns the model
  # invisibly - a visible return would make an explicit 'print(model)' call
  # at the console print the summary a second time.
  cat("XGBoost model object\n")
  cat("Call:\n ")
  print(attributes(x)$call)
  cat("Objective: ", attributes(x)$params$objective, "\n", sep = "")
  cat("Number of iterations: ", xgb.get.num.boosted.rounds(x), "\n", sep = "")
  cat("Number of features: ", xgb.num_feature(x), "\n", sep = "")
  # Shows at most the first five elements, with an ellipsis when truncated.
  printable_head <- function(v) {
    v_sub <- utils::head(v, 5L)
    return(
      sprintf(
        "%s%s",
        paste(v_sub, collapse = ", "),
        # scalar condition: 'if' is the right tool, not ifelse()
        if (length(v_sub) < length(v)) ", ..." else ""
      )
    )
  }
  if (NROW(attributes(x)$metadata$y_levels)) {
    cat(
      "Classes: ",
      printable_head(attributes(x)$metadata$y_levels),
      "\n",
      sep = ""
    )
  } else if (NROW(attributes(x)$params$quantile_alpha)) {
    cat(
      "Prediction quantile",
      if (length(attributes(x)$params$quantile_alpha) > 1L) "s" else "",
      ": ",
      printable_head(attributes(x)$params$quantile_alpha),
      "\n",
      sep = ""
    )
  } else if (NROW(attributes(x)$metadata$y_names)) {
    cat(
      "Prediction targets: ",
      printable_head(attributes(x)$metadata$y_names),
      "\n",
      sep = ""
    )
  } else if (attributes(x)$metadata$n_targets > 1L) {
    cat(
      # fixed typo in user-visible output: "predition" -> "prediction"
      "Number of prediction targets: ",
      attributes(x)$metadata$n_targets,
      "\n",
      sep = ""
    )
  }
  return(invisible(x))
}
#' Training part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#' \item \code{label} the label for each record
#' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' <https://archive.ics.uci.edu/ml/datasets/Mushroom>
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' <http://archive.ics.uci.edu/ml>. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.train
#' @usage data(agaricus.train)
#' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 127 variables
NULL
#' Test part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#' \item \code{label} the label for each record
#' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' <https://archive.ics.uci.edu/ml/datasets/Mushroom>
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' <http://archive.ics.uci.edu/ml>. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.test
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables
NULL
# Various imports
#' @importClassesFrom Matrix dgCMatrix dgRMatrix CsparseMatrix
#' @importFrom Matrix sparse.model.matrix
#' @importFrom data.table data.table
#' @importFrom data.table is.data.table
#' @importFrom data.table as.data.table
#' @importFrom data.table :=
#' @importFrom data.table rbindlist
#' @importFrom data.table setkey
#' @importFrom data.table setkeyv
#' @importFrom data.table setnames
#' @importFrom jsonlite fromJSON
#' @importFrom jsonlite toJSON
#' @importFrom methods new
#' @importFrom utils object.size str tail
#' @importFrom stats coef
#' @importFrom stats predict
#' @importFrom stats median
#' @importFrom stats sd
#' @importFrom stats variable.names
#' @importFrom utils head
#' @importFrom graphics barplot
#' @importFrom graphics lines
#' @importFrom graphics points
#' @importFrom graphics grid
#' @importFrom graphics par
#' @importFrom graphics title
#' @importFrom grDevices rgb
#'
#' @import methods
#' @useDynLib xgboost, .registration = TRUE
NULL