[R] Fix nthread in DMatrix constructor. (#7127)

* Break the R C API for nthread.
This commit is contained in:
Jiaming Yuan
2021-08-03 17:39:25 +08:00
committed by GitHub
parent ba47eda61b
commit 8ee127469f
5 changed files with 18 additions and 11 deletions

View File

@@ -9,6 +9,8 @@
#include <cstring>
#include <cstdio>
#include <sstream>
#include "../../src/common/threading_utils.h"
#include "./xgboost_R.h"
/*!
@@ -77,7 +79,7 @@ XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
return ret;
}
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
SEXP ret;
R_API_BEGIN();
SEXP dim = getAttrib(mat, R_DimSymbol);
@@ -93,7 +95,9 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
}
std::vector<float> data(nrow * ncol);
dmlc::OMPException exc;
#pragma omp parallel for schedule(static)
int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
#pragma omp parallel for schedule(static) num_threads(threads)
for (omp_ulong i = 0; i < nrow; ++i) {
exc.Run([&]() {
for (size_t j = 0; j < ncol; ++j) {
@@ -103,7 +107,8 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
}
exc.Rethrow();
DMatrixHandle handle;
CHECK_CALL(XGDMatrixCreateFromMat(BeginPtr(data), nrow, ncol, asReal(missing), &handle));
CHECK_CALL(XGDMatrixCreateFromMat_omp(BeginPtr(data), nrow, ncol,
asReal(missing), &handle, threads));
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END();