Remove omp_get_max_threads in tree updaters. (#7590)
This commit is contained in:
@@ -146,8 +146,7 @@ class ColumnMatrix {
|
||||
}
|
||||
|
||||
// construct column matrix from GHistIndexMatrix
|
||||
inline void Init(const GHistIndexMatrix& gmat,
|
||||
double sparse_threshold) {
|
||||
inline void Init(const GHistIndexMatrix& gmat, double sparse_threshold, int32_t n_threads) {
|
||||
const int32_t nfeature = static_cast<int32_t>(gmat.cut.Ptrs().size() - 1);
|
||||
const size_t nrow = gmat.row_ptr.size() - 1;
|
||||
// identify type of each column
|
||||
@@ -208,12 +207,15 @@ class ColumnMatrix {
|
||||
if (all_dense) {
|
||||
BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize();
|
||||
if (gmat_bin_size == kUint8BinsTypeSize) {
|
||||
SetIndexAllDense(gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues);
|
||||
SetIndexAllDense(gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues,
|
||||
n_threads);
|
||||
} else if (gmat_bin_size == kUint16BinsTypeSize) {
|
||||
SetIndexAllDense(gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues);
|
||||
SetIndexAllDense(gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues,
|
||||
n_threads);
|
||||
} else {
|
||||
CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize);
|
||||
SetIndexAllDense(gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues);
|
||||
CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize);
|
||||
SetIndexAllDense(gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues,
|
||||
n_threads);
|
||||
}
|
||||
/* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize
|
||||
but for ColumnMatrix we still have a chance to reduce the memory consumption */
|
||||
@@ -266,13 +268,13 @@ class ColumnMatrix {
|
||||
template <typename T>
|
||||
inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat,
|
||||
const size_t nrow, const size_t nfeature,
|
||||
const bool noMissingValues) {
|
||||
const bool noMissingValues, int32_t n_threads) {
|
||||
T* local_index = reinterpret_cast<T*>(&index_[0]);
|
||||
|
||||
/* missing values make sense only for column with type kDenseColumn,
|
||||
and if no missing values were observed it could be handled much faster. */
|
||||
if (noMissingValues) {
|
||||
ParallelFor(omp_ulong(nrow), [&](omp_ulong rid) {
|
||||
ParallelFor(nrow, n_threads, [&](auto rid) {
|
||||
const size_t ibegin = rid*nfeature;
|
||||
const size_t iend = (rid+1)*nfeature;
|
||||
size_t j = 0;
|
||||
|
||||
Reference in New Issue
Block a user