From 7f738e7f6fd0fa5b11db38fd0717b15426af53a9 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Sat, 29 Jan 2022 18:54:57 +0200 Subject: [PATCH] [R] Accept CSR data for predictions (#7615) --- R-package/R/xgb.Booster.R | 6 ++++- R-package/R/xgb.DMatrix.R | 17 ++++++++++++-- R-package/man/predict.xgb.Booster.Rd | 8 +++++-- R-package/man/xgb.DMatrix.Rd | 6 +++-- R-package/src/init.c | 2 ++ R-package/src/xgboost_R.cc | 33 +++++++++++++++++++++++++++ R-package/src/xgboost_R.h | 12 ++++++++++ R-package/tests/testthat/test_basic.R | 16 +++++++++++++ 8 files changed, 93 insertions(+), 7 deletions(-) diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 640e04d0a..2f1f5091c 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -162,7 +162,11 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) { #' Predicted values based on either xgboost model or model handle object. #' #' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle} -#' @param newdata takes \code{matrix}, \code{dgCMatrix}, local data file or \code{xgb.DMatrix}. +#' @param newdata takes \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector}, +#' local data file or \code{xgb.DMatrix}. +#' +#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing +#' a sparse vector, it will take it as a row vector. #' @param missing Missing is only used when input is dense matrix. Pick a float value that represents #' missing values in data (e.g., sometimes 0 or some other extreme value is used). 
#' @param outputmargin whether the prediction should be returned in the for of original untransformed diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index 265e860c4..970e317cd 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -4,8 +4,10 @@ #' Supported input file formats are either a LIBSVM text file or a binary file that was created previously by #' \code{\link{xgb.DMatrix.save}}). #' -#' @param data a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object, or a character -#' string representing a filename. +#' @param data a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object, +#' a \code{dgRMatrix} object (only when making predictions from a fitted model), +#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be +#' interpreted as a row vector), or a character string representing a filename. #' @param info a named list of additional information to store in the \code{xgb.DMatrix} object. #' See \code{\link{setinfo}} for the specific allowed kinds of #' @param missing a float value to represents missing values in data (used only when input is a dense matrix). 
@@ -37,6 +39,17 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data), as.integer(NVL(nthread, -1)) ) cnames <- colnames(data) + } else if (inherits(data, "dgRMatrix")) { + handle <- .Call( + XGDMatrixCreateFromCSR_R, data@p, data@j, data@x, ncol(data), as.integer(NVL(nthread, -1)) + ) + cnames <- colnames(data) + } else if (inherits(data, "dsparseVector")) { + indptr <- c(0L, as.integer(length(data@i))) + ind <- as.integer(data@i) - 1L + handle <- .Call( + XGDMatrixCreateFromCSR_R, indptr, ind, data@x, length(data), as.integer(NVL(nthread, -1)) + ) } else { stop("xgb.DMatrix does not support construction from ", typeof(data)) } diff --git a/R-package/man/predict.xgb.Booster.Rd b/R-package/man/predict.xgb.Booster.Rd index 34d1f86e5..067cbf207 100644 --- a/R-package/man/predict.xgb.Booster.Rd +++ b/R-package/man/predict.xgb.Booster.Rd @@ -27,7 +27,11 @@ \arguments{ \item{object}{Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}} -\item{newdata}{takes \code{matrix}, \code{dgCMatrix}, local data file or \code{xgb.DMatrix}.} +\item{newdata}{takes \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector}, + local data file or \code{xgb.DMatrix}. + + For single-row predictions on sparse data, it's recommended to use CSR format. If passing + a sparse vector, it will take it as a row vector.} \item{missing}{Missing is only used when input is dense matrix. Pick a float value that represents missing values in data (e.g., sometimes 0 or some other extreme value is used).} @@ -55,7 +59,7 @@ training predicting will perform dropout.} \item{iterationrange}{Specifies which layer of trees are used in prediction. For example, if a random forest is trained with 100 rounds. 
Specifying -`iteration_range=(1, 21)`, then only the forests built during [1, 21) (half open set) +`iterationrange=(1, 21)`, then only the forests built during [1, 21) (half open set) rounds are used in this prediction. It's 1-based index just like R vector. When set to \code{c(1, 1)} XGBoost will use all trees.} diff --git a/R-package/man/xgb.DMatrix.Rd b/R-package/man/xgb.DMatrix.Rd index df9977653..52a31cfd1 100644 --- a/R-package/man/xgb.DMatrix.Rd +++ b/R-package/man/xgb.DMatrix.Rd @@ -14,8 +14,10 @@ xgb.DMatrix( ) } \arguments{ -\item{data}{a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object, or a character -string representing a filename.} +\item{data}{a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object, +a \code{dgRMatrix} object (only when making predictions from a fitted model), +a \code{dsparseVector} object (only when making predictions from a fitted model, will be +interpreted as a row vector), or a character string representing a filename.} \item{info}{a named list of additional information to store in the \code{xgb.DMatrix} object. 
See \code{\link{setinfo}} for the specific allowed kinds of}
diff --git a/R-package/src/init.c b/R-package/src/init.c
index 9a4d0cd53..4e38f8220 100644
--- a/R-package/src/init.c
+++ b/R-package/src/init.c
@@ -38,6 +38,7 @@ extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
 extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
 extern SEXP XGCheckNullPtr_R(SEXP);
 extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP);
+extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
 extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
@@ -73,6 +74,7 @@ static const R_CallMethodDef CallEntries[] = {
   {"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
   {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
   {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 5},
+  {"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 5},
   {"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
   {"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
   {"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc
index f40af4cfe..2383eb9a6 100644
--- a/R-package/src/xgboost_R.cc
+++ b/R-package/src/xgboost_R.cc
@@ -164,6 +164,39 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data,
   return ret;
 }
 
+XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data,
+                                      SEXP num_col, SEXP n_threads) {
+  SEXP ret;
+  R_API_BEGIN();
+  const int *p_indptr = INTEGER(indptr);
+  const int *p_indices = INTEGER(indices);
+  const double *p_data = REAL(data);
+  size_t nindptr = static_cast<size_t>(length(indptr));
+  size_t ndata = static_cast<size_t>(length(data));
+  size_t ncol = static_cast<size_t>(INTEGER(num_col)[0]);
+  std::vector<size_t> row_ptr_(nindptr);
+  std::vector<unsigned> indices_(ndata);
+  std::vector<float> data_(ndata);
+
+  for (size_t i = 0; i < nindptr; ++i) {
+    row_ptr_[i] = static_cast<size_t>(p_indptr[i]);
+  }
+  int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
+  xgboost::common::ParallelFor(ndata, threads, [&](xgboost::omp_ulong i) {
+    indices_[i] = static_cast<unsigned>(p_indices[i]);
+    data_[i] = static_cast<float>(p_data[i]);
+  });
+  DMatrixHandle handle;
+  CHECK_CALL(XGDMatrixCreateFromCSREx(BeginPtr(row_ptr_), BeginPtr(indices_),
+                                      BeginPtr(data_), nindptr, ndata,
+                                      ncol, &handle));
+  ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
+  R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
+  R_API_END();
+  UNPROTECT(1);
+  return ret;
+}
+
 XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
   SEXP ret;
   R_API_BEGIN();
diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h
index 16b9d4667..3ece8417d 100644
--- a/R-package/src/xgboost_R.h
+++ b/R-package/src/xgboost_R.h
@@ -65,6 +65,18 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat,
 XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_row,
                                       SEXP n_threads);
 
+/*!
+ * \brief create a matrix content from CSR format
+ * \param indptr pointer to row headers
+ * \param indices column indices
+ * \param data content of the data
+ * \param num_col number of columns (when it's set to 0, then guess from data)
+ * \param n_threads Number of threads used to construct DMatrix from csr matrix.
+ * \return created dmatrix
+ */
+XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_col,
+                                      SEXP n_threads);
+
 /*!
* \brief create a new dmatrix from sliced content of existing matrix * \param handle instance of data matrix to be sliced diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index e90232dae..ad8c8a830 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -1,4 +1,5 @@ require(xgboost) +library(Matrix) context("basic functions") @@ -459,3 +460,18 @@ test_that("strict_shape works", { test_iris() test_agaricus() }) + +test_that("'predict' accepts CSR data", { + X <- agaricus.train$data + y <- agaricus.train$label + x_csc <- as(X[1L, , drop = FALSE], "CsparseMatrix") + x_csr <- as(x_csc, "RsparseMatrix") + x_spv <- as(x_csc, "sparseVector") + bst <- xgboost(data = X, label = y, objective = "binary:logistic", + nrounds = 5L, verbose = FALSE) + p_csc <- predict(bst, x_csc) + p_csr <- predict(bst, x_csr) + p_spv <- predict(bst, x_spv) + expect_equal(p_csc, p_csr) + expect_equal(p_csc, p_spv) +})