[R] Fix nthread in DMatrix constructor. (#7127)

* Break the R C API for nthread: `XGDMatrixCreateFromMat_R` now takes a third `n_threads` argument.
This commit is contained in:
Jiaming Yuan 2021-08-03 17:39:25 +08:00 committed by GitHub
parent ba47eda61b
commit 8ee127469f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 18 additions and 11 deletions

View File

@ -20,7 +20,7 @@
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data') #' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data') #' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
#' @export #' @export
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...) { xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) {
cnames <- NULL cnames <- NULL
if (typeof(data) == "character") { if (typeof(data) == "character") {
if (length(data) > 1) if (length(data) > 1)
@ -29,7 +29,7 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...)
data <- path.expand(data) data <- path.expand(data)
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent)) handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
} else if (is.matrix(data)) { } else if (is.matrix(data)) {
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing) handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
cnames <- colnames(data) cnames <- colnames(data)
} else if (inherits(data, "dgCMatrix")) { } else if (inherits(data, "dgCMatrix")) {
handle <- .Call(XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data)) handle <- .Call(XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data))
@ -51,12 +51,12 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...)
# get dmatrix from data, label # get dmatrix from data, label
# internal helper method # internal helper method
xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL) { xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL, nthread = NULL) {
if (inherits(data, "dgCMatrix") || is.matrix(data)) { if (inherits(data, "dgCMatrix") || is.matrix(data)) {
if (is.null(label)) { if (is.null(label)) {
stop("label must be provided when data is a matrix") stop("label must be provided when data is a matrix")
} }
dtrain <- xgb.DMatrix(data, label = label, missing = missing) dtrain <- xgb.DMatrix(data, label = label, missing = missing, nthread = nthread)
if (!is.null(weight)){ if (!is.null(weight)){
setinfo(dtrain, "weight", weight) setinfo(dtrain, "weight", weight)
} }

View File

@ -10,7 +10,7 @@ xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL,
save_period = NULL, save_name = "xgboost.model", save_period = NULL, save_name = "xgboost.model",
xgb_model = NULL, callbacks = list(), ...) { xgb_model = NULL, callbacks = list(), ...) {
dtrain <- xgb.get.DMatrix(data, label, missing, weight) dtrain <- xgb.get.DMatrix(data, label, missing, weight, nthread = params$nthread)
watchlist <- list(train = dtrain) watchlist <- list(train = dtrain)

View File

@ -38,7 +38,7 @@ extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
extern SEXP XGCheckNullPtr_R(SEXP); extern SEXP XGCheckNullPtr_R(SEXP);
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP); extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP); extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP); extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixNumCol_R(SEXP); extern SEXP XGDMatrixNumCol_R(SEXP);
extern SEXP XGDMatrixNumRow_R(SEXP); extern SEXP XGDMatrixNumRow_R(SEXP);
@ -73,7 +73,7 @@ static const R_CallMethodDef CallEntries[] = {
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1}, {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 4}, {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 4},
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2}, {"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 2}, {"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2}, {"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1}, {"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1}, {"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},

View File

@ -9,6 +9,8 @@
#include <cstring> #include <cstring>
#include <cstdio> #include <cstdio>
#include <sstream> #include <sstream>
#include "../../src/common/threading_utils.h"
#include "./xgboost_R.h" #include "./xgboost_R.h"
/*! /*!
@ -77,7 +79,7 @@ XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
return ret; return ret;
} }
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) { XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
SEXP ret; SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
SEXP dim = getAttrib(mat, R_DimSymbol); SEXP dim = getAttrib(mat, R_DimSymbol);
@ -93,7 +95,9 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
} }
std::vector<float> data(nrow * ncol); std::vector<float> data(nrow * ncol);
dmlc::OMPException exc; dmlc::OMPException exc;
#pragma omp parallel for schedule(static) int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
#pragma omp parallel for schedule(static) num_threads(threads)
for (omp_ulong i = 0; i < nrow; ++i) { for (omp_ulong i = 0; i < nrow; ++i) {
exc.Run([&]() { exc.Run([&]() {
for (size_t j = 0; j < ncol; ++j) { for (size_t j = 0; j < ncol; ++j) {
@ -103,7 +107,8 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
} }
exc.Rethrow(); exc.Rethrow();
DMatrixHandle handle; DMatrixHandle handle;
CHECK_CALL(XGDMatrixCreateFromMat(BeginPtr(data), nrow, ncol, asReal(missing), &handle)); CHECK_CALL(XGDMatrixCreateFromMat_omp(BeginPtr(data), nrow, ncol,
asReal(missing), &handle, threads));
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END(); R_API_END();

View File

@ -47,10 +47,12 @@ XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent);
* This assumes the matrix is stored in column major format * This assumes the matrix is stored in column major format
* \param data R Matrix object * \param data R Matrix object
* \param missing which value to represent missing value * \param missing which value to represent missing value
* \param n_threads Number of threads used to construct DMatrix from dense matrix.
* \return created dmatrix * \return created dmatrix
*/ */
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat,
SEXP missing); SEXP missing,
SEXP n_threads);
/*! /*!
* \brief create a matrix content from CSC format * \brief create a matrix content from CSC format
* \param indptr pointer to column headers * \param indptr pointer to column headers