diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 7613c9152..febefb757 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -111,6 +111,21 @@ xgb.get.handle <- function(object) { #' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not. #' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output #' type and shape of predictions are invariant to the model type. +#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names +#' match (only applicable when both `object` and `newdata` have feature names). +#' +#' If the column names differ and `newdata` is not an `xgb.DMatrix`, will try to reorder +#' the columns in `newdata` to match with the booster's. +#' +#' If the booster has feature types and `newdata` is either an `xgb.DMatrix` or `data.frame`, +#' will additionally verify that categorical columns are of the correct type in `newdata`, +#' throwing an error if they do not match. +#' +#' If passing `FALSE`, it is assumed that the feature names and types are the same, +#' and come in the same order as in the training data. +#' +#' Note that this check might add some sizable latency to the predictions, so it's +#' recommended to disable it for performance-sensitive applications. #' @param ... Not used. #' #' @details @@ -271,7 +286,11 @@ xgb.get.handle <- function(object) { #' @export predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE, predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE, - reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, ...) { + reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, + validate_features = FALSE, ...) { + if (validate_features) { + newdata <- validate.features(object, newdata) + } if (!inherits(newdata, "xgb.DMatrix")) { nthread <- xgb.nthread(object) newdata <- xgb.DMatrix( @@ -418,6 +437,85 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA return(arr) } +validate.features <- function(bst, newdata) { + if (is.character(newdata)) { + # this will be encountered when passing file paths + return(newdata) + } + if (inherits(newdata, "sparseVector")) { + # in this case, newdata won't have metadata + return(newdata) + } + if (is.vector(newdata)) { + newdata <- as.matrix(newdata) + } + + booster_names <- getinfo(bst, "feature_name") + checked_names <- FALSE + if (NROW(booster_names)) { + + try_reorder <- FALSE + if (inherits(newdata, "xgb.DMatrix")) { + curr_names <- getinfo(newdata, "feature_name") + } else { + curr_names <- colnames(newdata) + try_reorder <- TRUE + } + + if (NROW(curr_names)) { + checked_names <- TRUE + + if (length(curr_names) != length(booster_names) || any(curr_names != booster_names)) { + + if (!try_reorder) { + stop("Feature names in 'newdata' do not match with booster's.") + } else { + if (inherits(newdata, "data.table")) { + newdata <- newdata[, booster_names, with = FALSE] + } else { + newdata <- newdata[, booster_names, drop = FALSE] + } + } + + } + + } # if (NROW(curr_names)) { + + } # if (NROW(booster_names)) { + + if (inherits(newdata, c("data.frame", "xgb.DMatrix"))) { + + booster_types <- getinfo(bst, "feature_type") + if (!NROW(booster_types)) { + # Note: types in the booster are optional. Other interfaces + # might not even save it as booster attributes for example, + # even if the model uses categorical features. + return(newdata) + } + if (inherits(newdata, "xgb.DMatrix")) { + curr_types <- getinfo(newdata, "feature_type") + if (length(curr_types) != length(booster_types) || any(curr_types != booster_types)) { + stop("Feature types in 'newdata' do not match with booster's.") + } + } + if (inherits(newdata, "data.frame")) { + is_factor <- sapply(newdata, is.factor) + if (any(is_factor != (booster_types == "c"))) { + stop( + paste0( + "Feature types in 'newdata' do not match with booster's for same columns (by ", + ifelse(checked_names, "name", "position"), + ")." + ) + ) + } + } + + } + + return(newdata) +} + #' @title Accessors for serializable attributes of a model #' diff --git a/R-package/man/predict.xgb.Booster.Rd b/R-package/man/predict.xgb.Booster.Rd index 7a6dd6c13..95e7a51fd 100644 --- a/R-package/man/predict.xgb.Booster.Rd +++ b/R-package/man/predict.xgb.Booster.Rd @@ -17,6 +17,7 @@ training = FALSE, iterationrange = NULL, strict_shape = FALSE, + validate_features = FALSE, ... ) } @@ -66,6 +67,23 @@ base-1 indexing, and inclusive of both ends). \item{strict_shape}{Default is \code{FALSE}. When set to \code{TRUE}, the output type and shape of predictions are invariant to the model type.} +\item{validate_features}{When \code{TRUE}, validate that the Booster's and newdata's feature_names +match (only applicable when both \code{object} and \code{newdata} have feature names). + +\if{html}{\out{