[R] Add optional check on column names matching in predict (#10020)

This commit is contained in:
david-cortes 2024-01-31 08:43:22 +01:00 committed by GitHub
parent c53d59f8db
commit 1e72dc1276
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 196 additions and 1 deletions

View File

@ -111,6 +111,21 @@ xgb.get.handle <- function(object) {
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
#' type and shape of predictions are invariant to the model type.
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
#' match (only applicable when both `object` and `newdata` have feature names).
#'
#' If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder
#' the columns in `newdata` to match with the booster's.
#'
#' If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`,
#' will additionally verify that categorical columns are of the correct type in `newdata`,
#' throwing an error if they do not match.
#'
#' If passing `FALSE`, it is assumed that the feature names and types are the same,
#' and come in the same order as in the training data.
#'
#' Note that this check might add some sizable latency to the predictions, so it's
#' recommended to disable it for performance-sensitive applications.
#' @param ... Not used.
#'
#' @details
@ -271,7 +286,11 @@ xgb.get.handle <- function(object) {
#' @export
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, ...) {
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
validate_features = FALSE, ...) {
if (validate_features) {
newdata <- validate.features(object, newdata)
}
if (!inherits(newdata, "xgb.DMatrix")) {
nthread <- xgb.nthread(object)
newdata <- xgb.DMatrix(
@ -418,6 +437,85 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
return(arr)
}
validate.features <- function(bst, newdata) {
if (is.character(newdata)) {
# this will be encountered when passing file paths
return(newdata)
}
if (inherits(newdata, "sparseVector")) {
# in this case, newdata won't have metadata
return(newdata)
}
if (is.vector(newdata)) {
newdata <- as.matrix(newdata)
}
booster_names <- getinfo(bst, "feature_name")
checked_names <- FALSE
if (NROW(booster_names)) {
try_reorder <- FALSE
if (inherits(newdata, "xgb.DMatrix")) {
curr_names <- getinfo(newdata, "feature_name")
} else {
curr_names <- colnames(newdata)
try_reorder <- TRUE
}
if (NROW(curr_names)) {
checked_names <- TRUE
if (length(curr_names) != length(booster_names) || any(curr_names != booster_names)) {
if (!try_reorder) {
stop("Feature names in 'newdata' do not match with booster's.")
} else {
if (inherits(newdata, "data.table")) {
newdata <- newdata[, booster_names, with = FALSE]
} else {
newdata <- newdata[, booster_names, drop = FALSE]
}
}
}
} # if (NROW(curr_names)) {
} # if (NROW(booster_names)) {
if (inherits(newdata, c("data.frame", "xgb.DMatrix"))) {
booster_types <- getinfo(bst, "feature_type")
if (!NROW(booster_types)) {
# Note: types in the booster are optional. Other interfaces
# might not even save it as booster attributes for example,
# even if the model uses categorical features.
return(newdata)
}
if (inherits(newdata, "xgb.DMatrix")) {
curr_types <- getinfo(newdata, "feature_type")
if (length(curr_types) != length(booster_types) || any(curr_types != booster_types)) {
stop("Feature types in 'newdata' do not match with booster's.")
}
}
if (inherits(newdata, "data.frame")) {
is_factor <- sapply(newdata, is.factor)
if (any(is_factor != (booster_types == "c"))) {
stop(
paste0(
"Feature types in 'newdata' do not match with booster's for same columns (by ",
ifelse(checked_names, "name", "position"),
")."
)
)
}
}
}
return(newdata)
}
#' @title Accessors for serializable attributes of a model
#'

View File

@ -17,6 +17,7 @@
training = FALSE,
iterationrange = NULL,
strict_shape = FALSE,
validate_features = FALSE,
...
)
}
@ -66,6 +67,23 @@ base-1 indexing, and inclusive of both ends).
\item{strict_shape}{Default is \code{FALSE}. When set to \code{TRUE}, the output
type and shape of predictions are invariant to the model type.}
\item{validate_features}{When \code{TRUE}, validate that the Booster's and newdata's feature_names
match (only applicable when both \code{object} and \code{newdata} have feature names).
\if{html}{\out{<div class="sourceCode">}}\preformatted{ If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder
the columns in `newdata` to match with the booster's.
If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`,
will additionally verify that categorical columns are of the correct type in `newdata`,
throwing an error if they do not match.
If passing `FALSE`, it is assumed that the feature names and types are the same,
and come in the same order as in the training data.
Note that this check might add some sizable latency to the predictions, so it's
recommended to disable it for performance-sensitive applications.
}\if{html}{\out{</div>}}}
\item{...}{Not used.}
}
\value{

View File

@ -511,3 +511,82 @@ test_that('convert.labels works', {
expect_equal(class(res), 'numeric')
}
})
test_that("validate.features works as expected", {
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
dm <- xgb.DMatrix(x, label = y, nthread = 1)
model <- xgb.train(
params = list(nthread = 1),
data = dm,
nrounds = 3
)
# result is output as-is when needed
res <- validate.features(model, x)
expect_equal(res, x)
res <- validate.features(model, dm)
expect_identical(res, dm)
res <- validate.features(model, as(x[1, ], "dsparseVector"))
expect_equal(as.numeric(res), unname(x[1, ]))
res <- validate.features(model, "file.txt")
expect_equal(res, "file.txt")
# columns are reordered
res <- validate.features(model, mtcars[, rev(names(mtcars))])
expect_equal(names(res), colnames(x))
expect_equal(as.matrix(res), x)
res <- validate.features(model, as.matrix(mtcars[, rev(names(mtcars))]))
expect_equal(colnames(res), colnames(x))
expect_equal(res, x)
res <- validate.features(model, mtcars[1, rev(names(mtcars)), drop = FALSE])
expect_equal(names(res), colnames(x))
expect_equal(unname(as.matrix(res)), unname(x[1, , drop = FALSE]))
res <- validate.features(model, as.data.table(mtcars[, rev(names(mtcars))]))
expect_equal(names(res), colnames(x))
expect_equal(unname(as.matrix(res)), unname(x))
# error when columns are missing
expect_error({
validate.features(model, mtcars[, 1:3])
})
expect_error({
validate.features(model, as.matrix(mtcars[, 1:ncol(x)])) # nolint
})
expect_error({
validate.features(model, xgb.DMatrix(mtcars[, 1:3]))
})
expect_error({
validate.features(model, as(x[, 1:3], "CsparseMatrix"))
})
# error when it cannot reorder or subset
expect_error({
validate.features(model, xgb.DMatrix(mtcars))
}, "Feature names")
expect_error({
validate.features(model, xgb.DMatrix(x[, rev(colnames(x))]))
}, "Feature names")
# no error about types if the booster doesn't have types
expect_error({
validate.features(model, xgb.DMatrix(x, feature_types = c(rep("q", 5), rep("c", 5))))
}, NA)
tmp <- mtcars
tmp[["vs"]] <- factor(tmp[["vs"]])
expect_error({
validate.features(model, tmp)
}, NA)
# error when types do not match
setinfo(model, "feature_type", rep("q", 10))
expect_error({
validate.features(model, xgb.DMatrix(x, feature_types = c(rep("q", 5), rep("c", 5))))
}, "Feature types")
tmp <- mtcars
tmp[["vs"]] <- factor(tmp[["vs"]])
expect_error({
validate.features(model, tmp)
}, "Feature types")
})