From 8ee127469f76226f709e9bb8eab99e403d6c1b33 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 3 Aug 2021 17:39:25 +0800 Subject: [PATCH] [R] Fix nthread in DMatrix constructor. (#7127) * Break the R C API for nthread. --- R-package/R/xgb.DMatrix.R | 8 ++++---- R-package/R/xgboost.R | 2 +- R-package/src/init.c | 4 ++-- R-package/src/xgboost_R.cc | 11 ++++++++--- R-package/src/xgboost_R.h | 4 +++- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index 2460e23a4..3ff0fae4f 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -20,7 +20,7 @@ #' dtrain <- xgb.DMatrix('xgb.DMatrix.data') #' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data') #' @export -xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...) { +xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) { cnames <- NULL if (typeof(data) == "character") { if (length(data) > 1) @@ -29,7 +29,7 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...) data <- path.expand(data) handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent)) } else if (is.matrix(data)) { - handle <- .Call(XGDMatrixCreateFromMat_R, data, missing) + handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1))) cnames <- colnames(data) } else if (inherits(data, "dgCMatrix")) { handle <- .Call(XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data)) @@ -51,12 +51,12 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...) # get dmatrix from data, label # internal helper method -xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL) { +xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL, nthread = NULL) { if (inherits(data, "dgCMatrix") || is.matrix(data)) { if (is.null(label)) { stop("label must be provided when data is a matrix") } - dtrain <- xgb.DMatrix(data, label = label, missing = missing) + dtrain <- xgb.DMatrix(data, label = label, missing = missing, nthread = nthread) if (!is.null(weight)){ setinfo(dtrain, "weight", weight) } diff --git a/R-package/R/xgboost.R b/R-package/R/xgboost.R index 460f3c963..4b8c15a3e 100644 --- a/R-package/R/xgboost.R +++ b/R-package/R/xgboost.R @@ -10,7 +10,7 @@ xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL, save_period = NULL, save_name = "xgboost.model", xgb_model = NULL, callbacks = list(), ...) { - dtrain <- xgb.get.DMatrix(data, label, missing, weight) + dtrain <- xgb.get.DMatrix(data, label, missing, weight, nthread = params$nthread) watchlist <- list(train = dtrain) diff --git a/R-package/src/init.c b/R-package/src/init.c index 5f136ff22..119fdb9f3 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -38,7 +38,7 @@ extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP); extern SEXP XGCheckNullPtr_R(SEXP); extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP); -extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP); +extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP); extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP); extern SEXP XGDMatrixNumCol_R(SEXP); extern SEXP XGDMatrixNumRow_R(SEXP); @@ -73,7 +73,7 @@ static const R_CallMethodDef CallEntries[] = { {"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1}, {"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 4}, {"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2}, - {"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 2}, + {"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3}, {"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2}, {"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1}, {"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 56fb61f8d..9921bb74b 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -9,6 +9,8 @@ #include #include #include + +#include "../../src/common/threading_utils.h" #include "./xgboost_R.h" /*! @@ -77,7 +79,7 @@ XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) { return ret; } -XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) { +XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) { SEXP ret; R_API_BEGIN(); SEXP dim = getAttrib(mat, R_DimSymbol); @@ -93,7 +95,9 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) { } std::vector data(nrow * ncol); dmlc::OMPException exc; - #pragma omp parallel for schedule(static) + int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads)); + +#pragma omp parallel for schedule(static) num_threads(threads) for (omp_ulong i = 0; i < nrow; ++i) { exc.Run([&]() { for (size_t j = 0; j < ncol; ++j) { @@ -103,7 +107,8 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) { } exc.Rethrow(); DMatrixHandle handle; - CHECK_CALL(XGDMatrixCreateFromMat(BeginPtr(data), nrow, ncol, asReal(missing), &handle)); + CHECK_CALL(XGDMatrixCreateFromMat_omp(BeginPtr(data), nrow, ncol, + asReal(missing), &handle, threads)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_API_END(); diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 467aa54a3..786514593 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -47,10 +47,12 @@ XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent); * This assumes the matrix is stored in column major format * \param data R Matrix object * \param missing which value to represent missing value + * \param n_threads Number of threads used to construct DMatrix from dense matrix. * \return created dmatrix */ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, - SEXP missing); + SEXP missing, + SEXP n_threads); /*! * \brief create a matrix content from CSC format * \param indptr pointer to column headers