[R] Use inplace predict (#9829)
--------- Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
729fd97196
commit
f7005d32c1
@ -77,26 +77,45 @@ xgb.get.handle <- function(object) {
|
|||||||
|
|
||||||
#' Predict method for XGBoost model
|
#' Predict method for XGBoost model
|
||||||
#'
|
#'
|
||||||
#' Predicted values based on either xgboost model or model handle object.
|
#' Predict values on data based on xgboost model.
|
||||||
#'
|
#'
|
||||||
#' @param object Object of class `xgb.Booster`.
|
#' @param object Object of class `xgb.Booster`.
|
||||||
#' @param newdata Takes `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
|
#' @param newdata Takes `data.frame`, `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
|
||||||
#' local data file, or `xgb.DMatrix`.
|
#' local data file, or `xgb.DMatrix`.
|
||||||
#' For single-row predictions on sparse data, it is recommended to use the CSR format.
|
#'
|
||||||
#' If passing a sparse vector, it will take it as a row vector.
|
#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing
|
||||||
#' @param missing Only used when input is a dense matrix. Pick a float value that represents
|
#' a sparse vector, it will take it as a row vector.
|
||||||
#' missing values in data (e.g., 0 or some other extreme value).
|
#'
|
||||||
|
#' Note that, for repeated predictions on the same data, one might want to create a DMatrix to
|
||||||
|
#' pass here instead of passing R types like matrices or data frames, as predictions will be
|
||||||
|
#' faster on DMatrix.
|
||||||
|
#'
|
||||||
|
#' If `newdata` is a `data.frame`, be aware that:\itemize{
|
||||||
|
#' \item Columns will be converted to numeric if they aren't already, which could potentially make
|
||||||
|
#' the operation slower than in an equivalent `matrix` object.
|
||||||
|
#' \item The order of the columns must match with that of the data from which the model was fitted
|
||||||
|
#' (i.e. columns will not be referenced by their names, just by their order in the data).
|
||||||
|
#' \item If the model was fitted to data with categorical columns, these columns must be of
|
||||||
|
#' `factor` type here, and must use the same encoding (i.e. have the same levels).
|
||||||
|
#' \item If `newdata` contains any `factor` columns, they will be converted to base-0
|
||||||
|
#' encoding (same as during DMatrix creation) - hence, one should not pass a `factor`
|
||||||
|
#' under a column which during training had a different type.
|
||||||
|
#' }
|
||||||
|
#' @param missing Float value that represents missing values in data (e.g., 0 or some other extreme value).
|
||||||
|
#'
|
||||||
|
#' This parameter is not used when `newdata` is an `xgb.DMatrix` - in such cases, one should pass
|
||||||
|
#' this as an argument to the DMatrix constructor instead.
|
||||||
#' @param outputmargin Whether the prediction should be returned in the form of original untransformed
|
#' @param outputmargin Whether the prediction should be returned in the form of original untransformed
|
||||||
#' sum of predictions from boosting iterations' results. E.g., setting `outputmargin=TRUE` for
|
#' sum of predictions from boosting iterations' results. E.g., setting `outputmargin=TRUE` for
|
||||||
#' logistic regression would return log-odds instead of probabilities.
|
#' logistic regression would return log-odds instead of probabilities.
|
||||||
#' @param predleaf Whether to predict pre-tree leaf indices.
|
#' @param predleaf Whether to predict per-tree leaf indices.
|
||||||
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
|
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
|
||||||
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
|
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
|
||||||
#' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
|
#' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
|
||||||
#' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
|
#' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
|
||||||
#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
|
#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
|
||||||
#' or `predinteraction` is `TRUE`.
|
#' or `predinteraction` is `TRUE`.
|
||||||
#' @param training Whether the predictions are used for training. For dart booster,
|
#' @param training Whether the prediction result is used for training. For dart booster,
|
||||||
#' training predicting will perform dropout.
|
#' training predicting will perform dropout.
|
||||||
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
|
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
|
||||||
#' a two-dimensional vector with the start and end numbers in the sequence (same format as R's `seq` - i.e.
|
#' a two-dimensional vector with the start and end numbers in the sequence (same format as R's `seq` - i.e.
|
||||||
@ -111,6 +130,12 @@ xgb.get.handle <- function(object) {
|
|||||||
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
|
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
|
||||||
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
|
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
|
||||||
#' type and shape of predictions are invariant to the model type.
|
#' type and shape of predictions are invariant to the model type.
|
||||||
|
#' @param base_margin Base margin used for boosting from existing model.
|
||||||
|
#'
|
||||||
|
#' Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
|
||||||
|
#' result in an error, as it needs to be added to the DMatrix instead (e.g. by passing it as
|
||||||
|
#' an argument in its constructor, or by calling \link{setinfo.xgb.DMatrix}).
|
||||||
|
#'
|
||||||
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
|
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
|
||||||
#' match (only applicable when both `object` and `newdata` have feature names).
|
#' match (only applicable when both `object` and `newdata` have feature names).
|
||||||
#'
|
#'
|
||||||
@ -287,16 +312,80 @@ xgb.get.handle <- function(object) {
|
|||||||
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
|
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
|
||||||
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
|
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
|
||||||
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
|
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
|
||||||
validate_features = FALSE, ...) {
|
validate_features = FALSE, base_margin = NULL, ...) {
|
||||||
if (validate_features) {
|
if (validate_features) {
|
||||||
newdata <- validate.features(object, newdata)
|
newdata <- validate.features(object, newdata)
|
||||||
}
|
}
|
||||||
if (!inherits(newdata, "xgb.DMatrix")) {
|
is_dmatrix <- inherits(newdata, "xgb.DMatrix")
|
||||||
|
if (is_dmatrix && !is.null(base_margin)) {
|
||||||
|
stop(
|
||||||
|
"'base_margin' is not supported when passing 'xgb.DMatrix' as input.",
|
||||||
|
" Should be passed as argument to 'xgb.DMatrix' constructor."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
use_as_df <- FALSE
|
||||||
|
use_as_dense_matrix <- FALSE
|
||||||
|
use_as_csr_matrix <- FALSE
|
||||||
|
n_row <- NULL
|
||||||
|
if (!is_dmatrix) {
|
||||||
|
|
||||||
|
inplace_predict_supported <- !predcontrib && !predinteraction && !predleaf
|
||||||
|
if (inplace_predict_supported) {
|
||||||
|
booster_type <- xgb.booster_type(object)
|
||||||
|
if (booster_type == "gblinear" || (booster_type == "dart" && training)) {
|
||||||
|
inplace_predict_supported <- FALSE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (inplace_predict_supported) {
|
||||||
|
|
||||||
|
if (is.matrix(newdata)) {
|
||||||
|
use_as_dense_matrix <- TRUE
|
||||||
|
} else if (is.data.frame(newdata)) {
|
||||||
|
# note: since here it turns it into a non-data-frame list,
|
||||||
|
# needs to keep track of the number of rows it had for later
|
||||||
|
n_row <- nrow(newdata)
|
||||||
|
newdata <- lapply(
|
||||||
|
newdata,
|
||||||
|
function(x) {
|
||||||
|
if (is.factor(x)) {
|
||||||
|
return(as.numeric(x) - 1)
|
||||||
|
} else {
|
||||||
|
return(as.numeric(x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
use_as_df <- TRUE
|
||||||
|
} else if (inherits(newdata, "dgRMatrix")) {
|
||||||
|
use_as_csr_matrix <- TRUE
|
||||||
|
csr_data <- list(newdata@p, newdata@j, newdata@x, ncol(newdata))
|
||||||
|
} else if (inherits(newdata, "dsparseVector")) {
|
||||||
|
use_as_csr_matrix <- TRUE
|
||||||
|
n_row <- 1L
|
||||||
|
i <- newdata@i - 1L
|
||||||
|
if (storage.mode(i) != "integer") {
|
||||||
|
storage.mode(i) <- "integer"
|
||||||
|
}
|
||||||
|
csr_data <- list(c(0L, length(i)), i, newdata@x, length(newdata))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
} # if (!is_dmatrix)
|
||||||
|
|
||||||
|
if (!is_dmatrix && !use_as_dense_matrix && !use_as_csr_matrix && !use_as_df) {
|
||||||
nthread <- xgb.nthread(object)
|
nthread <- xgb.nthread(object)
|
||||||
newdata <- xgb.DMatrix(
|
newdata <- xgb.DMatrix(
|
||||||
newdata,
|
newdata,
|
||||||
missing = missing, nthread = NVL(nthread, -1)
|
missing = missing,
|
||||||
|
base_margin = base_margin,
|
||||||
|
nthread = NVL(nthread, -1)
|
||||||
)
|
)
|
||||||
|
is_dmatrix <- TRUE
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is.null(n_row)) {
|
||||||
|
n_row <- nrow(newdata)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -354,18 +443,30 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
|||||||
args$type <- set_type(6)
|
args$type <- set_type(6)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
json_conf <- jsonlite::toJSON(args, auto_unbox = TRUE)
|
||||||
|
if (is_dmatrix) {
|
||||||
predts <- .Call(
|
predts <- .Call(
|
||||||
XGBoosterPredictFromDMatrix_R,
|
XGBoosterPredictFromDMatrix_R, xgb.get.handle(object), newdata, json_conf
|
||||||
xgb.get.handle(object),
|
|
||||||
newdata,
|
|
||||||
jsonlite::toJSON(args, auto_unbox = TRUE)
|
|
||||||
)
|
)
|
||||||
|
} else if (use_as_dense_matrix) {
|
||||||
|
predts <- .Call(
|
||||||
|
XGBoosterPredictFromDense_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
|
||||||
|
)
|
||||||
|
} else if (use_as_csr_matrix) {
|
||||||
|
predts <- .Call(
|
||||||
|
XGBoosterPredictFromCSR_R, xgb.get.handle(object), csr_data, missing, json_conf, base_margin
|
||||||
|
)
|
||||||
|
} else if (use_as_df) {
|
||||||
|
predts <- .Call(
|
||||||
|
XGBoosterPredictFromColumnar_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
names(predts) <- c("shape", "results")
|
names(predts) <- c("shape", "results")
|
||||||
shape <- predts$shape
|
shape <- predts$shape
|
||||||
arr <- predts$results
|
arr <- predts$results
|
||||||
|
|
||||||
n_ret <- length(arr)
|
n_ret <- length(arr)
|
||||||
n_row <- nrow(newdata)
|
|
||||||
if (n_row != shape[1]) {
|
if (n_row != shape[1]) {
|
||||||
stop("Incorrect predict shape.")
|
stop("Incorrect predict shape.")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,25 +18,47 @@
|
|||||||
iterationrange = NULL,
|
iterationrange = NULL,
|
||||||
strict_shape = FALSE,
|
strict_shape = FALSE,
|
||||||
validate_features = FALSE,
|
validate_features = FALSE,
|
||||||
|
base_margin = NULL,
|
||||||
...
|
...
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{object}{Object of class \code{xgb.Booster}.}
|
\item{object}{Object of class \code{xgb.Booster}.}
|
||||||
|
|
||||||
\item{newdata}{Takes \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector},
|
\item{newdata}{Takes \code{data.frame}, \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector},
|
||||||
local data file, or \code{xgb.DMatrix}.
|
local data file, or \code{xgb.DMatrix}.
|
||||||
For single-row predictions on sparse data, it is recommended to use the CSR format.
|
|
||||||
If passing a sparse vector, it will take it as a row vector.}
|
|
||||||
|
|
||||||
\item{missing}{Only used when input is a dense matrix. Pick a float value that represents
|
\if{html}{\out{<div class="sourceCode">}}\preformatted{ For single-row predictions on sparse data, it's recommended to use CSR format. If passing
|
||||||
missing values in data (e.g., 0 or some other extreme value).}
|
a sparse vector, it will take it as a row vector.
|
||||||
|
|
||||||
|
Note that, for repeated predictions on the same data, one might want to create a DMatrix to
|
||||||
|
pass here instead of passing R types like matrices or data frames, as predictions will be
|
||||||
|
faster on DMatrix.
|
||||||
|
|
||||||
|
If `newdata` is a `data.frame`, be aware that:\\itemize\{
|
||||||
|
\\item Columns will be converted to numeric if they aren't already, which could potentially make
|
||||||
|
the operation slower than in an equivalent `matrix` object.
|
||||||
|
\\item The order of the columns must match with that of the data from which the model was fitted
|
||||||
|
(i.e. columns will not be referenced by their names, just by their order in the data).
|
||||||
|
\\item If the model was fitted to data with categorical columns, these columns must be of
|
||||||
|
`factor` type here, and must use the same encoding (i.e. have the same levels).
|
||||||
|
\\item If `newdata` contains any `factor` columns, they will be converted to base-0
|
||||||
|
encoding (same as during DMatrix creation) - hence, one should not pass a `factor`
|
||||||
|
under a column which during training had a different type.
|
||||||
|
\}
|
||||||
|
}\if{html}{\out{</div>}}}
|
||||||
|
|
||||||
|
\item{missing}{Float value that represents missing values in data (e.g., 0 or some other extreme value).
|
||||||
|
|
||||||
|
\if{html}{\out{<div class="sourceCode">}}\preformatted{ This parameter is not used when `newdata` is an `xgb.DMatrix` - in such cases, one should pass
|
||||||
|
this as an argument to the DMatrix constructor instead.
|
||||||
|
}\if{html}{\out{</div>}}}
|
||||||
|
|
||||||
\item{outputmargin}{Whether the prediction should be returned in the form of original untransformed
|
\item{outputmargin}{Whether the prediction should be returned in the form of original untransformed
|
||||||
sum of predictions from boosting iterations' results. E.g., setting \code{outputmargin=TRUE} for
|
sum of predictions from boosting iterations' results. E.g., setting \code{outputmargin=TRUE} for
|
||||||
logistic regression would return log-odds instead of probabilities.}
|
logistic regression would return log-odds instead of probabilities.}
|
||||||
|
|
||||||
\item{predleaf}{Whether to predict pre-tree leaf indices.}
|
\item{predleaf}{Whether to predict per-tree leaf indices.}
|
||||||
|
|
||||||
\item{predcontrib}{Whether to return feature contributions to individual predictions (see Details).}
|
\item{predcontrib}{Whether to return feature contributions to individual predictions (see Details).}
|
||||||
|
|
||||||
@ -48,7 +70,7 @@ logistic regression would return log-odds instead of probabilities.}
|
|||||||
prediction outputs per case. No effect if \code{predleaf}, \code{predcontrib},
|
prediction outputs per case. No effect if \code{predleaf}, \code{predcontrib},
|
||||||
or \code{predinteraction} is \code{TRUE}.}
|
or \code{predinteraction} is \code{TRUE}.}
|
||||||
|
|
||||||
\item{training}{Whether the predictions are used for training. For dart booster,
|
\item{training}{Whether the prediction result is used for training. For dart booster,
|
||||||
training predicting will perform dropout.}
|
training predicting will perform dropout.}
|
||||||
|
|
||||||
\item{iterationrange}{Sequence of rounds/iterations from the model to use for prediction, specified by passing
|
\item{iterationrange}{Sequence of rounds/iterations from the model to use for prediction, specified by passing
|
||||||
@ -84,6 +106,13 @@ match (only applicable when both \code{object} and \code{newdata} have feature n
|
|||||||
recommended to disable it for performance-sensitive applications.
|
recommended to disable it for performance-sensitive applications.
|
||||||
}\if{html}{\out{</div>}}}
|
}\if{html}{\out{</div>}}}
|
||||||
|
|
||||||
|
\item{base_margin}{Base margin used for boosting from existing model.
|
||||||
|
|
||||||
|
\if{html}{\out{<div class="sourceCode">}}\preformatted{ Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
|
||||||
|
result in an error, as it needs to be added to the DMatrix instead (e.g. by passing it as
|
||||||
|
an argument in its constructor, or by calling \link{setinfo.xgb.DMatrix}).
|
||||||
|
}\if{html}{\out{</div>}}}
|
||||||
|
|
||||||
\item{...}{Not used.}
|
\item{...}{Not used.}
|
||||||
}
|
}
|
||||||
\value{
|
\value{
|
||||||
@ -115,7 +144,7 @@ When \code{strict_shape = TRUE}, the output is always an array:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
\description{
|
\description{
|
||||||
Predicted values based on either xgboost model or model handle object.
|
Predict values on data based on xgboost model.
|
||||||
}
|
}
|
||||||
\details{
|
\details{
|
||||||
Note that \code{iterationrange} would currently do nothing for predictions from "gblinear",
|
Note that \code{iterationrange} would currently do nothing for predictions from "gblinear",
|
||||||
|
|||||||
@ -37,6 +37,9 @@ extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
|||||||
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
|
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
|
||||||
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
|
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
|
||||||
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
|
||||||
|
extern SEXP XGBoosterPredictFromDense_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
|
extern SEXP XGBoosterPredictFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
|
extern SEXP XGBoosterPredictFromColumnar_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
|
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
|
||||||
extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
|
||||||
@ -96,6 +99,9 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
|
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
|
||||||
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
|
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
|
||||||
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
|
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
|
||||||
|
{"XGBoosterPredictFromDense_R", (DL_FUNC) &XGBoosterPredictFromDense_R, 5},
|
||||||
|
{"XGBoosterPredictFromCSR_R", (DL_FUNC) &XGBoosterPredictFromCSR_R, 5},
|
||||||
|
{"XGBoosterPredictFromColumnar_R", (DL_FUNC) &XGBoosterPredictFromColumnar_R, 5},
|
||||||
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
|
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
|
||||||
{"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
|
{"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
|
||||||
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
|
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
#include <memory>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -207,25 +208,24 @@ SEXP SafeAllocInteger(size_t size, SEXP continuation_token) {
|
|||||||
return xgboost::Json::Dump(jinterface);
|
return xgboost::Json::Dump(jinterface);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
|
void AddMissingToJson(xgboost::Json *jconfig, SEXP missing, SEXPTYPE arr_type) {
|
||||||
using namespace ::xgboost; // NOLINT
|
if (Rf_isNull(missing) || ISNAN(Rf_asReal(missing))) {
|
||||||
Json jconfig{Object{}};
|
|
||||||
|
|
||||||
const SEXPTYPE missing_type = TYPEOF(missing);
|
|
||||||
if (Rf_isNull(missing) || (missing_type == REALSXP && ISNAN(Rf_asReal(missing))) ||
|
|
||||||
(missing_type == LGLSXP && Rf_asLogical(missing) == R_NaInt) ||
|
|
||||||
(missing_type == INTSXP && Rf_asInteger(missing) == R_NaInt)) {
|
|
||||||
// missing is not specified
|
// missing is not specified
|
||||||
if (arr_type == REALSXP) {
|
if (arr_type == REALSXP) {
|
||||||
jconfig["missing"] = std::numeric_limits<double>::quiet_NaN();
|
(*jconfig)["missing"] = std::numeric_limits<double>::quiet_NaN();
|
||||||
} else {
|
} else {
|
||||||
jconfig["missing"] = R_NaInt;
|
(*jconfig)["missing"] = R_NaInt;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// missing specified
|
// missing specified
|
||||||
jconfig["missing"] = Rf_asReal(missing);
|
(*jconfig)["missing"] = Rf_asReal(missing);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
|
||||||
|
using namespace ::xgboost; // NOLINT
|
||||||
|
Json jconfig{Object{}};
|
||||||
|
AddMissingToJson(&jconfig, missing, arr_type);
|
||||||
jconfig["nthread"] = Rf_asInteger(n_threads);
|
jconfig["nthread"] = Rf_asInteger(n_threads);
|
||||||
return Json::Dump(jconfig);
|
return Json::Dump(jconfig);
|
||||||
}
|
}
|
||||||
@ -411,7 +411,7 @@ XGB_DLL SEXP XGDMatrixCreateFromDF_R(SEXP df, SEXP missing, SEXP n_threads) {
|
|||||||
DMatrixHandle handle;
|
DMatrixHandle handle;
|
||||||
std::int32_t rc{0};
|
std::int32_t rc{0};
|
||||||
{
|
{
|
||||||
std::string sinterface = MakeArrayInterfaceFromRDataFrame(df);
|
const std::string sinterface = MakeArrayInterfaceFromRDataFrame(df);
|
||||||
xgboost::Json jconfig{xgboost::Object{}};
|
xgboost::Json jconfig{xgboost::Object{}};
|
||||||
jconfig["missing"] = asReal(missing);
|
jconfig["missing"] = asReal(missing);
|
||||||
jconfig["nthread"] = asInteger(n_threads);
|
jconfig["nthread"] = asInteger(n_threads);
|
||||||
@ -463,7 +463,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP
|
|||||||
Json jconfig{Object{}};
|
Json jconfig{Object{}};
|
||||||
// Construct configuration
|
// Construct configuration
|
||||||
jconfig["nthread"] = Integer{threads};
|
jconfig["nthread"] = Integer{threads};
|
||||||
jconfig["missing"] = xgboost::Number{asReal(missing)};
|
AddMissingToJson(&jconfig, missing, TYPEOF(data));
|
||||||
std::string config;
|
std::string config;
|
||||||
Json::Dump(jconfig, &config);
|
Json::Dump(jconfig, &config);
|
||||||
res_code = XGDMatrixCreateFromCSC(sindptr.c_str(), sindices.c_str(), sdata.c_str(), nrow,
|
res_code = XGDMatrixCreateFromCSC(sindptr.c_str(), sindices.c_str(), sdata.c_str(), nrow,
|
||||||
@ -498,7 +498,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
|
|||||||
Json jconfig{Object{}};
|
Json jconfig{Object{}};
|
||||||
// Construct configuration
|
// Construct configuration
|
||||||
jconfig["nthread"] = Integer{threads};
|
jconfig["nthread"] = Integer{threads};
|
||||||
jconfig["missing"] = xgboost::Number{asReal(missing)};
|
AddMissingToJson(&jconfig, missing, TYPEOF(data));
|
||||||
std::string config;
|
std::string config;
|
||||||
Json::Dump(jconfig, &config);
|
Json::Dump(jconfig, &config);
|
||||||
res_code = XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), ncol,
|
res_code = XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), ncol,
|
||||||
@ -1247,7 +1247,60 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
|
|||||||
return mkString(ret);
|
return mkString(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
|
namespace {
|
||||||
|
|
||||||
|
struct ProxyDmatrixError : public std::exception {};
|
||||||
|
|
||||||
|
struct ProxyDmatrixWrapper {
|
||||||
|
DMatrixHandle proxy_dmat_handle;
|
||||||
|
|
||||||
|
ProxyDmatrixWrapper() {
|
||||||
|
int res_code = XGProxyDMatrixCreate(&this->proxy_dmat_handle);
|
||||||
|
if (res_code != 0) {
|
||||||
|
throw ProxyDmatrixError();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~ProxyDmatrixWrapper() {
|
||||||
|
if (this->proxy_dmat_handle) {
|
||||||
|
XGDMatrixFree(this->proxy_dmat_handle);
|
||||||
|
this->proxy_dmat_handle = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DMatrixHandle get_handle() {
|
||||||
|
return this->proxy_dmat_handle;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::unique_ptr<ProxyDmatrixWrapper> GetProxyDMatrixWithBaseMargin(SEXP base_margin) {
|
||||||
|
if (Rf_isNull(base_margin)) {
|
||||||
|
return std::unique_ptr<ProxyDmatrixWrapper>(nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
SEXP base_margin_dim = Rf_getAttrib(base_margin, R_DimSymbol);
|
||||||
|
int res_code;
|
||||||
|
try {
|
||||||
|
const std::string array_str = Rf_isNull(base_margin_dim)?
|
||||||
|
MakeArrayInterfaceFromRVector(base_margin) : MakeArrayInterfaceFromRMat(base_margin);
|
||||||
|
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat(new ProxyDmatrixWrapper());
|
||||||
|
res_code = XGDMatrixSetInfoFromInterface(proxy_dmat->get_handle(),
|
||||||
|
"base_margin",
|
||||||
|
array_str.c_str());
|
||||||
|
if (res_code != 0) {
|
||||||
|
throw ProxyDmatrixError();
|
||||||
|
}
|
||||||
|
return proxy_dmat;
|
||||||
|
} catch(ProxyDmatrixError &err) {
|
||||||
|
Rf_error("%s", XGBGetLastError());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum class PredictionInputType {DMatrix, DenseMatrix, CSRMatrix, DataFrame};
|
||||||
|
|
||||||
|
SEXP XGBoosterPredictGeneric(SEXP handle, SEXP input_data, SEXP json_config,
|
||||||
|
PredictionInputType input_type, SEXP missing,
|
||||||
|
SEXP base_margin) {
|
||||||
SEXP r_out_shape;
|
SEXP r_out_shape;
|
||||||
SEXP r_out_result;
|
SEXP r_out_result;
|
||||||
SEXP r_out = PROTECT(allocVector(VECSXP, 2));
|
SEXP r_out = PROTECT(allocVector(VECSXP, 2));
|
||||||
@ -1259,9 +1312,79 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
|
|||||||
bst_ulong out_dim;
|
bst_ulong out_dim;
|
||||||
bst_ulong const *out_shape;
|
bst_ulong const *out_shape;
|
||||||
float const *out_result;
|
float const *out_result;
|
||||||
CHECK_CALL(XGBoosterPredictFromDMatrix(R_ExternalPtrAddr(handle),
|
|
||||||
R_ExternalPtrAddr(dmat), c_json_config,
|
int res_code;
|
||||||
&out_shape, &out_dim, &out_result));
|
{
|
||||||
|
switch (input_type) {
|
||||||
|
case PredictionInputType::DMatrix: {
|
||||||
|
res_code = XGBoosterPredictFromDMatrix(R_ExternalPtrAddr(handle),
|
||||||
|
R_ExternalPtrAddr(input_data), c_json_config,
|
||||||
|
&out_shape, &out_dim, &out_result);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case PredictionInputType::CSRMatrix: {
|
||||||
|
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
|
||||||
|
base_margin);
|
||||||
|
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
|
||||||
|
|
||||||
|
SEXP indptr = VECTOR_ELT(input_data, 0);
|
||||||
|
SEXP indices = VECTOR_ELT(input_data, 1);
|
||||||
|
SEXP data = VECTOR_ELT(input_data, 2);
|
||||||
|
const int ncol_csr = Rf_asInteger(VECTOR_ELT(input_data, 3));
|
||||||
|
const SEXPTYPE type_data = TYPEOF(data);
|
||||||
|
CHECK_EQ(type_data, REALSXP);
|
||||||
|
std::string sindptr, sindices, sdata;
|
||||||
|
CreateFromSparse(indptr, indices, data, &sindptr, &sindices, &sdata);
|
||||||
|
|
||||||
|
xgboost::StringView json_str(c_json_config);
|
||||||
|
xgboost::Json new_json = xgboost::Json::Load(json_str);
|
||||||
|
AddMissingToJson(&new_json, missing, type_data);
|
||||||
|
const std::string new_c_json = xgboost::Json::Dump(new_json);
|
||||||
|
|
||||||
|
res_code = XGBoosterPredictFromCSR(
|
||||||
|
R_ExternalPtrAddr(handle), sindptr.c_str(), sindices.c_str(), sdata.c_str(),
|
||||||
|
ncol_csr, new_c_json.c_str(), proxy_dmat_handle, &out_shape, &out_dim, &out_result);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case PredictionInputType::DenseMatrix: {
|
||||||
|
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
|
||||||
|
base_margin);
|
||||||
|
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
|
||||||
|
const std::string array_str = MakeArrayInterfaceFromRMat(input_data);
|
||||||
|
|
||||||
|
xgboost::StringView json_str(c_json_config);
|
||||||
|
xgboost::Json new_json = xgboost::Json::Load(json_str);
|
||||||
|
AddMissingToJson(&new_json, missing, TYPEOF(input_data));
|
||||||
|
const std::string new_c_json = xgboost::Json::Dump(new_json);
|
||||||
|
|
||||||
|
res_code = XGBoosterPredictFromDense(
|
||||||
|
R_ExternalPtrAddr(handle), array_str.c_str(), new_c_json.c_str(),
|
||||||
|
proxy_dmat_handle, &out_shape, &out_dim, &out_result);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case PredictionInputType::DataFrame: {
|
||||||
|
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
|
||||||
|
base_margin);
|
||||||
|
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
|
||||||
|
|
||||||
|
const std::string df_str = MakeArrayInterfaceFromRDataFrame(input_data);
|
||||||
|
|
||||||
|
xgboost::StringView json_str(c_json_config);
|
||||||
|
xgboost::Json new_json = xgboost::Json::Load(json_str);
|
||||||
|
AddMissingToJson(&new_json, missing, REALSXP);
|
||||||
|
const std::string new_c_json = xgboost::Json::Dump(new_json);
|
||||||
|
|
||||||
|
res_code = XGBoosterPredictFromColumnar(
|
||||||
|
R_ExternalPtrAddr(handle), df_str.c_str(), new_c_json.c_str(),
|
||||||
|
proxy_dmat_handle, &out_shape, &out_dim, &out_result);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CHECK_CALL(res_code);
|
||||||
|
|
||||||
r_out_shape = PROTECT(allocVector(INTSXP, out_dim));
|
r_out_shape = PROTECT(allocVector(INTSXP, out_dim));
|
||||||
size_t len = 1;
|
size_t len = 1;
|
||||||
@ -1282,6 +1405,31 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
|
|||||||
return r_out;
|
return r_out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// R entry point for prediction on an xgb.DMatrix.
// Delegates to XGBoosterPredictGeneric, tagging the input as PredictionInputType::DMatrix
// and passing R_NilValue for both the 'missing' and 'base_margin' arguments.
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
  const PredictionInputType input_type = PredictionInputType::DMatrix;
  SEXP r_missing = R_NilValue;
  SEXP r_base_margin = R_NilValue;
  SEXP r_result = XGBoosterPredictGeneric(handle, dmat, json_config, input_type,
                                          r_missing, r_base_margin);
  return r_result;
}
|
||||||
|
|
||||||
|
// R entry point for in-place prediction on a dense R matrix.
// Delegates to XGBoosterPredictGeneric, tagging the input as PredictionInputType::DenseMatrix
// and forwarding the caller's 'missing' and 'base_margin' unchanged.
XGB_DLL SEXP XGBoosterPredictFromDense_R(SEXP handle, SEXP R_mat, SEXP missing,
                                         SEXP json_config, SEXP base_margin) {
  const PredictionInputType input_type = PredictionInputType::DenseMatrix;
  SEXP r_result = XGBoosterPredictGeneric(handle, R_mat, json_config, input_type,
                                          missing, base_margin);
  return r_result;
}
|
||||||
|
|
||||||
|
// R entry point for in-place prediction on CSR data packed into an R list.
// Delegates to XGBoosterPredictGeneric, tagging the input as PredictionInputType::CSRMatrix
// and forwarding the caller's 'missing' and 'base_margin' unchanged.
XGB_DLL SEXP XGBoosterPredictFromCSR_R(SEXP handle, SEXP lst, SEXP missing,
                                       SEXP json_config, SEXP base_margin) {
  const PredictionInputType input_type = PredictionInputType::CSRMatrix;
  SEXP r_result = XGBoosterPredictGeneric(handle, lst, json_config, input_type,
                                          missing, base_margin);
  return r_result;
}
|
||||||
|
|
||||||
|
// R entry point for in-place prediction on an R data.frame.
// Delegates to XGBoosterPredictGeneric, tagging the input as PredictionInputType::DataFrame
// and forwarding the caller's 'missing' and 'base_margin' unchanged.
XGB_DLL SEXP XGBoosterPredictFromColumnar_R(SEXP handle, SEXP R_df, SEXP missing,
                                            SEXP json_config, SEXP base_margin) {
  const PredictionInputType input_type = PredictionInputType::DataFrame;
  SEXP r_result = XGBoosterPredictGeneric(handle, R_df, json_config, input_type,
                                          missing, base_margin);
  return r_result;
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
|
XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
|
CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
|
||||||
|
|||||||
@ -371,6 +371,50 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
|
|||||||
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
||||||
*/
|
*/
|
||||||
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config);
|
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Run prediction on R dense matrix
|
||||||
|
* \param handle handle
|
||||||
|
* \param R_mat R matrix
|
||||||
|
* \param missing missing value
|
||||||
|
* \param json_config See `XGBoosterPredictFromDense` in xgboost c_api.h. Doesn't include 'missing'
|
||||||
|
* \param base_margin base margin for the prediction
|
||||||
|
*
|
||||||
|
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGBoosterPredictFromDense_R(SEXP handle, SEXP R_mat, SEXP missing,
|
||||||
|
SEXP json_config, SEXP base_margin);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Run prediction on R CSR matrix
|
||||||
|
* \param handle handle
|
||||||
|
* \param lst An R list, containing, in this order:
|
||||||
|
* (a) 'p' array (a.k.a. indptr)
|
||||||
|
* (b) 'j' array (a.k.a. indices)
|
||||||
|
* (c) 'x' array (a.k.a. data / values)
|
||||||
|
* (d) number of columns
|
||||||
|
* \param missing missing value
|
||||||
|
* \param json_config See `XGBoosterPredictFromCSR` in xgboost c_api.h. Doesn't include 'missing'
|
||||||
|
* \param base_margin base margin for the prediction
|
||||||
|
*
|
||||||
|
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGBoosterPredictFromCSR_R(SEXP handle, SEXP lst, SEXP missing,
|
||||||
|
SEXP json_config, SEXP base_margin);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Run prediction on R data.frame
|
||||||
|
* \param handle handle
|
||||||
|
* \param R_df R data.frame
|
||||||
|
* \param missing missing value
|
||||||
|
* \param json_config See `XGBoosterPredictFromColumnar` in xgboost c_api.h. Doesn't include 'missing'
|
||||||
|
* \param base_margin base margin for the prediction
|
||||||
|
*
|
||||||
|
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGBoosterPredictFromColumnar_R(SEXP handle, SEXP R_df, SEXP missing,
|
||||||
|
SEXP json_config, SEXP base_margin);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief load model from existing file
|
* \brief load model from existing file
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
|
|||||||
@ -139,8 +139,8 @@ test_that("dart prediction works", {
|
|||||||
pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, iterationrange = c(1, nrounds))
|
pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, iterationrange = c(1, nrounds))
|
||||||
pred_by_train_2 <- predict(booster_by_train, newdata = dtrain, training = TRUE)
|
pred_by_train_2 <- predict(booster_by_train, newdata = dtrain, training = TRUE)
|
||||||
|
|
||||||
expect_true(all(matrix(pred_by_train_0, byrow = TRUE) == matrix(pred_by_xgboost_0, byrow = TRUE)))
|
expect_equal(pred_by_train_0, pred_by_xgboost_0, tolerance = 1e-6)
|
||||||
expect_true(all(matrix(pred_by_train_1, byrow = TRUE) == matrix(pred_by_xgboost_1, byrow = TRUE)))
|
expect_equal(pred_by_train_1, pred_by_xgboost_1, tolerance = 1e-6)
|
||||||
expect_true(all(matrix(pred_by_train_2, byrow = TRUE) == matrix(pred_by_xgboost_2, byrow = TRUE)))
|
expect_true(all(matrix(pred_by_train_2, byrow = TRUE) == matrix(pred_by_xgboost_2, byrow = TRUE)))
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -651,6 +651,51 @@ test_that("Can use ranking objectives with either 'qid' or 'group'", {
|
|||||||
expect_equal(pred_qid, pred_gr)
|
expect_equal(pred_qid, pred_gr)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("Can predict on data.frame objects", {
  # Fit a small regression model on the matrix form of mtcars, then check that
  # predicting from the equivalent data.frame gives the same values as
  # predicting from a DMatrix built out of the matrix.
  data("mtcars")
  y <- mtcars$mpg
  x_df <- mtcars[, -1]
  x_mat <- as.matrix(x_df)
  dm <- xgb.DMatrix(x_mat, label = y, nthread = n_threads)
  model <- xgb.train(
    params = list(
      tree_method = "hist",
      objective = "reg:squarederror",
      nthread = n_threads
    ),
    data = dm,
    nrounds = 5
  )

  # The reference DMatrix gets the same thread cap as every other call in this
  # test, instead of falling back to the library default.
  dm_reference <- xgb.DMatrix(x_mat, nthread = n_threads)
  pred_mat <- predict(model, dm_reference, nthread = n_threads)
  pred_df <- predict(model, x_df, nthread = n_threads)
  expect_equal(pred_mat, pred_df)
})
|
||||||
|
|
||||||
|
test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", {
  # A base margin attached to a DMatrix must produce the same predictions as
  # the same base margin passed directly to predict() alongside a plain matrix.
  data("mtcars")
  y <- mtcars$mpg
  x <- as.matrix(mtcars[, -1])
  dm <- xgb.DMatrix(x, label = y, nthread = n_threads)
  model <- xgb.train(
    params = list(
      tree_method = "hist",
      objective = "reg:squarederror",
      nthread = n_threads
    ),
    data = dm,
    nrounds = 5
  )

  set.seed(123)
  base_margin <- rnorm(nrow(x))
  # Same thread cap as the training DMatrix above, for consistency.
  dm_w_base <- xgb.DMatrix(
    data = x,
    base_margin = base_margin,
    nthread = n_threads
  )
  pred_from_dm <- predict(model, dm_w_base)
  pred_from_mat <- predict(model, x, base_margin = base_margin)

  expect_equal(pred_from_dm, pred_from_mat)
})
|
||||||
|
|
||||||
test_that("Coefficients from gblinear have the expected shape and names", {
|
test_that("Coefficients from gblinear have the expected shape and names", {
|
||||||
# Single-column coefficients
|
# Single-column coefficients
|
||||||
data(mtcars)
|
data(mtcars)
|
||||||
|
|||||||
@ -302,6 +302,37 @@ test_that("xgb.DMatrix: Inf as missing", {
|
|||||||
file.remove(fname_nan)
|
file.remove(fname_nan)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.DMatrix: missing in CSR", {
  # A dense matrix with an NA and its CSR equivalent must serialize to
  # byte-identical DMatrix files when NA is declared as the missing value.
  x_dense <- matrix(as.numeric(1:10), nrow = 5)
  x_dense[2, 1] <- NA_real_

  x_csr <- as(x_dense, "RsparseMatrix")

  # Write into the session temp directory rather than the working directory,
  # so the test leaves nothing behind even if it errors part-way through.
  fname_dense <- tempfile(fileext = ".dmatrix")
  fname_csr <- tempfile(fileext = ".dmatrix")

  m_dense <- xgb.DMatrix(x_dense, nthread = n_threads, missing = NA_real_)
  xgb.DMatrix.save(m_dense, fname_dense)

  m_csr <- xgb.DMatrix(x_csr, nthread = n_threads, missing = NA)
  xgb.DMatrix.save(m_csr, fname_csr)

  expect_equal(file.size(fname_dense), file.size(fname_csr))

  # Passing a file name lets readBin() open and close the file itself, so no
  # connection can be left open. Each file is read with its own size so that a
  # length mismatch is not hidden by truncation.
  bytes_dense <- readBin(fname_dense, what = "raw", n = file.size(fname_dense))
  bytes_csr <- readBin(fname_csr, what = "raw", n = file.size(fname_csr))

  expect_equal(length(bytes_dense), length(bytes_csr))
  expect_equal(bytes_dense, bytes_csr)

  unlink(c(fname_dense, fname_csr))
})
|
||||||
|
|
||||||
test_that("xgb.DMatrix: error on three-dimensional array", {
|
test_that("xgb.DMatrix: error on three-dimensional array", {
|
||||||
set.seed(123)
|
set.seed(123)
|
||||||
x <- matrix(rnorm(500), nrow = 50)
|
x <- matrix(rnorm(500), nrow = 50)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user