[R] Add optional check on column names matching in predict (#10020)

This commit is contained in:
david-cortes
2024-01-31 08:43:22 +01:00
committed by GitHub
parent c53d59f8db
commit 1e72dc1276
3 changed files with 196 additions and 1 deletions

View File

@@ -111,6 +111,21 @@ xgb.get.handle <- function(object) {
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
#' type and shape of predictions are invariant to the model type.
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
#' match (only applicable when both `object` and `newdata` have feature names).
#'
#' If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder
#' the columns in `newdata` to match with the booster's.
#'
#' If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`,
#' will additionally verify that categorical columns are of the correct type in `newdata`,
#' throwing an error if they do not match.
#'
#' If passing `FALSE`, it is assumed that the feature names and types are the same,
#' and come in the same order as in the training data.
#'
#' Note that this check might add some sizable latency to the predictions, so it's
#' recommended to disable it for performance-sensitive applications.
#' @param ... Not used.
#'
#' @details
@@ -271,7 +286,11 @@ xgb.get.handle <- function(object) {
#' @export
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, ...) {
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
validate_features = FALSE, ...) {
if (validate_features) {
newdata <- validate.features(object, newdata)
}
if (!inherits(newdata, "xgb.DMatrix")) {
nthread <- xgb.nthread(object)
newdata <- xgb.DMatrix(
@@ -418,6 +437,85 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
return(arr)
}
validate.features <- function(bst, newdata) {
if (is.character(newdata)) {
# this will be encountered when passing file paths
return(newdata)
}
if (inherits(newdata, "sparseVector")) {
# in this case, newdata won't have metadata
return(newdata)
}
if (is.vector(newdata)) {
newdata <- as.matrix(newdata)
}
booster_names <- getinfo(bst, "feature_name")
checked_names <- FALSE
if (NROW(booster_names)) {
try_reorder <- FALSE
if (inherits(newdata, "xgb.DMatrix")) {
curr_names <- getinfo(newdata, "feature_name")
} else {
curr_names <- colnames(newdata)
try_reorder <- TRUE
}
if (NROW(curr_names)) {
checked_names <- TRUE
if (length(curr_names) != length(booster_names) || any(curr_names != booster_names)) {
if (!try_reorder) {
stop("Feature names in 'newdata' do not match with booster's.")
} else {
if (inherits(newdata, "data.table")) {
newdata <- newdata[, booster_names, with = FALSE]
} else {
newdata <- newdata[, booster_names, drop = FALSE]
}
}
}
} # if (NROW(curr_names)) {
} # if (NROW(booster_names)) {
if (inherits(newdata, c("data.frame", "xgb.DMatrix"))) {
booster_types <- getinfo(bst, "feature_type")
if (!NROW(booster_types)) {
# Note: types in the booster are optional. Other interfaces
# might not even save it as booster attributes for example,
# even if the model uses categorical features.
return(newdata)
}
if (inherits(newdata, "xgb.DMatrix")) {
curr_types <- getinfo(newdata, "feature_type")
if (length(curr_types) != length(booster_types) || any(curr_types != booster_types)) {
stop("Feature types in 'newdata' do not match with booster's.")
}
}
if (inherits(newdata, "data.frame")) {
is_factor <- sapply(newdata, is.factor)
if (any(is_factor != (booster_types == "c"))) {
stop(
paste0(
"Feature types in 'newdata' do not match with booster's for same columns (by ",
ifelse(checked_names, "name", "position"),
")."
)
)
}
}
}
return(newdata)
}
#' @title Accessors for serializable attributes of a model
#'