diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 7613c9152..febefb757 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -111,6 +111,21 @@ xgb.get.handle <- function(object) { #' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not. #' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output #' type and shape of predictions are invariant to the model type. +#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names +#' match (only applicable when both `object` and `newdata` have feature names). +#' +#' If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder +#' the columns in `newdata` to match with the booster's. +#' +#' If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`, +#' will additionally verify that categorical columns are of the correct type in `newdata`, +#' throwing an error if they do not match. +#' +#' If passing `FALSE`, it is assumed that the feature names and types are the same, +#' and come in the same order as in the training data. +#' +#' Note that this check might add some sizable latency to the predictions, so it's +#' recommended to disable it for performance-sensitive applications. #' @param ... Not used. #' #' @details @@ -271,7 +286,11 @@ xgb.get.handle <- function(object) { #' @export predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE, predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE, - reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, ...) { + reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, + validate_features = FALSE, ...) 
{
+  if (validate_features) {
+    newdata <- validate.features(object, newdata)
+  }
   if (!inherits(newdata, "xgb.DMatrix")) {
     nthread <- xgb.nthread(object)
     newdata <- xgb.DMatrix(
@@ -418,6 +437,108 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
   return(arr)
 }
 
+# Validate the features of `newdata` against the metadata stored in a booster.
+#
+# Arguments:
+#   bst      Booster whose feature names / types are the reference.
+#   newdata  Prediction input: file path, sparseVector, vector, matrix,
+#            sparse matrix, data.frame / data.table, or xgb.DMatrix.
+#
+# Returns `newdata`, with columns reordered / subset to the booster's feature
+# names when both sides have names and `newdata` is not an xgb.DMatrix.
+#
+# Throws an error when feature names cannot be matched, or when the booster
+# has feature types and the categorical columns of an xgb.DMatrix / data.frame
+# do not agree with them.
+validate.features <- function(bst, newdata) {
+  if (is.character(newdata)) {
+    # this will be encountered when passing file paths
+    return(newdata)
+  }
+  if (inherits(newdata, "sparseVector")) {
+    # in this case, newdata won't have metadata
+    return(newdata)
+  }
+  if (is.vector(newdata)) {
+    newdata <- as.matrix(newdata)
+  }
+
+  booster_names <- getinfo(bst, "feature_name")
+  checked_names <- FALSE
+  if (NROW(booster_names)) {
+
+    is_dmatrix <- inherits(newdata, "xgb.DMatrix")
+    if (is_dmatrix) {
+      curr_names <- getinfo(newdata, "feature_name")
+    } else {
+      curr_names <- colnames(newdata)
+    }
+
+    if (NROW(curr_names)) {
+      checked_names <- TRUE
+
+      if (length(curr_names) != length(booster_names) || any(curr_names != booster_names)) {
+
+        if (is_dmatrix) {
+          # columns of an xgb.DMatrix cannot be rearranged here
+          stop("Feature names in 'newdata' do not match with booster's.")
+        }
+
+        # fail with an explicit message instead of a cryptic subsetting error
+        missing_names <- setdiff(booster_names, curr_names)
+        if (length(missing_names)) {
+          stop(
+            "Feature names in 'newdata' do not match with booster's. Missing columns: ",
+            paste(missing_names, collapse = ", ")
+          )
+        }
+
+        if (inherits(newdata, "data.table")) {
+          newdata <- newdata[, booster_names, with = FALSE]
+        } else {
+          newdata <- newdata[, booster_names, drop = FALSE]
+        }
+
+      }
+
+    } # if (NROW(curr_names)) {
+
+  } # if (NROW(booster_names)) {
+
+  if (inherits(newdata, c("data.frame", "xgb.DMatrix"))) {
+
+    booster_types <- getinfo(bst, "feature_type")
+    if (!NROW(booster_types)) {
+      # Note: types in the booster are optional. Other interfaces
+      # might not even save it as booster attributes for example,
+      # even if the model uses categorical features.
+      return(newdata)
+    }
+    if (inherits(newdata, "xgb.DMatrix")) {
+      curr_types <- getinfo(newdata, "feature_type")
+      if (length(curr_types) != length(booster_types) || any(curr_types != booster_types)) {
+        stop("Feature types in 'newdata' do not match with booster's.")
+      }
+    }
+    if (inherits(newdata, "data.frame")) {
+      # vapply() guarantees a logical vector, even for zero-column input
+      is_factor <- vapply(newdata, is.factor, logical(1))
+      is_categorical <- booster_types == "c"
+      # compare lengths first so that '!=' never recycles silently
+      if (length(is_factor) != length(is_categorical) || any(is_factor != is_categorical)) {
+        stop(
+          "Feature types in 'newdata' do not match with booster's for same columns (by ",
+          if (checked_names) "name" else "position",
+          ")."
+        )
+      }
+    }
+
+  }
+
+  return(newdata)
+}
+
 
 #' @title Accessors for serializable attributes of a model
 #'
diff --git a/R-package/man/predict.xgb.Booster.Rd b/R-package/man/predict.xgb.Booster.Rd
index 7a6dd6c13..95e7a51fd 100644
--- a/R-package/man/predict.xgb.Booster.Rd
+++ b/R-package/man/predict.xgb.Booster.Rd
@@ -17,6 +17,7 @@
   training = FALSE,
   iterationrange = NULL,
   strict_shape = FALSE,
+  validate_features = FALSE,
   ...
 )
 }
@@ -66,6 +67,23 @@ base-1 indexing, and inclusive of both ends).
 \item{strict_shape}{Default is \code{FALSE}. When set to \code{TRUE}, the output
 type and shape of predictions are invariant to the model type.}
 
+\item{validate_features}{When \code{TRUE}, validate that the Booster's and newdata's feature_names
+match (only applicable when both \code{object} and \code{newdata} have feature names).
+
+\if{html}{\out{
}}\preformatted{ If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder + the columns in `newdata` to match with the booster's. + + If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`, + will additionally verify that categorical columns are of the correct type in `newdata`, + throwing an error if they do not match. + + If passing `FALSE`, it is assumed that the feature names and types are the same, + and come in the same order as in the training data. + + Note that this check might add some sizable latency to the predictions, so it's + recommended to disable it for performance-sensitive applications. +}\if{html}{\out{
}}}
+
 \item{...}{Not used.}
 }
 \value{
diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R
index badac0213..38b5ca066 100644
--- a/R-package/tests/testthat/test_helpers.R
+++ b/R-package/tests/testthat/test_helpers.R
@@ -511,3 +511,84 @@ test_that('convert.labels works', {
     expect_equal(class(res), 'numeric')
   }
 })
+
+test_that("validate.features works as expected", {
+  data(mtcars)
+  y <- mtcars$mpg
+  x <- as.matrix(mtcars[, -1])
+  dm <- xgb.DMatrix(x, label = y, nthread = 1)
+  model <- xgb.train(
+    params = list(nthread = 1),
+    data = dm,
+    nrounds = 3
+  )
+  # every column of 'mtcars' (label included), in reverse order
+  reversed <- mtcars[, rev(names(mtcars))]
+
+  # result is output as-is when needed
+  res <- validate.features(model, x)
+  expect_equal(res, x)
+  res <- validate.features(model, dm)
+  expect_identical(res, dm)
+  res <- validate.features(model, as(x[1, ], "dsparseVector"))
+  expect_equal(as.numeric(res), unname(x[1, ]))
+  res <- validate.features(model, "file.txt")
+  expect_equal(res, "file.txt")
+
+  # columns are reordered
+  res <- validate.features(model, reversed)
+  expect_equal(names(res), colnames(x))
+  expect_equal(as.matrix(res), x)
+  res <- validate.features(model, as.matrix(reversed))
+  expect_equal(colnames(res), colnames(x))
+  expect_equal(res, x)
+  res <- validate.features(model, reversed[1, , drop = FALSE])
+  expect_equal(names(res), colnames(x))
+  expect_equal(unname(as.matrix(res)), unname(x[1, , drop = FALSE]))
+  res <- validate.features(model, as.data.table(reversed))
+  expect_equal(names(res), colnames(x))
+  expect_equal(unname(as.matrix(res)), unname(x))
+
+  # error when columns are missing
+  expect_error({
+    validate.features(model, mtcars[, 1:3])
+  })
+  expect_error({
+    validate.features(model, as.matrix(mtcars[, seq_len(ncol(x))]))
+  })
+  expect_error({
+    validate.features(model, xgb.DMatrix(mtcars[, 1:3]))
+  })
+  expect_error({
+    validate.features(model, as(x[, 1:3], "CsparseMatrix"))
+  })
+
+  # error when it cannot reorder or subset
+  expect_error({
+    validate.features(model, xgb.DMatrix(mtcars))
+  }, "Feature names")
+  expect_error({
+    validate.features(model, xgb.DMatrix(x[, rev(colnames(x))]))
+  }, "Feature names")
+
+  # no error about types if the booster doesn't have types
+  expect_error({
+    validate.features(model, xgb.DMatrix(x, feature_types = c(rep("q", 5), rep("c", 5))))
+  }, NA)
+  tmp <- mtcars
+  tmp[["vs"]] <- factor(tmp[["vs"]])
+  expect_error({
+    validate.features(model, tmp)
+  }, NA)
+
+  # error when types do not match
+  setinfo(model, "feature_type", rep("q", 10))
+  expect_error({
+    validate.features(model, xgb.DMatrix(x, feature_types = c(rep("q", 5), rep("c", 5))))
+  }, "Feature types")
+  tmp <- mtcars
+  tmp[["vs"]] <- factor(tmp[["vs"]])
+  expect_error({
+    validate.features(model, tmp)
+  }, "Feature types")
+})