diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 75f0eb60e..69973bfcf 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -302,22 +302,7 @@ class SparsePage { SparsePage GetTranspose(int num_columns, int32_t n_threads) const; - void SortRows() { - auto ncol = static_cast(this->Size()); - dmlc::OMPException exc; -#pragma omp parallel for schedule(dynamic, 1) - for (bst_omp_uint i = 0; i < ncol; ++i) { - exc.Run([&]() { - if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) { - std::sort( - this->data.HostVector().begin() + this->offset.HostVector()[i], - this->data.HostVector().begin() + this->offset.HostVector()[i + 1], - Entry::CmpValue); - } - }); - } - exc.Rethrow(); - } + void SortRows(int32_t n_threads); /** * \brief Pushes external data batch onto this page diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 35e50252f..5cee0779c 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2019 by Contributors + * Copyright 2014-2022 by XGBoost Contributors * \file tree_updater.h * \brief General primitive for tree learning, * Updating a collection of trees given the information. @@ -32,7 +32,7 @@ class Json; */ class TreeUpdater : public Configurable { protected: - GenericParameter const* tparam_; + GenericParameter const* ctx_; public: /*! \brief virtual destructor */ diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 804d7a568..b652bcc4a 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -146,8 +146,7 @@ class ColumnMatrix { } // construct column matrix from GHistIndexMatrix - inline void Init(const GHistIndexMatrix& gmat, - double sparse_threshold) { + inline void Init(const GHistIndexMatrix& gmat, double sparse_threshold, int32_t n_threads) { const int32_t nfeature = static_cast(gmat.cut.Ptrs().size() - 1); const size_t nrow = gmat.row_ptr.size() - 1; // identify type of each column @@ -208,12 +207,15 @@ class ColumnMatrix { if (all_dense) { BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize(); if (gmat_bin_size == kUint8BinsTypeSize) { - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues); + SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + n_threads); } else if (gmat_bin_size == kUint16BinsTypeSize) { - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues); + SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + n_threads); } else { - CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize); - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues); + CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize); + SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + n_threads); } /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize but for ColumnMatrix we still have a chance to reduce the memory consumption */ @@ -266,13 +268,13 @@ class ColumnMatrix { template inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat, const size_t nrow, const size_t nfeature, - const bool noMissingValues) { + const bool noMissingValues, int32_t n_threads) { T* local_index = reinterpret_cast(&index_[0]); /* missing values make sense only for column with type kDenseColumn, and if no missing values were observed it could be handled much faster. */ if (noMissingValues) { - ParallelFor(omp_ulong(nrow), [&](omp_ulong rid) { + ParallelFor(nrow, n_threads, [&](auto rid) { const size_t ibegin = rid*nfeature; const size_t iend = (rid+1)*nfeature; size_t j = 0; diff --git a/src/data/data.cc b/src/data/data.cc index a318680e8..3d1e3cc28 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1035,6 +1035,16 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const { return transpose; } +void SparsePage::SortRows(int32_t n_threads) { + auto& h_offset = this->offset.HostVector(); + auto& h_data = this->data.HostVector(); + common::ParallelFor(this->Size(), n_threads, [&](auto i) { + if (h_offset[i] < h_offset[i + 1]) { + std::sort(h_data.begin() + h_offset[i], h_data.begin() + h_offset[i + 1], Entry::CmpValue); + } + }); +} + void SparsePage::Push(const SparsePage &batch) { auto& data_vec = data.HostVector(); auto& offset_vec = offset.HostVector(); diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 09ed2f806..d447f14ce 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -67,7 +67,7 @@ BatchSet SimpleDMatrix::GetSortedColumnBatches() { if (!sorted_column_page_) { sorted_column_page_.reset( new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads()))); - sorted_column_page_->SortRows(); + sorted_column_page_->SortRows(ctx_.Threads()); } auto begin_iter = BatchIterator( new SimpleBatchIteratorImpl(sorted_column_page_)); diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 90fdcea8f..4bada04c8 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -339,7 +339,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn { this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_)); CHECK_EQ(this->page_->Size(), n_features_); CHECK_EQ(this->page_->data.Size(), csr->data.Size()); - this->page_->SortRows(); + this->page_->SortRows(this->nthreads_); page_->SetBaseRowId(csr->base_rowid); this->WriteCache(); } diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 6d0aed009..05f6c4bb5 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2015 by Contributors + * Copyright 2015-2022 by XGBoost Contributors * \file tree_updater.cc * \brief Registry of tree updaters. */ @@ -21,7 +21,7 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const LOG(FATAL) << "Unknown tree updater " << name; } auto p_updater = (e->body)(task); - p_updater->tparam_ = tparam; + p_updater->ctx_ = tparam; return p_updater; } diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 1f3c7342b..9ae8c12ae 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2021 XGBoost contributors + * Copyright 2021-2022 XGBoost contributors * * \brief Implementation for the approx tree method. */ @@ -318,10 +318,10 @@ class GlobalApproxUpdater : public TreeUpdater { param_.learning_rate = lr / trees.size(); if (hist_param_.single_precision_histogram) { - f32_impl_ = std::make_unique>(param_, m->Info(), tparam_, + f32_impl_ = std::make_unique>(param_, m->Info(), ctx_, column_sampler_, task_, &monitor_); } else { - f64_impl_ = std::make_unique>(param_, m->Info(), tparam_, + f64_impl_ = std::make_unique>(param_, m->Info(), ctx_, column_sampler_, task_, &monitor_); } diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index c7c60b750..da239b209 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014 by Contributors + * Copyright 2014-2022 by XGBoost Contributors * \file updater_basemaker-inl.h * \brief implement a common tree constructor * \author Tianqi Chen @@ -220,9 +220,7 @@ class BaseMaker: public TreeUpdater { // set default direct nodes to default // for leaf nodes that are not fresh, mark then to ~nid, // so that they are ignored in future statistics collection - const auto ndata = static_cast(p_fmat->Info().num_row_); - - common::ParallelFor(ndata, [&](bst_omp_uint ridx) { + common::ParallelFor(p_fmat->Info().num_row_, ctx_->Threads(), [&](auto ridx) { const int nid = this->DecodePosition(ridx); if (tree[nid].IsLeaf()) { // mark finish when it is not a fresh leaf @@ -256,8 +254,7 @@ class BaseMaker: public TreeUpdater { auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid); if (it != sorted_split_set.end() && *it == fid) { - const auto ndata = static_cast(col.size()); - common::ParallelFor(ndata, [&](bst_omp_uint j) { + common::ParallelFor(col.size(), ctx_->Threads(), [&](auto j) { const bst_uint ridx = col[j].index; const bst_float fvalue = col[j].fvalue; const int nid = this->DecodePosition(ridx); @@ -312,8 +309,7 @@ class BaseMaker: public TreeUpdater { auto page = batch.GetView(); for (auto fid : fsplits) { auto col = page[fid]; - const auto ndata = static_cast(col.size()); - common::ParallelFor(ndata, [&](bst_omp_uint j) { + common::ParallelFor(col.size(), ctx_->Threads(), [&](auto j) { const bst_uint ridx = col[j].index; const bst_float fvalue = col[j].fvalue; const int nid = this->DecodePosition(ridx); @@ -337,10 +333,10 @@ class BaseMaker: public TreeUpdater { std::vector< std::vector > *p_thread_temp, std::vector *p_node_stats) { std::vector< std::vector > &thread_temp = *p_thread_temp; - thread_temp.resize(omp_get_max_threads()); + thread_temp.resize(ctx_->Threads()); p_node_stats->resize(tree.param.num_nodes); dmlc::OMPException exc; -#pragma omp parallel +#pragma omp parallel num_threads(ctx_->Threads()) { exc.Run([&]() { const int tid = omp_get_thread_num(); @@ -352,8 +348,7 @@ class BaseMaker: public TreeUpdater { } exc.Rethrow(); // setup position - const auto ndata = static_cast(fmat.Info().num_row_); - common::ParallelFor(ndata, [&](bst_omp_uint ridx) { + common::ParallelFor(fmat.Info().num_row_, ctx_->Threads(), [&](auto ridx) { const int nid = position_[ridx]; const int tid = omp_get_thread_num(); if (nid >= 0) { diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index c78bbf6f5..d121bb4ad 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2019 by Contributors + * Copyright 2014-2022 by XGBoost Contributors * \file updater_colmaker.cc * \brief use columnwise update to construct a tree * \author Tianqi Chen @@ -114,8 +114,8 @@ class ColMaker: public TreeUpdater { interaction_constraints_.Configure(param_, dmat->Info().num_row_); // build tree for (auto tree : trees) { - CHECK(tparam_); - Builder builder(param_, colmaker_param_, interaction_constraints_, tparam_, + CHECK(ctx_); + Builder builder(param_, colmaker_param_, interaction_constraints_, ctx_, column_densities_); builder.Update(gpair->ConstHostVector(), dmat, tree); } @@ -270,17 +270,11 @@ class ColMaker: public TreeUpdater { } const MetaInfo& info = fmat.Info(); // setup position - const auto ndata = static_cast(info.num_row_); - dmlc::OMPException exc; - #pragma omp parallel for schedule(static) - for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) { - exc.Run([&]() { - const int tid = omp_get_thread_num(); - if (position_[ridx] < 0) return; - stemp_[tid][position_[ridx]].stats.Add(gpair[ridx]); - }); - } - exc.Rethrow(); + common::ParallelFor(info.num_row_, ctx_->Threads(), [&](auto ridx) { + int32_t const tid = omp_get_thread_num(); + if (position_[ridx] < 0) return; + stemp_[tid][position_[ridx]].stats.Add(gpair[ridx]); + }); // sum the per thread statistics together for (int nid : qexpand) { GradStats stats; @@ -449,27 +443,20 @@ class ColMaker: public TreeUpdater { // update the solution candidate virtual void UpdateSolution(const SortedCSCPage &batch, const std::vector &feat_set, - const std::vector &gpair, - DMatrix*) { + const std::vector &gpair, DMatrix *) { // start enumeration - const auto num_features = static_cast(feat_set.size()); -#if defined(_OPENMP) + const auto num_features = feat_set.size(); CHECK(this->ctx_); const int batch_size = // NOLINT std::max(static_cast(num_features / this->ctx_->Threads() / 32), 1); -#endif // defined(_OPENMP) - { - auto page = batch.GetView(); - dmlc::OMPException exc; -#pragma omp parallel for schedule(dynamic, batch_size) - for (bst_omp_uint i = 0; i < num_features; ++i) { - exc.Run([&]() { + auto page = batch.GetView(); + common::ParallelFor( + num_features, ctx_->Threads(), common::Sched::Dyn(batch_size), [&](auto i) { auto evaluator = tree_evaluator_.GetEvaluator(); bst_feature_t const fid = feat_set[i]; int32_t const tid = omp_get_thread_num(); auto c = page[fid]; - const bool ind = - c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue; + const bool ind = c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue; if (colmaker_train_param_.NeedForwardSearch(column_densities_[fid], ind)) { this->EnumerateSplit(c.data(), c.data() + c.size(), +1, fid, gpair, stemp_[tid], evaluator); @@ -479,9 +466,6 @@ class ColMaker: public TreeUpdater { stemp_[tid], evaluator); } }); - } - exc.Rethrow(); - } } // find splits at current level, do split per level inline void FindSplit(int depth, @@ -529,11 +513,9 @@ class ColMaker: public TreeUpdater { // set default direct nodes to default // for leaf nodes that are not fresh, mark then to ~nid, // so that they are ignored in future statistics collection - const auto ndata = static_cast(p_fmat->Info().num_row_); - - common::ParallelFor(ndata, [&](bst_omp_uint ridx) { - CHECK_LT(ridx, position_.size()) - << "ridx exceed bound " << "ridx="<< ridx << " pos=" << position_.size(); + common::ParallelFor(p_fmat->Info().num_row_, this->ctx_->Threads(), [&](auto ridx) { + CHECK_LT(ridx, position_.size()) << "ridx exceed bound " + << "ridx=" << ridx << " pos=" << position_.size(); const int nid = this->DecodePosition(ridx); if (tree[nid].IsLeaf()) { // mark finish when it is not a fresh leaf @@ -577,8 +559,7 @@ class ColMaker: public TreeUpdater { auto page = batch.GetView(); for (auto fid : fsplits) { auto col = page[fid]; - const auto ndata = static_cast(col.size()); - common::ParallelFor(ndata, [&](bst_omp_uint j) { + common::ParallelFor(col.size(), this->ctx_->Threads(), [&](auto j) { const bst_uint ridx = col[j].index; const int nid = this->DecodePosition(ridx); const bst_float fvalue = col[j].fvalue; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 9f92b8654..9587c3b83 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -875,11 +875,11 @@ class GPUHistMaker : public TreeUpdater { if (hist_maker_param_.single_precision_histogram) { float_maker_.reset(new GPUHistMakerSpecialised(task_)); float_maker_->param_ = param; - float_maker_->Configure(args, tparam_); + float_maker_->Configure(args, ctx_); } else { double_maker_.reset(new GPUHistMakerSpecialised(task_)); double_maker_->param_ = param; - double_maker_->Configure(args, tparam_); + double_maker_->Configure(args, ctx_); } } diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index ee4a618bb..0a85d2d73 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2019 by Contributors + * Copyright 2014-2022 by XGBoost Contributors * \file updater_histmaker.cc * \brief use histogram counting to construct a tree * \author Tianqi Chen @@ -203,27 +203,22 @@ class HistMaker: public BaseMaker { // get the best split condition for each node std::vector sol(qexpand_.size()); std::vector left_sum(qexpand_.size()); - auto nexpand = static_cast(qexpand_.size()); - dmlc::OMPException exc; -#pragma omp parallel for schedule(dynamic, 1) - for (bst_omp_uint wid = 0; wid < nexpand; ++wid) { - exc.Run([&]() { - const int nid = qexpand_[wid]; - CHECK_EQ(node2workindex_[nid], static_cast(wid)); - SplitEntry &best = sol[wid]; - GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0]; - for (size_t i = 0; i < feature_set.size(); ++i) { - // Query is thread safe as it's a const function. - if (!this->interaction_constraints_.Query(nid, feature_set[i])) { - continue; - } - - EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature+1)], - node_sum, feature_set[i], &best, &left_sum[wid]); + auto nexpand = qexpand_.size(); + common::ParallelFor(nexpand, ctx_->Threads(), common::Sched::Dyn(1), [&](auto wid) { + const int nid = qexpand_[wid]; + CHECK_EQ(node2workindex_[nid], static_cast(wid)); + SplitEntry &best = sol[wid]; + GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0]; + for (size_t i = 0; i < feature_set.size(); ++i) { + // Query is thread safe as it's a const function. + if (!this->interaction_constraints_.Query(nid, feature_set[i])) { + continue; } - }); - } - exc.Rethrow(); + + EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature + 1)], node_sum, feature_set[i], + &best, &left_sum[wid]); + } + }); // get the best result, we can synchronize the solution for (bst_omp_uint wid = 0; wid < nexpand; ++wid) { const bst_node_t nid = qexpand_[wid]; @@ -341,26 +336,19 @@ class CQHistMaker: public HistMaker { // if it is C++11, use lazy evaluation for Allreduce, // to gain speedup in recovery auto lazy_get_hist = [&]() { - thread_hist_.resize(omp_get_max_threads()); + thread_hist_.resize(ctx_->Threads()); // start accumulating statistics for (const auto &batch : p_fmat->GetBatches()) { auto page = batch.GetView(); // start enumeration - const auto nsize = static_cast(fset.size()); - dmlc::OMPException exc; -#pragma omp parallel for schedule(dynamic, 1) - for (bst_omp_uint i = 0; i < nsize; ++i) { - exc.Run([&]() { - int fid = fset[i]; - int offset = feat2workindex_[fid]; - if (offset >= 0) { - this->UpdateHistCol(gpair, page[fid], info, tree, - fset, offset, - &thread_hist_[omp_get_thread_num()]); - } - }); - } - exc.Rethrow(); + common::ParallelFor(fset.size(), ctx_->Threads(), common::Sched::Dyn(1), [&](auto i) { + int fid = fset[i]; + int offset = feat2workindex_[fid]; + if (offset >= 0) { + this->UpdateHistCol(gpair, page[fid], info, tree, fset, offset, + &thread_hist_[omp_get_thread_num()]); + } + }); } // update node statistics. this->GetNodeStats(gpair, *p_fmat, tree, @@ -412,7 +400,7 @@ class CQHistMaker: public HistMaker { } { // get summary - thread_sketch_.resize(omp_get_max_threads()); + thread_sketch_.resize(ctx_->Threads()); // TWOPASS: use the real set + split set in the column iteration. this->SetDefaultPostion(p_fmat, tree); @@ -426,21 +414,15 @@ class CQHistMaker: public HistMaker { this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree); auto page = batch.GetView(); // start enumeration - const auto nsize = static_cast(work_set_.size()); - dmlc::OMPException exc; -#pragma omp parallel for schedule(dynamic, 1) - for (bst_omp_uint i = 0; i < nsize; ++i) { - exc.Run([&]() { - int fid = work_set_[i]; - int offset = feat2workindex_[fid]; - if (offset >= 0) { - this->UpdateSketchCol(gpair, page[fid], tree, - work_set_size, offset, - &thread_sketch_[omp_get_thread_num()]); - } - }); - } - exc.Rethrow(); + common::ParallelFor(work_set_.size(), ctx_->Threads(), common::Sched::Dyn(1), + [&](auto i) { + int fid = work_set_[i]; + int offset = feat2workindex_[fid]; + if (offset >= 0) { + this->UpdateSketchCol(gpair, page[fid], tree, work_set_size, offset, + &thread_sketch_[omp_get_thread_num()]); + } + }); } for (size_t i = 0; i < sketchs_.size(); ++i) { common::WXQuantileSketch::SummaryContainer out; diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index 293f302cf..f71f1c698 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2020 by Contributors + * Copyright 2014-2022 by XGBoost Contributors * \file updater_prune.cc * \brief prune a tree given the statistics * \author Tianqi Chen @@ -24,7 +24,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune); class TreePruner: public TreeUpdater { public: explicit TreePruner(ObjInfo task) { - syncher_.reset(TreeUpdater::Create("sync", tparam_, task)); + syncher_.reset(TreeUpdater::Create("sync", ctx_, task)); pruner_monitor_.Init("TreePruner"); } char const* Name() const override { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 334f9e324..540b157a2 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2021 by Contributors + * Copyright 2017-2022 by XGBoost Contributors * \file updater_quantile_hist.cc * \brief use quantized feature values to construct a tree * \author Philip Cho, Tianqi Checn, Egor Smirnov @@ -40,19 +40,18 @@ DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); void QuantileHistMaker::Configure(const Args& args) { // initialize pruner if (!pruner_) { - pruner_.reset(TreeUpdater::Create("prune", tparam_, task_)); + pruner_.reset(TreeUpdater::Create("prune", ctx_, task_)); } pruner_->Configure(args); param_.UpdateAllowUnknown(args); hist_maker_param_.UpdateAllowUnknown(args); } -template +template void QuantileHistMaker::SetBuilder(const size_t n_trees, - std::unique_ptr>* builder, - DMatrix *dmat) { + std::unique_ptr>* builder, DMatrix* dmat) { builder->reset( - new Builder(n_trees, param_, std::move(pruner_), dmat, task_)); + new Builder(n_trees, param_, std::move(pruner_), dmat, task_, ctx_)); } template @@ -75,7 +74,7 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, auto p_gmat = it.Page(); if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) { updater_monitor_.Start("GmatInitialization"); - column_matrix_.Init(*p_gmat, param_.sparse_threshold); + column_matrix_.Init(*p_gmat, param_.sparse_threshold, ctx_->Threads()); updater_monitor_.Stop("GmatInitialization"); // A proper solution is puting cut matrix in DMatrix, see: // https://github.com/dmlc/xgboost/issues/5143 @@ -347,7 +346,7 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( return row_set_collection_[node].Size(); }, 1024); CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId); - common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) { + common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node, common::Range1d r) { const RowSetCollection::Elem rowset = row_set_collection_[node]; if (rowset.begin != nullptr && rowset.end != nullptr) { int nid = rowset.node_id; @@ -388,20 +387,19 @@ void QuantileHistMaker::Builder::InitSampling(const DMatrix& fmat, } } #else - const size_t nthread = this->nthread_; uint64_t initial_seed = rnd(); - const size_t discard_size = info.num_row_ / nthread; + auto n_threads = static_cast(ctx_->Threads()); + const size_t discard_size = info.num_row_ / n_threads; std::bernoulli_distribution coin_flip(param_.subsample); dmlc::OMPException exc; - #pragma omp parallel num_threads(nthread) + #pragma omp parallel num_threads(n_threads) { exc.Run([&]() { const size_t tid = omp_get_thread_num(); const size_t ibegin = tid * discard_size; - const size_t iend = (tid == (nthread - 1)) ? - info.num_row_ : ibegin + discard_size; + const size_t iend = (tid == (n_threads - 1)) ? info.num_row_ : ibegin + discard_size; RandomReplace::MakeIf([&](size_t i, RandomReplace::EngineT& eng) { return !(gpair_ref[i].GetHess() >= 0.0f && coin_flip(eng)); }, GradientPair(0), initial_seed, ibegin, iend, &gpair_ref); @@ -436,16 +434,9 @@ void QuantileHistMaker::Builder::InitData( uint32_t nbins = gmat.cut.Ptrs().back(); // initialize histogram builder dmlc::OMPException exc; -#pragma omp parallel - { - exc.Run([&]() { - this->nthread_ = omp_get_num_threads(); - }); - } exc.Rethrow(); - this->histogram_builder_->Reset( - nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin}, - this->nthread_, 1, rabit::IsDistributed()); + this->histogram_builder_->Reset(nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin}, + this->ctx_->Threads(), 1, rabit::IsDistributed()); std::vector& row_indices = *row_set_collection_.Data(); row_indices.resize(info.num_row_); @@ -463,13 +454,14 @@ void QuantileHistMaker::Builder::InitData( // We should check that the partitioning was done correctly // and each row of the dataset fell into exactly one of the categories } - common::MemStackAllocator buff(this->nthread_); + auto n_threads = this->ctx_->Threads(); + common::MemStackAllocator buff(n_threads); bool* p_buff = buff.Get(); - std::fill(p_buff, p_buff + this->nthread_, false); + std::fill(p_buff, p_buff + this->ctx_->Threads(), false); - const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_); + const size_t block_size = info.num_row_ / n_threads + !!(info.num_row_ % n_threads); - #pragma omp parallel num_threads(this->nthread_) +#pragma omp parallel num_threads(n_threads) { exc.Run([&]() { const size_t tid = omp_get_thread_num(); @@ -488,7 +480,7 @@ void QuantileHistMaker::Builder::InitData( exc.Rethrow(); bool has_neg_hess = false; - for (int32_t tid = 0; tid < this->nthread_; ++tid) { + for (int32_t tid = 0; tid < n_threads; ++tid) { if (p_buff[tid]) { has_neg_hess = true; } @@ -503,7 +495,7 @@ void QuantileHistMaker::Builder::InitData( } row_indices.resize(j); } else { - #pragma omp parallel num_threads(this->nthread_) + #pragma omp parallel num_threads(n_threads) { exc.Run([&]() { const size_t tid = omp_get_thread_num(); @@ -543,10 +535,10 @@ void QuantileHistMaker::Builder::InitData( p_last_tree_ = &tree; if (data_layout_ == DataLayout::kDenseDataOneBased) { evaluator_.reset(new HistEvaluator{ - param_, info, this->nthread_, column_sampler_, task_, true}); + param_, info, this->ctx_->Threads(), column_sampler_, task_, true}); } else { evaluator_.reset(new HistEvaluator{ - param_, info, this->nthread_, column_sampler_, task_, false}); + param_, info, this->ctx_->Threads(), column_sampler_, task_, false}); } if (data_layout_ == DataLayout::kDenseDataZeroBased @@ -642,7 +634,7 @@ void QuantileHistMaker::Builder::ApplySplit(const std::vectornthread_, [&](size_t node_in_set, common::Range1d r) { + common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) { size_t begin = r.begin(); const int32_t nid = nodes[node_in_set].nid; const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin); @@ -673,7 +665,7 @@ void QuantileHistMaker::Builder::ApplySplit(const std::vectornthread_, [&](size_t node_in_set, common::Range1d r) { + common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) { const int32_t nid = nodes[node_in_set].nid; partition_builder_.MergeToArray(node_in_set, r.begin(), const_cast(row_set_collection_[nid].begin)); diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 5fb357e32..f2103270c 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2021 by Contributors + * Copyright 2017-2022 by XGBoost Contributors * \file updater_quantile_hist.h * \brief use quantized feature values to construct a tree * \author Philip Cho, Tianqi Chen, Egor Smirnov @@ -155,14 +155,16 @@ class QuantileHistMaker: public TreeUpdater { using GradientPairT = xgboost::detail::GradientPairInternal; // constructor explicit Builder(const size_t n_trees, const TrainParam& param, - std::unique_ptr pruner, DMatrix const* fmat, ObjInfo task) + std::unique_ptr pruner, DMatrix const* fmat, ObjInfo task, + GenericParameter const* ctx) : n_trees_(n_trees), param_(param), pruner_(std::move(pruner)), p_last_tree_(nullptr), p_last_fmat_(fmat), histogram_builder_{new HistogramBuilder}, - task_{task} { + task_{task}, + ctx_{ctx} { builder_monitor_.Init("Quantile::Builder"); } // update one tree, growing @@ -225,8 +227,6 @@ class QuantileHistMaker: public TreeUpdater { // --data fields-- const size_t n_trees_; const TrainParam& param_; - // number of omp thread used during training - int nthread_; std::shared_ptr column_sampler_{ std::make_shared()}; @@ -258,9 +258,10 @@ class QuantileHistMaker: public TreeUpdater { enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; DataLayout data_layout_; - std::unique_ptr> - histogram_builder_; + std::unique_ptr> histogram_builder_; ObjInfo task_; + // Context for number of threads + GenericParameter const* ctx_; common::Monitor builder_monitor_; }; diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 993899c7b..d17c1e144 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014 by Contributors + * Copyright 2014-2022 by XGBoost Contributors * \file updater_refresh.cc * \brief refresh the statistics and leaf value on the tree on the dataset * \author Tianqi Chen @@ -51,11 +51,11 @@ class TreeRefresher: public TreeUpdater { std::vector > stemp; std::vector fvec_temp; // setup temp space for each thread - const int nthread = omp_get_max_threads(); + const int nthread = ctx_->Threads(); fvec_temp.resize(nthread, RegTree::FVec()); stemp.resize(nthread, std::vector()); dmlc::OMPException exc; - #pragma omp parallel +#pragma omp parallel num_threads(nthread) { exc.Run([&]() { int tid = omp_get_thread_num(); @@ -78,7 +78,7 @@ class TreeRefresher: public TreeUpdater { auto page = batch.GetView(); CHECK_LT(batch.Size(), std::numeric_limits::max()); const auto nbatch = static_cast(batch.Size()); - common::ParallelFor(nbatch, [&](bst_omp_uint i) { + common::ParallelFor(nbatch, ctx_->Threads(), [&](bst_omp_uint i) { SparsePage::Inst inst = page[i]; const int tid = omp_get_thread_num(); const auto ridx = static_cast(batch.base_rowid + i); @@ -95,7 +95,7 @@ class TreeRefresher: public TreeUpdater { } // aggregate the statistics auto num_nodes = static_cast(stemp[0].size()); - common::ParallelFor(num_nodes, [&](int nid) { + common::ParallelFor(num_nodes, ctx_->Threads(), [&](int nid) { for (int tid = 1; tid < nthread; ++tid) { stemp[0][nid].Add(stemp[tid][nid]); } diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc index 6dc831834..46ca6e6bb 100644 --- a/tests/cpp/common/test_column_matrix.cc +++ b/tests/cpp/common/test_column_matrix.cc @@ -19,7 +19,7 @@ TEST(DenseColumn, Test) { auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix(); GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); ColumnMatrix column_matrix; - column_matrix.Init(gmat, 0.2); + column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0)); for (auto i = 0ull; i < dmat->Info().num_row_; i++) { for (auto j = 0ull; j < dmat->Info().num_col_; j++) { @@ -66,7 +66,7 @@ TEST(SparseColumn, Test) { auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix(); GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); ColumnMatrix column_matrix; - column_matrix.Init(gmat, 0.5); + column_matrix.Init(gmat, 0.5, common::OmpGetNumThreads(0)); switch (column_matrix.GetTypeSize()) { case kUint8BinsTypeSize: { auto col = column_matrix.GetColumn(0); @@ -106,7 +106,7 @@ TEST(DenseColumnWithMissing, Test) { auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix(); GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); ColumnMatrix column_matrix; - column_matrix.Init(gmat, 0.2); + column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0)); switch (column_matrix.GetTypeSize()) { case kUint8BinsTypeSize: { auto col = column_matrix.GetColumn(0); diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index 6c3d0f9d6..5dc3d0646 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -76,7 +76,7 @@ TEST(SparsePage, PushCSCAfterTranspose) { // Make sure that the final sparse page has the right number of entries ASSERT_EQ(kEntries, page.data.Size()); - page.SortRows(); + page.SortRows(common::OmpGetNumThreads(0)); auto v = page.GetView(); for (size_t i = 0; i < v.Size(); ++i) { auto column = v[i]; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 0d60f0e44..931ac79cb 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -27,8 +27,8 @@ class QuantileHistMock : public QuantileHistMaker { using GHistRowT = typename RealImpl::GHistRowT; BuilderMock(const TrainParam ¶m, std::unique_ptr pruner, - DMatrix const *fmat) - : RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}) {} + DMatrix const *fmat, GenericParameter const* ctx) + : RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}, ctx) {} public: void TestInitData(const GHistIndexMatrix& gmat, @@ -166,7 +166,7 @@ class QuantileHistMock : public QuantileHistMaker { ColumnMatrix cm; // treat everything as dense, as this is what we intend to test here - cm.Init(gmat, 0.0); + cm.Init(gmat, 0.0, common::OmpGetNumThreads(0)); RealImpl::InitData(gmat, *dmat, tree, &row_gpairs); const size_t num_row = dmat->Info().num_row_; // split by feature 0 @@ -222,6 +222,7 @@ class QuantileHistMock : public QuantileHistMaker { int static constexpr kNRows = 8, kNCols = 16; std::shared_ptr dmat_; + GenericParameter ctx_; const std::vector > cfg_; std::shared_ptr > float_builder_; std::shared_ptr > double_builder_; @@ -233,18 +234,12 @@ class QuantileHistMock : public QuantileHistMaker { QuantileHistMaker{ObjInfo{ObjInfo::kRegression}}, cfg_{args} { QuantileHistMaker::Configure(args); dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); + ctx_.UpdateAllowUnknown(Args{}); if (single_precision_histogram) { - float_builder_.reset( - new BuilderMock( - param_, - std::move(pruner_), - dmat_.get())); + float_builder_.reset(new BuilderMock(param_, std::move(pruner_), dmat_.get(), &ctx_)); } else { double_builder_.reset( - new BuilderMock( - param_, - std::move(pruner_), - dmat_.get())); + new BuilderMock(param_, std::move(pruner_), dmat_.get(), &ctx_)); } } ~QuantileHistMock() override = default;