[R] Fix nthread in DMatrix constructor. (#7127)
* Break the R C API for nthread.
This commit is contained in:
parent
ba47eda61b
commit
8ee127469f
@ -20,7 +20,7 @@
|
|||||||
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
#' dtrain <- xgb.DMatrix('xgb.DMatrix.data')
|
||||||
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
|
#' if (file.exists('xgb.DMatrix.data')) file.remove('xgb.DMatrix.data')
|
||||||
#' @export
|
#' @export
|
||||||
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...) {
|
xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthread = NULL, ...) {
|
||||||
cnames <- NULL
|
cnames <- NULL
|
||||||
if (typeof(data) == "character") {
|
if (typeof(data) == "character") {
|
||||||
if (length(data) > 1)
|
if (length(data) > 1)
|
||||||
@ -29,7 +29,7 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...)
|
|||||||
data <- path.expand(data)
|
data <- path.expand(data)
|
||||||
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
|
handle <- .Call(XGDMatrixCreateFromFile_R, data, as.integer(silent))
|
||||||
} else if (is.matrix(data)) {
|
} else if (is.matrix(data)) {
|
||||||
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing)
|
handle <- .Call(XGDMatrixCreateFromMat_R, data, missing, as.integer(NVL(nthread, -1)))
|
||||||
cnames <- colnames(data)
|
cnames <- colnames(data)
|
||||||
} else if (inherits(data, "dgCMatrix")) {
|
} else if (inherits(data, "dgCMatrix")) {
|
||||||
handle <- .Call(XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data))
|
handle <- .Call(XGDMatrixCreateFromCSC_R, data@p, data@i, data@x, nrow(data))
|
||||||
@ -51,12 +51,12 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, ...)
|
|||||||
|
|
||||||
# get dmatrix from data, label
|
# get dmatrix from data, label
|
||||||
# internal helper method
|
# internal helper method
|
||||||
xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL) {
|
xgb.get.DMatrix <- function(data, label = NULL, missing = NA, weight = NULL, nthread = NULL) {
|
||||||
if (inherits(data, "dgCMatrix") || is.matrix(data)) {
|
if (inherits(data, "dgCMatrix") || is.matrix(data)) {
|
||||||
if (is.null(label)) {
|
if (is.null(label)) {
|
||||||
stop("label must be provided when data is a matrix")
|
stop("label must be provided when data is a matrix")
|
||||||
}
|
}
|
||||||
dtrain <- xgb.DMatrix(data, label = label, missing = missing)
|
dtrain <- xgb.DMatrix(data, label = label, missing = missing, nthread = nthread)
|
||||||
if (!is.null(weight)){
|
if (!is.null(weight)){
|
||||||
setinfo(dtrain, "weight", weight)
|
setinfo(dtrain, "weight", weight)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,7 +10,7 @@ xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL,
|
|||||||
save_period = NULL, save_name = "xgboost.model",
|
save_period = NULL, save_name = "xgboost.model",
|
||||||
xgb_model = NULL, callbacks = list(), ...) {
|
xgb_model = NULL, callbacks = list(), ...) {
|
||||||
|
|
||||||
dtrain <- xgb.get.DMatrix(data, label, missing, weight)
|
dtrain <- xgb.get.DMatrix(data, label, missing, weight, nthread = params$nthread)
|
||||||
|
|
||||||
watchlist <- list(train = dtrain)
|
watchlist <- list(train = dtrain)
|
||||||
|
|
||||||
|
|||||||
@ -38,7 +38,7 @@ extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
|
|||||||
extern SEXP XGCheckNullPtr_R(SEXP);
|
extern SEXP XGCheckNullPtr_R(SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP);
|
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixNumCol_R(SEXP);
|
extern SEXP XGDMatrixNumCol_R(SEXP);
|
||||||
extern SEXP XGDMatrixNumRow_R(SEXP);
|
extern SEXP XGDMatrixNumRow_R(SEXP);
|
||||||
@ -73,7 +73,7 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
|
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
|
||||||
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 4},
|
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 4},
|
||||||
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
||||||
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 2},
|
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
||||||
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
||||||
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
||||||
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
||||||
|
|||||||
@ -9,6 +9,8 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "../../src/common/threading_utils.h"
|
||||||
#include "./xgboost_R.h"
|
#include "./xgboost_R.h"
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -77,7 +79,7 @@ XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
|
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
|
||||||
SEXP ret;
|
SEXP ret;
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
SEXP dim = getAttrib(mat, R_DimSymbol);
|
SEXP dim = getAttrib(mat, R_DimSymbol);
|
||||||
@ -93,7 +95,9 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
|
|||||||
}
|
}
|
||||||
std::vector<float> data(nrow * ncol);
|
std::vector<float> data(nrow * ncol);
|
||||||
dmlc::OMPException exc;
|
dmlc::OMPException exc;
|
||||||
#pragma omp parallel for schedule(static)
|
int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
|
||||||
|
|
||||||
|
#pragma omp parallel for schedule(static) num_threads(threads)
|
||||||
for (omp_ulong i = 0; i < nrow; ++i) {
|
for (omp_ulong i = 0; i < nrow; ++i) {
|
||||||
exc.Run([&]() {
|
exc.Run([&]() {
|
||||||
for (size_t j = 0; j < ncol; ++j) {
|
for (size_t j = 0; j < ncol; ++j) {
|
||||||
@ -103,7 +107,8 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
|
|||||||
}
|
}
|
||||||
exc.Rethrow();
|
exc.Rethrow();
|
||||||
DMatrixHandle handle;
|
DMatrixHandle handle;
|
||||||
CHECK_CALL(XGDMatrixCreateFromMat(BeginPtr(data), nrow, ncol, asReal(missing), &handle));
|
CHECK_CALL(XGDMatrixCreateFromMat_omp(BeginPtr(data), nrow, ncol,
|
||||||
|
asReal(missing), &handle, threads));
|
||||||
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||||
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
||||||
R_API_END();
|
R_API_END();
|
||||||
|
|||||||
@ -47,10 +47,12 @@ XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent);
|
|||||||
* This assumes the matrix is stored in column major format
|
* This assumes the matrix is stored in column major format
|
||||||
* \param data R Matrix object
|
* \param data R Matrix object
|
||||||
* \param missing which value to represent missing value
|
* \param missing which value to represent missing value
|
||||||
|
* \param n_threads Number of threads used to construct DMatrix from dense matrix.
|
||||||
* \return created dmatrix
|
* \return created dmatrix
|
||||||
*/
|
*/
|
||||||
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat,
|
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat,
|
||||||
SEXP missing);
|
SEXP missing,
|
||||||
|
SEXP n_threads);
|
||||||
/*!
|
/*!
|
||||||
* \brief create a matrix content from CSC format
|
* \brief create a matrix content from CSC format
|
||||||
* \param indptr pointer to column headers
|
* \param indptr pointer to column headers
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user