[R] Add optional check on column names matching in predict (#10020)
This commit is contained in:
parent
c53d59f8db
commit
1e72dc1276
@ -111,6 +111,21 @@ xgb.get.handle <- function(object) {
|
||||
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
|
||||
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
|
||||
#' type and shape of predictions are invariant to the model type.
|
||||
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
|
||||
#' match (only applicable when both `object` and `newdata` have feature names).
|
||||
#'
|
||||
#' If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder
|
||||
#' the columns in `newdata` to match with the booster's.
|
||||
#'
|
||||
#' If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`,
|
||||
#' will additionally verify that categorical columns are of the correct type in `newdata`,
|
||||
#' throwing an error if they do not match.
|
||||
#'
|
||||
#' If passing `FALSE`, it is assumed that the feature names and types are the same,
|
||||
#' and come in the same order as in the training data.
|
||||
#'
|
||||
#' Note that this check might add some sizable latency to the predictions, so it's
|
||||
#' recommended to disable it for performance-sensitive applications.
|
||||
#' @param ... Not used.
|
||||
#'
|
||||
#' @details
|
||||
@ -271,7 +286,11 @@ xgb.get.handle <- function(object) {
|
||||
#' @export
|
||||
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
|
||||
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
|
||||
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, ...) {
|
||||
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
|
||||
validate_features = FALSE, ...) {
|
||||
if (validate_features) {
|
||||
newdata <- validate.features(object, newdata)
|
||||
}
|
||||
if (!inherits(newdata, "xgb.DMatrix")) {
|
||||
nthread <- xgb.nthread(object)
|
||||
newdata <- xgb.DMatrix(
|
||||
@ -418,6 +437,85 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
return(arr)
|
||||
}
|
||||
|
||||
validate.features <- function(bst, newdata) {
|
||||
if (is.character(newdata)) {
|
||||
# this will be encountered when passing file paths
|
||||
return(newdata)
|
||||
}
|
||||
if (inherits(newdata, "sparseVector")) {
|
||||
# in this case, newdata won't have metadata
|
||||
return(newdata)
|
||||
}
|
||||
if (is.vector(newdata)) {
|
||||
newdata <- as.matrix(newdata)
|
||||
}
|
||||
|
||||
booster_names <- getinfo(bst, "feature_name")
|
||||
checked_names <- FALSE
|
||||
if (NROW(booster_names)) {
|
||||
|
||||
try_reorder <- FALSE
|
||||
if (inherits(newdata, "xgb.DMatrix")) {
|
||||
curr_names <- getinfo(newdata, "feature_name")
|
||||
} else {
|
||||
curr_names <- colnames(newdata)
|
||||
try_reorder <- TRUE
|
||||
}
|
||||
|
||||
if (NROW(curr_names)) {
|
||||
checked_names <- TRUE
|
||||
|
||||
if (length(curr_names) != length(booster_names) || any(curr_names != booster_names)) {
|
||||
|
||||
if (!try_reorder) {
|
||||
stop("Feature names in 'newdata' do not match with booster's.")
|
||||
} else {
|
||||
if (inherits(newdata, "data.table")) {
|
||||
newdata <- newdata[, booster_names, with = FALSE]
|
||||
} else {
|
||||
newdata <- newdata[, booster_names, drop = FALSE]
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} # if (NROW(curr_names)) {
|
||||
|
||||
} # if (NROW(booster_names)) {
|
||||
|
||||
if (inherits(newdata, c("data.frame", "xgb.DMatrix"))) {
|
||||
|
||||
booster_types <- getinfo(bst, "feature_type")
|
||||
if (!NROW(booster_types)) {
|
||||
# Note: types in the booster are optional. Other interfaces
|
||||
# might not even save it as booster attributes for example,
|
||||
# even if the model uses categorical features.
|
||||
return(newdata)
|
||||
}
|
||||
if (inherits(newdata, "xgb.DMatrix")) {
|
||||
curr_types <- getinfo(newdata, "feature_type")
|
||||
if (length(curr_types) != length(booster_types) || any(curr_types != booster_types)) {
|
||||
stop("Feature types in 'newdata' do not match with booster's.")
|
||||
}
|
||||
}
|
||||
if (inherits(newdata, "data.frame")) {
|
||||
is_factor <- sapply(newdata, is.factor)
|
||||
if (any(is_factor != (booster_types == "c"))) {
|
||||
stop(
|
||||
paste0(
|
||||
"Feature types in 'newdata' do not match with booster's for same columns (by ",
|
||||
ifelse(checked_names, "name", "position"),
|
||||
")."
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return(newdata)
|
||||
}
|
||||
|
||||
|
||||
#' @title Accessors for serializable attributes of a model
|
||||
#'
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
training = FALSE,
|
||||
iterationrange = NULL,
|
||||
strict_shape = FALSE,
|
||||
validate_features = FALSE,
|
||||
...
|
||||
)
|
||||
}
|
||||
@ -66,6 +67,23 @@ base-1 indexing, and inclusive of both ends).
|
||||
\item{strict_shape}{Default is \code{FALSE}. When set to \code{TRUE}, the output
|
||||
type and shape of predictions are invariant to the model type.}
|
||||
|
||||
\item{validate_features}{When \code{TRUE}, validate that the Booster's and newdata's feature_names
|
||||
match (only applicable when both \code{object} and \code{newdata} have feature names).
|
||||
|
||||
\if{html}{\out{<div class="sourceCode">}}\preformatted{ If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder
|
||||
the columns in `newdata` to match with the booster's.
|
||||
|
||||
If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`,
|
||||
will additionally verify that categorical columns are of the correct type in `newdata`,
|
||||
throwing an error if they do not match.
|
||||
|
||||
If passing `FALSE`, it is assumed that the feature names and types are the same,
|
||||
and come in the same order as in the training data.
|
||||
|
||||
Note that this check might add some sizable latency to the predictions, so it's
|
||||
recommended to disable it for performance-sensitive applications.
|
||||
}\if{html}{\out{</div>}}}
|
||||
|
||||
\item{...}{Not used.}
|
||||
}
|
||||
\value{
|
||||
|
||||
@ -511,3 +511,82 @@ test_that('convert.labels works', {
|
||||
expect_equal(class(res), 'numeric')
|
||||
}
|
||||
})
|
||||
|
||||
test_that("validate.features works as expected", {
|
||||
data(mtcars)
|
||||
y <- mtcars$mpg
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
dm <- xgb.DMatrix(x, label = y, nthread = 1)
|
||||
model <- xgb.train(
|
||||
params = list(nthread = 1),
|
||||
data = dm,
|
||||
nrounds = 3
|
||||
)
|
||||
|
||||
# result is output as-is when needed
|
||||
res <- validate.features(model, x)
|
||||
expect_equal(res, x)
|
||||
res <- validate.features(model, dm)
|
||||
expect_identical(res, dm)
|
||||
res <- validate.features(model, as(x[1, ], "dsparseVector"))
|
||||
expect_equal(as.numeric(res), unname(x[1, ]))
|
||||
res <- validate.features(model, "file.txt")
|
||||
expect_equal(res, "file.txt")
|
||||
|
||||
# columns are reordered
|
||||
res <- validate.features(model, mtcars[, rev(names(mtcars))])
|
||||
expect_equal(names(res), colnames(x))
|
||||
expect_equal(as.matrix(res), x)
|
||||
res <- validate.features(model, as.matrix(mtcars[, rev(names(mtcars))]))
|
||||
expect_equal(colnames(res), colnames(x))
|
||||
expect_equal(res, x)
|
||||
res <- validate.features(model, mtcars[1, rev(names(mtcars)), drop = FALSE])
|
||||
expect_equal(names(res), colnames(x))
|
||||
expect_equal(unname(as.matrix(res)), unname(x[1, , drop = FALSE]))
|
||||
res <- validate.features(model, as.data.table(mtcars[, rev(names(mtcars))]))
|
||||
expect_equal(names(res), colnames(x))
|
||||
expect_equal(unname(as.matrix(res)), unname(x))
|
||||
|
||||
# error when columns are missing
|
||||
expect_error({
|
||||
validate.features(model, mtcars[, 1:3])
|
||||
})
|
||||
expect_error({
|
||||
validate.features(model, as.matrix(mtcars[, 1:ncol(x)])) # nolint
|
||||
})
|
||||
expect_error({
|
||||
validate.features(model, xgb.DMatrix(mtcars[, 1:3]))
|
||||
})
|
||||
expect_error({
|
||||
validate.features(model, as(x[, 1:3], "CsparseMatrix"))
|
||||
})
|
||||
|
||||
# error when it cannot reorder or subset
|
||||
expect_error({
|
||||
validate.features(model, xgb.DMatrix(mtcars))
|
||||
}, "Feature names")
|
||||
expect_error({
|
||||
validate.features(model, xgb.DMatrix(x[, rev(colnames(x))]))
|
||||
}, "Feature names")
|
||||
|
||||
# no error about types if the booster doesn't have types
|
||||
expect_error({
|
||||
validate.features(model, xgb.DMatrix(x, feature_types = c(rep("q", 5), rep("c", 5))))
|
||||
}, NA)
|
||||
tmp <- mtcars
|
||||
tmp[["vs"]] <- factor(tmp[["vs"]])
|
||||
expect_error({
|
||||
validate.features(model, tmp)
|
||||
}, NA)
|
||||
|
||||
# error when types do not match
|
||||
setinfo(model, "feature_type", rep("q", 10))
|
||||
expect_error({
|
||||
validate.features(model, xgb.DMatrix(x, feature_types = c(rep("q", 5), rep("c", 5))))
|
||||
}, "Feature types")
|
||||
tmp <- mtcars
|
||||
tmp[["vs"]] <- factor(tmp[["vs"]])
|
||||
expect_error({
|
||||
validate.features(model, tmp)
|
||||
}, "Feature types")
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user