[R] Accept CSR data for predictions (#7615)

This commit is contained in:
david-cortes 2022-01-29 18:54:57 +02:00 committed by GitHub
parent 549bd419bb
commit 7f738e7f6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 93 additions and 7 deletions

View File

@ -162,7 +162,11 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
#' Predicted values based on either xgboost model or model handle object. #' Predicted values based on either xgboost model or model handle object.
#' #'
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle} #' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}
#' @param newdata takes \code{matrix}, \code{dgCMatrix}, local data file or \code{xgb.DMatrix}. #' @param newdata takes \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector},
#' local data file or \code{xgb.DMatrix}.
#'
#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing
#' a sparse vector, it will take it as a row vector.
#' @param missing Missing is only used when input is dense matrix. Pick a float value that represents #' @param missing Missing is only used when input is dense matrix. Pick a float value that represents
#' missing values in data (e.g., sometimes 0 or some other extreme value is used). #' missing values in data (e.g., sometimes 0 or some other extreme value is used).
#' @param outputmargin whether the prediction should be returned in the for of original untransformed #' @param outputmargin whether the prediction should be returned in the for of original untransformed

View File

@ -4,8 +4,10 @@
#' Supported input file formats are either a LIBSVM text file or a binary file that was created previously by #' Supported input file formats are either a LIBSVM text file or a binary file that was created previously by
#' \code{\link{xgb.DMatrix.save}}). #' \code{\link{xgb.DMatrix.save}}).
#' #'
#' @param data a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object, or a character #' @param data a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object,
#' string representing a filename. #' a \code{dgRMatrix} object (only when making predictions from a fitted model),
#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be
#' interpreted as a row vector), or a character string representing a filename.
#' @param info a named list of additional information to store in the \code{xgb.DMatrix} object. #' @param info a named list of additional information to store in the \code{xgb.DMatrix} object.
#' See \code{\link{setinfo}} for the specific allowed kinds of #' See \code{\link{setinfo}} for the specific allowed kinds of
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix). #' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
@ -37,6 +39,17 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data), as.integer(NVL(nthread, -1)) XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data), as.integer(NVL(nthread, -1))
) )
cnames <- colnames(data) cnames <- colnames(data)
} else if (inherits(data, "dgRMatrix")) {
handle <- .Call(
XGDMatrixCreateFromCSR_R, data@p, data@j, data@x, ncol(data), as.integer(NVL(nthread, -1))
)
cnames <- colnames(data)
} else if (inherits(data, "dsparseVector")) {
indptr <- c(0L, as.integer(length(data@i)))
ind <- as.integer(data@i) - 1L
handle <- .Call(
XGDMatrixCreateFromCSR_R, indptr, ind, data@x, length(data), as.integer(NVL(nthread, -1))
)
} else { } else {
stop("xgb.DMatrix does not support construction from ", typeof(data)) stop("xgb.DMatrix does not support construction from ", typeof(data))
} }

View File

@ -27,7 +27,11 @@
\arguments{ \arguments{
\item{object}{Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}} \item{object}{Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}}
\item{newdata}{takes \code{matrix}, \code{dgCMatrix}, local data file or \code{xgb.DMatrix}.} \item{newdata}{takes \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector},
local data file or \code{xgb.DMatrix}.
For single-row predictions on sparse data, it's recommended to use CSR format. If passing
a sparse vector, it will take it as a row vector.}
\item{missing}{Missing is only used when input is dense matrix. Pick a float value that represents \item{missing}{Missing is only used when input is dense matrix. Pick a float value that represents
missing values in data (e.g., sometimes 0 or some other extreme value is used).} missing values in data (e.g., sometimes 0 or some other extreme value is used).}
@ -55,7 +59,7 @@ training predicting will perform dropout.}
\item{iterationrange}{Specifies which layer of trees are used in prediction. For \item{iterationrange}{Specifies which layer of trees are used in prediction. For
example, if a random forest is trained with 100 rounds. Specifying example, if a random forest is trained with 100 rounds. Specifying
`iteration_range=(1, 21)`, then only the forests built during [1, 21) (half open set) `iterationrange=(1, 21)`, then only the forests built during [1, 21) (half open set)
rounds are used in this prediction. It's 1-based index just like R vector. When set rounds are used in this prediction. It's 1-based index just like R vector. When set
to \code{c(1, 1)} XGBoost will use all trees.} to \code{c(1, 1)} XGBoost will use all trees.}

View File

@ -14,8 +14,10 @@ xgb.DMatrix(
) )
} }
\arguments{ \arguments{
\item{data}{a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object, or a character \item{data}{a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object,
string representing a filename.} a \code{dgRMatrix} object (only when making predictions from a fitted model),
a \code{dsparseVector} object (only when making predictions from a fitted model, will be
interpreted as a row vector), or a character string representing a filename.}
\item{info}{a named list of additional information to store in the \code{xgb.DMatrix} object. \item{info}{a named list of additional information to store in the \code{xgb.DMatrix} object.
See \code{\link{setinfo}} for the specific allowed kinds of} See \code{\link{setinfo}} for the specific allowed kinds of}

View File

@ -38,6 +38,7 @@ extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
extern SEXP XGCheckNullPtr_R(SEXP); extern SEXP XGCheckNullPtr_R(SEXP);
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP); extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP); extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
@ -73,6 +74,7 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3}, {"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1}, {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 5}, {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 5},
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 5},
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2}, {"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3}, {"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2}, {"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},

View File

@ -164,6 +164,39 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data,
return ret; return ret;
} }
XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data,
SEXP num_col, SEXP n_threads) {
SEXP ret;
R_API_BEGIN();
const int *p_indptr = INTEGER(indptr);
const int *p_indices = INTEGER(indices);
const double *p_data = REAL(data);
size_t nindptr = static_cast<size_t>(length(indptr));
size_t ndata = static_cast<size_t>(length(data));
size_t ncol = static_cast<size_t>(INTEGER(num_col)[0]);
std::vector<size_t> row_ptr_(nindptr);
std::vector<unsigned> indices_(ndata);
std::vector<float> data_(ndata);
for (size_t i = 0; i < nindptr; ++i) {
row_ptr_[i] = static_cast<size_t>(p_indptr[i]);
}
int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
xgboost::common::ParallelFor(ndata, threads, [&](xgboost::omp_ulong i) {
indices_[i] = static_cast<unsigned>(p_indices[i]);
data_[i] = static_cast<float>(p_data[i]);
});
DMatrixHandle handle;
CHECK_CALL(XGDMatrixCreateFromCSREx(BeginPtr(row_ptr_), BeginPtr(indices_),
BeginPtr(data_), nindptr, ndata,
ncol, &handle));
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END();
UNPROTECT(1);
return ret;
}
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
SEXP ret; SEXP ret;
R_API_BEGIN(); R_API_BEGIN();

View File

@ -65,6 +65,18 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat,
XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_row, XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_row,
SEXP n_threads); SEXP n_threads);
/*!
* \brief create a matrix content from CSR format
* \param indptr pointer to row headers
* \param indices column indices
* \param data content of the data
* \param num_col numer of columns (when it's set to 0, then guess from data)
* \param n_threads Number of threads used to construct DMatrix from csr matrix.
* \return created dmatrix
*/
XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_col,
SEXP n_threads);
/*! /*!
* \brief create a new dmatrix from sliced content of existing matrix * \brief create a new dmatrix from sliced content of existing matrix
* \param handle instance of data matrix to be sliced * \param handle instance of data matrix to be sliced

View File

@ -1,4 +1,5 @@
require(xgboost) require(xgboost)
library(Matrix)
context("basic functions") context("basic functions")
@ -459,3 +460,18 @@ test_that("strict_shape works", {
test_iris() test_iris()
test_agaricus() test_agaricus()
}) })
test_that("'predict' accepts CSR data", {
X <- agaricus.train$data
y <- agaricus.train$label
x_csc <- as(X[1L, , drop = FALSE], "CsparseMatrix")
x_csr <- as(x_csc, "RsparseMatrix")
x_spv <- as(x_csc, "sparseVector")
bst <- xgboost(data = X, label = y, objective = "binary:logistic",
nrounds = 5L, verbose = FALSE)
p_csc <- predict(bst, x_csc)
p_csr <- predict(bst, x_csr)
p_spv <- predict(bst, x_spv)
expect_equal(p_csc, p_csr)
expect_equal(p_csc, p_spv)
})