[R] Accept CSR data for predictions (#7615)

This commit is contained in:
david-cortes
2022-01-29 18:54:57 +02:00
committed by GitHub
parent 549bd419bb
commit 7f738e7f6f
8 changed files with 93 additions and 7 deletions

View File

@@ -164,6 +164,39 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data,
return ret;
}
XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data,
SEXP num_col, SEXP n_threads) {
SEXP ret;
R_API_BEGIN();
const int *p_indptr = INTEGER(indptr);
const int *p_indices = INTEGER(indices);
const double *p_data = REAL(data);
size_t nindptr = static_cast<size_t>(length(indptr));
size_t ndata = static_cast<size_t>(length(data));
size_t ncol = static_cast<size_t>(INTEGER(num_col)[0]);
std::vector<size_t> row_ptr_(nindptr);
std::vector<unsigned> indices_(ndata);
std::vector<float> data_(ndata);
for (size_t i = 0; i < nindptr; ++i) {
row_ptr_[i] = static_cast<size_t>(p_indptr[i]);
}
int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
xgboost::common::ParallelFor(ndata, threads, [&](xgboost::omp_ulong i) {
indices_[i] = static_cast<unsigned>(p_indices[i]);
data_[i] = static_cast<float>(p_data[i]);
});
DMatrixHandle handle;
CHECK_CALL(XGDMatrixCreateFromCSREx(BeginPtr(row_ptr_), BeginPtr(indices_),
BeginPtr(data_), nindptr, ndata,
ncol, &handle));
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END();
UNPROTECT(1);
return ret;
}
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
SEXP ret;
R_API_BEGIN();