Remove omp_get_max_threads in data. (#7588)

This commit is contained in:
Jiaming Yuan 2022-01-24 02:44:07 +08:00 committed by GitHub
parent f84291c1e1
commit 5817840858
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 97 additions and 92 deletions

View File

@ -300,7 +300,7 @@ class SparsePage {
base_rowid = row_id; base_rowid = row_id;
} }
SparsePage GetTranspose(int num_columns) const; SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
void SortRows() { void SortRows() {
auto ncol = static_cast<bst_omp_uint>(this->Size()); auto ncol = static_cast<bst_omp_uint>(this->Size());

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2021 by Contributors * Copyright 2017-2022 by XGBoost Contributors
* \file hist_util.h * \file hist_util.h
* \brief Utility for fast histogram aggregation * \brief Utility for fast histogram aggregation
* \author Philip Cho, Tianqi Chen * \author Philip Cho, Tianqi Chen
@ -137,19 +137,18 @@ class HistogramCuts {
* \param use_sorted Whether should we use SortedCSC for sketching, it's more efficient * \param use_sorted Whether should we use SortedCSC for sketching, it's more efficient
* but consumes more memory. * but consumes more memory.
*/ */
inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sorted = false, inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
Span<float> const hessian = {}) { bool use_sorted = false, Span<float> const hessian = {}) {
HistogramCuts out; HistogramCuts out;
auto const& info = m->Info(); auto const& info = m->Info();
const auto threads = omp_get_max_threads(); std::vector<std::vector<bst_row_t>> column_sizes(n_threads);
std::vector<std::vector<bst_row_t>> column_sizes(threads);
for (auto& column : column_sizes) { for (auto& column : column_sizes) {
column.resize(info.num_col_, 0); column.resize(info.num_col_, 0);
} }
std::vector<bst_row_t> reduced(info.num_col_, 0); std::vector<bst_row_t> reduced(info.num_col_, 0);
for (auto const& page : m->GetBatches<SparsePage>()) { for (auto const& page : m->GetBatches<SparsePage>()) {
auto const &entries_per_column = auto const& entries_per_column =
HostSketchContainer::CalcColumnSize(page, info.num_col_, threads); HostSketchContainer::CalcColumnSize(page, info.num_col_, n_threads);
for (size_t i = 0; i < entries_per_column.size(); ++i) { for (size_t i = 0; i < entries_per_column.size(); ++i) {
reduced[i] += entries_per_column[i]; reduced[i] += entries_per_column[i];
} }
@ -157,14 +156,14 @@ inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sort
if (!use_sorted) { if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
hessian, threads); hessian, n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) { for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian); container.PushRowPage(page, info, hessian);
} }
container.MakeCuts(&out); container.MakeCuts(&out);
} else { } else {
SortedSketchContainer container{ SortedSketchContainer container{
max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, threads}; max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) { for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian); container.PushColPage(page, info, hessian);
} }

View File

@ -263,17 +263,6 @@ inline int32_t OmpSetNumThreads(int32_t* p_threads) {
return nthread_original; return nthread_original;
} }
inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
auto& threads = *p_threads;
int32_t nthread_original = omp_get_max_threads();
if (threads <= 0) {
threads = nthread_original;
}
threads = std::min(threads, OmpGetThreadLimit());
omp_set_num_threads(threads);
return nthread_original;
}
inline int32_t OmpGetNumThreads(int32_t n_threads) { inline int32_t OmpGetNumThreads(int32_t n_threads) {
if (n_threads <= 0) { if (n_threads <= 0) {
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads()); n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015-2021 by Contributors * Copyright 2015-2022 by XGBoost Contributors
* \file data.cc * \file data.cc
*/ */
#include <dmlc/registry.h> #include <dmlc/registry.h>
@ -1001,15 +1001,14 @@ DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> *adapter, XGBoostBatchCSR> *adapter,
float missing, int nthread, const std::string &cache_prefix); float missing, int nthread, const std::string &cache_prefix);
SparsePage SparsePage::GetTranspose(int num_columns) const { SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
SparsePage transpose; SparsePage transpose;
common::ParallelGroupBuilder<Entry, bst_row_t> builder(&transpose.offset.HostVector(), common::ParallelGroupBuilder<Entry, bst_row_t> builder(&transpose.offset.HostVector(),
&transpose.data.HostVector()); &transpose.data.HostVector());
const int nthread = omp_get_max_threads(); builder.InitBudget(num_columns, n_threads);
builder.InitBudget(num_columns, nthread);
long batch_size = static_cast<long>(this->Size()); // NOLINT(*) long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
auto page = this->GetView(); auto page = this->GetView();
common::ParallelFor(batch_size, [&](long i) { // NOLINT(*) common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*)
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
auto inst = page[i]; auto inst = page[i];
for (const auto& entry : inst) { for (const auto& entry : inst) {
@ -1017,7 +1016,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const {
} }
}); });
builder.InitStorage(); builder.InitStorage();
common::ParallelFor(batch_size, [&](long i) { // NOLINT(*) common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*)
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
auto inst = page[i]; auto inst = page[i];
for (const auto& entry : inst) { for (const auto& entry : inst) {
@ -1059,8 +1058,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
constexpr bool kIsRowMajor = AdapterBatchT::kIsRowMajor; constexpr bool kIsRowMajor = AdapterBatchT::kIsRowMajor;
// Allow threading only for row-major case as column-major requires O(nthread*batch_size) memory // Allow threading only for row-major case as column-major requires O(nthread*batch_size) memory
nthread = kIsRowMajor ? nthread : 1; nthread = kIsRowMajor ? nthread : 1;
// Set number of threads but keep old value so we can reset it after
int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread);
if (!kIsRowMajor) { if (!kIsRowMajor) {
CHECK_EQ(nthread, 1); CHECK_EQ(nthread, 1);
} }
@ -1085,7 +1082,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
expected_rows = kIsRowMajor ? batch_size : expected_rows; expected_rows = kIsRowMajor ? batch_size : expected_rows;
uint64_t max_columns = 0; uint64_t max_columns = 0;
if (batch_size == 0) { if (batch_size == 0) {
omp_set_num_threads(nthread_original);
return max_columns; return max_columns;
} }
const size_t thread_size = batch_size / nthread; const size_t thread_size = batch_size / nthread;
@ -1154,7 +1150,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}); });
} }
exec.Rethrow(); exec.Rethrow();
omp_set_num_threads(nthread_original);
return max_columns; return max_columns;
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2021 by Contributors * Copyright 2017-2022 by XGBoost Contributors
* \brief Data type for fast histogram aggregation. * \brief Data type for fast histogram aggregation.
*/ */
#include <algorithm> #include <algorithm>
@ -126,17 +126,16 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
}); });
} }
void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, int32_t n_threads,
common::Span<float> hess) { common::Span<float> hess) {
// We use sorted sketching for approx tree method since it's more efficient in // We use sorted sketching for approx tree method since it's more efficient in
// computation time (but higher memory usage). // computation time (but higher memory usage).
cut = common::SketchOnDMatrix(p_fmat, max_bins, sorted_sketch, hess); cut = common::SketchOnDMatrix(p_fmat, max_bins, n_threads, sorted_sketch, hess);
max_num_bins = max_bins; max_num_bins = max_bins;
const int32_t nthread = omp_get_max_threads();
const uint32_t nbins = cut.Ptrs().back(); const uint32_t nbins = cut.Ptrs().back();
hit_count.resize(nbins, 0); hit_count.resize(nbins, 0);
hit_count_tloc_.resize(nthread * nbins, 0); hit_count_tloc_.resize(n_threads * nbins, 0);
this->p_fmat = p_fmat; this->p_fmat = p_fmat;
size_t new_size = 1; size_t new_size = 1;
@ -154,7 +153,7 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch,
auto ft = p_fmat->Info().feature_types.ConstHostSpan(); auto ft = p_fmat->Info().feature_types.ConstHostSpan();
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) { for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
this->PushBatch(batch, ft, rbegin, prev_sum, nbins, nthread); this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
prev_sum = row_ptr[rbegin + batch.Size()]; prev_sum = row_ptr[rbegin + batch.Size()];
rbegin += batch.Size(); rbegin += batch.Size();
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2021 by Contributors * Copyright 2017-2022 by XGBoost Contributors
* \brief Data type for fast histogram aggregation. * \brief Data type for fast histogram aggregation.
*/ */
#ifndef XGBOOST_DATA_GRADIENT_INDEX_H_ #ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
@ -19,9 +19,8 @@ namespace xgboost {
* index for CPU histogram. On GPU ellpack page is used. * index for CPU histogram. On GPU ellpack page is used.
*/ */
class GHistIndexMatrix { class GHistIndexMatrix {
void PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft, void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, size_t rbegin,
size_t rbegin, size_t prev_sum, uint32_t nbins, size_t prev_sum, uint32_t nbins, int32_t n_threads);
int32_t n_threads);
public: public:
/*! \brief row pointer to rows by element position */ /*! \brief row pointer to rows by element position */
@ -37,11 +36,13 @@ class GHistIndexMatrix {
size_t base_rowid{0}; size_t base_rowid{0};
GHistIndexMatrix() = default; GHistIndexMatrix() = default;
GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, common::Span<float> hess = {}) { GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, int32_t n_threads,
this->Init(x, max_bin, sorted_sketch, hess); common::Span<float> hess = {}) {
this->Init(x, max_bin, sorted_sketch, n_threads, hess);
} }
// Create a global histogram matrix, given cut // Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, common::Span<float> hess); void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, int32_t n_threads,
common::Span<float> hess);
void Init(SparsePage const& page, common::Span<FeatureType const> ft, void Init(SparsePage const& page, common::Span<FeatureType const> ft,
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense, common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
int32_t n_threads); int32_t n_threads);

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014~2021 by Contributors * Copyright 2014~2022 by XGBoost Contributors
* \file simple_dmatrix.cc * \file simple_dmatrix.cc
* \brief the input data structure for gradient boosting * \brief the input data structure for gradient boosting
* \author Tianqi Chen * \author Tianqi Chen
@ -55,7 +55,7 @@ BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() { BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
// column page doesn't exist, generate it // column page doesn't exist, generate it
if (!column_page_) { if (!column_page_) {
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_))); column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
} }
auto begin_iter = auto begin_iter =
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_)); BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
@ -66,7 +66,7 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
// Sorted column page doesn't exist, generate it // Sorted column page doesn't exist, generate it
if (!sorted_column_page_) { if (!sorted_column_page_) {
sorted_column_page_.reset( sorted_column_page_.reset(
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_))); new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
sorted_column_page_->SortRows(); sorted_column_page_->SortRows();
} }
auto begin_iter = BatchIterator<SortedCSCPage>( auto begin_iter = BatchIterator<SortedCSCPage>(
@ -99,7 +99,8 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
CHECK_EQ(param.gpu_id, -1); CHECK_EQ(param.gpu_id, -1);
// Used only by approx. // Used only by approx.
auto sorted_sketch = param.regen; auto sorted_sketch = param.regen;
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, sorted_sketch, param.hess)); gradient_index_.reset(
new GHistIndexMatrix(this, param.max_bin, sorted_sketch, this->ctx_.Threads(), param.hess));
batch_param_ = param; batch_param_ = param;
CHECK_EQ(batch_param_.hess.data(), param.hess.data()); CHECK_EQ(batch_param_.hess.data(), param.hess.data());
} }
@ -110,6 +111,8 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
template <typename AdapterT> template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
this->ctx_.nthread = nthread;
std::vector<uint64_t> qids; std::vector<uint64_t> qids;
uint64_t default_max = std::numeric_limits<uint64_t>::max(); uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max; uint64_t last_group_id = default_max;
@ -124,7 +127,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Iterate over batches of input data // Iterate over batches of input data
while (adapter->Next()) { while (adapter->Next()) {
auto& batch = adapter->Value(); auto& batch = adapter->Value();
auto batch_max_columns = sparse_page_->Push(batch, missing, nthread); auto batch_max_columns = sparse_page_->Push(batch, missing, ctx_.Threads());
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns); inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
total_batch_size += batch.Size(); total_batch_size += batch.Size();
// Append meta information if available // Append meta information if available

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015-2021 by Contributors * Copyright 2015-2022 by XGBoost Contributors
* \file simple_dmatrix.h * \file simple_dmatrix.h
* \brief In-memory version of DMatrix. * \brief In-memory version of DMatrix.
* \author Tianqi Chen * \author Tianqi Chen
@ -61,6 +61,9 @@ class SimpleDMatrix : public DMatrix {
bool SparsePageExists() const override { bool SparsePageExists() const override {
return true; return true;
} }
private:
GenericParameter ctx_;
}; };
} // namespace data } // namespace data
} // namespace xgboost } // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2021 by Contributors * Copyright 2014-2022 by Contributors
* \file sparse_page_dmatrix.cc * \file sparse_page_dmatrix.cc
* \brief The external memory version of Page Iterator. * \brief The external memory version of Page Iterator.
* \author Tianqi Chen * \author Tianqi Chen
@ -164,7 +164,8 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
// all index here. // all index here.
if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) { if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
this->InitializeSparsePage(); this->InitializeSparsePage();
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.regen}); ghist_index_page_.reset(
new GHistIndexMatrix{this, param.max_bin, param.regen, ctx_.Threads()});
this->InitializeSparsePage(); this->InitializeSparsePage();
batch_param_ = param; batch_param_ = param;
} }
@ -181,7 +182,8 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_); MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
// Use sorted sketch for approx. // Use sorted sketch for approx.
auto sorted_sketch = param.regen; auto sorted_sketch = param.regen;
auto cuts = common::SketchOnDMatrix(this, param.max_bin, sorted_sketch, param.hess); auto cuts =
common::SketchOnDMatrix(this, param.max_bin, ctx_.Threads(), sorted_sketch, param.hess);
this->InitializeSparsePage(); // reset after use. this->InitializeSparsePage(); // reset after use.
batch_param_ = param; batch_param_ = param;

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright (c) 2014-2021 by Contributors * Copyright 2014-2022 by XGBoost Contributors
* \file sparse_page_source.h * \file sparse_page_source.h
*/ */
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
@ -311,7 +311,7 @@ class CSCPageSource : public PageSourceIncMixIn<CSCPage> {
auto const &csr = source_->Page(); auto const &csr = source_->Page();
this->page_.reset(new CSCPage{}); this->page_.reset(new CSCPage{});
// we might be able to optimize this by merging transpose and pushcsc // we might be able to optimize this by merging transpose and pushcsc
this->page_->PushCSC(csr->GetTranspose(n_features_)); this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
page_->SetBaseRowId(csr->base_rowid); page_->SetBaseRowId(csr->base_rowid);
this->WriteCache(); this->WriteCache();
} }
@ -336,7 +336,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
auto const &csr = this->source_->Page(); auto const &csr = this->source_->Page();
this->page_.reset(new SortedCSCPage{}); this->page_.reset(new SortedCSCPage{});
// we might be able to optimize this by merging transpose and pushcsc // we might be able to optimize this by merging transpose and pushcsc
this->page_->PushCSC(csr->GetTranspose(n_features_)); this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
CHECK_EQ(this->page_->Size(), n_features_); CHECK_EQ(this->page_->Size(), n_features_);
CHECK_EQ(this->page_->data.Size(), csr->data.Size()); CHECK_EQ(this->page_->data.Size(), csr->data.Size());
this->page_->SortRows(); this->page_->SortRows();

View File

@ -1,3 +1,6 @@
/*!
* Copyright 2018-2022 by XGBoost Contributors
*/
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
@ -14,7 +17,7 @@ TEST(DenseColumn, Test) {
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2}; static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (size_t max_num_bin : max_num_bins) { for (size_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix(); auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false); GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2); column_matrix.Init(gmat, 0.2);
@ -61,7 +64,7 @@ TEST(SparseColumn, Test) {
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2}; static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
for (size_t max_num_bin : max_num_bins) { for (size_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix(); auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false); GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.5); column_matrix.Init(gmat, 0.5);
switch (column_matrix.GetTypeSize()) { switch (column_matrix.GetTypeSize()) {
@ -101,7 +104,7 @@ TEST(DenseColumnWithMissing, Test) {
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 }; static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
for (size_t max_num_bin : max_num_bins) { for (size_t max_num_bin : max_num_bins) {
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix(); auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), max_num_bin, false); GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
ColumnMatrix column_matrix; ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2); column_matrix.Init(gmat, 0.2);
switch (column_matrix.GetTypeSize()) { switch (column_matrix.GetTypeSize()) {
@ -130,7 +133,7 @@ void TestGHistIndexMatrixCreation(size_t nthreads) {
/* This should create multiple sparse pages */ /* This should create multiple sparse pages */
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries) }; std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries) };
omp_set_num_threads(nthreads); omp_set_num_threads(nthreads);
GHistIndexMatrix gmat(dmat.get(), 256, false); GHistIndexMatrix gmat(dmat.get(), 256, false, common::OmpGetNumThreads(0));
} }
TEST(HistIndexCreationWithExternalMemory, Test) { TEST(HistIndexCreationWithExternalMemory, Test) {

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2019-2021 by XGBoost Contributors * Copyright 2019-2022 by XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
@ -188,7 +188,7 @@ TEST(HistUtil, DenseCutsCategorical) {
std::vector<float> x_sorted(x); std::vector<float> x_sorted(x);
std::sort(x_sorted.begin(), x_sorted.end()); std::sort(x_sorted.begin(), x_sorted.end());
auto dmat = GetDMatrixFromData(x, n, 1); auto dmat = GetDMatrixFromData(x, n, 1);
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
auto cuts_from_sketch = cuts.Values(); auto cuts_from_sketch = cuts.Values();
EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
@ -207,7 +207,7 @@ TEST(HistUtil, DenseCutsAccuracyTest) {
auto x = GenerateRandom(num_rows, num_columns); auto x = GenerateRandom(num_rows, num_columns);
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
@ -224,11 +224,13 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
dmat->Info().weights_.HostVector() = w; dmat->Info().weights_.HostVector() = w;
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
{ {
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, true); HistogramCuts cuts =
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), true);
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
{ {
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, false); HistogramCuts cuts =
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), false);
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
@ -249,13 +251,15 @@ void TestQuantileWithHessian(bool use_sorted) {
dmat->Info().weights_.HostVector() = w; dmat->Info().weights_.HostVector() = w;
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, use_sorted, hessian); HistogramCuts cuts_hess =
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted, hessian);
for (size_t i = 0; i < w.size(); ++i) { for (size_t i = 0; i < w.size(); ++i) {
dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i]; dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
} }
ValidateCuts(cuts_hess, dmat.get(), num_bins); ValidateCuts(cuts_hess, dmat.get(), num_bins);
HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins, use_sorted); HistogramCuts cuts_wh =
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted);
ValidateCuts(cuts_wh, dmat.get(), num_bins); ValidateCuts(cuts_wh, dmat.get(), num_bins);
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size()); ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
@ -283,7 +287,7 @@ TEST(HistUtil, DenseCutsExternalMemory) {
auto dmat = auto dmat =
GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir); GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir);
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
ValidateCuts(cuts, dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins);
} }
} }
@ -303,7 +307,7 @@ TEST(HistUtil, IndexBinBound) {
for (auto max_bin : bin_sizes) { for (auto max_bin : bin_sizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
GHistIndexMatrix hmat(p_fmat.get(), max_bin, false); GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0));
EXPECT_EQ(hmat.index.Size(), kRows*kCols); EXPECT_EQ(hmat.index.Size(), kRows*kCols);
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize()); EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
} }
@ -326,7 +330,7 @@ TEST(HistUtil, IndexBinData) {
for (auto max_bin : kBinSizes) { for (auto max_bin : kBinSizes) {
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
GHistIndexMatrix hmat(p_fmat.get(), max_bin, false); GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0));
uint32_t* offsets = hmat.index.Offset(); uint32_t* offsets = hmat.index.Offset();
EXPECT_EQ(hmat.index.Size(), kRows*kCols); EXPECT_EQ(hmat.index.Size(), kRows*kCols);
switch (max_bin) { switch (max_bin) {
@ -351,7 +355,7 @@ void TestSketchFromWeights(bool with_group) {
size_t constexpr kGroups = 10; size_t constexpr kGroups = 10;
auto m = auto m =
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix(); RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix();
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins); common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
MetaInfo info; MetaInfo info;
auto& h_weights = info.weights_.HostVector(); auto& h_weights = info.weights_.HostVector();
@ -385,7 +389,7 @@ void TestSketchFromWeights(bool with_group) {
ValidateCuts(cuts, m.get(), kBins); ValidateCuts(cuts, m.get(), kBins);
if (with_group) { if (with_group) {
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins); HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
for (size_t i = 0; i < cuts.Values().size(); ++i) { for (size_t i = 0; i < cuts.Values().size(); ++i) {
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]); EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
} }
@ -404,14 +408,12 @@ TEST(HistUtil, SketchFromWeights) {
} }
TEST(HistUtil, SketchCategoricalFeatures) { TEST(HistUtil, SketchCategoricalFeatures) {
TestCategoricalSketch(1000, 256, 32, false, TestCategoricalSketch(1000, 256, 32, false, [](DMatrix* p_fmat, int32_t num_bins) {
[](DMatrix *p_fmat, int32_t num_bins) { return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0));
return SketchOnDMatrix(p_fmat, num_bins); });
}); TestCategoricalSketch(1000, 256, 32, true, [](DMatrix* p_fmat, int32_t num_bins) {
TestCategoricalSketch(1000, 256, 32, true, return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0));
[](DMatrix *p_fmat, int32_t num_bins) { });
return SketchOnDMatrix(p_fmat, num_bins);
});
} }
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2019-2021 by XGBoost Contributors * Copyright 2019-2022 by XGBoost Contributors
*/ */
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
@ -28,7 +28,7 @@ namespace common {
template <typename AdapterT> template <typename AdapterT>
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) { HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
data::SimpleDMatrix dmat(adapter, missing, 1); data::SimpleDMatrix dmat(adapter, missing, 1);
HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins); HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, common::OmpGetNumThreads(0));
return cuts; return cuts;
} }
@ -40,7 +40,7 @@ TEST(HistUtil, DeviceSketch) {
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins); HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());

View File

@ -1,3 +1,6 @@
/*!
* Copyright 2020-2022 by XGBoost Contributors
*/
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "test_quantile.h" #include "test_quantile.h"
#include "../../../src/common/quantile.h" #include "../../../src/common/quantile.h"
@ -201,7 +204,7 @@ TEST(Quantile, SameOnAllWorkers) {
.MaxCategory(17) .MaxCategory(17)
.Seed(rank + seed) .Seed(rank + seed)
.GenerateDMatrix(); .GenerateDMatrix();
auto cuts = SketchOnDMatrix(m.get(), n_bins); auto cuts = SketchOnDMatrix(m.get(), n_bins, common::OmpGetNumThreads(0));
std::vector<float> cut_values(cuts.Values().size() * world, 0); std::vector<float> cut_values(cuts.Values().size() * world, 0);
std::vector< std::vector<
typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type> typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type>

View File

@ -1,3 +1,6 @@
/*!
* Copyright 2019-2022 by XGBoost Contributors
*/
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <fstream> #include <fstream>
@ -66,7 +69,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
SparsePage page; // Consolidated sparse page SparsePage page; // Consolidated sparse page
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) { for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
// Transpose each batch and push // Transpose each batch and push
SparsePage tmp = batch.GetTranspose(ncols); SparsePage tmp = batch.GetTranspose(ncols, common::OmpGetNumThreads(0));
page.PushCSC(tmp); page.PushCSC(tmp);
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2021 XGBoost contributors * Copyright 2021-2022 XGBoost contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/data.h> #include <xgboost/data.h>
@ -36,7 +36,7 @@ TEST(GradientIndex, FromCategoricalBasic) {
BatchParam p(0, max_bins); BatchParam p(0, max_bins);
GHistIndexMatrix gidx; GHistIndexMatrix gidx;
gidx.Init(m.get(), max_bins, false, {}); gidx.Init(m.get(), max_bins, false, common::OmpGetNumThreads(0), {});
auto x_copy = x; auto x_copy = x;
std::sort(x_copy.begin(), x_copy.end()); std::sort(x_copy.begin(), x_copy.end());

View File

@ -1,3 +1,6 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
*/
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include "../../../../src/tree/hist/evaluate_splits.h" #include "../../../../src/tree/hist/evaluate_splits.h"
@ -29,7 +32,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
size_t constexpr kMaxBins = 4; size_t constexpr kMaxBins = 4;
// dense, no missing values // dense, no missing values
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false); GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0));
common::RowSetCollection row_set_collection; common::RowSetCollection row_set_collection;
std::vector<size_t> &row_indices = *row_set_collection.Data(); std::vector<size_t> &row_indices = *row_set_collection.Data();
row_indices.resize(kRows); row_indices.resize(kRows);

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2018-2021 by Contributors * Copyright 2018-2022 by XGBoost Contributors
*/ */
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
@ -162,7 +162,7 @@ class QuantileHistMock : public QuantileHistMaker {
// kNRows samples with kNCols features // kNRows samples with kNCols features
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix(); auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false); GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0));
ColumnMatrix cm; ColumnMatrix cm;
// treat everything as dense, as this is what we intend to test here // treat everything as dense, as this is what we intend to test here
@ -253,7 +253,7 @@ class QuantileHistMock : public QuantileHistMaker {
void TestInitData() { void TestInitData() {
size_t constexpr kMaxBins = 4; size_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false); GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
RegTree tree = RegTree(); RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_); tree.param.UpdateAllowUnknown(cfg_);
@ -270,7 +270,7 @@ class QuantileHistMock : public QuantileHistMaker {
void TestInitDataSampling() { void TestInitDataSampling() {
size_t constexpr kMaxBins = 4; size_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false); GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
RegTree tree = RegTree(); RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_); tree.param.UpdateAllowUnknown(cfg_);