[R] Use inplace predict (#9829)

---------

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
david-cortes 2024-02-23 19:03:54 +01:00 committed by GitHub
parent 729fd97196
commit f7005d32c1
7 changed files with 450 additions and 46 deletions

R/xgb.Booster.R

@@ -77,26 +77,45 @@ xgb.get.handle <- function(object) {
#' Predict method for XGBoost model
#'
#' Predict values on data based on xgboost model.
#'
#' @param object Object of class `xgb.Booster`.
#' @param newdata Takes `data.frame`, `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
#' local data file, or `xgb.DMatrix`.
#'
#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing
#' a sparse vector, it will take it as a row vector.
#'
#' Note that, for repeated predictions on the same data, one might want to create a DMatrix to
#' pass here instead of passing R types like matrices or data frames, as predictions will be
#' faster on DMatrix.
#'
#' If `newdata` is a `data.frame`, be aware that:\itemize{
#' \item Columns will be converted to numeric if they aren't already, which could potentially make
#' the operation slower than in an equivalent `matrix` object.
#' \item The order of the columns must match with that of the data from which the model was fitted
#' (i.e. columns will not be referenced by their names, just by their order in the data).
#' \item If the model was fitted to data with categorical columns, these columns must be of
#' `factor` type here, and must use the same encoding (i.e. have the same levels).
#' \item If `newdata` contains any `factor` columns, they will be converted to base-0
#' encoding (same as during DMatrix creation) - hence, one should not pass a `factor`
#' under a column which during training had a different type.
#' }
#' @param missing Float value that represents missing values in data (e.g., 0 or some other extreme value).
#'
#' This parameter is not used when `newdata` is an `xgb.DMatrix` - in such cases, should pass
#' this as an argument to the DMatrix constructor instead.
#' @param outputmargin Whether the prediction should be returned in the form of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting `outputmargin=TRUE` for
#' logistic regression would return log-odds instead of probabilities.
#' @param predleaf Whether to predict per-tree leaf indices.
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
#' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
#' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
#' or `predinteraction` is `TRUE`.
#' @param training Whether the prediction result is used for training. For dart booster,
#' training predicting will perform dropout.
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
#' a two-dimensional vector with the start and end numbers in the sequence (same format as R's `seq` - i.e.
@@ -111,6 +130,12 @@ xgb.get.handle <- function(object) {
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
#' type and shape of predictions are invariant to the model type.
#' @param base_margin Base margin used for boosting from existing model.
#'
#' Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
#' be ignored as it needs to be added to the DMatrix instead (e.g. by passing it as
#' an argument in its constructor, or by calling \link{setinfo.xgb.DMatrix}).
#'
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
#' match (only applicable when both `object` and `newdata` have feature names).
#'
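As a usage illustration of the `newdata` and `missing` documentation above, here is a minimal editorial sketch (not part of the diff; it assumes the built-in `mtcars` data and a small squared-error model):

library(xgboost)
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
model <- xgb.train(
  params = list(objective = "reg:squarederror", nthread = 1),
  data = xgb.DMatrix(x, label = y, nthread = 1),
  nrounds = 5
)
# Dense matrix input: the default missing = NA marks missing entries.
pred_mat <- predict(model, x)
# data.frame input: columns are coerced to numeric and matched by position,
# so the column order must be the same as in the training data.
pred_df <- predict(model, mtcars[, -1])
all.equal(pred_mat, pred_df)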
@@ -287,16 +312,80 @@ xgb.get.handle <- function(object) {
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
validate_features = FALSE, base_margin = NULL, ...) {
if (validate_features) {
newdata <- validate.features(object, newdata)
}
is_dmatrix <- inherits(newdata, "xgb.DMatrix")
if (is_dmatrix && !is.null(base_margin)) {
stop(
"'base_margin' is not supported when passing 'xgb.DMatrix' as input.",
" Should be passed as argument to 'xgb.DMatrix' constructor."
)
}
use_as_df <- FALSE
use_as_dense_matrix <- FALSE
use_as_csr_matrix <- FALSE
n_row <- NULL
if (!is_dmatrix) {
inplace_predict_supported <- !predcontrib && !predinteraction && !predleaf
if (inplace_predict_supported) {
booster_type <- xgb.booster_type(object)
if (booster_type == "gblinear" || (booster_type == "dart" && training)) {
inplace_predict_supported <- FALSE
}
}
if (inplace_predict_supported) {
if (is.matrix(newdata)) {
use_as_dense_matrix <- TRUE
} else if (is.data.frame(newdata)) {
# note: since here it turns it into a non-data-frame list,
# needs to keep track of the number of rows it had for later
n_row <- nrow(newdata)
newdata <- lapply(
newdata,
function(x) {
if (is.factor(x)) {
return(as.numeric(x) - 1)
} else {
return(as.numeric(x))
}
}
)
use_as_df <- TRUE
} else if (inherits(newdata, "dgRMatrix")) {
use_as_csr_matrix <- TRUE
csr_data <- list(newdata@p, newdata@j, newdata@x, ncol(newdata))
} else if (inherits(newdata, "dsparseVector")) {
use_as_csr_matrix <- TRUE
n_row <- 1L
i <- newdata@i - 1L
if (storage.mode(i) != "integer") {
storage.mode(i) <- "integer"
}
csr_data <- list(c(0L, length(i)), i, newdata@x, length(newdata))
}
}
} # if (!is_dmatrix)
if (!is_dmatrix && !use_as_dense_matrix && !use_as_csr_matrix && !use_as_df) {
nthread <- xgb.nthread(object)
newdata <- xgb.DMatrix(
newdata,
missing = missing,
base_margin = base_margin,
nthread = NVL(nthread, -1)
)
is_dmatrix <- TRUE
}
if (is.null(n_row)) {
n_row <- nrow(newdata)
}
@@ -354,18 +443,30 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
args$type <- set_type(6)
}
json_conf <- jsonlite::toJSON(args, auto_unbox = TRUE)
if (is_dmatrix) {
predts <- .Call(
XGBoosterPredictFromDMatrix_R, xgb.get.handle(object), newdata, json_conf
)
} else if (use_as_dense_matrix) {
predts <- .Call(
XGBoosterPredictFromDense_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
)
} else if (use_as_csr_matrix) {
predts <- .Call(
XGBoosterPredictFromCSR_R, xgb.get.handle(object), csr_data, missing, json_conf, base_margin
)
} else if (use_as_df) {
predts <- .Call(
XGBoosterPredictFromColumnar_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
)
}
names(predts) <- c("shape", "results")
shape <- predts$shape
arr <- predts$results
n_ret <- length(arr)
if (n_row != shape[1]) {
stop("Incorrect predict shape.")
}
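Following the dispatch logic above, `base_margin` can now be supplied directly to `predict()` for non-DMatrix inputs, while an `xgb.DMatrix` must carry its own base margin. A small editorial sketch (not part of the diff; reuses `model`, `x`, and `y` from the earlier sketch, with an arbitrary margin):

margin <- rep(mean(y), nrow(x))
# In-place path: base_margin is passed straight to predict() for a matrix input.
pred_inplace <- predict(model, x, base_margin = margin)
# DMatrix path: the margin has to be attached to the DMatrix itself;
# passing base_margin together with an xgb.DMatrix input raises an error.
dm <- xgb.DMatrix(x, base_margin = margin, nthread = 1)
pred_dmatrix <- predict(model, dm)
all.equal(pred_inplace, pred_dmatrix)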

man/predict.xgb.Booster.Rd

@@ -18,25 +18,47 @@
iterationrange = NULL,
strict_shape = FALSE,
validate_features = FALSE,
base_margin = NULL,
...
)
}
\arguments{
\item{object}{Object of class \code{xgb.Booster}.}
\item{newdata}{Takes \code{data.frame}, \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector},
local data file, or \code{xgb.DMatrix}.

\if{html}{\out{<div class="sourceCode">}}\preformatted{ For single-row predictions on sparse data, it's recommended to use CSR format. If passing
a sparse vector, it will take it as a row vector.
Note that, for repeated predictions on the same data, one might want to create a DMatrix to
pass here instead of passing R types like matrices or data frames, as predictions will be
faster on DMatrix.
If `newdata` is a `data.frame`, be aware that:\\itemize\{
\\item Columns will be converted to numeric if they aren't already, which could potentially make
the operation slower than in an equivalent `matrix` object.
\\item The order of the columns must match with that of the data from which the model was fitted
(i.e. columns will not be referenced by their names, just by their order in the data).
\\item If the model was fitted to data with categorical columns, these columns must be of
`factor` type here, and must use the same encoding (i.e. have the same levels).
\\item If `newdata` contains any `factor` columns, they will be converted to base-0
encoding (same as during DMatrix creation) - hence, one should not pass a `factor`
under a column which during training had a different type.
\}
}\if{html}{\out{</div>}}}
\item{missing}{Float value that represents missing values in data (e.g., 0 or some other extreme value).
\if{html}{\out{<div class="sourceCode">}}\preformatted{ This parameter is not used when `newdata` is an `xgb.DMatrix` - in such cases, should pass
this as an argument to the DMatrix constructor instead.
}\if{html}{\out{</div>}}}
\item{outputmargin}{Whether the prediction should be returned in the form of original untransformed
sum of predictions from boosting iterations' results. E.g., setting \code{outputmargin=TRUE} for
logistic regression would return log-odds instead of probabilities.}
\item{predleaf}{Whether to predict per-tree leaf indices.}
\item{predcontrib}{Whether to return feature contributions to individual predictions (see Details).}
@@ -48,7 +70,7 @@ logistic regression would return log-odds instead of probabilities.}
prediction outputs per case. No effect if \code{predleaf}, \code{predcontrib},
or \code{predinteraction} is \code{TRUE}.}
\item{training}{Whether the prediction result is used for training. For dart booster,
training predicting will perform dropout.}
\item{iterationrange}{Sequence of rounds/iterations from the model to use for prediction, specified by passing
@@ -84,6 +106,13 @@ match (only applicable when both \code{object} and \code{newdata} have feature n
recommended to disable it for performance-sensitive applications.
}\if{html}{\out{</div>}}}
\item{base_margin}{Base margin used for boosting from existing model.
\if{html}{\out{<div class="sourceCode">}}\preformatted{ Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
be ignored as it needs to be added to the DMatrix instead (e.g. by passing it as
an argument in its constructor, or by calling \link{setinfo.xgb.DMatrix}).
}\if{html}{\out{</div>}}}
\item{...}{Not used.}
}
\value{
@@ -115,7 +144,7 @@ When \code{strict_shape = TRUE}, the output is always an array:
}
}
\description{
Predict values on data based on xgboost model.
}
\details{
Note that \code{iterationrange} would currently do nothing for predictions from "gblinear",

src/init.c

@@ -37,6 +37,9 @@ extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromDense_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromColumnar_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
@@ -96,6 +99,9 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
{"XGBoosterPredictFromDense_R", (DL_FUNC) &XGBoosterPredictFromDense_R, 5},
{"XGBoosterPredictFromCSR_R", (DL_FUNC) &XGBoosterPredictFromCSR_R, 5},
{"XGBoosterPredictFromColumnar_R", (DL_FUNC) &XGBoosterPredictFromColumnar_R, 5},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2}, {"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
{"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3}, {"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3}, {"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},

src/xgboost_R.cc

@@ -13,6 +13,7 @@
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <limits>
#include <sstream>
#include <string>
@@ -207,25 +208,24 @@ SEXP SafeAllocInteger(size_t size, SEXP continuation_token) {
return xgboost::Json::Dump(jinterface);
}
void AddMissingToJson(xgboost::Json *jconfig, SEXP missing, SEXPTYPE arr_type) {
if (Rf_isNull(missing) || ISNAN(Rf_asReal(missing))) {
// missing is not specified
if (arr_type == REALSXP) {
(*jconfig)["missing"] = std::numeric_limits<double>::quiet_NaN();
} else {
(*jconfig)["missing"] = R_NaInt;
}
} else {
// missing specified
(*jconfig)["missing"] = Rf_asReal(missing);
}
}
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
using namespace ::xgboost;  // NOLINT
Json jconfig{Object{}};
AddMissingToJson(&jconfig, missing, arr_type);
jconfig["nthread"] = Rf_asInteger(n_threads);
return Json::Dump(jconfig);
}
@@ -411,7 +411,7 @@ XGB_DLL SEXP XGDMatrixCreateFromDF_R(SEXP df, SEXP missing, SEXP n_threads) {
DMatrixHandle handle;
std::int32_t rc{0};
{
const std::string sinterface = MakeArrayInterfaceFromRDataFrame(df);
xgboost::Json jconfig{xgboost::Object{}};
jconfig["missing"] = asReal(missing);
jconfig["nthread"] = asInteger(n_threads);
@@ -463,7 +463,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP
Json jconfig{Object{}};
// Construct configuration
jconfig["nthread"] = Integer{threads};
AddMissingToJson(&jconfig, missing, TYPEOF(data));
std::string config;
Json::Dump(jconfig, &config);
res_code = XGDMatrixCreateFromCSC(sindptr.c_str(), sindices.c_str(), sdata.c_str(), nrow,
@@ -498,7 +498,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
Json jconfig{Object{}};
// Construct configuration
jconfig["nthread"] = Integer{threads};
AddMissingToJson(&jconfig, missing, TYPEOF(data));
std::string config;
Json::Dump(jconfig, &config);
res_code = XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), ncol,
@@ -1247,7 +1247,60 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
return mkString(ret);
}
namespace {
struct ProxyDmatrixError : public std::exception {};
struct ProxyDmatrixWrapper {
DMatrixHandle proxy_dmat_handle;
ProxyDmatrixWrapper() {
int res_code = XGProxyDMatrixCreate(&this->proxy_dmat_handle);
if (res_code != 0) {
throw ProxyDmatrixError();
}
}
~ProxyDmatrixWrapper() {
if (this->proxy_dmat_handle) {
XGDMatrixFree(this->proxy_dmat_handle);
this->proxy_dmat_handle = nullptr;
}
}
DMatrixHandle get_handle() {
return this->proxy_dmat_handle;
}
};
std::unique_ptr<ProxyDmatrixWrapper> GetProxyDMatrixWithBaseMargin(SEXP base_margin) {
if (Rf_isNull(base_margin)) {
return std::unique_ptr<ProxyDmatrixWrapper>(nullptr);
}
SEXP base_margin_dim = Rf_getAttrib(base_margin, R_DimSymbol);
int res_code;
try {
const std::string array_str = Rf_isNull(base_margin_dim)?
MakeArrayInterfaceFromRVector(base_margin) : MakeArrayInterfaceFromRMat(base_margin);
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat(new ProxyDmatrixWrapper());
res_code = XGDMatrixSetInfoFromInterface(proxy_dmat->get_handle(),
"base_margin",
array_str.c_str());
if (res_code != 0) {
throw ProxyDmatrixError();
}
return proxy_dmat;
} catch(ProxyDmatrixError &err) {
Rf_error("%s", XGBGetLastError());
}
}
enum class PredictionInputType {DMatrix, DenseMatrix, CSRMatrix, DataFrame};
SEXP XGBoosterPredictGeneric(SEXP handle, SEXP input_data, SEXP json_config,
PredictionInputType input_type, SEXP missing,
SEXP base_margin) {
SEXP r_out_shape;
SEXP r_out_result;
SEXP r_out = PROTECT(allocVector(VECSXP, 2));
@@ -1259,9 +1312,79 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
bst_ulong out_dim;
bst_ulong const *out_shape;
float const *out_result;
int res_code;
{
switch (input_type) {
case PredictionInputType::DMatrix: {
res_code = XGBoosterPredictFromDMatrix(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(input_data), c_json_config,
&out_shape, &out_dim, &out_result);
break;
}
case PredictionInputType::CSRMatrix: {
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
base_margin);
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
SEXP indptr = VECTOR_ELT(input_data, 0);
SEXP indices = VECTOR_ELT(input_data, 1);
SEXP data = VECTOR_ELT(input_data, 2);
const int ncol_csr = Rf_asInteger(VECTOR_ELT(input_data, 3));
const SEXPTYPE type_data = TYPEOF(data);
CHECK_EQ(type_data, REALSXP);
std::string sindptr, sindices, sdata;
CreateFromSparse(indptr, indices, data, &sindptr, &sindices, &sdata);
xgboost::StringView json_str(c_json_config);
xgboost::Json new_json = xgboost::Json::Load(json_str);
AddMissingToJson(&new_json, missing, type_data);
const std::string new_c_json = xgboost::Json::Dump(new_json);
res_code = XGBoosterPredictFromCSR(
R_ExternalPtrAddr(handle), sindptr.c_str(), sindices.c_str(), sdata.c_str(),
ncol_csr, new_c_json.c_str(), proxy_dmat_handle, &out_shape, &out_dim, &out_result);
break;
}
case PredictionInputType::DenseMatrix: {
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
base_margin);
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
const std::string array_str = MakeArrayInterfaceFromRMat(input_data);
xgboost::StringView json_str(c_json_config);
xgboost::Json new_json = xgboost::Json::Load(json_str);
AddMissingToJson(&new_json, missing, TYPEOF(input_data));
const std::string new_c_json = xgboost::Json::Dump(new_json);
res_code = XGBoosterPredictFromDense(
R_ExternalPtrAddr(handle), array_str.c_str(), new_c_json.c_str(),
proxy_dmat_handle, &out_shape, &out_dim, &out_result);
break;
}
case PredictionInputType::DataFrame: {
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
base_margin);
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
const std::string df_str = MakeArrayInterfaceFromRDataFrame(input_data);
xgboost::StringView json_str(c_json_config);
xgboost::Json new_json = xgboost::Json::Load(json_str);
AddMissingToJson(&new_json, missing, REALSXP);
const std::string new_c_json = xgboost::Json::Dump(new_json);
res_code = XGBoosterPredictFromColumnar(
R_ExternalPtrAddr(handle), df_str.c_str(), new_c_json.c_str(),
proxy_dmat_handle, &out_shape, &out_dim, &out_result);
break;
}
}
}
CHECK_CALL(res_code);
r_out_shape = PROTECT(allocVector(INTSXP, out_dim));
size_t len = 1;
@@ -1282,6 +1405,31 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
return r_out;
}
} // namespace
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
return XGBoosterPredictGeneric(handle, dmat, json_config,
PredictionInputType::DMatrix, R_NilValue, R_NilValue);
}
XGB_DLL SEXP XGBoosterPredictFromDense_R(SEXP handle, SEXP R_mat, SEXP missing,
SEXP json_config, SEXP base_margin) {
return XGBoosterPredictGeneric(handle, R_mat, json_config,
PredictionInputType::DenseMatrix, missing, base_margin);
}
XGB_DLL SEXP XGBoosterPredictFromCSR_R(SEXP handle, SEXP lst, SEXP missing,
SEXP json_config, SEXP base_margin) {
return XGBoosterPredictGeneric(handle, lst, json_config,
PredictionInputType::CSRMatrix, missing, base_margin);
}
XGB_DLL SEXP XGBoosterPredictFromColumnar_R(SEXP handle, SEXP R_df, SEXP missing,
SEXP json_config, SEXP base_margin) {
return XGBoosterPredictGeneric(handle, R_df, json_config,
PredictionInputType::DataFrame, missing, base_margin);
}
XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));

src/xgboost_R.h

@@ -371,6 +371,50 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
*/
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config);
/*!
* \brief Run prediction on R dense matrix
* \param handle handle
* \param R_mat R matrix
* \param missing missing value
* \param json_config See `XGBoosterPredictFromDense` in xgboost c_api.h. Doesn't include 'missing'
* \param base_margin base margin for the prediction
*
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
*/
XGB_DLL SEXP XGBoosterPredictFromDense_R(SEXP handle, SEXP R_mat, SEXP missing,
SEXP json_config, SEXP base_margin);
/*!
* \brief Run prediction on R CSR matrix
* \param handle handle
* \param lst An R list, containing, in this order:
* (a) 'p' array (a.k.a. indptr)
* (b) 'j' array (a.k.a. indices)
* (c) 'x' array (a.k.a. data / values)
* (d) number of columns
* \param missing missing value
* \param json_config See `XGBoosterPredictFromCSR` in xgboost c_api.h. Doesn't include 'missing'
* \param base_margin base margin for the prediction
*
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
*/
XGB_DLL SEXP XGBoosterPredictFromCSR_R(SEXP handle, SEXP lst, SEXP missing,
SEXP json_config, SEXP base_margin);
/*!
* \brief Run prediction on R data.frame
* \param handle handle
* \param R_df R data.frame
* \param missing missing value
* \param json_config See `XGBoosterPredictFromDense` in xgboost c_api.h. Doesn't include 'missing'
* \param base_margin base margin for the prediction
*
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
*/
XGB_DLL SEXP XGBoosterPredictFromColumnar_R(SEXP handle, SEXP R_df, SEXP missing,
SEXP json_config, SEXP base_margin);
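For reference, a short R sketch (editorial, not part of the diff) of the list layout that the `XGBoosterPredictFromCSR_R` declaration above documents, mirroring how `predict.xgb.Booster` assembles it from a `dgRMatrix`; end users call `predict()` rather than the registered routine directly:

library(Matrix)
x_csr <- as(matrix(c(1, 0, 2, 0, 0, 3), nrow = 2, byrow = TRUE), "RsparseMatrix")
# (a) indptr ('p'), (b) column indices ('j'), (c) values ('x'), (d) number of columns
csr_data <- list(x_csr@p, x_csr@j, x_csr@x, ncol(x_csr))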
/*!
* \brief load model from existing file
* \param handle handle

tests/testthat/test_basic.R

@@ -139,8 +139,8 @@ test_that("dart prediction works", {
pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, iterationrange = c(1, nrounds))
pred_by_train_2 <- predict(booster_by_train, newdata = dtrain, training = TRUE)
expect_equal(pred_by_train_0, pred_by_xgboost_0, tolerance = 1e-6)
expect_equal(pred_by_train_1, pred_by_xgboost_1, tolerance = 1e-6)
expect_true(all(matrix(pred_by_train_2, byrow = TRUE) == matrix(pred_by_xgboost_2, byrow = TRUE)))
})
@@ -651,6 +651,51 @@ test_that("Can use ranking objectives with either 'qid' or 'group'", {
expect_equal(pred_qid, pred_gr)
})
test_that("Can predict on data.frame objects", {
data("mtcars")
y <- mtcars$mpg
x_df <- mtcars[, -1]
x_mat <- as.matrix(x_df)
dm <- xgb.DMatrix(x_mat, label = y, nthread = n_threads)
model <- xgb.train(
params = list(
tree_method = "hist",
objective = "reg:squarederror",
nthread = n_threads
),
data = dm,
nrounds = 5
)
pred_mat <- predict(model, xgb.DMatrix(x_mat), nthread = n_threads)
pred_df <- predict(model, x_df, nthread = n_threads)
expect_equal(pred_mat, pred_df)
})
test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", {
data("mtcars")
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
dm <- xgb.DMatrix(x, label = y, nthread = n_threads)
model <- xgb.train(
params = list(
tree_method = "hist",
objective = "reg:squarederror",
nthread = n_threads
),
data = dm,
nrounds = 5
)
set.seed(123)
base_margin <- rnorm(nrow(x))
dm_w_base <- xgb.DMatrix(data = x, base_margin = base_margin)
pred_from_dm <- predict(model, dm_w_base)
pred_from_mat <- predict(model, x, base_margin = base_margin)
expect_equal(pred_from_dm, pred_from_mat)
})
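A further check along the same lines (editorial sketch, not part of the diff) could exercise the CSR in-place path, assuming the same `n_threads` value and Matrix setup used by the surrounding tests:

test_that("Predictions on dgRMatrix match xgb.DMatrix", {
  data("mtcars")
  y <- mtcars$mpg
  x <- as.matrix(mtcars[, -1])
  model <- xgb.train(
    params = list(tree_method = "hist", objective = "reg:squarederror", nthread = n_threads),
    data = xgb.DMatrix(x, label = y, nthread = n_threads),
    nrounds = 5
  )
  # CSR input goes through the in-place prediction path; DMatrix is the reference.
  x_csr <- as(x, "RsparseMatrix")
  expect_equal(predict(model, x_csr), predict(model, xgb.DMatrix(x)), tolerance = 1e-6)
})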
test_that("Coefficients from gblinear have the expected shape and names", { test_that("Coefficients from gblinear have the expected shape and names", {
# Single-column coefficients # Single-column coefficients
data(mtcars) data(mtcars)

tests/testthat/test_dmatrix.R

@@ -302,6 +302,37 @@ test_that("xgb.DMatrix: Inf as missing", {
file.remove(fname_nan)
})
test_that("xgb.DMatrix: missing in CSR", {
x_dense <- matrix(as.numeric(1:10), nrow = 5)
x_dense[2, 1] <- NA_real_
x_csr <- as(x_dense, "RsparseMatrix")
m_dense <- xgb.DMatrix(x_dense, nthread = n_threads, missing = NA_real_)
xgb.DMatrix.save(m_dense, "dense.dmatrix")
m_csr <- xgb.DMatrix(x_csr, nthread = n_threads, missing = NA)
xgb.DMatrix.save(m_csr, "csr.dmatrix")
denseconn <- file("dense.dmatrix", "rb")
csrconn <- file("csr.dmatrix", "rb")
expect_equal(file.size("dense.dmatrix"), file.size("csr.dmatrix"))
bytes <- file.size("dense.dmatrix")
densedmatrix <- readBin(denseconn, "raw", n = bytes)
csrmatrix <- readBin(csrconn, "raw", n = bytes)
expect_equal(length(densedmatrix), length(csrmatrix))
expect_equal(densedmatrix, csrmatrix)
close(denseconn)
close(csrconn)
file.remove("dense.dmatrix")
file.remove("csr.dmatrix")
})
test_that("xgb.DMatrix: error on three-dimensional array", { test_that("xgb.DMatrix: error on three-dimensional array", {
set.seed(123) set.seed(123)
x <- matrix(rnorm(500), nrow = 50) x <- matrix(rnorm(500), nrow = 50)