Remove omp_get_max_threads in tree updaters. (#7590)

Jiaming Yuan 2022-01-26 19:55:47 +08:00 committed by GitHub
parent 24789429fd
commit 5d7818e75d
19 changed files with 142 additions and 199 deletions
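
Note for readers of the diff below: the recurring change is that raw OpenMP loops and buffers sized implicitly via omp_get_max_threads() are replaced with common::ParallelFor calls and explicit num_threads clauses that take the thread count from the updater's GenericParameter context (ctx_->Threads()). The following is a minimal standalone sketch of that pattern; the ParallelFor helper here is a simplified stand-in, not XGBoost's actual implementation, and the hard-coded thread count stands in for ctx_->Threads().

#include <omp.h>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified stand-in for common::ParallelFor: run fn(i) for i in [0, n)
// on an explicitly requested number of threads instead of whatever
// omp_get_max_threads() happens to report.
template <typename Index, typename Fn>
void ParallelFor(Index n, std::int32_t n_threads, Fn&& fn) {
#pragma omp parallel for num_threads(n_threads) schedule(static)
  for (std::int64_t i = 0; i < static_cast<std::int64_t>(n); ++i) {
    fn(static_cast<Index>(i));
  }
}

int main() {
  std::vector<double> values(1024, 1.0);
  std::int32_t n_threads = 4;  // in the real code this comes from ctx_->Threads()
  ParallelFor(values.size(), n_threads, [&](std::size_t i) { values[i] *= 2.0; });
  std::printf("%f %f\n", values.front(), values.back());  // 2.000000 2.000000
  return 0;
}

The diff itself routes the same idea through common::ParallelFor, dmlc::OMPException for exception propagation, and explicit num_threads clauses on the remaining raw parallel regions.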

View File

@@ -302,22 +302,7 @@ class SparsePage {
   SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
-  void SortRows() {
-    auto ncol = static_cast<bst_omp_uint>(this->Size());
-    dmlc::OMPException exc;
-#pragma omp parallel for schedule(dynamic, 1)
-    for (bst_omp_uint i = 0; i < ncol; ++i) {
-      exc.Run([&]() {
-        if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) {
-          std::sort(
-              this->data.HostVector().begin() + this->offset.HostVector()[i],
-              this->data.HostVector().begin() + this->offset.HostVector()[i + 1],
-              Entry::CmpValue);
-        }
-      });
-    }
-    exc.Rethrow();
-  }
+  void SortRows(int32_t n_threads);
   /**
    * \brief Pushes external data batch onto this page

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file tree_updater.h
  * \brief General primitive for tree learning,
  *  Updating a collection of trees given the information.
@@ -32,7 +32,7 @@ class Json;
  */
 class TreeUpdater : public Configurable {
  protected:
-  GenericParameter const* tparam_;
+  GenericParameter const* ctx_;
  public:
   /*! \brief virtual destructor */

View File

@@ -146,8 +146,7 @@ class ColumnMatrix {
   }
   // construct column matrix from GHistIndexMatrix
-  inline void Init(const GHistIndexMatrix& gmat,
-                   double sparse_threshold) {
+  inline void Init(const GHistIndexMatrix& gmat, double sparse_threshold, int32_t n_threads) {
     const int32_t nfeature = static_cast<int32_t>(gmat.cut.Ptrs().size() - 1);
     const size_t nrow = gmat.row_ptr.size() - 1;
     // identify type of each column
@@ -208,12 +207,15 @@ class ColumnMatrix {
     if (all_dense) {
       BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize();
       if (gmat_bin_size == kUint8BinsTypeSize) {
-        SetIndexAllDense(gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues);
+        SetIndexAllDense(gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues,
+                         n_threads);
       } else if (gmat_bin_size == kUint16BinsTypeSize) {
-        SetIndexAllDense(gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues);
+        SetIndexAllDense(gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues,
+                         n_threads);
       } else {
         CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize);
-        SetIndexAllDense(gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues);
+        SetIndexAllDense(gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues,
+                         n_threads);
       }
       /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize
          but for ColumnMatrix we still have a chance to reduce the memory consumption */
@@ -266,13 +268,13 @@ class ColumnMatrix {
   template <typename T>
   inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat,
                                const size_t nrow, const size_t nfeature,
-                               const bool noMissingValues) {
+                               const bool noMissingValues, int32_t n_threads) {
     T* local_index = reinterpret_cast<T*>(&index_[0]);
     /* missing values make sense only for column with type kDenseColumn,
        and if no missing values were observed it could be handled much faster. */
     if (noMissingValues) {
-      ParallelFor(omp_ulong(nrow), [&](omp_ulong rid) {
+      ParallelFor(nrow, n_threads, [&](auto rid) {
        const size_t ibegin = rid*nfeature;
        const size_t iend = (rid+1)*nfeature;
        size_t j = 0;

View File

@@ -1035,6 +1035,16 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
   return transpose;
 }
+void SparsePage::SortRows(int32_t n_threads) {
+  auto& h_offset = this->offset.HostVector();
+  auto& h_data = this->data.HostVector();
+  common::ParallelFor(this->Size(), n_threads, [&](auto i) {
+    if (h_offset[i] < h_offset[i + 1]) {
+      std::sort(h_data.begin() + h_offset[i], h_data.begin() + h_offset[i + 1], Entry::CmpValue);
+    }
+  });
+}
 void SparsePage::Push(const SparsePage &batch) {
   auto& data_vec = data.HostVector();
   auto& offset_vec = offset.HostVector();
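
Illustrative aside (not part of the commit): the new SortRows sorts each CSR row segment by value with an explicitly passed thread count. A minimal standalone sketch of the same pattern, with a simplified Entry type, a value-comparing lambda in place of Entry::CmpValue, and plain OpenMP in place of common::ParallelFor:

#include <omp.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Entry { std::uint32_t index; float fvalue; };  // simplified stand-in

// Sort each row segment [offset[i], offset[i + 1]) of a CSR-like page by value,
// using an explicitly requested number of threads.
void SortRows(std::vector<std::size_t> const& offset, std::vector<Entry>* data,
              std::int32_t n_threads) {
  auto n_rows = static_cast<std::int64_t>(offset.size()) - 1;
#pragma omp parallel for num_threads(n_threads) schedule(static)
  for (std::int64_t i = 0; i < n_rows; ++i) {
    if (offset[i] < offset[i + 1]) {
      std::sort(data->begin() + offset[i], data->begin() + offset[i + 1],
                [](Entry const& a, Entry const& b) { return a.fvalue < b.fvalue; });
    }
  }
}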

View File

@@ -67,7 +67,7 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
   if (!sorted_column_page_) {
     sorted_column_page_.reset(
         new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
-    sorted_column_page_->SortRows();
+    sorted_column_page_->SortRows(ctx_.Threads());
   }
   auto begin_iter = BatchIterator<SortedCSCPage>(
       new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_));

View File

@@ -339,7 +339,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
     this->page_->PushCSC(csr->GetTranspose(n_features_, nthreads_));
     CHECK_EQ(this->page_->Size(), n_features_);
     CHECK_EQ(this->page_->data.Size(), csr->data.Size());
-    this->page_->SortRows();
+    this->page_->SortRows(this->nthreads_);
     page_->SetBaseRowId(csr->base_rowid);
     this->WriteCache();
   }

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015 by Contributors
+ * Copyright 2015-2022 by XGBoost Contributors
  * \file tree_updater.cc
  * \brief Registry of tree updaters.
  */
@@ -21,7 +21,7 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const
     LOG(FATAL) << "Unknown tree updater " << name;
   }
   auto p_updater = (e->body)(task);
-  p_updater->tparam_ = tparam;
+  p_updater->ctx_ = tparam;
   return p_updater;
 }

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2021 XGBoost contributors
+ * Copyright 2021-2022 XGBoost contributors
  *
  * \brief Implementation for the approx tree method.
  */
@@ -318,10 +318,10 @@ class GlobalApproxUpdater : public TreeUpdater {
     param_.learning_rate = lr / trees.size();
     if (hist_param_.single_precision_histogram) {
-      f32_impl_ = std::make_unique<GloablApproxBuilder<float>>(param_, m->Info(), tparam_,
+      f32_impl_ = std::make_unique<GloablApproxBuilder<float>>(param_, m->Info(), ctx_,
                                                                column_sampler_, task_, &monitor_);
     } else {
-      f64_impl_ = std::make_unique<GloablApproxBuilder<double>>(param_, m->Info(), tparam_,
+      f64_impl_ = std::make_unique<GloablApproxBuilder<double>>(param_, m->Info(), ctx_,
                                                                 column_sampler_, task_, &monitor_);
     }

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file updater_basemaker-inl.h
  * \brief implement a common tree constructor
  * \author Tianqi Chen
@@ -220,9 +220,7 @@ class BaseMaker: public TreeUpdater {
     // set default direct nodes to default
     // for leaf nodes that are not fresh, mark then to ~nid,
    // so that they are ignored in future statistics collection
-    const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
-    common::ParallelFor(ndata, [&](bst_omp_uint ridx) {
+    common::ParallelFor(p_fmat->Info().num_row_, ctx_->Threads(), [&](auto ridx) {
      const int nid = this->DecodePosition(ridx);
      if (tree[nid].IsLeaf()) {
        // mark finish when it is not a fresh leaf
@@ -256,8 +254,7 @@ class BaseMaker: public TreeUpdater {
        auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid);
        if (it != sorted_split_set.end() && *it == fid) {
-          const auto ndata = static_cast<bst_omp_uint>(col.size());
-          common::ParallelFor(ndata, [&](bst_omp_uint j) {
+          common::ParallelFor(col.size(), ctx_->Threads(), [&](auto j) {
            const bst_uint ridx = col[j].index;
            const bst_float fvalue = col[j].fvalue;
            const int nid = this->DecodePosition(ridx);
@@ -312,8 +309,7 @@ class BaseMaker: public TreeUpdater {
      auto page = batch.GetView();
      for (auto fid : fsplits) {
        auto col = page[fid];
-        const auto ndata = static_cast<bst_omp_uint>(col.size());
-        common::ParallelFor(ndata, [&](bst_omp_uint j) {
+        common::ParallelFor(col.size(), ctx_->Threads(), [&](auto j) {
          const bst_uint ridx = col[j].index;
          const bst_float fvalue = col[j].fvalue;
          const int nid = this->DecodePosition(ridx);
@@ -337,10 +333,10 @@ class BaseMaker: public TreeUpdater {
                            std::vector< std::vector<TStats> > *p_thread_temp,
                            std::vector<TStats> *p_node_stats) {
    std::vector< std::vector<TStats> > &thread_temp = *p_thread_temp;
-    thread_temp.resize(omp_get_max_threads());
+    thread_temp.resize(ctx_->Threads());
    p_node_stats->resize(tree.param.num_nodes);
    dmlc::OMPException exc;
-#pragma omp parallel
+#pragma omp parallel num_threads(ctx_->Threads())
    {
      exc.Run([&]() {
        const int tid = omp_get_thread_num();
@@ -352,8 +348,7 @@ class BaseMaker: public TreeUpdater {
    }
    exc.Rethrow();
    // setup position
-    const auto ndata = static_cast<bst_omp_uint>(fmat.Info().num_row_);
-    common::ParallelFor(ndata, [&](bst_omp_uint ridx) {
+    common::ParallelFor(fmat.Info().num_row_, ctx_->Threads(), [&](auto ridx) {
      const int nid = position_[ridx];
      const int tid = omp_get_thread_num();
      if (nid >= 0) {
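
Illustrative aside (not part of the commit): the GetNodeStats hunk above sizes the per-thread scratch by the same count that the parallel region is pinned to, so the buffer index from omp_get_thread_num() always stays in range. A minimal standalone sketch of that pairing, using a hard-coded thread count in place of ctx_->Threads():

#include <omp.h>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::int32_t n_threads = 4;  // stands in for ctx_->Threads()
  // Size per-thread scratch by the same count the parallel region uses,
  // instead of omp_get_max_threads().
  std::vector<std::vector<double>> thread_temp(n_threads);
#pragma omp parallel num_threads(n_threads)
  {
    int tid = omp_get_thread_num();
    thread_temp[tid].assign(8, 1.0);  // per-thread accumulation buffer
  }
  double total = 0.0;
  for (auto const& t : thread_temp) {
    total = std::accumulate(t.begin(), t.end(), total);
  }
  std::printf("total = %f\n", total);
  return 0;
}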

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file updater_colmaker.cc
  * \brief use columnwise update to construct a tree
  * \author Tianqi Chen
@@ -114,8 +114,8 @@ class ColMaker: public TreeUpdater {
     interaction_constraints_.Configure(param_, dmat->Info().num_row_);
     // build tree
     for (auto tree : trees) {
-      CHECK(tparam_);
-      Builder builder(param_, colmaker_param_, interaction_constraints_, tparam_,
+      CHECK(ctx_);
+      Builder builder(param_, colmaker_param_, interaction_constraints_, ctx_,
                       column_densities_);
       builder.Update(gpair->ConstHostVector(), dmat, tree);
     }
@@ -270,17 +270,11 @@ class ColMaker: public TreeUpdater {
      }
      const MetaInfo& info = fmat.Info();
      // setup position
-      const auto ndata = static_cast<bst_omp_uint>(info.num_row_);
-      dmlc::OMPException exc;
-#pragma omp parallel for schedule(static)
-      for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) {
-        exc.Run([&]() {
-          const int tid = omp_get_thread_num();
+      common::ParallelFor(info.num_row_, ctx_->Threads(), [&](auto ridx) {
+        int32_t const tid = omp_get_thread_num();
        if (position_[ridx] < 0) return;
        stemp_[tid][position_[ridx]].stats.Add(gpair[ridx]);
      });
-      }
-      exc.Rethrow();
      // sum the per thread statistics together
      for (int nid : qexpand) {
        GradStats stats;
@@ -449,27 +443,20 @@ class ColMaker: public TreeUpdater {
    // update the solution candidate
    virtual void UpdateSolution(const SortedCSCPage &batch,
                                const std::vector<bst_feature_t> &feat_set,
-                                const std::vector<GradientPair> &gpair,
-                                DMatrix*) {
+                                const std::vector<GradientPair> &gpair, DMatrix *) {
      // start enumeration
-      const auto num_features = static_cast<bst_omp_uint>(feat_set.size());
-#if defined(_OPENMP)
+      const auto num_features = feat_set.size();
      CHECK(this->ctx_);
      const int batch_size =  // NOLINT
          std::max(static_cast<int>(num_features / this->ctx_->Threads() / 32), 1);
-#endif  // defined(_OPENMP)
-      {
      auto page = batch.GetView();
-      dmlc::OMPException exc;
-#pragma omp parallel for schedule(dynamic, batch_size)
-      for (bst_omp_uint i = 0; i < num_features; ++i) {
-        exc.Run([&]() {
+      common::ParallelFor(
+          num_features, ctx_->Threads(), common::Sched::Dyn(batch_size), [&](auto i) {
            auto evaluator = tree_evaluator_.GetEvaluator();
            bst_feature_t const fid = feat_set[i];
            int32_t const tid = omp_get_thread_num();
            auto c = page[fid];
-            const bool ind =
-                c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue;
+            const bool ind = c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue;
            if (colmaker_train_param_.NeedForwardSearch(column_densities_[fid], ind)) {
              this->EnumerateSplit(c.data(), c.data() + c.size(), +1, fid, gpair, stemp_[tid],
                                   evaluator);
@@ -480,9 +467,6 @@ class ColMaker: public TreeUpdater {
            }
          });
    }
-      exc.Rethrow();
-      }
-    }
    // find splits at current level, do split per level
    inline void FindSplit(int depth,
                          const std::vector<int> &qexpand,
@@ -529,11 +513,9 @@ class ColMaker: public TreeUpdater {
      // set default direct nodes to default
      // for leaf nodes that are not fresh, mark then to ~nid,
      // so that they are ignored in future statistics collection
-      const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
-      common::ParallelFor(ndata, [&](bst_omp_uint ridx) {
-        CHECK_LT(ridx, position_.size())
-            << "ridx exceed bound " << "ridx="<< ridx << " pos=" << position_.size();
+      common::ParallelFor(p_fmat->Info().num_row_, this->ctx_->Threads(), [&](auto ridx) {
+        CHECK_LT(ridx, position_.size()) << "ridx exceed bound "
+                                         << "ridx=" << ridx << " pos=" << position_.size();
        const int nid = this->DecodePosition(ridx);
        if (tree[nid].IsLeaf()) {
          // mark finish when it is not a fresh leaf
@@ -577,8 +559,7 @@ class ColMaker: public TreeUpdater {
        auto page = batch.GetView();
        for (auto fid : fsplits) {
          auto col = page[fid];
-          const auto ndata = static_cast<bst_omp_uint>(col.size());
-          common::ParallelFor(ndata, [&](bst_omp_uint j) {
+          common::ParallelFor(col.size(), this->ctx_->Threads(), [&](auto j) {
            const bst_uint ridx = col[j].index;
            const int nid = this->DecodePosition(ridx);
            const bst_float fvalue = col[j].fvalue;
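
Illustrative aside (not part of the commit): UpdateSolution above replaces a raw `#pragma omp parallel for schedule(dynamic, batch_size)` with common::ParallelFor plus a common::Sched::Dyn(batch_size) scheduling hint. A standalone sketch of the same idea; the Sched type and ParallelFor overload here are simplified assumptions, not XGBoost's real helpers, and the batch-size heuristic mirrors the one in the hunk:

#include <omp.h>
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Assumed scheduling tag, loosely mirroring common::Sched::Dyn in the hunk above.
struct Sched {
  int chunk{1};
  static Sched Dyn(int chunk) {
    Sched s;
    s.chunk = chunk;
    return s;
  }
};

// Run fn(i) for i in [0, n) on n_threads threads with dynamic scheduling.
template <typename Fn>
void ParallelFor(std::size_t n, std::int32_t n_threads, Sched s, Fn&& fn) {
  int chunk = std::max(s.chunk, 1);
#pragma omp parallel for num_threads(n_threads) schedule(dynamic, chunk)
  for (std::int64_t i = 0; i < static_cast<std::int64_t>(n); ++i) {
    fn(static_cast<std::size_t>(i));
  }
}

int main() {
  std::size_t num_features = 256;
  std::int32_t n_threads = 8;
  // Mirrors the heuristic in the hunk: features / threads / 32, at least 1.
  int batch_size = std::max(static_cast<int>(num_features / n_threads / 32), 1);
  std::vector<int> owner(num_features, -1);
  ParallelFor(num_features, n_threads, Sched::Dyn(batch_size),
              [&](std::size_t i) { owner[i] = omp_get_thread_num(); });
  std::printf("feature 0 processed by thread %d\n", owner[0]);
  return 0;
}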

View File

@@ -875,11 +875,11 @@ class GPUHistMaker : public TreeUpdater {
     if (hist_maker_param_.single_precision_histogram) {
       float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>(task_));
       float_maker_->param_ = param;
-      float_maker_->Configure(args, tparam_);
+      float_maker_->Configure(args, ctx_);
     } else {
       double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>(task_));
       double_maker_->param_ = param;
-      double_maker_->Configure(args, tparam_);
+      double_maker_->Configure(args, ctx_);
     }
   }

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file updater_histmaker.cc
  * \brief use histogram counting to construct a tree
  * \author Tianqi Chen
@@ -203,11 +203,8 @@ class HistMaker: public BaseMaker {
     // get the best split condition for each node
     std::vector<SplitEntry> sol(qexpand_.size());
     std::vector<GradStats> left_sum(qexpand_.size());
-    auto nexpand = static_cast<bst_omp_uint>(qexpand_.size());
-    dmlc::OMPException exc;
-#pragma omp parallel for schedule(dynamic, 1)
-    for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
-      exc.Run([&]() {
+    auto nexpand = qexpand_.size();
+    common::ParallelFor(nexpand, ctx_->Threads(), common::Sched::Dyn(1), [&](auto wid) {
      const int nid = qexpand_[wid];
      CHECK_EQ(node2workindex_[nid], static_cast<int>(wid));
      SplitEntry &best = sol[wid];
@@ -218,12 +215,10 @@ class HistMaker: public BaseMaker {
          continue;
        }
-        EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature+1)],
-                       node_sum, feature_set[i], &best, &left_sum[wid]);
+        EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature + 1)], node_sum, feature_set[i],
+                       &best, &left_sum[wid]);
      }
    });
-    }
-    exc.Rethrow();
    // get the best result, we can synchronize the solution
    for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
      const bst_node_t nid = qexpand_[wid];
@@ -341,27 +336,20 @@ class CQHistMaker: public HistMaker {
    // if it is C++11, use lazy evaluation for Allreduce,
    // to gain speedup in recovery
    auto lazy_get_hist = [&]() {
-      thread_hist_.resize(omp_get_max_threads());
+      thread_hist_.resize(ctx_->Threads());
      // start accumulating statistics
      for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>()) {
        auto page = batch.GetView();
        // start enumeration
-        const auto nsize = static_cast<bst_omp_uint>(fset.size());
-        dmlc::OMPException exc;
-#pragma omp parallel for schedule(dynamic, 1)
-        for (bst_omp_uint i = 0; i < nsize; ++i) {
-          exc.Run([&]() {
+        common::ParallelFor(fset.size(), ctx_->Threads(), common::Sched::Dyn(1), [&](auto i) {
          int fid = fset[i];
          int offset = feat2workindex_[fid];
          if (offset >= 0) {
-            this->UpdateHistCol(gpair, page[fid], info, tree,
-                                fset, offset,
+            this->UpdateHistCol(gpair, page[fid], info, tree, fset, offset,
                                &thread_hist_[omp_get_thread_num()]);
          }
        });
      }
-        exc.Rethrow();
-      }
      // update node statistics.
      this->GetNodeStats(gpair, *p_fmat, tree,
                         &thread_stats_, &node_stats_);
@@ -412,7 +400,7 @@ class CQHistMaker: public HistMaker {
    }
    {
      // get summary
-      thread_sketch_.resize(omp_get_max_threads());
+      thread_sketch_.resize(ctx_->Threads());
      // TWOPASS: use the real set + split set in the column iteration.
      this->SetDefaultPostion(p_fmat, tree);
@@ -426,22 +414,16 @@ class CQHistMaker: public HistMaker {
        this->CorrectNonDefaultPositionByBatch(batch, fsplit_set_, tree);
        auto page = batch.GetView();
        // start enumeration
-        const auto nsize = static_cast<bst_omp_uint>(work_set_.size());
-        dmlc::OMPException exc;
-#pragma omp parallel for schedule(dynamic, 1)
-        for (bst_omp_uint i = 0; i < nsize; ++i) {
-          exc.Run([&]() {
+        common::ParallelFor(work_set_.size(), ctx_->Threads(), common::Sched::Dyn(1),
+                            [&](auto i) {
          int fid = work_set_[i];
          int offset = feat2workindex_[fid];
          if (offset >= 0) {
-            this->UpdateSketchCol(gpair, page[fid], tree,
-                                  work_set_size, offset,
+            this->UpdateSketchCol(gpair, page[fid], tree, work_set_size, offset,
                                  &thread_sketch_[omp_get_thread_num()]);
          }
        });
      }
-        exc.Rethrow();
-      }
      for (size_t i = 0; i < sketchs_.size(); ++i) {
        common::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
        sketchs_[i].GetSummary(&out);

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2020 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file updater_prune.cc
  * \brief prune a tree given the statistics
  * \author Tianqi Chen
@@ -24,7 +24,7 @@ DMLC_REGISTRY_FILE_TAG(updater_prune);
 class TreePruner: public TreeUpdater {
  public:
  explicit TreePruner(ObjInfo task) {
-    syncher_.reset(TreeUpdater::Create("sync", tparam_, task));
+    syncher_.reset(TreeUpdater::Create("sync", ctx_, task));
    pruner_monitor_.Init("TreePruner");
  }
  char const* Name() const override {

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2021 by Contributors
+ * Copyright 2017-2022 by XGBoost Contributors
  * \file updater_quantile_hist.cc
  * \brief use quantized feature values to construct a tree
  * \author Philip Cho, Tianqi Checn, Egor Smirnov
@@ -40,19 +40,18 @@ DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
 void QuantileHistMaker::Configure(const Args& args) {
   // initialize pruner
   if (!pruner_) {
-    pruner_.reset(TreeUpdater::Create("prune", tparam_, task_));
+    pruner_.reset(TreeUpdater::Create("prune", ctx_, task_));
   }
   pruner_->Configure(args);
   param_.UpdateAllowUnknown(args);
   hist_maker_param_.UpdateAllowUnknown(args);
 }
-template<typename GradientSumT>
+template <typename GradientSumT>
 void QuantileHistMaker::SetBuilder(const size_t n_trees,
-                                   std::unique_ptr<Builder<GradientSumT>>* builder,
-                                   DMatrix *dmat) {
+                                   std::unique_ptr<Builder<GradientSumT>>* builder, DMatrix* dmat) {
   builder->reset(
-      new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat, task_));
+      new Builder<GradientSumT>(n_trees, param_, std::move(pruner_), dmat, task_, ctx_));
 }
 template<typename GradientSumT>
@@ -75,7 +74,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
    auto p_gmat = it.Page();
    if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
      updater_monitor_.Start("GmatInitialization");
-      column_matrix_.Init(*p_gmat, param_.sparse_threshold);
+      column_matrix_.Init(*p_gmat, param_.sparse_threshold, ctx_->Threads());
      updater_monitor_.Stop("GmatInitialization");
      // A proper solution is puting cut matrix in DMatrix, see:
      // https://github.com/dmlc/xgboost/issues/5143
@@ -347,7 +346,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
        return row_set_collection_[node].Size();
      }, 1024);
  CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
-  common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
+  common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node, common::Range1d r) {
    const RowSetCollection::Elem rowset = row_set_collection_[node];
    if (rowset.begin != nullptr && rowset.end != nullptr) {
      int nid = rowset.node_id;
@@ -388,20 +387,19 @@ void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
    }
  }
 #else
-  const size_t nthread = this->nthread_;
  uint64_t initial_seed = rnd();
-  const size_t discard_size = info.num_row_ / nthread;
+  auto n_threads = static_cast<size_t>(ctx_->Threads());
+  const size_t discard_size = info.num_row_ / n_threads;
  std::bernoulli_distribution coin_flip(param_.subsample);
  dmlc::OMPException exc;
-#pragma omp parallel num_threads(nthread)
+#pragma omp parallel num_threads(n_threads)
  {
    exc.Run([&]() {
      const size_t tid = omp_get_thread_num();
      const size_t ibegin = tid * discard_size;
-      const size_t iend = (tid == (nthread - 1)) ?
-          info.num_row_ : ibegin + discard_size;
+      const size_t iend = (tid == (n_threads - 1)) ? info.num_row_ : ibegin + discard_size;
      RandomReplace::MakeIf([&](size_t i, RandomReplace::EngineT& eng) {
        return !(gpair_ref[i].GetHess() >= 0.0f && coin_flip(eng));
      }, GradientPair(0), initial_seed, ibegin, iend, &gpair_ref);
@@ -436,16 +434,9 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
  uint32_t nbins = gmat.cut.Ptrs().back();
  // initialize histogram builder
  dmlc::OMPException exc;
-#pragma omp parallel
-  {
-    exc.Run([&]() {
-      this->nthread_ = omp_get_num_threads();
-    });
-  }
  exc.Rethrow();
-  this->histogram_builder_->Reset(
-      nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
-      this->nthread_, 1, rabit::IsDistributed());
+  this->histogram_builder_->Reset(nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
+                                  this->ctx_->Threads(), 1, rabit::IsDistributed());
  std::vector<size_t>& row_indices = *row_set_collection_.Data();
  row_indices.resize(info.num_row_);
@@ -463,13 +454,14 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
      // We should check that the partitioning was done correctly
      // and each row of the dataset fell into exactly one of the categories
    }
-    common::MemStackAllocator<bool, 128> buff(this->nthread_);
+    auto n_threads = this->ctx_->Threads();
+    common::MemStackAllocator<bool, 128> buff(n_threads);
    bool* p_buff = buff.Get();
-    std::fill(p_buff, p_buff + this->nthread_, false);
-    const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_);
-#pragma omp parallel num_threads(this->nthread_)
+    std::fill(p_buff, p_buff + this->ctx_->Threads(), false);
+    const size_t block_size = info.num_row_ / n_threads + !!(info.num_row_ % n_threads);
+#pragma omp parallel num_threads(n_threads)
    {
      exc.Run([&]() {
        const size_t tid = omp_get_thread_num();
@@ -488,7 +480,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
    exc.Rethrow();
    bool has_neg_hess = false;
-    for (int32_t tid = 0; tid < this->nthread_; ++tid) {
+    for (int32_t tid = 0; tid < n_threads; ++tid) {
      if (p_buff[tid]) {
        has_neg_hess = true;
      }
@@ -503,7 +495,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
      }
      row_indices.resize(j);
    } else {
-#pragma omp parallel num_threads(this->nthread_)
+#pragma omp parallel num_threads(n_threads)
      {
        exc.Run([&]() {
          const size_t tid = omp_get_thread_num();
@@ -543,10 +535,10 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
  p_last_tree_ = &tree;
  if (data_layout_ == DataLayout::kDenseDataOneBased) {
    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
-        param_, info, this->nthread_, column_sampler_, task_, true});
+        param_, info, this->ctx_->Threads(), column_sampler_, task_, true});
  } else {
    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
-        param_, info, this->nthread_, column_sampler_, task_, false});
+        param_, info, this->ctx_->Threads(), column_sampler_, task_, false});
  }
  if (data_layout_ == DataLayout::kDenseDataZeroBased
@@ -642,7 +634,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<CPUE
  });
  // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
  // Store results in intermediate buffers from partition_builder_
-  common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) {
+  common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) {
    size_t begin = r.begin();
    const int32_t nid = nodes[node_in_set].nid;
    const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
@@ -673,7 +665,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<CPUE
  // 4. Copy elements from partition_builder_ to row_set_collection_ back
  //    with updated row-indexes for each tree-node
-  common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) {
+  common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) {
    const int32_t nid = nodes[node_in_set].nid;
    partition_builder_.MergeToArray(node_in_set, r.begin(),
                                    const_cast<size_t*>(row_set_collection_[nid].begin));
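
Illustrative aside (not part of the commit): InitSampling above splits the rows into contiguous per-thread chunks of size num_row / n_threads, with the last thread absorbing the remainder, so each thread can drive its own sampling pass. A minimal standalone sketch of that chunking, with an assumed per-thread engine and seed in place of XGBoost's RandomReplace helpers:

#include <omp.h>
#include <cstdint>
#include <random>
#include <vector>

int main() {
  std::size_t num_row = 1000;
  std::int32_t n_threads = 4;                    // stands in for ctx_->Threads()
  std::size_t discard_size = num_row / n_threads;
  std::vector<float> subsampled(num_row, 0.0f);
  std::bernoulli_distribution coin_flip(0.5);
#pragma omp parallel num_threads(n_threads)
  {
    std::size_t tid = omp_get_thread_num();
    std::size_t ibegin = tid * discard_size;
    std::size_t iend = (tid == static_cast<std::size_t>(n_threads) - 1)
                           ? num_row : ibegin + discard_size;
    std::mt19937 eng(42 + tid);                  // per-thread engine (illustrative seed)
    for (std::size_t i = ibegin; i < iend; ++i) {
      subsampled[i] = coin_flip(eng) ? 1.0f : 0.0f;
    }
  }
  return 0;
}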

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2021 by Contributors
+ * Copyright 2017-2022 by XGBoost Contributors
  * \file updater_quantile_hist.h
  * \brief use quantized feature values to construct a tree
  * \author Philip Cho, Tianqi Chen, Egor Smirnov
@@ -155,14 +155,16 @@ class QuantileHistMaker: public TreeUpdater {
     using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
     // constructor
     explicit Builder(const size_t n_trees, const TrainParam& param,
-                     std::unique_ptr<TreeUpdater> pruner, DMatrix const* fmat, ObjInfo task)
+                     std::unique_ptr<TreeUpdater> pruner, DMatrix const* fmat, ObjInfo task,
+                     GenericParameter const* ctx)
        : n_trees_(n_trees),
          param_(param),
          pruner_(std::move(pruner)),
          p_last_tree_(nullptr),
          p_last_fmat_(fmat),
          histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
-          task_{task} {
+          task_{task},
+          ctx_{ctx} {
      builder_monitor_.Init("Quantile::Builder");
    }
    // update one tree, growing
@@ -225,8 +227,6 @@ class QuantileHistMaker: public TreeUpdater {
    // --data fields--
    const size_t n_trees_;
    const TrainParam& param_;
-    // number of omp thread used during training
-    int nthread_;
    std::shared_ptr<common::ColumnSampler> column_sampler_{
        std::make_shared<common::ColumnSampler>()};
@@ -258,9 +258,10 @@ class QuantileHistMaker: public TreeUpdater {
    enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
    DataLayout data_layout_;
-    std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>>
-        histogram_builder_;
+    std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_;
    ObjInfo task_;
+    // Context for number of threads
+    GenericParameter const* ctx_;
    common::Monitor builder_monitor_;
  };

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014 by Contributors
+ * Copyright 2014-2022 by XGBoost Contributors
  * \file updater_refresh.cc
  * \brief refresh the statistics and leaf value on the tree on the dataset
  * \author Tianqi Chen
@@ -51,11 +51,11 @@ class TreeRefresher: public TreeUpdater {
     std::vector<std::vector<GradStats> > stemp;
     std::vector<RegTree::FVec> fvec_temp;
     // setup temp space for each thread
-    const int nthread = omp_get_max_threads();
+    const int nthread = ctx_->Threads();
     fvec_temp.resize(nthread, RegTree::FVec());
     stemp.resize(nthread, std::vector<GradStats>());
     dmlc::OMPException exc;
-#pragma omp parallel
+#pragma omp parallel num_threads(nthread)
     {
       exc.Run([&]() {
         int tid = omp_get_thread_num();
@@ -78,7 +78,7 @@ class TreeRefresher: public TreeUpdater {
      auto page = batch.GetView();
      CHECK_LT(batch.Size(), std::numeric_limits<unsigned>::max());
      const auto nbatch = static_cast<bst_omp_uint>(batch.Size());
-      common::ParallelFor(nbatch, [&](bst_omp_uint i) {
+      common::ParallelFor(nbatch, ctx_->Threads(), [&](bst_omp_uint i) {
        SparsePage::Inst inst = page[i];
        const int tid = omp_get_thread_num();
        const auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
@@ -95,7 +95,7 @@ class TreeRefresher: public TreeUpdater {
    }
    // aggregate the statistics
    auto num_nodes = static_cast<int>(stemp[0].size());
-    common::ParallelFor(num_nodes, [&](int nid) {
+    common::ParallelFor(num_nodes, ctx_->Threads(), [&](int nid) {
      for (int tid = 1; tid < nthread; ++tid) {
        stemp[0][nid].Add(stemp[tid][nid]);
      }

View File

@@ -19,7 +19,7 @@ TEST(DenseColumn, Test) {
     auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
     GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
     ColumnMatrix column_matrix;
-    column_matrix.Init(gmat, 0.2);
+    column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
     for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
       for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
@@ -66,7 +66,7 @@ TEST(SparseColumn, Test) {
     auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
     GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
     ColumnMatrix column_matrix;
-    column_matrix.Init(gmat, 0.5);
+    column_matrix.Init(gmat, 0.5, common::OmpGetNumThreads(0));
     switch (column_matrix.GetTypeSize()) {
       case kUint8BinsTypeSize: {
         auto col = column_matrix.GetColumn<uint8_t, true>(0);
@@ -106,7 +106,7 @@ TEST(DenseColumnWithMissing, Test) {
     auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
     GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0));
     ColumnMatrix column_matrix;
-    column_matrix.Init(gmat, 0.2);
+    column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
     switch (column_matrix.GetTypeSize()) {
       case kUint8BinsTypeSize: {
         auto col = column_matrix.GetColumn<uint8_t, true>(0);

View File

@@ -76,7 +76,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
   // Make sure that the final sparse page has the right number of entries
   ASSERT_EQ(kEntries, page.data.Size());
-  page.SortRows();
+  page.SortRows(common::OmpGetNumThreads(0));
   auto v = page.GetView();
   for (size_t i = 0; i < v.Size(); ++i) {
     auto column = v[i];

View File

@@ -27,8 +27,8 @@ class QuantileHistMock : public QuantileHistMaker {
     using GHistRowT = typename RealImpl::GHistRowT;
     BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
-                DMatrix const *fmat)
-        : RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}) {}
+                DMatrix const *fmat, GenericParameter const* ctx)
+        : RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}, ctx) {}
    public:
     void TestInitData(const GHistIndexMatrix& gmat,
@@ -166,7 +166,7 @@ class QuantileHistMock : public QuantileHistMaker {
      ColumnMatrix cm;
      // treat everything as dense, as this is what we intend to test here
-      cm.Init(gmat, 0.0);
+      cm.Init(gmat, 0.0, common::OmpGetNumThreads(0));
      RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
      const size_t num_row = dmat->Info().num_row_;
      // split by feature 0
@@ -222,6 +222,7 @@ class QuantileHistMock : public QuantileHistMaker {
  int static constexpr kNRows = 8, kNCols = 16;
  std::shared_ptr<xgboost::DMatrix> dmat_;
+  GenericParameter ctx_;
  const std::vector<std::pair<std::string, std::string> > cfg_;
  std::shared_ptr<BuilderMock<float> > float_builder_;
  std::shared_ptr<BuilderMock<double> > double_builder_;
@@ -233,18 +234,12 @@ class QuantileHistMock : public QuantileHistMaker {
        QuantileHistMaker{ObjInfo{ObjInfo::kRegression}}, cfg_{args} {
    QuantileHistMaker::Configure(args);
    dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
+    ctx_.UpdateAllowUnknown(Args{});
    if (single_precision_histogram) {
-      float_builder_.reset(
-          new BuilderMock<float>(
-              param_,
-              std::move(pruner_),
-              dmat_.get()));
+      float_builder_.reset(new BuilderMock<float>(param_, std::move(pruner_), dmat_.get(), &ctx_));
    } else {
      double_builder_.reset(
-          new BuilderMock<double>(
-              param_,
-              std::move(pruner_),
-              dmat_.get()));
+          new BuilderMock<double>(param_, std::move(pruner_), dmat_.get(), &ctx_));
    }
  }
  ~QuantileHistMock() override = default;