[R] Accept CSR data for predictions (#7615)
This commit is contained in:
parent
549bd419bb
commit
7f738e7f6f
@ -162,7 +162,11 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
|||||||
#' Predicted values based on either xgboost model or model handle object.
|
#' Predicted values based on either xgboost model or model handle object.
|
||||||
#'
|
#'
|
||||||
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}
|
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}
|
||||||
#' @param newdata takes \code{matrix}, \code{dgCMatrix}, local data file or \code{xgb.DMatrix}.
|
#' @param newdata takes \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector},
|
||||||
|
#' local data file or \code{xgb.DMatrix}.
|
||||||
|
#'
|
||||||
|
#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing
|
||||||
|
#' a sparse vector, it will take it as a row vector.
|
||||||
#' @param missing Missing is only used when input is dense matrix. Pick a float value that represents
|
#' @param missing Missing is only used when input is dense matrix. Pick a float value that represents
|
||||||
#' missing values in data (e.g., sometimes 0 or some other extreme value is used).
|
#' missing values in data (e.g., sometimes 0 or some other extreme value is used).
|
||||||
#' @param outputmargin whether the prediction should be returned in the form of original untransformed
|
#' @param outputmargin whether the prediction should be returned in the form of original untransformed
|
||||||
|
|||||||
@ -4,8 +4,10 @@
|
|||||||
#' Supported input file formats are either a LIBSVM text file or a binary file that was created previously by
|
#' Supported input file formats are either a LIBSVM text file or a binary file that was created previously by
|
||||||
#' \code{\link{xgb.DMatrix.save}}).
|
#' \code{\link{xgb.DMatrix.save}}).
|
||||||
#'
|
#'
|
||||||
#' @param data a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object, or a character
|
#' @param data a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object,
|
||||||
#' string representing a filename.
|
#' a \code{dgRMatrix} object (only when making predictions from a fitted model),
|
||||||
|
#' a \code{dsparseVector} object (only when making predictions from a fitted model, will be
|
||||||
|
#' interpreted as a row vector), or a character string representing a filename.
|
||||||
#' @param info a named list of additional information to store in the \code{xgb.DMatrix} object.
|
#' @param info a named list of additional information to store in the \code{xgb.DMatrix} object.
|
||||||
#' See \code{\link{setinfo}} for the specific allowed kinds of information.
|
#' See \code{\link{setinfo}} for the specific allowed kinds of information.
|
||||||
#' @param missing a float value to represent missing values in data (used only when input is a dense matrix).
|
#' @param missing a float value to represent missing values in data (used only when input is a dense matrix).
|
||||||
@ -37,6 +39,17 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
|
|||||||
XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data), as.integer(NVL(nthread, -1))
|
XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data), as.integer(NVL(nthread, -1))
|
||||||
)
|
)
|
||||||
cnames <- colnames(data)
|
cnames <- colnames(data)
|
||||||
|
} else if (inherits(data, "dgRMatrix")) {
|
||||||
|
handle <- .Call(
|
||||||
|
XGDMatrixCreateFromCSR_R, data@p, data@j, data@x, ncol(data), as.integer(NVL(nthread, -1))
|
||||||
|
)
|
||||||
|
cnames <- colnames(data)
|
||||||
|
} else if (inherits(data, "dsparseVector")) {
|
||||||
|
indptr <- c(0L, as.integer(length(data@i)))
|
||||||
|
ind <- as.integer(data@i) - 1L
|
||||||
|
handle <- .Call(
|
||||||
|
XGDMatrixCreateFromCSR_R, indptr, ind, data@x, length(data), as.integer(NVL(nthread, -1))
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
stop("xgb.DMatrix does not support construction from ", typeof(data))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,7 +27,11 @@
|
|||||||
\arguments{
|
\arguments{
|
||||||
\item{object}{Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}}
|
\item{object}{Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}}
|
||||||
|
|
||||||
\item{newdata}{takes \code{matrix}, \code{dgCMatrix}, local data file or \code{xgb.DMatrix}.}
|
\item{newdata}{takes \code{matrix}, \code{dgCMatrix}, \code{dgRMatrix}, \code{dsparseVector},
|
||||||
|
local data file or \code{xgb.DMatrix}.
|
||||||
|
|
||||||
|
For single-row predictions on sparse data, it's recommended to use CSR format. If passing
|
||||||
|
a sparse vector, it will take it as a row vector.}
|
||||||
|
|
||||||
\item{missing}{Missing is only used when input is dense matrix. Pick a float value that represents
|
\item{missing}{Missing is only used when input is dense matrix. Pick a float value that represents
|
||||||
missing values in data (e.g., sometimes 0 or some other extreme value is used).}
|
missing values in data (e.g., sometimes 0 or some other extreme value is used).}
|
||||||
@ -55,7 +59,7 @@ training predicting will perform dropout.}
|
|||||||
|
|
||||||
\item{iterationrange}{Specifies which layer of trees are used in prediction. For
|
\item{iterationrange}{Specifies which layer of trees are used in prediction. For
|
||||||
example, if a random forest is trained with 100 rounds. Specifying
|
example, if a random forest is trained with 100 rounds. Specifying
|
||||||
`iteration_range=(1, 21)`, then only the forests built during [1, 21) (half open set)
|
`iterationrange=(1, 21)`, then only the forests built during [1, 21) (half open set)
|
||||||
rounds are used in this prediction. It's 1-based index just like R vector. When set
|
rounds are used in this prediction. It's 1-based index just like R vector. When set
|
||||||
to \code{c(1, 1)} XGBoost will use all trees.}
|
to \code{c(1, 1)} XGBoost will use all trees.}
|
||||||
|
|
||||||
|
|||||||
@ -14,8 +14,10 @@ xgb.DMatrix(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{data}{a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object, or a character
|
\item{data}{a \code{matrix} object (either numeric or integer), a \code{dgCMatrix} object,
|
||||||
string representing a filename.}
|
a \code{dgRMatrix} object (only when making predictions from a fitted model),
|
||||||
|
a \code{dsparseVector} object (only when making predictions from a fitted model, will be
|
||||||
|
interpreted as a row vector), or a character string representing a filename.}
|
||||||
|
|
||||||
\item{info}{a named list of additional information to store in the \code{xgb.DMatrix} object.
|
\item{info}{a named list of additional information to store in the \code{xgb.DMatrix} object.
|
||||||
See \code{\link{setinfo}} for the specific allowed kinds of information.}
|
See \code{\link{setinfo}} for the specific allowed kinds of information.}
|
||||||
|
|||||||
@ -38,6 +38,7 @@ extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
|
|||||||
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGCheckNullPtr_R(SEXP);
|
extern SEXP XGCheckNullPtr_R(SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
|
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
||||||
@ -73,6 +74,7 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
|
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
|
||||||
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
|
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
|
||||||
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 5},
|
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 5},
|
||||||
|
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 5},
|
||||||
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
||||||
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
||||||
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
||||||
|
|||||||
@ -164,6 +164,39 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data,
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data,
|
||||||
|
SEXP num_col, SEXP n_threads) {
|
||||||
|
SEXP ret;
|
||||||
|
R_API_BEGIN();
|
||||||
|
const int *p_indptr = INTEGER(indptr);
|
||||||
|
const int *p_indices = INTEGER(indices);
|
||||||
|
const double *p_data = REAL(data);
|
||||||
|
size_t nindptr = static_cast<size_t>(length(indptr));
|
||||||
|
size_t ndata = static_cast<size_t>(length(data));
|
||||||
|
size_t ncol = static_cast<size_t>(INTEGER(num_col)[0]);
|
||||||
|
std::vector<size_t> row_ptr_(nindptr);
|
||||||
|
std::vector<unsigned> indices_(ndata);
|
||||||
|
std::vector<float> data_(ndata);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < nindptr; ++i) {
|
||||||
|
row_ptr_[i] = static_cast<size_t>(p_indptr[i]);
|
||||||
|
}
|
||||||
|
int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
|
||||||
|
xgboost::common::ParallelFor(ndata, threads, [&](xgboost::omp_ulong i) {
|
||||||
|
indices_[i] = static_cast<unsigned>(p_indices[i]);
|
||||||
|
data_[i] = static_cast<float>(p_data[i]);
|
||||||
|
});
|
||||||
|
DMatrixHandle handle;
|
||||||
|
CHECK_CALL(XGDMatrixCreateFromCSREx(BeginPtr(row_ptr_), BeginPtr(indices_),
|
||||||
|
BeginPtr(data_), nindptr, ndata,
|
||||||
|
ncol, &handle));
|
||||||
|
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||||
|
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
||||||
|
R_API_END();
|
||||||
|
UNPROTECT(1);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
|
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
|
||||||
SEXP ret;
|
SEXP ret;
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
|
|||||||
@ -65,6 +65,18 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat,
|
|||||||
XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_row,
|
XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_row,
|
||||||
SEXP n_threads);
|
SEXP n_threads);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief create a matrix content from CSR format
|
||||||
|
* \param indptr pointer to row headers
|
||||||
|
* \param indices column indices
|
||||||
|
* \param data content of the data
|
||||||
|
* \param num_col number of columns (when it's set to 0, then guess from data)
|
||||||
|
* \param n_threads Number of threads used to construct DMatrix from csr matrix.
|
||||||
|
* \return created dmatrix
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP num_col,
|
||||||
|
SEXP n_threads);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief create a new dmatrix from sliced content of existing matrix
|
* \brief create a new dmatrix from sliced content of existing matrix
|
||||||
* \param handle instance of data matrix to be sliced
|
* \param handle instance of data matrix to be sliced
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
require(xgboost)
|
require(xgboost)
|
||||||
|
library(Matrix)
|
||||||
|
|
||||||
context("basic functions")
|
context("basic functions")
|
||||||
|
|
||||||
@ -459,3 +460,18 @@ test_that("strict_shape works", {
|
|||||||
test_iris()
|
test_iris()
|
||||||
test_agaricus()
|
test_agaricus()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("'predict' accepts CSR data", {
|
||||||
|
X <- agaricus.train$data
|
||||||
|
y <- agaricus.train$label
|
||||||
|
x_csc <- as(X[1L, , drop = FALSE], "CsparseMatrix")
|
||||||
|
x_csr <- as(x_csc, "RsparseMatrix")
|
||||||
|
x_spv <- as(x_csc, "sparseVector")
|
||||||
|
bst <- xgboost(data = X, label = y, objective = "binary:logistic",
|
||||||
|
nrounds = 5L, verbose = FALSE)
|
||||||
|
p_csc <- predict(bst, x_csc)
|
||||||
|
p_csr <- predict(bst, x_csr)
|
||||||
|
p_spv <- predict(bst, x_spv)
|
||||||
|
expect_equal(p_csc, p_csr)
|
||||||
|
expect_equal(p_csc, p_spv)
|
||||||
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user