prescreen.parameters <- function(params) {
  if (!NROW(params)) {
    return(list())
  }
  if (!is.list(params)) {
    stop("'params' must be a list or NULL.")
  }

  params <- params[!vapply(params, is.null, logical(1))]

  if ("num_class" %in% names(params)) {
    stop("'num_class' cannot be manually specified for 'xgboost()'. Pass a factor 'y' instead.")
  }
  if ("process_type" %in% names(params)) {
    if (params$process_type != "default") {
      stop("Non-default 'process_type' is not supported for 'xgboost()'. Try 'xgb.train()'.")
    }
  }

  return(params)
}

prescreen.objective <- function(objective) {
  if (!is.null(objective)) {
    if (!is.character(objective) || length(objective) != 1L || is.na(objective)) {
      stop("'objective' must be a single character/string variable.")
    }

    if (objective %in% .OBJECTIVES_NON_DEFAULT_MODE()) {
      stop(
        "Objectives with non-default prediction mode (",
        paste(.OBJECTIVES_NON_DEFAULT_MODE(), collapse = ", "),
        ") are not supported in 'xgboost()'. Try 'xgb.train()'."
      )
    }
  }
}

process.base.margin <- function(base_margin, nrows, ncols) {
  if (!NROW(base_margin)) {
    return(NULL)
  }
  if (is.array(base_margin) && length(dim(base_margin)) > 2) {
    stop(
      "'base_margin' should not have more than 2 dimensions for any objective (got: ",
      length(dim(base_margin)),
      " dimensions)."
    )
  }
  if (inherits(base_margin, c("sparseMatrix", "sparseVector"))) {
    warning(
      "Got a sparse matrix type (class: ",
      paste(class(base_margin), collapse = ", "),
      ") for 'base_margin'. Will convert to dense matrix."
    )
    base_margin <- as.matrix(base_margin)
  }
  if (NROW(base_margin) != nrows) {
    stop(
      "'base_margin' has incorrect number of rows. Expected: ",
      nrows,
      ". Got: ",
      NROW(base_margin)
    )
  }

  if (ncols == 1L) {
    if (inherits(base_margin, c("matrix", "data.frame"))) {
      if (ncol(base_margin) != 1L) {
        stop("'base_margin' should be a 1-d vector for the given objective and data.")
      }
      if (is.data.frame(base_margin)) {
        base_margin <- base_margin[[1L]]
      } else {
        base_margin <- base_margin[, 1L]
      }
    }
    if (!is.numeric(base_margin)) {
      base_margin <- as.numeric(base_margin)
    }
  } else {
    supported_multicol <- c("matrix", "data.frame")
    if (!inherits(base_margin, supported_multicol)) {
      stop(
        "'base_margin' should be a matrix with ",
        ncols,
        " columns for the given objective and data. Got class: ",
        paste(class(base_margin), collapse = ", ")
      )
    }
    if (ncol(base_margin) != ncols) {
      stop(
        "'base_margin' has incorrect number of columns. Expected: ",
        ncols,
        ". Got: ",
        ncol(base_margin)
      )
    }
    if (!is.matrix(base_margin)) {
      base_margin <- as.matrix(base_margin)
    }
  }

  return(base_margin)
}

process.y.margin.and.objective <- function(
  y,
  base_margin,
  objective,
  params
) {

  if (!NROW(y)) {
    stop("Passed empty 'y'.")
  }

  if (is.array(y) && length(dim(y)) > 2) {
    stop(
      "'y' should not have more than 2 dimensions for any objective (got: ",
      length(dim(y)),
      ")."
    )
  }

  if (inherits(y, c("sparseMatrix", "sparseVector"))) {
    warning(
      "Got a sparse matrix type (class: ",
      paste(class(y), collapse = ", "),
      ") for 'y'. Will convert to dense matrix."
    )
    y <- as.matrix(y)
  }

  if (is.character(y)) {
    if (!is.vector(y)) {
      if (NCOL(y) > 1L) {
        stop("Multi-column categorical 'y' is not supported.")
      }
      y <- as.vector(y)
    }
    y <- factor(y)
  }

  if (is.logical(y)) {
    if (!is.vector(y)) {
      if (NCOL(y) > 1L) {
        stop("Multi-column logical/boolean 'y' is not supported.")
      }
      y <- as.vector(y)
    }
    y <- factor(y, c(FALSE, TRUE))
  }

  if (is.factor(y)) {

    y_levels <- levels(y)
    if (length(y_levels) < 2) {
      stop("Factor 'y' has less than 2 levels.")
    }
    if (length(y_levels) == 2) {
      if (is.null(objective)) {
        objective <- "binary:logistic"
      } else {
        if (!(objective %in% .BINARY_CLASSIF_OBJECTIVES())) {
          stop(
            "Got binary 'y' - supported objectives for this data are: ",
            paste(.BINARY_CLASSIF_OBJECTIVES(), collapse = ", "),
            ". Was passed: ",
            objective
          )
        }
      }

      if (!is.null(base_margin)) {
        base_margin <- process.base.margin(base_margin, length(y), 1)
      }

      out <- list(
        params = list(
          objective = objective
        ),
        metadata = list(
          y_levels = y_levels,
          n_targets = 1
        )
      )
    } else { # length(levels) > 2
      if (is.null(objective)) {
        objective <- "multi:softprob"
      } else {
        if (!(objective %in% .MULTICLASS_CLASSIF_OBJECTIVES())) {
          stop(
            "Got non-binary factor 'y' - supported objectives for this data are: ",
            paste(.MULTICLASS_CLASSIF_OBJECTIVES(), collapse = ", "),
            ". Was passed: ",
            objective
          )
        }
      }

      if (!is.null(base_margin)) {
        base_margin <- process.base.margin(base_margin, length(y), length(y_levels))
      }

      out <- list(
        params = list(
          objective = objective,
          num_class = length(y_levels)
        ),
        metadata = list(
          y_levels = y_levels,
          n_targets = length(y_levels)
        )
      )
    }

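    # The booster works with 0-based integer class codes, so the factor is re-encoded here;
    # the original levels are kept under 'metadata$y_levels' so that predictions can later be
    # mapped back to the user's labels.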
    out$dmatrix_args <- list(
      label = as.numeric(y) - 1,
      base_margin = base_margin
    )

  } else if (inherits(y, "Surv")) {

    y_attr <- attributes(y)
    supported_surv_types <- c("left", "right", "interval")
    if (!(y_attr$type %in% supported_surv_types)) {
      stop(
        "Survival objectives are only supported for types: ",
        paste(supported_surv_types, collapse = ", "),
        ". Was passed: ",
        y_attr$type
      )
    }

    if (is.null(objective)) {
      objective <- "survival:aft"
    } else {
      if (y_attr$type == "right") {
        if (!(objective %in% .SURVIVAL_RIGHT_CENSORING_OBJECTIVES())) {
          stop(
            "Got right-censored 'y' variable - supported objectives for this data are: ",
            paste(.SURVIVAL_RIGHT_CENSORING_OBJECTIVES(), collapse = ", "),
            ". Was passed: ",
            objective
          )
        }
      } else {
        if (!(objective %in% .SURVIVAL_ALL_CENSORING_OBJECTIVES())) {
          stop(
            "Got ", y_attr$type, "-censored 'y' variable - supported objectives for this data are: ",
            paste(.SURVIVAL_ALL_CENSORING_OBJECTIVES(), collapse = ", "),
            ". Was passed: ",
            objective
          )
        }
      }
    }

    if (!is.null(base_margin)) {
      base_margin <- process.base.margin(base_margin, nrow(y), 1)
    }

    out <- list(
      params = list(
        objective = objective
      ),
      metadata = list(
        n_targets = 1
      )
    )

    # Note: the 'Surv' object class that is passed as 'y' might have either 2 or 3 columns
    # depending on the type of censoring, and the last column in both cases is the one that
    # indicates the observation type (e.g. censored / uncensored).
    # In the case of interval censoring, the second column will not always have values with
    # infinites filled in. For more information, see the code behind the 'print.Surv' method.

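    # Status coding recap (as documented for the 'survival' package, stated here as an aid and
    # not re-checked programmatically): for left/right censoring the status column is
    # 0 = censored, 1 = event; for type "interval" it is 0 = right censored, 1 = event,
    # 2 = left censored, 3 = interval censored. XGBoost's 'survival:cox' takes right-censored
    # observations as negative label values, which is what the sign flip below produces, while
    # the AFT objectives take explicit lower/upper bounds, with Inf marking right censoring and
    # 0 marking left censoring.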
    if (objective == "survival:cox") {
      # Can only get here when using right censoring
      if (y_attr$type != "right") {
        stop("Internal error.")
      }

      out$dmatrix_args <- list(
        label = y[, 1L] * (2 * (y[, 2L] - 0.5))
      )

    } else {
      if (y_attr$type == "left") {
        lb <- ifelse(
          y[, 2L] == 0,
          0,
          y[, 1L]
        )
        ub <- y[, 1L]
        out$dmatrix_args <- list(
          label_lower_bound = lb,
          label_upper_bound = ub
        )
      } else if (y_attr$type == "right") {
        lb <- y[, 1L]
        ub <- ifelse(
          y[, 2L] == 0,
          Inf,
          y[, 1L]
        )
        out$dmatrix_args <- list(
          label_lower_bound = lb,
          label_upper_bound = ub
        )
      } else if (y_attr$type == "interval") {
        out$dmatrix_args <- list(
          label_lower_bound = ifelse(y[, 3L] == 2, 0, y[, 1L]),
          label_upper_bound = ifelse(
            y[, 3L] == 0, Inf,
            ifelse(y[, 3L] == 3, y[, 2L], y[, 1L])
          )
        )
      }

      if (min(out$dmatrix_args$label_lower_bound) < 0) {
        stop("Survival objectives are only defined for non-negative 'y'.")
      }
    }

    out$dmatrix_args$base_margin <- base_margin

  } else if (is.vector(y)) {

    if (is.null(objective)) {
      objective <- "reg:squarederror"
    } else if (!(objective %in% .REGRESSION_OBJECTIVES())) {
      stop(
        "Got numeric 'y' - supported objectives for this data are: ",
        paste(.REGRESSION_OBJECTIVES(), collapse = ", "),
        ". Was passed: ",
        objective
      )
    }

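    # 'reg:quantileerror' fits one output per requested quantile, so when several values are
    # passed in 'quantile_alpha' the model is treated as multi-output for prediction purposes.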
    n_targets <- 1L
    if (objective == "reg:quantileerror" && NROW(params$quantile_alpha) > 1) {
      n_targets <- NROW(params$quantile_alpha)
    }

    if (!is.null(base_margin)) {
      base_margin <- process.base.margin(base_margin, length(y), n_targets)
    }

    out <- list(
      params = list(
        objective = objective
      ),
      metadata = list(
        n_targets = n_targets
      ),
      dmatrix_args = list(
        label = as.numeric(y),
        base_margin = base_margin
      )
    )

  } else if (is.data.frame(y)) {
    if (ncol(y) == 1L) {
      return(process.y.margin.and.objective(y[[1L]], base_margin, objective, params))
    }

    if (is.null(objective)) {
      objective <- "reg:squarederror"
    } else if (!(objective %in% .MULTI_TARGET_OBJECTIVES())) {
      stop(
        "Got multi-column 'y' - supported objectives for this data are: ",
        paste(.MULTI_TARGET_OBJECTIVES(), collapse = ", "),
        ". Was passed: ",
        objective
      )
    }

    y_names <- names(y)
    y <- lapply(y, function(x) {
      if (!inherits(x, c("numeric", "integer"))) {
        stop(
          "Multi-target 'y' only supports 'numeric' and 'integer' types. Got: ",
          paste(class(x), collapse = ", ")
        )
      }
      return(as.numeric(x))
    })
    y <- as.data.frame(y) |> as.matrix()

    if (!is.null(base_margin)) {
      base_margin <- process.base.margin(base_margin, nrow(y), ncol(y))
    }

    out <- list(
      params = list(
        objective = objective
      ),
      dmatrix_args = list(
        label = y,
        base_margin = base_margin
      ),
      metadata = list(
        y_names = y_names,
        n_targets = ncol(y)
      )
    )

  } else if (is.matrix(y)) {
    if (ncol(y) == 1L) {
      return(process.y.margin.and.objective(as.vector(y), base_margin, objective, params))
    }

    if (!is.null(objective) && !(objective %in% .MULTI_TARGET_OBJECTIVES())) {
      stop(
        "Got multi-column 'y' - supported objectives for this data are: ",
        paste(.MULTI_TARGET_OBJECTIVES(), collapse = ", "),
        ". Was passed: ",
        objective
      )
    }
    if (is.null(objective)) {
      objective <- "reg:squarederror"
    }

    y_names <- colnames(y)
    if (storage.mode(y) != "double") {
      storage.mode(y) <- "double"
    }

    if (!is.null(base_margin)) {
      base_margin <- process.base.margin(base_margin, nrow(y), ncol(y))
    }

    out <- list(
      params = list(
        objective = objective
      ),
      dmatrix_args = list(
        label = y,
        base_margin = base_margin
      ),
      metadata = list(
        n_targets = ncol(y)
      )
    )

    if (NROW(y_names) == ncol(y)) {
      out$metadata$y_names <- y_names
    }

  } else {
    stop("Passed 'y' object with unsupported class: ", paste(class(y), collapse = ", "))
  }

  return(out)
}

process.row.weights <- function(w, lst_args) {
  if (!is.null(w)) {
    if ("label" %in% names(lst_args$dmatrix_args)) {
      nrow_y <- NROW(lst_args$dmatrix_args$label)
    } else if ("label_lower_bound" %in% names(lst_args$dmatrix_args)) {
      nrow_y <- length(lst_args$dmatrix_args$label_lower_bound)
    } else {
      stop("Internal error.")
    }
    if (!is.numeric(w)) {
      w <- as.numeric(w)
    }
    if (length(w) != nrow_y) {
      stop(
        "'weights' must be a 1-d vector with the same length as 'y' (",
        length(w), " vs. ", nrow_y, ")."
      )
    }
    lst_args$dmatrix_args$weight <- w
  }
  return(lst_args)
}

check.nthreads <- function(nthreads) {
  if (is.null(nthreads)) {
    return(1L)
  }
  if (!inherits(nthreads, c("numeric", "integer")) || !NROW(nthreads)) {
    stop("'nthreads' must be a positive scalar value.")
  }
  if (length(nthreads) > 1L) {
    nthreads <- utils::head(nthreads, 1L)
  }
  if (is.na(nthreads) || nthreads < 0) {
    stop("Passed invalid 'nthreads': ", nthreads)
  }
  if (is.numeric(nthreads)) {
    if (floor(nthreads) != nthreads) {
      stop("'nthreads' must be an integer.")
    }
  }
  return(as.integer(nthreads))
}

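# QuantileDMatrix pre-bins the data for the histogram-based tree method, so a regular DMatrix
# is used instead whenever the booster or tree method checked below cannot take advantage of it.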
check.can.use.qdm <- function(x, params) {
  if ("booster" %in% names(params)) {
    if (params$booster == "gblinear") {
      return(FALSE)
    }
  }
  if ("tree_method" %in% names(params)) {
    if (params$tree_method %in% c("exact", "approx")) {
      return(FALSE)
    }
  }
  return(TRUE)
}

process.x.and.col.args <- function(
  x,
  monotone_constraints,
  interaction_constraints,
  feature_weights,
  lst_args,
  use_qdm
) {
  if (is.null(x)) {
    stop("'x' cannot be NULL.")
  }
  if (inherits(x, "xgb.DMatrix")) {
    stop("Cannot pass 'xgb.DMatrix' as 'x' to 'xgboost()'. Try 'xgb.train()' instead.")
  }
  supported_x_types <- c("data.frame", "matrix", "dgTMatrix", "dgCMatrix", "dgRMatrix")
  if (!inherits(x, supported_x_types)) {
    stop(
      "'x' must be one of the following classes: ",
      paste(supported_x_types, collapse = ", "),
      ". Got: ",
      paste(class(x), collapse = ", ")
    )
  }
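  # When a QuantileDMatrix is built later on, sparse inputs are passed in CSR form, so other
  # sparse formats are converted to 'dgRMatrix' here.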
  if (use_qdm && inherits(x, "sparseMatrix") && !inherits(x, "dgRMatrix")) {
    x <- methods::as(x, "RsparseMatrix")
    if (!inherits(x, "RsparseMatrix")) {
      stop("Internal error: casting sparse matrix did not yield 'dgRMatrix'.")
    }
  }

  if (NROW(feature_weights)) {
    if (is.list(feature_weights)) {
      feature_weights <- unlist(feature_weights)
    }
    if (!inherits(feature_weights, c("numeric", "integer"))) {
      stop("'feature_weights' must be a numeric vector or named list matching to columns of 'x'.")
    }
    if (NROW(names(feature_weights)) && NROW(colnames(x))) {
      matched <- match(colnames(x), names(feature_weights))
      matched <- matched[!is.na(matched)]
      matched <- matched[!duplicated(matched)]
      if (length(matched) > 0 && length(matched) < length(feature_weights)) {
        stop(
          "'feature_weights' names do not contain all columns of 'x'. Missing: ",
          paste(utils::head(setdiff(colnames(x), names(feature_weights))), collapse = ", ")
        )
      }
      if (length(matched)) {
        feature_weights <- feature_weights[matched]
      } else {
        warning("Names of 'feature_weights' do not match with 'x'. Names will be ignored.")
      }
    }

    lst_args$dmatrix_args$feature_weights <- unname(feature_weights)
  }

  if (NROW(monotone_constraints)) {

    if (NROW(monotone_constraints) > ncol(x)) {
      stop(
        "'monotone_constraints' contains more entries than there are columns in 'x' (",
        NROW(monotone_constraints), " vs. ", ncol(x), ")."
      )
    }

    if (is.list(monotone_constraints)) {

      if (!NROW(names(monotone_constraints))) {
        stop(
          "If passing 'monotone_constraints' as a named list,",
          " must have names matching to columns of 'x'."
        )
      }
      if (!NROW(colnames(x))) {
        stop("If passing 'monotone_constraints' as a named list, 'x' must have column names.")
      }
      if (anyDuplicated(names(monotone_constraints))) {
        stop(
          "'monotone_constraints' contains duplicated names: ",
          paste(
            names(monotone_constraints)[duplicated(names(monotone_constraints))] |> utils::head(),
            collapse = ", "
          )
        )
      }
      if (NROW(setdiff(names(monotone_constraints), colnames(x)))) {
        stop(
          "'monotone_constraints' contains column names not present in 'x': ",
          paste(utils::head(setdiff(names(monotone_constraints), colnames(x))), collapse = ", ")
        )
      }

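      # Columns of 'x' that are not referenced in the list keep a constraint of 0 (no
      # monotonicity restriction), matching the documented behavior for partial input.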
      vec_monotone_constr <- rep(0, ncol(x))
      matched <- match(names(monotone_constraints), colnames(x))
      vec_monotone_constr[matched] <- unlist(monotone_constraints)
      lst_args$params$monotone_constraints <- unname(vec_monotone_constr)

    } else if (inherits(monotone_constraints, c("numeric", "integer"))) {

      if (NROW(names(monotone_constraints)) && NROW(colnames(x))) {
        if (length(monotone_constraints) < ncol(x)) {
          return(
            process.x.and.col.args(
              x,
              as.list(monotone_constraints),
              interaction_constraints,
              feature_weights,
              lst_args,
              use_qdm
            )
          )
        } else {
          matched <- match(colnames(x), names(monotone_constraints))
          matched <- matched[!is.na(matched)]
          matched <- matched[!duplicated(matched)]
          if (length(matched)) {
            monotone_constraints <- monotone_constraints[matched]
          } else {
            warning("Names of 'monotone_constraints' do not match with 'x'. Names will be ignored.")
          }
        }
      } else {
        if (length(monotone_constraints) != ncol(x)) {
          stop(
            "If passing 'monotone_constraints' as unnamed vector or not using column names,",
            " must have length matching to number of columns in 'x'. Got: ",
            length(monotone_constraints), " (vs. ", ncol(x), ")"
          )
        }
      }

      lst_args$params$monotone_constraints <- unname(monotone_constraints)

    } else if (is.character(monotone_constraints)) {
      lst_args$params$monotone_constraints <- monotone_constraints
    } else {
      stop(
        "Passed unsupported type for 'monotone_constraints': ",
        paste(class(monotone_constraints), collapse = ", ")
      )
    }
  }

  if (NROW(interaction_constraints)) {
    if (!is.list(interaction_constraints)) {
      stop("'interaction_constraints' must be a list of vectors.")
    }
    cnames <- colnames(x)
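    # Entries may reference columns either by name or by 1-based position; both are translated
    # below to the 0-based indices that the booster expects.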
    lst_args$params$interaction_constraints <- lapply(interaction_constraints, function(idx) {
      if (!NROW(idx)) {
        stop("Elements in 'interaction_constraints' cannot be empty.")
      }

      if (is.character(idx)) {
        if (!NROW(cnames)) {
          stop(
            "Passed a character vector for 'interaction_constraints', but 'x' ",
            "has no column names to match them against."
          )
        }
        out <- match(idx, cnames) - 1L
        if (anyNA(out)) {
          stop(
            "'interaction_constraints' contains column names not present in 'x': ",
            paste(utils::head(idx[which(is.na(out))]), collapse = ", ")
          )
        }
        return(out)
      } else if (inherits(idx, c("numeric", "integer"))) {
        if (anyNA(idx)) {
          stop("'interaction_constraints' cannot contain NA values.")
        }
        if (min(idx) < 1) {
          stop("Column indices for 'interaction_constraints' must follow base-1 indexing.")
        }
        if (max(idx) > ncol(x)) {
          stop("'interaction_constraints' contains invalid column indices.")
        }
        if (is.numeric(idx)) {
          if (any(idx != floor(idx))) {
            stop(
              "'interaction_constraints' must contain only integer indices. Got non-integer: ",
              paste(utils::head(idx[which(idx != floor(idx))]), collapse = ", ")
            )
          }
        }
        return(idx - 1L)
      } else {
        stop(
          "Elements in 'interaction_constraints' must be vectors of types ",
          "'integer', 'numeric', or 'character'. Got: ",
          paste(class(idx), collapse = ", ")
        )
      }
    })
  }

  lst_args$dmatrix_args$data <- x
  return(lst_args)
}

#' @noMd
#' @export
#' @title Fit XGBoost Model
#' @description Fits an XGBoost model (boosted decision tree ensemble) to given x/y data.
#'
#' See the tutorial \href{https://xgboost.readthedocs.io/en/stable/tutorials/model.html}{
#' Introduction to Boosted Trees} for a longer explanation of what XGBoost does.
#'
#' This function is intended to provide a more user-friendly interface for XGBoost that follows
#' R's conventions for model fitting and predictions, but which doesn't expose all of the
#' possible functionalities of the core XGBoost library.
#'
#' See \link{xgb.train} for a more flexible low-level alternative which is similar across different
#' language bindings of XGBoost and which exposes the full library's functionalities.
#' @details For package authors using `xgboost` as a dependency, it is highly recommended to use
#' \link{xgb.train} in package code instead of `xgboost()`, since it has a more stable interface
#' and performs fewer data conversions and copies along the way.
#' @references \itemize{
#' \item Chen, Tianqi, and Carlos Guestrin. "XGBoost: A Scalable Tree Boosting System."
#' Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and
#' Data Mining, 2016.
#' \item \url{https://xgboost.readthedocs.io/en/stable/}
#' }
#' @param x The features / covariates. Can be passed as:\itemize{
#' \item A numeric or integer `matrix`.
#' \item A `data.frame`, in which all columns are one of the following types:\itemize{
#' \item `numeric`
#' \item `integer`
#' \item `logical`
#' \item `factor`
#' }
#'
#' Columns of `factor` type will be assumed to be categorical, while other column types will
#' be assumed to be numeric.
#' \item A sparse matrix from the `Matrix` package, either as `dgCMatrix` or `dgRMatrix` class.
#' }
#'
#' Note that categorical features are only supported for `data.frame` inputs, and are automatically
#' determined based on their types. See \link{xgb.train} with \link{xgb.DMatrix} for more flexible
#' variants that would allow something like categorical features on sparse matrices.
#' @param y The response variable. Allowed values are:\itemize{
#' \item A numeric or integer vector (for regression tasks).
#' \item A factor or character vector (for binary and multi-class classification tasks).
#' \item A logical (boolean) vector (for binary classification tasks).
#' \item A numeric or integer matrix or `data.frame` with numeric/integer columns
#' (for multi-task regression tasks).
#' \item A `Surv` object from the `survival` package (for survival tasks).
#' }
#'
#' If `objective` is `NULL`, the right task will be determined automatically based on
#' the class of `y`.
#'
#' If `objective` is not `NULL`, it must match with the type of `y` - e.g. `factor` types of `y`
#' can only be used with classification objectives and vice-versa.
#'
#' For binary classification, the last factor level of `y` will be used as the "positive"
#' class - that is, the numbers from `predict` will reflect the probabilities of belonging to this
#' class instead of to the first factor level. If `y` is a `logical` vector, then `TRUE` will be
#' set as the last level.
#' @param objective Optimization objective to minimize based on the supplied data, to be passed
#' by name as a string / character (e.g. `reg:absoluteerror`). See the
#' \href{https://xgboost.readthedocs.io/en/stable/parameter.html#learning-task-parameters}{
#' Learning Task Parameters} page for more detailed information on allowed values.
#'
#' If `NULL` (the default), will be automatically determined from `y` according to the following
#' logic:\itemize{
#' \item If `y` is a factor with 2 levels, will use `binary:logistic`.
#' \item If `y` is a factor with more than 2 levels, will use `multi:softprob` (number of classes
#' will be determined automatically, should not be passed under `params`).
#' \item If `y` is a `Surv` object from the `survival` package, will use `survival:aft` (note that
#' the only types supported are left / right / interval censored).
#' \item Otherwise, will use `reg:squarederror`.
#' }
#'
#' If `objective` is not `NULL`, it must match with the type of `y` - e.g. `factor` types of `y`
#' can only be used with classification objectives and vice-versa.
#'
#' Note that not all possible `objective` values supported by the core XGBoost library are allowed
#' here - for example, objectives which are a variation of another but with a different default
#' prediction type (e.g. `multi:softmax` vs. `multi:softprob`) are not allowed, and neither are
#' ranking objectives, nor custom objectives at the moment.
#' @param nrounds Number of boosting iterations / rounds.
#'
#' Note that the default number of boosting rounds here is not automatically tuned, and different
#' problems will have vastly different optimal numbers of boosting rounds.
#' @param weights Sample weights for each row in `x` and `y`. If `NULL` (the default), each row
#' will have the same weight.
#'
#' If not `NULL`, should be passed as a numeric vector with length matching to the number of
#' rows in `x`.
#' @param verbosity Verbosity of printing messages. Valid values are 0 (silent), 1 (warning),
#' 2 (info), and 3 (debug).
#' @param nthreads Number of parallel threads to use. If passing zero, will use all CPU threads.
#' @param seed Seed to use for random number generation. If passing `NULL`, will draw a random
#' number using R's PRNG system to use as seed.
#' @param monotone_constraints Optional monotonicity constraints for features.
#'
#' Can be passed either as a named list (when `x` has column names), or as a vector. If passed
#' as a vector and `x` has column names, will try to match the elements by name.
#'
#' A value of `+1` for a given feature makes the model predictions / scores constrained to be
#' a monotonically increasing function of that feature (that is, as the value of the feature
#' increases, the model prediction cannot decrease), while a value of `-1` makes it a monotonically
#' decreasing function. A value of zero imposes no constraint.
#'
#' The input for `monotone_constraints` can be a subset of the columns of `x` if named, in which
#' case the columns that are not referred to in `monotone_constraints` will be assumed to have
#' a value of zero (no constraint imposed on the model for those features).
#'
#' See the tutorial \href{https://xgboost.readthedocs.io/en/stable/tutorials/monotonic.html}{
#' Monotonic Constraints} for a more detailed explanation.
#' @param interaction_constraints Constraints on permitted feature interactions.
#' The constraints must be specified in the form of a list of vectors referencing columns in the
#' data, e.g. `list(c(1, 2), c(3, 4, 5))` (with these numbers being column indices, numbering
#' starting at 1 - i.e. the first sublist references the first and second columns) or
#' `list(c("Sepal.Length", "Sepal.Width"), c("Petal.Length", "Petal.Width"))` (referencing
#' columns by name), where each vector is a group of indices of features that are allowed to
#' interact with each other.
#'
#' See the tutorial
#' \href{https://xgboost.readthedocs.io/en/stable/tutorials/feature_interaction_constraint.html}{
#' Feature Interaction Constraints} for more information.
#' @param feature_weights Feature weights for column sampling.
#'
#' Can be passed either as a vector with length matching the number of columns of `x`, or as a
#' named list (only if `x` has column names) with names matching columns of `x`. If it is a
#' named vector, will try to match the entries to column names of `x` by name.
#'
#' If `NULL` (the default), all columns will have the same weight.
#' @param base_margin Base margin used for boosting from existing model.
#'
#' If passing it, will start the gradient boosting procedure from the scores that are provided
#' here - for example, one can pass the raw scores from a previous model, or some per-observation
#' offset, or similar.
#'
#' Should be either a numeric vector or numeric matrix (for multi-class and multi-target objectives)
#' with the same number of rows as `x` and number of columns corresponding to the number of
#' optimization targets, and should be in the untransformed scale (for example, for objective
#' `binary:logistic`, it should have log-odds, not probabilities; and for objective
#' `multi:softprob`, should have number of columns matching the number of classes in the data).
#'
#' Note that, if it contains more than one column, then columns will not be matched by name to
#' the corresponding `y` - `base_margin` should have the same column order that the model will use
#' (for example, for objective `multi:softprob`, columns of `base_margin` will be matched against
#' `levels(y)` by their position, regardless of what `colnames(base_margin)` returns).
#'
#' If `NULL`, will start from zero, but note that for most objectives, an intercept is usually
#' added (controllable through parameter `base_score` instead) when `base_margin` is not passed.
#' @param ... Other training parameters. See the online documentation
#' \href{https://xgboost.readthedocs.io/en/stable/parameter.html}{XGBoost Parameters} for
#' details about possible values and what they do.
#'
#' Note that not all possible values from the core XGBoost library are allowed as `params` for
#' `xgboost()` - in particular, values which require an already-fitted booster object (such as
#' `process_type`) are not accepted here.
#' @return A model object, inheriting from both `xgboost` and `xgb.Booster`. Compared to the regular
#' `xgb.Booster` model class produced by \link{xgb.train}, this `xgboost` class will have an
#' additional attribute `metadata` containing information which is used for formatting prediction
#' outputs, such as class names for classification problems.
#' @examples
#' library(xgboost)
#' data(mtcars)
#'
#' # Fit a small regression model on the mtcars data
#' model_regression <- xgboost(mtcars[, -1], mtcars$mpg, nthreads = 1, nrounds = 3)
#' predict(model_regression, mtcars, validate_features = TRUE)
#'
#' # Task objective is determined automatically according to the type of 'y'
#' data(iris)
#' model_classif <- xgboost(iris[, -5], iris$Species, nthreads = 1, nrounds = 5)
#' predict(model_classif, iris, validate_features = TRUE)
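#'
#' # Illustrative sketch of the named-list interface for 'monotone_constraints' described above
#' # (the choice of 'hp' with a decreasing constraint is arbitrary, not a modeling recommendation)
#' model_monotonic <- xgboost(
#'   mtcars[, -1], mtcars$mpg,
#'   monotone_constraints = list(hp = -1),
#'   nthreads = 1, nrounds = 3
#' )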
xgboost <- function(
  x,
  y,
  objective = NULL,
  nrounds = 100L,
  weights = NULL,
  verbosity = 0L,
  nthreads = parallel::detectCores(),
  seed = 0L,
  monotone_constraints = NULL,
  interaction_constraints = NULL,
  feature_weights = NULL,
  base_margin = NULL,
  ...
) {
  # Note: '...' is a workaround, to be removed later by making all parameters be arguments
  params <- list(...)
  params <- prescreen.parameters(params)
  prescreen.objective(objective)
  use_qdm <- check.can.use.qdm(x, params)
  lst_args <- process.y.margin.and.objective(y, base_margin, objective, params)
  lst_args <- process.row.weights(weights, lst_args)
  lst_args <- process.x.and.col.args(
    x,
    monotone_constraints,
    interaction_constraints,
    feature_weights,
    lst_args,
    use_qdm
  )

  if (use_qdm && "max_bin" %in% names(params)) {
    lst_args$dmatrix_args$max_bin <- params$max_bin
  }

  nthreads <- check.nthreads(nthreads)
  lst_args$dmatrix_args$nthread <- nthreads
  lst_args$params$nthread <- nthreads
  lst_args$params$seed <- seed

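  # Parameters derived above from 'x' / 'y' (objective, num_class, constraint vectors, threads,
  # seed) are combined with whatever was passed through '...'.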
  params <- c(lst_args$params, params)

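  # Build either a QuantileDMatrix (more memory-efficient for the hist tree method) or a regular
  # DMatrix, then delegate the actual training to xgb.train().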
  fn_dm <- if (use_qdm) xgb.QuantileDMatrix else xgb.DMatrix
  dm <- do.call(fn_dm, lst_args$dmatrix_args)
  model <- xgb.train(
    params = params,
    data = dm,
    nrounds = nrounds,
    verbose = verbosity
  )
  attributes(model)$metadata <- lst_args$metadata
  attributes(model)$call <- match.call()
  class(model) <- c("xgboost", class(model))
  return(model)
}

#' @export
print.xgboost <- function(x, ...) {
  cat("XGBoost model object\n")
  cat("Call:\n ")
  print(attributes(x)$call)
  cat("Objective: ", attributes(x)$params$objective, "\n", sep = "")
  cat("Number of iterations: ", xgb.get.num.boosted.rounds(x), "\n", sep = "")
  cat("Number of features: ", xgb.num_feature(x), "\n", sep = "")

  printable_head <- function(v) {
    v_sub <- utils::head(v, 5L)
    return(
      sprintf(
        "%s%s",
        paste(v_sub, collapse = ", "),
        ifelse(length(v_sub) < length(v), ", ...", "")
      )
    )
  }

  if (NROW(attributes(x)$metadata$y_levels)) {
    cat(
      "Classes: ",
      printable_head(attributes(x)$metadata$y_levels),
      "\n",
      sep = ""
    )
  } else if (NROW(attributes(x)$params$quantile_alpha)) {
    cat(
      "Prediction quantile",
      ifelse(length(attributes(x)$params$quantile_alpha) > 1L, "s", ""),
      ": ",
      printable_head(attributes(x)$params$quantile_alpha),
      "\n",
      sep = ""
    )
  } else if (NROW(attributes(x)$metadata$y_names)) {
    cat(
      "Prediction targets: ",
      printable_head(attributes(x)$metadata$y_names),
      "\n",
      sep = ""
    )
  } else if (attributes(x)$metadata$n_targets > 1L) {
    cat(
      "Number of prediction targets: ",
      attributes(x)$metadata$n_targets,
      "\n",
      sep = ""
    )
  }

  return(invisible(x))
}


#' Training part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#' \item \code{label} the label for each record
#' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' <https://archive.ics.uci.edu/ml/datasets/Mushroom>
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' <http://archive.ics.uci.edu/ml>. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.train
#' @usage data(agaricus.train)
#' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 126 variables
NULL

#' Test part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#' \item \code{label} the label for each record
#' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' <https://archive.ics.uci.edu/ml/datasets/Mushroom>
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' <http://archive.ics.uci.edu/ml>. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.test
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables
NULL

# Various imports
#' @importClassesFrom Matrix dgCMatrix dgRMatrix CsparseMatrix
#' @importFrom Matrix sparse.model.matrix
#' @importFrom data.table data.table
#' @importFrom data.table is.data.table
#' @importFrom data.table as.data.table
#' @importFrom data.table :=
#' @importFrom data.table rbindlist
#' @importFrom data.table setkey
#' @importFrom data.table setkeyv
#' @importFrom data.table setnames
#' @importFrom jsonlite fromJSON
#' @importFrom jsonlite toJSON
#' @importFrom methods new
#' @importFrom utils object.size str tail
#' @importFrom stats coef
#' @importFrom stats predict
#' @importFrom stats median
#' @importFrom stats sd
#' @importFrom stats variable.names
#' @importFrom utils head
#' @importFrom graphics barplot
#' @importFrom graphics lines
#' @importFrom graphics points
#' @importFrom graphics grid
#' @importFrom graphics par
#' @importFrom graphics title
#' @importFrom grDevices rgb
#'
#' @import methods
#' @useDynLib xgboost, .registration = TRUE
NULL