[R] Use inplace predict (#9829)

---------

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
david-cortes
2024-02-23 19:03:54 +01:00
committed by GitHub
parent 729fd97196
commit f7005d32c1
7 changed files with 450 additions and 46 deletions

View File

@@ -77,26 +77,45 @@ xgb.get.handle <- function(object) {
#' Predict method for XGBoost model
#'
#' Predicted values based on either xgboost model or model handle object.
#' Predict values on data based on xgboost model.
#'
#' @param object Object of class `xgb.Booster`.
#' @param newdata Takes `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
#' @param newdata Takes `data.frame`, `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
#' local data file, or `xgb.DMatrix`.
#' For single-row predictions on sparse data, it is recommended to use the CSR format.
#' If passing a sparse vector, it will take it as a row vector.
#' @param missing Only used when input is a dense matrix. Pick a float value that represents
#' missing values in data (e.g., 0 or some other extreme value).
#'
#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing
#' a sparse vector, it will take it as a row vector.
#'
#' Note that, for repeated predictions on the same data, one might want to create a DMatrix to
#' pass here instead of passing R types like matrices or data frames, as predictions will be
#' faster on DMatrix.
#'
#' If `newdata` is a `data.frame`, be aware that:\itemize{
#' \item Columns will be converted to numeric if they aren't already, which could potentially make
#' the operation slower than in an equivalent `matrix` object.
#' \item The order of the columns must match with that of the data from which the model was fitted
#' (i.e. columns will not be referenced by their names, just by their order in the data).
#' \item If the model was fitted to data with categorical columns, these columns must be of
#' `factor` type here, and must use the same encoding (i.e. have the same levels).
#' \item If `newdata` contains any `factor` columns, they will be converted to base-0
#' encoding (same as during DMatrix creation) - hence, one should not pass a `factor`
#' under a column which during training had a different type.
#' }
#' @param missing Float value that represents missing values in data (e.g., 0 or some other extreme value).
#'
#' This parameter is not used when `newdata` is an `xgb.DMatrix` - in such cases, should pass
#' this as an argument to the DMatrix constructor instead.
#' @param outputmargin Whether the prediction should be returned in the form of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting `outputmargin=TRUE` for
#' logistic regression would return log-odds instead of probabilities.
#' @param predleaf Whether to predict pre-tree leaf indices.
#' @param predleaf Whether to predict per-tree leaf indices.
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
#' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
#' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
#' or `predinteraction` is `TRUE`.
#' @param training Whether the predictions are used for training. For dart booster,
#' @param training Whether the prediction result is used for training. For dart booster,
#' training predicting will perform dropout.
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
#' a two-dimensional vector with the start and end numbers in the sequence (same format as R's `seq` - i.e.
@@ -111,6 +130,12 @@ xgb.get.handle <- function(object) {
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
#' type and shape of predictions are invariant to the model type.
#' @param base_margin Base margin used for boosting from existing model.
#'
#' Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
#' be ignored as it needs to be added to the DMatrix instead (e.g. by passing it as
#' an argument in its constructor, or by calling \link{setinfo.xgb.DMatrix}).
#'
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
#' match (only applicable when both `object` and `newdata` have feature names).
#'
@@ -287,16 +312,80 @@ xgb.get.handle <- function(object) {
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
validate_features = FALSE, ...) {
validate_features = FALSE, base_margin = NULL, ...) {
if (validate_features) {
newdata <- validate.features(object, newdata)
}
if (!inherits(newdata, "xgb.DMatrix")) {
is_dmatrix <- inherits(newdata, "xgb.DMatrix")
if (is_dmatrix && !is.null(base_margin)) {
stop(
"'base_margin' is not supported when passing 'xgb.DMatrix' as input.",
" Should be passed as argument to 'xgb.DMatrix' constructor."
)
}
use_as_df <- FALSE
use_as_dense_matrix <- FALSE
use_as_csr_matrix <- FALSE
n_row <- NULL
if (!is_dmatrix) {
inplace_predict_supported <- !predcontrib && !predinteraction && !predleaf
if (inplace_predict_supported) {
booster_type <- xgb.booster_type(object)
if (booster_type == "gblinear" || (booster_type == "dart" && training)) {
inplace_predict_supported <- FALSE
}
}
if (inplace_predict_supported) {
if (is.matrix(newdata)) {
use_as_dense_matrix <- TRUE
} else if (is.data.frame(newdata)) {
# note: since here it turns it into a non-data-frame list,
# needs to keep track of the number of rows it had for later
n_row <- nrow(newdata)
newdata <- lapply(
newdata,
function(x) {
if (is.factor(x)) {
return(as.numeric(x) - 1)
} else {
return(as.numeric(x))
}
}
)
use_as_df <- TRUE
} else if (inherits(newdata, "dgRMatrix")) {
use_as_csr_matrix <- TRUE
csr_data <- list(newdata@p, newdata@j, newdata@x, ncol(newdata))
} else if (inherits(newdata, "dsparseVector")) {
use_as_csr_matrix <- TRUE
n_row <- 1L
i <- newdata@i - 1L
if (storage.mode(i) != "integer") {
storage.mode(i) <- "integer"
}
csr_data <- list(c(0L, length(i)), i, newdata@x, length(newdata))
}
}
} # if (!is_dmatrix)
if (!is_dmatrix && !use_as_dense_matrix && !use_as_csr_matrix && !use_as_df) {
nthread <- xgb.nthread(object)
newdata <- xgb.DMatrix(
newdata,
missing = missing, nthread = NVL(nthread, -1)
missing = missing,
base_margin = base_margin,
nthread = NVL(nthread, -1)
)
is_dmatrix <- TRUE
}
if (is.null(n_row)) {
n_row <- nrow(newdata)
}
@@ -354,18 +443,30 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
args$type <- set_type(6)
}
predts <- .Call(
XGBoosterPredictFromDMatrix_R,
xgb.get.handle(object),
newdata,
jsonlite::toJSON(args, auto_unbox = TRUE)
)
json_conf <- jsonlite::toJSON(args, auto_unbox = TRUE)
if (is_dmatrix) {
predts <- .Call(
XGBoosterPredictFromDMatrix_R, xgb.get.handle(object), newdata, json_conf
)
} else if (use_as_dense_matrix) {
predts <- .Call(
XGBoosterPredictFromDense_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
)
} else if (use_as_csr_matrix) {
predts <- .Call(
XGBoosterPredictFromCSR_R, xgb.get.handle(object), csr_data, missing, json_conf, base_margin
)
} else if (use_as_df) {
predts <- .Call(
XGBoosterPredictFromColumnar_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
)
}
names(predts) <- c("shape", "results")
shape <- predts$shape
arr <- predts$results
n_ret <- length(arr)
n_row <- nrow(newdata)
if (n_row != shape[1]) {
stop("Incorrect predict shape.")
}