[R] Add optional check on column names matching in predict (#10020)
This commit is contained in:
@@ -111,6 +111,21 @@ xgb.get.handle <- function(object) {
|
||||
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
|
||||
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
|
||||
#' type and shape of predictions are invariant to the model type.
|
||||
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
|
||||
#' match (only applicable when both `object` and `newdata` have feature names).
|
||||
#'
|
||||
#' If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder
|
||||
#' the columns in `newdata` to match with the booster's.
|
||||
#'
|
||||
#' If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`,
|
||||
#' will additionally verify that categorical columns are of the correct type in `newdata`,
|
||||
#' throwing an error if they do not match.
|
||||
#'
|
||||
#' If passing `FALSE`, it is assumed that the feature names and types are the same,
|
||||
#' and come in the same order as in the training data.
|
||||
#'
|
||||
#' Note that this check can add noticeable latency to the predictions, so it's
|
||||
#' recommended to disable it for performance-sensitive applications.
|
||||
#' @param ... Not used.
|
||||
#'
|
||||
#' @details
|
||||
@@ -271,7 +286,11 @@ xgb.get.handle <- function(object) {
|
||||
#' @export
|
||||
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
|
||||
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
|
||||
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, ...) {
|
||||
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
|
||||
validate_features = FALSE, ...) {
|
||||
if (validate_features) {
|
||||
newdata <- validate.features(object, newdata)
|
||||
}
|
||||
if (!inherits(newdata, "xgb.DMatrix")) {
|
||||
nthread <- xgb.nthread(object)
|
||||
newdata <- xgb.DMatrix(
|
||||
@@ -418,6 +437,85 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
return(arr)
|
||||
}
|
||||
|
||||
# Validate `newdata` against the feature metadata stored in booster `bst`.
#
# Behaviour, in order:
# * Character input (file paths) and `sparseVector` input are returned as-is,
#   since no feature metadata can be read from them here.
# * Plain vectors are converted with `as.matrix()`.
# * Feature names: when both the booster and `newdata` carry feature names and
#   they differ, an `xgb.DMatrix` raises an error, while any other input has
#   its columns selected in the booster's order. An error is raised if a
#   booster feature is absent from `newdata`.
# * Feature types: when the booster has feature types and `newdata` is an
#   `xgb.DMatrix` or a `data.frame`, the types must agree. For data frames,
#   factor columns must line up exactly with the booster's "c" entries.
#
# Arguments:
#   bst     Booster object queried through `getinfo()`.
#   newdata Data to be validated.
#
# Returns `newdata`, possibly converted to a matrix and/or with its columns
# reordered. Stops with an error on any mismatch described above.
validate.features <- function(bst, newdata) {
  if (is.character(newdata)) {
    # this will be encountered when passing file paths
    return(newdata)
  }
  if (inherits(newdata, "sparseVector")) {
    # in this case, newdata won't have metadata
    return(newdata)
  }
  if (is.vector(newdata)) {
    newdata <- as.matrix(newdata)
  }

  booster_names <- getinfo(bst, "feature_name")
  checked_names <- FALSE
  if (NROW(booster_names) > 0) {
    is_dmatrix <- inherits(newdata, "xgb.DMatrix")
    if (is_dmatrix) {
      curr_names <- getinfo(newdata, "feature_name")
    } else {
      curr_names <- colnames(newdata)
    }

    if (NROW(curr_names) > 0) {
      checked_names <- TRUE

      names_differ <- length(curr_names) != length(booster_names) ||
        any(curr_names != booster_names)

      if (names_differ) {
        # Columns of an xgb.DMatrix cannot be reordered here.
        if (is_dmatrix) {
          stop("Feature names in 'newdata' do not match with booster's.")
        }

        # Fail with an explicit message instead of letting the column
        # subsetting below raise an opaque indexing error.
        missing_names <- setdiff(booster_names, curr_names)
        if (length(missing_names) > 0) {
          # Only list the first few names to keep the message readable.
          shown <- missing_names[seq_len(min(length(missing_names), 10L))]
          stop(
            "Feature names in 'newdata' do not match with booster's",
            " (missing columns: ", paste(shown, collapse = ", "),
            if (length(missing_names) > length(shown)) ", ..." else "",
            ")."
          )
        }

        if (inherits(newdata, "data.table")) {
          newdata <- newdata[, booster_names, with = FALSE]
        } else {
          newdata <- newdata[, booster_names, drop = FALSE]
        }
      }
    }
  }

  if (inherits(newdata, c("data.frame", "xgb.DMatrix"))) {
    booster_types <- getinfo(bst, "feature_type")
    if (NROW(booster_types) == 0) {
      # Note: types in the booster are optional. Other interfaces
      # might not even save it as booster attributes for example,
      # even if the model uses categorical features.
      return(newdata)
    }

    if (inherits(newdata, "xgb.DMatrix")) {
      curr_types <- getinfo(newdata, "feature_type")
      if (length(curr_types) != length(booster_types) || any(curr_types != booster_types)) {
        stop("Feature types in 'newdata' do not match with booster's.")
      }
    }

    if (inherits(newdata, "data.frame")) {
      # vapply() keeps the result a logical vector even with zero columns.
      is_factor <- vapply(newdata, is.factor, logical(1L), USE.NAMES = FALSE)
      # The length check prevents silent recycling when the number of columns
      # differs from the number of feature types stored in the booster.
      types_differ <- length(is_factor) != length(booster_types) ||
        any(is_factor != (booster_types == "c"))
      if (types_differ) {
        stop(
          paste0(
            "Feature types in 'newdata' do not match with booster's for same columns (by ",
            if (checked_names) "name" else "position",
            ")."
          )
        )
      }
    }
  }

  return(newdata)
}
|
||||
|
||||
|
||||
#' @title Accessors for serializable attributes of a model
|
||||
#'
|
||||
|
||||
Reference in New Issue
Block a user