[R] Use inplace predict (#9829)
--------- Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -77,26 +77,45 @@ xgb.get.handle <- function(object) {
|
||||
|
||||
#' Predict method for XGBoost model
|
||||
#'
|
||||
#' Predicted values based on either xgboost model or model handle object.
|
||||
#' Predict values on data based on xgboost model.
|
||||
#'
|
||||
#' @param object Object of class `xgb.Booster`.
|
||||
#' @param newdata Takes `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
|
||||
#' @param newdata Takes `data.frame`, `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
|
||||
#' local data file, or `xgb.DMatrix`.
|
||||
#' For single-row predictions on sparse data, it is recommended to use the CSR format.
|
||||
#' If passing a sparse vector, it will take it as a row vector.
|
||||
#' @param missing Only used when input is a dense matrix. Pick a float value that represents
|
||||
#' missing values in data (e.g., 0 or some other extreme value).
|
||||
#'
|
||||
#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing
|
||||
#' a sparse vector, it will take it as a row vector.
|
||||
#'
|
||||
#' Note that, for repeated predictions on the same data, one might want to create a DMatrix to
|
||||
#' pass here instead of passing R types like matrices or data frames, as predictions will be
|
||||
#' faster on DMatrix.
|
||||
#'
|
||||
#' If `newdata` is a `data.frame`, be aware that:\itemize{
|
||||
#' \item Columns will be converted to numeric if they aren't already, which could potentially make
|
||||
#' the operation slower than in an equivalent `matrix` object.
|
||||
#' \item The order of the columns must match with that of the data from which the model was fitted
|
||||
#' (i.e. columns will not be referenced by their names, just by their order in the data).
|
||||
#' \item If the model was fitted to data with categorical columns, these columns must be of
|
||||
#' `factor` type here, and must use the same encoding (i.e. have the same levels).
|
||||
#' \item If `newdata` contains any `factor` columns, they will be converted to base-0
|
||||
#' encoding (same as during DMatrix creation) - hence, one should not pass a `factor`
|
||||
#' under a column which during training had a different type.
|
||||
#' }
|
||||
#' @param missing Float value that represents missing values in data (e.g., 0 or some other extreme value).
|
||||
#'
|
||||
#' This parameter is not used when `newdata` is an `xgb.DMatrix` - in such cases, one should pass
|
||||
#' this as an argument to the DMatrix constructor instead.
|
||||
#' @param outputmargin Whether the prediction should be returned in the form of original untransformed
|
||||
#' sum of predictions from boosting iterations' results. E.g., setting `outputmargin=TRUE` for
|
||||
#' logistic regression would return log-odds instead of probabilities.
|
||||
#' @param predleaf Whether to predict pre-tree leaf indices.
|
||||
#' @param predleaf Whether to predict per-tree leaf indices.
|
||||
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
|
||||
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
|
||||
#' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
|
||||
#' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
|
||||
#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
|
||||
#' or `predinteraction` is `TRUE`.
|
||||
#' @param training Whether the predictions are used for training. For dart booster,
|
||||
#' @param training Whether the prediction result is used for training. For dart booster,
|
||||
#' predicting in training mode will perform dropout.
|
||||
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
|
||||
#' a two-element vector with the start and end numbers in the sequence (same format as R's `seq` - i.e.
|
||||
@@ -111,6 +130,12 @@ xgb.get.handle <- function(object) {
|
||||
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
|
||||
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
|
||||
#' type and shape of predictions are invariant to the model type.
|
||||
#' @param base_margin Base margin used for boosting from existing model.
|
||||
#'
|
||||
#' Note that, if `newdata` is an `xgb.DMatrix` object, passing this argument will
|
||||
#' raise an error, as it needs to be added to the DMatrix instead (e.g. by passing it as
|
||||
#' an argument in its constructor, or by calling \link{setinfo.xgb.DMatrix}).
|
||||
#'
|
||||
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
|
||||
#' match (only applicable when both `object` and `newdata` have feature names).
|
||||
#'
|
||||
@@ -287,16 +312,80 @@ xgb.get.handle <- function(object) {
|
||||
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
|
||||
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
|
||||
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
|
||||
validate_features = FALSE, ...) {
|
||||
validate_features = FALSE, base_margin = NULL, ...) {
|
||||
if (validate_features) {
|
||||
newdata <- validate.features(object, newdata)
|
||||
}
|
||||
if (!inherits(newdata, "xgb.DMatrix")) {
|
||||
is_dmatrix <- inherits(newdata, "xgb.DMatrix")
|
||||
if (is_dmatrix && !is.null(base_margin)) {
|
||||
stop(
|
||||
"'base_margin' is not supported when passing 'xgb.DMatrix' as input.",
|
||||
" Should be passed as argument to 'xgb.DMatrix' constructor."
|
||||
)
|
||||
}
|
||||
|
||||
use_as_df <- FALSE
|
||||
use_as_dense_matrix <- FALSE
|
||||
use_as_csr_matrix <- FALSE
|
||||
n_row <- NULL
|
||||
if (!is_dmatrix) {
|
||||
|
||||
inplace_predict_supported <- !predcontrib && !predinteraction && !predleaf
|
||||
if (inplace_predict_supported) {
|
||||
booster_type <- xgb.booster_type(object)
|
||||
if (booster_type == "gblinear" || (booster_type == "dart" && training)) {
|
||||
inplace_predict_supported <- FALSE
|
||||
}
|
||||
}
|
||||
if (inplace_predict_supported) {
|
||||
|
||||
if (is.matrix(newdata)) {
|
||||
use_as_dense_matrix <- TRUE
|
||||
} else if (is.data.frame(newdata)) {
|
||||
# note: since here it turns it into a non-data-frame list,
|
||||
# needs to keep track of the number of rows it had for later
|
||||
n_row <- nrow(newdata)
|
||||
newdata <- lapply(
|
||||
newdata,
|
||||
function(x) {
|
||||
if (is.factor(x)) {
|
||||
return(as.numeric(x) - 1)
|
||||
} else {
|
||||
return(as.numeric(x))
|
||||
}
|
||||
}
|
||||
)
|
||||
use_as_df <- TRUE
|
||||
} else if (inherits(newdata, "dgRMatrix")) {
|
||||
use_as_csr_matrix <- TRUE
|
||||
csr_data <- list(newdata@p, newdata@j, newdata@x, ncol(newdata))
|
||||
} else if (inherits(newdata, "dsparseVector")) {
|
||||
use_as_csr_matrix <- TRUE
|
||||
n_row <- 1L
|
||||
i <- newdata@i - 1L
|
||||
if (storage.mode(i) != "integer") {
|
||||
storage.mode(i) <- "integer"
|
||||
}
|
||||
csr_data <- list(c(0L, length(i)), i, newdata@x, length(newdata))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} # if (!is_dmatrix)
|
||||
|
||||
if (!is_dmatrix && !use_as_dense_matrix && !use_as_csr_matrix && !use_as_df) {
|
||||
nthread <- xgb.nthread(object)
|
||||
newdata <- xgb.DMatrix(
|
||||
newdata,
|
||||
missing = missing, nthread = NVL(nthread, -1)
|
||||
missing = missing,
|
||||
base_margin = base_margin,
|
||||
nthread = NVL(nthread, -1)
|
||||
)
|
||||
is_dmatrix <- TRUE
|
||||
}
|
||||
|
||||
if (is.null(n_row)) {
|
||||
n_row <- nrow(newdata)
|
||||
}
|
||||
|
||||
|
||||
@@ -354,18 +443,30 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
args$type <- set_type(6)
|
||||
}
|
||||
|
||||
predts <- .Call(
|
||||
XGBoosterPredictFromDMatrix_R,
|
||||
xgb.get.handle(object),
|
||||
newdata,
|
||||
jsonlite::toJSON(args, auto_unbox = TRUE)
|
||||
)
|
||||
json_conf <- jsonlite::toJSON(args, auto_unbox = TRUE)
|
||||
if (is_dmatrix) {
|
||||
predts <- .Call(
|
||||
XGBoosterPredictFromDMatrix_R, xgb.get.handle(object), newdata, json_conf
|
||||
)
|
||||
} else if (use_as_dense_matrix) {
|
||||
predts <- .Call(
|
||||
XGBoosterPredictFromDense_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
|
||||
)
|
||||
} else if (use_as_csr_matrix) {
|
||||
predts <- .Call(
|
||||
XGBoosterPredictFromCSR_R, xgb.get.handle(object), csr_data, missing, json_conf, base_margin
|
||||
)
|
||||
} else if (use_as_df) {
|
||||
predts <- .Call(
|
||||
XGBoosterPredictFromColumnar_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
|
||||
)
|
||||
}
|
||||
|
||||
names(predts) <- c("shape", "results")
|
||||
shape <- predts$shape
|
||||
arr <- predts$results
|
||||
|
||||
n_ret <- length(arr)
|
||||
n_row <- nrow(newdata)
|
||||
if (n_row != shape[1]) {
|
||||
stop("Incorrect predict shape.")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user