From 581784085858553440dde09778532641c3296c1e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 24 Jan 2022 02:44:07 +0800 Subject: [PATCH] Remove `omp_get_max_threads` in data. (#7588) --- include/xgboost/data.h | 2 +- src/common/hist_util.h | 17 ++++----- src/common/threading_utils.h | 11 ------ src/data/data.cc | 15 +++----- src/data/gradient_index.cc | 11 +++--- src/data/gradient_index.h | 15 ++++---- src/data/simple_dmatrix.cc | 13 ++++--- src/data/simple_dmatrix.h | 5 ++- src/data/sparse_page_dmatrix.cc | 8 ++-- src/data/sparse_page_source.h | 6 +-- tests/cpp/common/test_column_matrix.cc | 11 ++++-- tests/cpp/common/test_hist_util.cc | 42 +++++++++++---------- tests/cpp/common/test_hist_util.cu | 6 +-- tests/cpp/common/test_quantile.cc | 5 ++- tests/cpp/data/test_data.cc | 5 ++- tests/cpp/data/test_gradient_index.cc | 4 +- tests/cpp/tree/hist/test_evaluate_splits.cc | 5 ++- tests/cpp/tree/test_quantile_hist.cc | 8 ++-- 18 files changed, 97 insertions(+), 92 deletions(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 23859ac22..75f0eb60e 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -300,7 +300,7 @@ class SparsePage { base_rowid = row_id; } - SparsePage GetTranspose(int num_columns) const; + SparsePage GetTranspose(int num_columns, int32_t n_threads) const; void SortRows() { auto ncol = static_cast(this->Size()); diff --git a/src/common/hist_util.h b/src/common/hist_util.h index e066ed3a3..8cb233605 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2021 by Contributors + * Copyright 2017-2022 by XGBoost Contributors * \file hist_util.h * \brief Utility for fast histogram aggregation * \author Philip Cho, Tianqi Chen @@ -137,19 +137,18 @@ class HistogramCuts { * \param use_sorted Whether should we use SortedCSC for sketching, it's more efficient * but consumes more memory. */ -inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sorted = false, - Span const hessian = {}) { +inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads, + bool use_sorted = false, Span const hessian = {}) { HistogramCuts out; auto const& info = m->Info(); - const auto threads = omp_get_max_threads(); - std::vector> column_sizes(threads); + std::vector> column_sizes(n_threads); for (auto& column : column_sizes) { column.resize(info.num_col_, 0); } std::vector reduced(info.num_col_, 0); for (auto const& page : m->GetBatches()) { - auto const &entries_per_column = - HostSketchContainer::CalcColumnSize(page, info.num_col_, threads); + auto const& entries_per_column = + HostSketchContainer::CalcColumnSize(page, info.num_col_, n_threads); for (size_t i = 0; i < entries_per_column.size(); ++i) { reduced[i] += entries_per_column[i]; } @@ -157,14 +156,14 @@ inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sort if (!use_sorted) { HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), - hessian, threads); + hessian, n_threads); for (auto const& page : m->GetBatches()) { container.PushRowPage(page, info, hessian); } container.MakeCuts(&out); } else { SortedSketchContainer container{ - max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, threads}; + max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, n_threads}; for (auto const& page : m->GetBatches()) { container.PushColPage(page, info, hessian); } diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index da8ddf3c2..929f7e4df 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -263,17 +263,6 @@ inline int32_t OmpSetNumThreads(int32_t* p_threads) { return nthread_original; } -inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) { - auto& threads = *p_threads; - int32_t nthread_original = omp_get_max_threads(); - if (threads <= 0) { - threads = nthread_original; - } - threads = std::min(threads, OmpGetThreadLimit()); - omp_set_num_threads(threads); - return nthread_original; -} - inline int32_t OmpGetNumThreads(int32_t n_threads) { if (n_threads <= 0) { n_threads = std::min(omp_get_num_procs(), omp_get_max_threads()); diff --git a/src/data/data.cc b/src/data/data.cc index 2de5bc8d4..a318680e8 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2021 by Contributors + * Copyright 2015-2022 by XGBoost Contributors * \file data.cc */ #include @@ -1001,15 +1001,14 @@ DMatrix::Create(data::IteratorAdapter *adapter, float missing, int nthread, const std::string &cache_prefix); -SparsePage SparsePage::GetTranspose(int num_columns) const { +SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const { SparsePage transpose; common::ParallelGroupBuilder builder(&transpose.offset.HostVector(), &transpose.data.HostVector()); - const int nthread = omp_get_max_threads(); - builder.InitBudget(num_columns, nthread); + builder.InitBudget(num_columns, n_threads); long batch_size = static_cast(this->Size()); // NOLINT(*) auto page = this->GetView(); - common::ParallelFor(batch_size, [&](long i) { // NOLINT(*) + common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*) int tid = omp_get_thread_num(); auto inst = page[i]; for (const auto& entry : inst) { @@ -1017,7 +1016,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const { } }); builder.InitStorage(); - common::ParallelFor(batch_size, [&](long i) { // NOLINT(*) + common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*) int tid = omp_get_thread_num(); auto inst = page[i]; for (const auto& entry : inst) { @@ -1059,8 +1058,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread constexpr bool kIsRowMajor = AdapterBatchT::kIsRowMajor; // Allow threading only for row-major case as column-major requires O(nthread*batch_size) memory nthread = kIsRowMajor ? nthread : 1; - // Set number of threads but keep old value so we can reset it after - int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread); if (!kIsRowMajor) { CHECK_EQ(nthread, 1); } @@ -1085,7 +1082,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread expected_rows = kIsRowMajor ? batch_size : expected_rows; uint64_t max_columns = 0; if (batch_size == 0) { - omp_set_num_threads(nthread_original); return max_columns; } const size_t thread_size = batch_size / nthread; @@ -1154,7 +1150,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread }); } exec.Rethrow(); - omp_set_num_threads(nthread_original); return max_columns; } diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index a004f5231..c68276e9a 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2021 by Contributors + * Copyright 2017-2022 by XGBoost Contributors * \brief Data type for fast histogram aggregation. */ #include @@ -126,17 +126,16 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, }); } -void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, +void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, int32_t n_threads, common::Span hess) { // We use sorted sketching for approx tree method since it's more efficient in // computation time (but higher memory usage). - cut = common::SketchOnDMatrix(p_fmat, max_bins, sorted_sketch, hess); + cut = common::SketchOnDMatrix(p_fmat, max_bins, n_threads, sorted_sketch, hess); max_num_bins = max_bins; - const int32_t nthread = omp_get_max_threads(); const uint32_t nbins = cut.Ptrs().back(); hit_count.resize(nbins, 0); - hit_count_tloc_.resize(nthread * nbins, 0); + hit_count_tloc_.resize(n_threads * nbins, 0); this->p_fmat = p_fmat; size_t new_size = 1; @@ -154,7 +153,7 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, auto ft = p_fmat->Info().feature_types.ConstHostSpan(); for (const auto &batch : p_fmat->GetBatches()) { - this->PushBatch(batch, ft, rbegin, prev_sum, nbins, nthread); + this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads); prev_sum = row_ptr[rbegin + batch.Size()]; rbegin += batch.Size(); } diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index f30e2267e..58f3a0753 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2021 by Contributors + * Copyright 2017-2022 by XGBoost Contributors * \brief Data type for fast histogram aggregation. */ #ifndef XGBOOST_DATA_GRADIENT_INDEX_H_ @@ -19,9 +19,8 @@ namespace xgboost { * index for CPU histogram. On GPU ellpack page is used. */ class GHistIndexMatrix { - void PushBatch(SparsePage const &batch, common::Span ft, - size_t rbegin, size_t prev_sum, uint32_t nbins, - int32_t n_threads); + void PushBatch(SparsePage const& batch, common::Span ft, size_t rbegin, + size_t prev_sum, uint32_t nbins, int32_t n_threads); public: /*! \brief row pointer to rows by element position */ @@ -37,11 +36,13 @@ class GHistIndexMatrix { size_t base_rowid{0}; GHistIndexMatrix() = default; - GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, common::Span hess = {}) { - this->Init(x, max_bin, sorted_sketch, hess); + GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, int32_t n_threads, + common::Span hess = {}) { + this->Init(x, max_bin, sorted_sketch, n_threads, hess); } // Create a global histogram matrix, given cut - void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, common::Span hess); + void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, int32_t n_threads, + common::Span hess); void Init(SparsePage const& page, common::Span ft, common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense, int32_t n_threads); diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 9ec343d91..09ed2f806 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014~2021 by Contributors + * Copyright 2014~2022 by XGBoost Contributors * \file simple_dmatrix.cc * \brief the input data structure for gradient boosting * \author Tianqi Chen @@ -55,7 +55,7 @@ BatchSet SimpleDMatrix::GetRowBatches() { BatchSet SimpleDMatrix::GetColumnBatches() { // column page doesn't exist, generate it if (!column_page_) { - column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_))); + column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads()))); } auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(column_page_)); @@ -66,7 +66,7 @@ BatchSet SimpleDMatrix::GetSortedColumnBatches() { // Sorted column page doesn't exist, generate it if (!sorted_column_page_) { sorted_column_page_.reset( - new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_))); + new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads()))); sorted_column_page_->SortRows(); } auto begin_iter = BatchIterator( @@ -99,7 +99,8 @@ BatchSet SimpleDMatrix::GetGradientIndex(const BatchParam& par CHECK_EQ(param.gpu_id, -1); // Used only by approx. auto sorted_sketch = param.regen; - gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, sorted_sketch, param.hess)); + gradient_index_.reset( + new GHistIndexMatrix(this, param.max_bin, sorted_sketch, this->ctx_.Threads(), param.hess)); batch_param_ = param; CHECK_EQ(batch_param_.hess.data(), param.hess.data()); } @@ -110,6 +111,8 @@ BatchSet SimpleDMatrix::GetGradientIndex(const BatchParam& par template SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { + this->ctx_.nthread = nthread; + std::vector qids; uint64_t default_max = std::numeric_limits::max(); uint64_t last_group_id = default_max; @@ -124,7 +127,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { // Iterate over batches of input data while (adapter->Next()) { auto& batch = adapter->Value(); - auto batch_max_columns = sparse_page_->Push(batch, missing, nthread); + auto batch_max_columns = sparse_page_->Push(batch, missing, ctx_.Threads()); inferred_num_columns = std::max(batch_max_columns, inferred_num_columns); total_batch_size += batch.Size(); // Append meta information if available diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index a4993e3b3..ad7e1c1f4 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2021 by Contributors + * Copyright 2015-2022 by XGBoost Contributors * \file simple_dmatrix.h * \brief In-memory version of DMatrix. * \author Tianqi Chen @@ -61,6 +61,9 @@ class SimpleDMatrix : public DMatrix { bool SparsePageExists() const override { return true; } + + private: + GenericParameter ctx_; }; } // namespace data } // namespace xgboost diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index bb95e0d99..0ce3b8c38 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2021 by Contributors + * Copyright 2014-2022 by Contributors * \file sparse_page_dmatrix.cc * \brief The external memory version of Page Iterator. * \author Tianqi Chen @@ -164,7 +164,8 @@ BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam& // all index here. if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) { this->InitializeSparsePage(); - ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.regen}); + ghist_index_page_.reset( + new GHistIndexMatrix{this, param.max_bin, param.regen, ctx_.Threads()}); this->InitializeSparsePage(); batch_param_ = param; } @@ -181,7 +182,8 @@ BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam& MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_); // Use sorted sketch for approx. auto sorted_sketch = param.regen; - auto cuts = common::SketchOnDMatrix(this, param.max_bin, sorted_sketch, param.hess); + auto cuts = + common::SketchOnDMatrix(this, param.max_bin, ctx_.Threads(), sorted_sketch, param.hess); this->InitializeSparsePage(); // reset after use. batch_param_ = param; diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index eec5052dc..90fdcea8f 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2014-2021 by Contributors + * Copyright 2014-2022 by XGBoost Contributors * \file sparse_page_source.h */ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ @@ -311,7 +311,7 @@ class CSCPageSource : public PageSourceIncMixIn { auto const &csr = source_->Page(); this->page_.reset(new CSCPage{}); // we might be able to optimize this by merging transpose and pushcsc - this->page_->PushCSC(csr->GetTranspose(n_features_)); + this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_)); page_->SetBaseRowId(csr->base_rowid); this->WriteCache(); } @@ -336,7 +336,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn { auto const &csr = this->source_->Page(); this->page_.reset(new SortedCSCPage{}); // we might be able to optimize this by merging transpose and pushcsc - this->page_->PushCSC(csr->GetTranspose(n_features_)); + this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_)); CHECK_EQ(this->page_->Size(), n_features_); CHECK_EQ(this->page_->data.Size(), csr->data.Size()); this->page_->SortRows(); diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc index 379d364e1..6dc831834 100644 --- a/tests/cpp/common/test_column_matrix.cc +++ b/tests/cpp/common/test_column_matrix.cc @@ -1,3 +1,6 @@ +/*! + * Copyright 2018-2022 by XGBoost Contributors + */ #include #include @@ -14,7 +17,7 @@ TEST(DenseColumn, Test) { static_cast(std::numeric_limits::max()) + 2}; for (size_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), max_num_bin, false); + GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); ColumnMatrix column_matrix; column_matrix.Init(gmat, 0.2); @@ -61,7 +64,7 @@ TEST(SparseColumn, Test) { static_cast(std::numeric_limits::max()) + 2}; for (size_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), max_num_bin, false); + GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); ColumnMatrix column_matrix; column_matrix.Init(gmat, 0.5); switch (column_matrix.GetTypeSize()) { @@ -101,7 +104,7 @@ TEST(DenseColumnWithMissing, Test) { static_cast(std::numeric_limits::max()) + 2 }; for (size_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), max_num_bin, false); + GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); ColumnMatrix column_matrix; column_matrix.Init(gmat, 0.2); switch (column_matrix.GetTypeSize()) { @@ -130,7 +133,7 @@ void TestGHistIndexMatrixCreation(size_t nthreads) { /* This should create multiple sparse pages */ std::unique_ptr dmat{ CreateSparsePageDMatrix(kEntries) }; omp_set_num_threads(nthreads); - GHistIndexMatrix gmat(dmat.get(), 256, false); + GHistIndexMatrix gmat(dmat.get(), 256, false, common::OmpGetNumThreads(0)); } TEST(HistIndexCreationWithExternalMemory, Test) { diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 350e544cf..8bcb33ca0 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2019-2021 by XGBoost Contributors + * Copyright 2019-2022 by XGBoost Contributors */ #include #include @@ -188,7 +188,7 @@ TEST(HistUtil, DenseCutsCategorical) { std::vector x_sorted(x); std::sort(x_sorted.begin(), x_sorted.end()); auto dmat = GetDMatrixFromData(x, n, 1); - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0)); auto cuts_from_sketch = cuts.Values(); EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); @@ -207,7 +207,7 @@ TEST(HistUtil, DenseCutsAccuracyTest) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); for (auto num_bins : bin_sizes) { - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0)); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -224,11 +224,13 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) { dmat->Info().weights_.HostVector() = w; for (auto num_bins : bin_sizes) { { - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, true); + HistogramCuts cuts = + SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), true); ValidateCuts(cuts, dmat.get(), num_bins); } { - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, false); + HistogramCuts cuts = + SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), false); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -249,13 +251,15 @@ void TestQuantileWithHessian(bool use_sorted) { dmat->Info().weights_.HostVector() = w; for (auto num_bins : bin_sizes) { - HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, use_sorted, hessian); + HistogramCuts cuts_hess = + SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted, hessian); for (size_t i = 0; i < w.size(); ++i) { dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i]; } ValidateCuts(cuts_hess, dmat.get(), num_bins); - HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins, use_sorted); + HistogramCuts cuts_wh = + SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted); ValidateCuts(cuts_wh, dmat.get(), num_bins); ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size()); @@ -283,7 +287,7 @@ TEST(HistUtil, DenseCutsExternalMemory) { auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir); for (auto num_bins : bin_sizes) { - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0)); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -303,7 +307,7 @@ TEST(HistUtil, IndexBinBound) { for (auto max_bin : bin_sizes) { auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - GHistIndexMatrix hmat(p_fmat.get(), max_bin, false); + GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0)); EXPECT_EQ(hmat.index.Size(), kRows*kCols); EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize()); } @@ -326,7 +330,7 @@ TEST(HistUtil, IndexBinData) { for (auto max_bin : kBinSizes) { auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - GHistIndexMatrix hmat(p_fmat.get(), max_bin, false); + GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0)); uint32_t* offsets = hmat.index.Offset(); EXPECT_EQ(hmat.index.Size(), kRows*kCols); switch (max_bin) { @@ -351,7 +355,7 @@ void TestSketchFromWeights(bool with_group) { size_t constexpr kGroups = 10; auto m = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix(); - common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins); + common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0)); MetaInfo info; auto& h_weights = info.weights_.HostVector(); @@ -385,7 +389,7 @@ void TestSketchFromWeights(bool with_group) { ValidateCuts(cuts, m.get(), kBins); if (with_group) { - HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins); + HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0)); for (size_t i = 0; i < cuts.Values().size(); ++i) { EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]); } @@ -404,14 +408,12 @@ TEST(HistUtil, SketchFromWeights) { } TEST(HistUtil, SketchCategoricalFeatures) { - TestCategoricalSketch(1000, 256, 32, false, - [](DMatrix *p_fmat, int32_t num_bins) { - return SketchOnDMatrix(p_fmat, num_bins); - }); - TestCategoricalSketch(1000, 256, 32, true, - [](DMatrix *p_fmat, int32_t num_bins) { - return SketchOnDMatrix(p_fmat, num_bins); - }); + TestCategoricalSketch(1000, 256, 32, false, [](DMatrix* p_fmat, int32_t num_bins) { + return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0)); + }); + TestCategoricalSketch(1000, 256, 32, true, [](DMatrix* p_fmat, int32_t num_bins) { + return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0)); + }); } } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index eb1b04cd5..4ab9b2c9e 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2019-2021 by XGBoost Contributors + * Copyright 2019-2022 by XGBoost Contributors */ #include #include @@ -28,7 +28,7 @@ namespace common { template HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) { data::SimpleDMatrix dmat(adapter, missing, 1); - HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins); + HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, common::OmpGetNumThreads(0)); return cuts; } @@ -40,7 +40,7 @@ TEST(HistUtil, DeviceSketch) { auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); - HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins); + HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0)); EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index a079fdee0..ca3b7b74c 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -1,3 +1,6 @@ +/*! + * Copyright 2020-2022 by XGBoost Contributors + */ #include #include "test_quantile.h" #include "../../../src/common/quantile.h" @@ -201,7 +204,7 @@ TEST(Quantile, SameOnAllWorkers) { .MaxCategory(17) .Seed(rank + seed) .GenerateDMatrix(); - auto cuts = SketchOnDMatrix(m.get(), n_bins); + auto cuts = SketchOnDMatrix(m.get(), n_bins, common::OmpGetNumThreads(0)); std::vector cut_values(cuts.Values().size() * world, 0); std::vector< typename std::remove_reference_t::value_type> diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index 6c1b42571..6c3d0f9d6 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -1,3 +1,6 @@ +/*! + * Copyright 2019-2022 by XGBoost Contributors + */ #include #include #include @@ -66,7 +69,7 @@ TEST(SparsePage, PushCSCAfterTranspose) { SparsePage page; // Consolidated sparse page for (const auto &batch : dmat->GetBatches()) { // Transpose each batch and push - SparsePage tmp = batch.GetTranspose(ncols); + SparsePage tmp = batch.GetTranspose(ncols, common::OmpGetNumThreads(0)); page.PushCSC(tmp); } diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc index bbb0f6de4..dacf5d901 100644 --- a/tests/cpp/data/test_gradient_index.cc +++ b/tests/cpp/data/test_gradient_index.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2021 XGBoost contributors + * Copyright 2021-2022 XGBoost contributors */ #include #include @@ -36,7 +36,7 @@ TEST(GradientIndex, FromCategoricalBasic) { BatchParam p(0, max_bins); GHistIndexMatrix gidx; - gidx.Init(m.get(), max_bins, false, {}); + gidx.Init(m.get(), max_bins, false, common::OmpGetNumThreads(0), {}); auto x_copy = x; std::sort(x_copy.begin(), x_copy.end()); diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index d7adde257..5df2ff4e9 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -1,3 +1,6 @@ +/*! + * Copyright 2021-2022 by XGBoost Contributors + */ #include #include #include "../../../../src/tree/hist/evaluate_splits.h" @@ -29,7 +32,7 @@ template void TestEvaluateSplits() { size_t constexpr kMaxBins = 4; // dense, no missing values - GHistIndexMatrix gmat(dmat.get(), kMaxBins, false); + GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0)); common::RowSetCollection row_set_collection; std::vector &row_indices = *row_set_collection.Data(); row_indices.resize(kRows); diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 006cbf30d..0d60f0e44 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2018-2021 by Contributors + * Copyright 2018-2022 by XGBoost Contributors */ #include #include @@ -162,7 +162,7 @@ class QuantileHistMock : public QuantileHistMaker { // kNRows samples with kNCols features auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), kMaxBins, false); + GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0)); ColumnMatrix cm; // treat everything as dense, as this is what we intend to test here @@ -253,7 +253,7 @@ class QuantileHistMock : public QuantileHistMaker { void TestInitData() { size_t constexpr kMaxBins = 4; - GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false); + GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0)); RegTree tree = RegTree(); tree.param.UpdateAllowUnknown(cfg_); @@ -270,7 +270,7 @@ class QuantileHistMock : public QuantileHistMaker { void TestInitDataSampling() { size_t constexpr kMaxBins = 4; - GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false); + GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0)); RegTree tree = RegTree(); tree.param.UpdateAllowUnknown(cfg_);