diff --git a/doc/parameter.rst b/doc/parameter.rst index 8c83c302a..fcd20cfc3 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -225,12 +225,15 @@ Parameters for Tree Booster list is a group of indices of features that are allowed to interact with each other. See tutorial for more information -Additional parameters for `gpu_hist` tree method +Additional parameters for `hist` and `gpu_hist` tree methods ================================================ * ``single_precision_histogram``, [default=``false``] - - Use single precision to build histograms. See document for GPU support for more details. + - Use single precision to build histograms instead of double precision. + +Additional parameters for `gpu_hist` tree method +================================================ * ``deterministic_histogram``, [default=``true``] diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 32af1077b..480242611 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -141,6 +141,15 @@ class GradientPairInternal { public: using ValueT = T; + inline void Add(const ValueT& grad, const ValueT& hess) { + grad_ += grad; + hess_ += hess; + } + + inline static void Reduce(GradientPairInternal& a, const GradientPairInternal& b) { // NOLINT(*) + a += b; + } + XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {} XGBOOST_DEVICE GradientPairInternal(T grad, T hess) { @@ -148,9 +157,8 @@ class GradientPairInternal { SetHess(hess); } - // Copy constructor if of same value type - XGBOOST_DEVICE GradientPairInternal(const GradientPairInternal &g) - : grad_(g.grad_), hess_(g.hess_) {} // NOLINT + // Copy constructor if of same value type, marked as default to be trivially_copyable + GradientPairInternal(const GradientPairInternal &g) = default; // Copy constructor if different value type - use getters and setters to // perform conversion diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index b144a8077..e3ca953d2 100644 --- a/src/common/hist_util.cc 
+++ b/src/common/hist_util.cc @@ -830,54 +830,78 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat, /*! * \brief fill a histogram by zeros in range [begin, end) */ -void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) { +template +void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) { #if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 - std::fill(hist.begin() + begin, hist.begin() + end, tree::GradStats()); + std::fill(hist.begin() + begin, hist.begin() + end, + xgboost::detail::GradientPairInternal()); #else // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 - memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats)); + memset(hist.data() + begin, '\0', (end-begin)* + sizeof(xgboost::detail::GradientPairInternal)); #endif // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 } +template void InitilizeHistByZeroes(GHistRow hist, size_t begin, + size_t end); +template void InitilizeHistByZeroes(GHistRow hist, size_t begin, + size_t end); /*! * \brief Increment hist as dst += add in range [begin, end) */ -void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) { - using FPType = decltype(tree::GradStats::sum_grad); - FPType* pdst = reinterpret_cast(dst.data()); - const FPType* padd = reinterpret_cast(add.data()); +template +void IncrementHist(GHistRow dst, const GHistRow add, + size_t begin, size_t end) { + GradientSumT* pdst = reinterpret_cast(dst.data()); + const GradientSumT* padd = reinterpret_cast(add.data()); for (size_t i = 2 * begin; i < 2 * end; ++i) { pdst[i] += padd[i]; } } +template void IncrementHist(GHistRow dst, const GHistRow add, + size_t begin, size_t end); +template void IncrementHist(GHistRow dst, const GHistRow add, + size_t begin, size_t end); /*! 
* \brief Copy hist from src to dst in range [begin, end) */ -void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) { - using FPType = decltype(tree::GradStats::sum_grad); - FPType* pdst = reinterpret_cast(dst.data()); - const FPType* psrc = reinterpret_cast(src.data()); +template +void CopyHist(GHistRow dst, const GHistRow src, + size_t begin, size_t end) { + GradientSumT* pdst = reinterpret_cast(dst.data()); + const GradientSumT* psrc = reinterpret_cast(src.data()); for (size_t i = 2 * begin; i < 2 * end; ++i) { pdst[i] = psrc[i]; } } +template void CopyHist(GHistRow dst, const GHistRow src, + size_t begin, size_t end); +template void CopyHist(GHistRow dst, const GHistRow src, + size_t begin, size_t end); /*! * \brief Compute Subtraction: dst = src1 - src2 in range [begin, end) */ -void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2, +template +void SubtractionHist(GHistRow dst, const GHistRow src1, + const GHistRow src2, size_t begin, size_t end) { - using FPType = decltype(tree::GradStats::sum_grad); - FPType* pdst = reinterpret_cast(dst.data()); - const FPType* psrc1 = reinterpret_cast(src1.data()); - const FPType* psrc2 = reinterpret_cast(src2.data()); + GradientSumT* pdst = reinterpret_cast(dst.data()); + const GradientSumT* psrc1 = reinterpret_cast(src1.data()); + const GradientSumT* psrc2 = reinterpret_cast(src2.data()); for (size_t i = 2 * begin; i < 2 * end; ++i) { pdst[i] = psrc1[i] - psrc2[i]; } } +template void SubtractionHist(GHistRow dst, const GHistRow src1, + const GHistRow src2, + size_t begin, size_t end); +template void SubtractionHist(GHistRow dst, const GHistRow src1, + const GHistRow src2, + size_t begin, size_t end); struct Prefetch { public: @@ -908,7 +932,7 @@ void BuildHistDenseKernel(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, const size_t n_features, - GHistRow hist) { + GHistRow hist) { const size_t size = row_indices.Size(); const size_t* 
rid = row_indices.begin; const float* pgh = reinterpret_cast(gpair.data()); @@ -948,7 +972,7 @@ template void BuildHistSparseKernel(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - GHistRow hist) { + GHistRow hist) { const size_t size = row_indices.Size(); const size_t* rid = row_indices.begin; const float* pgh = reinterpret_cast(gpair.data()); @@ -987,7 +1011,7 @@ void BuildHistSparseKernel(const std::vector& gpair, template void BuildHistDispatchKernel(const std::vector& gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, GHistRow hist, bool isDense) { + const GHistIndexMatrix& gmat, GHistRow hist, bool isDense) { if (isDense) { const size_t* row_ptr = gmat.row_ptr.data(); const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]]; @@ -1002,7 +1026,7 @@ void BuildHistDispatchKernel(const std::vector& gpair, template void BuildHistKernel(const std::vector& gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, const bool isDense, GHistRow hist) { + const GHistIndexMatrix& gmat, const bool isDense, GHistRow hist) { const bool is_dense = row_indices.Size() && isDense; switch (gmat.index.GetBinTypeSize()) { case kUint8BinsTypeSize: @@ -1022,12 +1046,12 @@ void BuildHistKernel(const std::vector& gpair, } } -void GHistBuilder::BuildHist(const std::vector& gpair, +template +void GHistBuilder::BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - GHistRow hist, + GHistRowT hist, bool isDense) { - using FPType = decltype(tree::GradStats::sum_grad); const size_t nrows = row_indices.Size(); const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); @@ -1036,21 +1060,34 @@ void GHistBuilder::BuildHist(const std::vector& gpair, if (contiguousBlock) { // contiguous memory access, built-in HW prefetching is enough - BuildHistKernel(gpair, row_indices, gmat, isDense, hist); + 
BuildHistKernel(gpair, row_indices, gmat, isDense, hist); } else { const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size); const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end); - BuildHistKernel(gpair, span1, gmat, isDense, hist); + BuildHistKernel(gpair, span1, gmat, isDense, hist); // no prefetching to avoid loading extra memory - BuildHistKernel(gpair, span2, gmat, isDense, hist); + BuildHistKernel(gpair, span2, gmat, isDense, hist); } } +template +void GHistBuilder::BuildHist(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, + GHistRow hist, + bool isDense); +template +void GHistBuilder::BuildHist(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, + GHistRow hist, + bool isDense); -void GHistBuilder::BuildBlockHist(const std::vector& gpair, +template +void GHistBuilder::BuildBlockHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexBlockMatrix& gmatb, - GHistRow hist) { + GHistRowT hist) { constexpr int kUnroll = 8; // loop unrolling factor const size_t nblock = gmatb.GetNumBlock(); const size_t nrows = row_indices.end - row_indices.begin; @@ -1058,7 +1095,7 @@ void GHistBuilder::BuildBlockHist(const std::vector& gpair, #if defined(_OPENMP) const auto nthread = static_cast(this->nthread_); // NOLINT #endif // defined(_OPENMP) - tree::GradStats* p_hist = hist.data(); + xgboost::detail::GradientPairInternal* p_hist = hist.data(); #pragma omp parallel for num_threads(nthread) schedule(guided) for (bst_omp_uint bid = 0; bid < nblock; ++bid) { @@ -1079,7 +1116,7 @@ void GHistBuilder::BuildBlockHist(const std::vector& gpair, for (int k = 0; k < kUnroll; ++k) { for (size_t j = ibegin[k]; j < iend[k]; ++j) { const uint32_t bin = gmat.index[j]; - p_hist[bin].Add(stat[k]); + p_hist[bin].Add(stat[k].GetGrad(), stat[k].GetHess()); } } } @@ -1090,13 +1127,27 @@ void 
GHistBuilder::BuildBlockHist(const std::vector& gpair, const GradientPair stat = gpair[rid]; for (size_t j = ibegin; j < iend; ++j) { const uint32_t bin = gmat.index[j]; - p_hist[bin].Add(stat); + p_hist[bin].Add(stat.GetGrad(), stat.GetHess()); } } } } +template +void GHistBuilder::BuildBlockHist(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexBlockMatrix& gmatb, + GHistRow hist); +template +void GHistBuilder::BuildBlockHist(const std::vector& gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexBlockMatrix& gmatb, + GHistRow hist); -void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) { + +template +void GHistBuilder::SubtractionTrick(GHistRowT self, + GHistRowT sibling, + GHistRowT parent) { const size_t size = self.size(); CHECK_EQ(sibling.size(), size); CHECK_EQ(parent.size(), size); @@ -1111,6 +1162,14 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa SubtractionHist(self, parent, sibling, ibegin, iend); } } +template +void GHistBuilder::SubtractionTrick(GHistRow self, + GHistRow sibling, + GHistRow parent); +template +void GHistBuilder::SubtractionTrick(GHistRow self, + GHistRow sibling, + GHistRow parent); } // namespace common } // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 92c520966..d60960301 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -391,46 +391,52 @@ class GHistIndexBlockMatrix { std::vector blocks_; }; -/*! - * \brief histogram of gradient statistics for a single node. - * Consists of multiple GradStats, each entry showing total gradient statistics - * for that particular bin - * Uses global bin id so as to represent all features simultaneously - */ -using GHistRow = Span; +template +using GHistRow = Span >; /*! 
* \brief fill a histogram by zeros */ -void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end); +template +void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end); /*! * \brief Increment hist as dst += add in range [begin, end) */ -void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end); +template +void IncrementHist(GHistRow dst, const GHistRow add, + size_t begin, size_t end); /*! * \brief Copy hist from src to dst in range [begin, end) */ -void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end); +template +void CopyHist(GHistRow dst, const GHistRow src, + size_t begin, size_t end); /*! * \brief Compute Subtraction: dst = src1 - src2 in range [begin, end) */ -void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2, +template +void SubtractionHist(GHistRow dst, const GHistRow src1, + const GHistRow src2, size_t begin, size_t end); /*! * \brief histogram of gradient statistics for multiple nodes */ +template class HistCollection { public: + using GHistRowT = GHistRow; + using GradientPairT = xgboost::detail::GradientPairInternal; + // access histogram for i-th node - GHistRow operator[](bst_uint nid) const { + GHistRowT operator[](bst_uint nid) const { constexpr uint32_t kMax = std::numeric_limits::max(); CHECK_NE(row_ptr_[nid], kMax); - tree::GradStats* ptr = - const_cast(dmlc::BeginPtr(data_) + row_ptr_[nid]); + GradientPairT* ptr = + const_cast(dmlc::BeginPtr(data_) + row_ptr_[nid]); return {ptr, nbins_}; } @@ -473,7 +479,7 @@ class HistCollection { /*! \brief amount of active nodes in hist collection */ uint32_t n_nodes_added_ = 0; - std::vector data_; + std::vector data_; /*! 
\brief row_ptr_[nid] locates bin for histogram of node nid */ std::vector row_ptr_; @@ -484,8 +490,11 @@ class HistCollection { * Supports processing multiple tree-nodes for nested parallelism * Able to reduce histograms across threads in efficient way */ +template class ParallelGHistBuilder { public: + using GHistRowT = GHistRow; + void Init(size_t nbins) { if (nbins != nbins_) { hist_buffer_.Init(nbins); @@ -496,7 +505,7 @@ class ParallelGHistBuilder { // Add new elements if needed, mark all hists as unused // targeted_hists - already allocated hists which should contain final results after Reduce() call void Reset(size_t nthreads, size_t nodes, const BlockedSpace2d& space, - const std::vector& targeted_hists) { + const std::vector& targeted_hists) { hist_buffer_.Init(nbins_); tid_nid_to_hist_.clear(); hist_memory_.clear(); @@ -518,12 +527,12 @@ class ParallelGHistBuilder { } // Get specified hist, initialize hist by zeros if it wasn't used before - GHistRow GetInitializedHist(size_t tid, size_t nid) { + GHistRowT GetInitializedHist(size_t tid, size_t nid) { CHECK_LT(nid, nodes_); CHECK_LT(tid, nthreads_); size_t idx = tid_nid_to_hist_.at({tid, nid}); - GHistRow hist = hist_memory_[idx]; + GHistRowT hist = hist_memory_[idx]; if (!hist_was_used_[tid * nodes_ + nid]) { InitilizeHistByZeroes(hist, 0, hist.size()); @@ -538,14 +547,14 @@ class ParallelGHistBuilder { CHECK_GT(end, begin); CHECK_LT(nid, nodes_); - GHistRow dst = targeted_hists_[nid]; + GHistRowT dst = targeted_hists_[nid]; bool is_updated = false; for (size_t tid = 0; tid < nthreads_; ++tid) { if (hist_was_used_[tid * nodes_ + nid]) { is_updated = true; const size_t idx = tid_nid_to_hist_.at({tid, nid}); - GHistRow src = hist_memory_[idx]; + GHistRowT src = hist_memory_[idx]; if (dst.data() != src.data()) { IncrementHist(dst, src, begin, end); @@ -636,7 +645,7 @@ class ParallelGHistBuilder { /*! \brief number of nodes which will be processed in parallel */ size_t nodes_ = 0; /*! 
\brief Buffer for additional histograms for Parallel processing */ - HistCollection hist_buffer_; + HistCollection hist_buffer_; /*! * \brief Marks which hists were used, it means that they should be merged. * Contains only {true or false} values @@ -647,9 +656,9 @@ class ParallelGHistBuilder { /*! \brief Buffer for additional histograms for Parallel processing */ std::vector threads_to_nids_map_; /*! \brief Contains histograms for final results */ - std::vector targeted_hists_; + std::vector targeted_hists_; /*! \brief Allocated memory for histograms used for construction */ - std::vector hist_memory_; + std::vector hist_memory_; /*! \brief map pair {tid, nid} to index of allocated histogram from hist_memory_ */ std::map, size_t> tid_nid_to_hist_; }; @@ -657,8 +666,11 @@ class ParallelGHistBuilder { /*! * \brief builder for histograms of gradient statistics */ +template class GHistBuilder { public: + using GHistRowT = GHistRow; + GHistBuilder() = default; GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {} @@ -666,15 +678,17 @@ class GHistBuilder { void BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, - GHistRow hist, + GHistRowT hist, bool isDense); // same, with feature grouping void BuildBlockHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, const GHistIndexBlockMatrix& gmatb, - GHistRow hist); + GHistRowT hist); // construct a histogram via subtraction trick - void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent); + void SubtractionTrick(GHistRowT self, + GHistRowT sibling, + GHistRowT parent); uint32_t GetNumBins() const { return nbins_; diff --git a/src/tree/param.h b/src/tree/param.h index 8a71cd1ef..280f06066 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -332,14 +332,15 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad) /*! 
\brief core statistics used for tree construction */ struct XGBOOST_ALIGNAS(16) GradStats { + using GradType = double; /*! \brief sum gradient statistics */ - double sum_grad { 0 }; + GradType sum_grad { 0 }; /*! \brief sum hessian statistics */ - double sum_hess { 0 }; + GradType sum_hess { 0 }; public: - XGBOOST_DEVICE double GetGrad() const { return sum_grad; } - XGBOOST_DEVICE double GetHess() const { return sum_hess; } + XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; } + XGBOOST_DEVICE GradType GetHess() const { return sum_hess; } friend std::ostream& operator<<(std::ostream& os, GradStats s) { os << s.GetGrad() << "/" << s.GetHess(); @@ -354,7 +355,7 @@ struct XGBOOST_ALIGNAS(16) GradStats { template XGBOOST_DEVICE explicit GradStats(const GpairT &sum) : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {} - explicit GradStats(const double grad, const double hess) + explicit GradStats(const GradType grad, const GradType hess) : sum_grad(grad), sum_hess(hess) {} /*! * \brief accumulate statistics @@ -379,7 +380,7 @@ struct XGBOOST_ALIGNAS(16) GradStats { /*! \return whether the statistics is not used yet */ inline bool Empty() const { return sum_hess == 0.0; } /*! 
\brief add statistics to the data */ - inline void Add(double grad, double hess) { + inline void Add(GradType grad, GradType hess) { sum_grad += grad; sum_hess += hess; } @@ -425,7 +426,11 @@ struct SplitEntryContainer { * \param split_index the feature index where the split is on */ bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const { - if (this->SplitIndex() <= split_index) { + if (std::isinf(new_loss_chg)) { // in some cases new_loss_chg can be NaN or Inf, + // for example when lambda = 0 & min_child_weight = 0 + // skip value in this case + return false; + } else if (this->SplitIndex() <= split_index) { return new_loss_chg > this->loss_chg; } else { return !(this->loss_chg > new_loss_chg); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 051a1a44f..30eb01a72 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -35,6 +35,8 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_quantile_hist); +DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); + void QuantileHistMaker::Configure(const Args& args) { // initialize pruner if (!pruner_) { @@ -42,7 +44,7 @@ void QuantileHistMaker::Configure(const Args& args) { } pruner_->Configure(args); param_.UpdateAllowUnknown(args); - + hist_maker_param_.UpdateAllowUnknown(args); // initialize the split evaluator if (!spliteval_) { spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator)); @@ -51,6 +53,32 @@ void QuantileHistMaker::Configure(const Args& args) { spliteval_->Init(&param_); } +template +void QuantileHistMaker::SetBuilder(std::unique_ptr>* builder, + DMatrix *dmat) { + builder->reset(new Builder( + param_, + std::move(pruner_), + std::unique_ptr(spliteval_->GetHostClone()), + int_constraint_, dmat)); + if (rabit::IsDistributed()) { + (*builder)->SetHistSynchronizer(new DistributedHistSynchronizer()); + (*builder)->SetHistRowsAdder(new DistributedHistRowsAdder()); + } else { + (*builder)->SetHistSynchronizer(new 
BatchHistSynchronizer()); + (*builder)->SetHistRowsAdder(new BatchHistRowsAdder()); + } +} + +template +void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr>& builder, + HostDeviceVector *gpair, + DMatrix *dmat, + const std::vector &trees) { + for (auto tree : trees) { + builder->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree); + } +} void QuantileHistMaker::Update(HostDeviceVector *gpair, DMatrix *dmat, const std::vector &trees) { @@ -71,22 +99,16 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, param_.learning_rate = lr / trees.size(); int_constraint_.Configure(param_, dmat->Info().num_col_); // build tree - if (!builder_) { - builder_.reset(new Builder( - param_, - std::move(pruner_), - std::unique_ptr(spliteval_->GetHostClone()), - int_constraint_, dmat)); - if (rabit::IsDistributed()) { - builder_->SetHistSynchronizer(new DistributedHistSynchronizer()); - builder_->SetHistRowsAdder(new DistributedHistRowsAdder()); - } else { - builder_->SetHistSynchronizer(new BatchHistSynchronizer()); - builder_->SetHistRowsAdder(new BatchHistRowsAdder()); + if (hist_maker_param_.single_precision_histogram) { + if (!float_builder_) { + SetBuilder(&float_builder_, dmat); } - } - for (auto tree : trees) { - builder_->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree); + CallBuilderUpdate(float_builder_, gpair, dmat, trees); + } else { + if (!double_builder_) { + SetBuilder(&double_builder_, dmat); + } + CallBuilderUpdate(double_builder_, gpair, dmat, trees); } param_.learning_rate = lr; @@ -97,14 +119,21 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, bool QuantileHistMaker::UpdatePredictionCache( const DMatrix* data, HostDeviceVector* out_preds) { - if (!builder_ || param_.subsample < 1.0f) { + if (param_.subsample < 1.0f) { return false; } else { - return builder_->UpdatePredictionCache(data, out_preds); + if (hist_maker_param_.single_precision_histogram && float_builder_) { + return float_builder_->UpdatePredictionCache(data, 
out_preds); + } else if (double_builder_) { + return double_builder_->UpdatePredictionCache(data, out_preds); + } else { + return false; + } } } -void BatchHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* builder, +template +void BatchHistSynchronizer::SyncHistograms(BuilderT* builder, int starting_index, int sync_count, RegTree *p_tree) { @@ -130,7 +159,8 @@ void BatchHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* builder, builder->builder_monitor_.Stop("SyncHistograms"); } -void DistributedHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* builder, +template +void DistributedHistSynchronizer::SyncHistograms(BuilderT* builder, int starting_index, int sync_count, RegTree *p_tree) { @@ -172,9 +202,11 @@ void DistributedHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* bui builder->builder_monitor_.Stop("SyncHistograms"); } -void DistributedHistSynchronizer::ParallelSubtractionHist(QuantileHistMaker::Builder* builder, +template +void DistributedHistSynchronizer::ParallelSubtractionHist( + BuilderT* builder, const common::BlockedSpace2d& space, - const std::vector& nodes, + const std::vector& nodes, const RegTree * p_tree) { common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) { const auto entry = nodes[node]; @@ -190,7 +222,8 @@ void DistributedHistSynchronizer::ParallelSubtractionHist(QuantileHistMaker::Bui }); } -void BatchHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder, +template +void BatchHistRowsAdder::AddHistRows(BuilderT* builder, int *starting_index, int *sync_count, RegTree *p_tree) { builder->builder_monitor_.Start("AddHistRows"); @@ -209,7 +242,8 @@ void BatchHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder, builder->builder_monitor_.Stop("AddHistRows"); } -void DistributedHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder, +template +void DistributedHistRowsAdder::AddHistRows(BuilderT* builder, int *starting_index, int *sync_count, 
RegTree *p_tree) { builder->builder_monitor_.Start("AddHistRows"); @@ -243,15 +277,28 @@ void DistributedHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder, builder->builder_monitor_.Stop("AddHistRows"); } -void QuantileHistMaker::Builder::SetHistSynchronizer(HistSynchronizer* sync) { +template +void QuantileHistMaker::Builder::SetHistSynchronizer( + HistSynchronizer* sync) { hist_synchronizer_.reset(sync); } +template void QuantileHistMaker::Builder::SetHistSynchronizer( + HistSynchronizer* sync); +template void QuantileHistMaker::Builder::SetHistSynchronizer( + HistSynchronizer* sync); -void QuantileHistMaker::Builder::SetHistRowsAdder(HistRowsAdder* adder) { +template +void QuantileHistMaker::Builder::SetHistRowsAdder( + HistRowsAdder* adder) { hist_rows_adder_.reset(adder); } +template void QuantileHistMaker::Builder::SetHistRowsAdder( + HistRowsAdder* sync); +template void QuantileHistMaker::Builder::SetHistRowsAdder( + HistRowsAdder* sync); -void QuantileHistMaker::Builder::BuildHistogramsLossGuide( +template +void QuantileHistMaker::Builder::BuildHistogramsLossGuide( ExpandEntry entry, const GHistIndexMatrix &gmat, const GHistIndexBlockMatrix &gmatb, @@ -274,7 +321,8 @@ void QuantileHistMaker::Builder::BuildHistogramsLossGuide( hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree); } -void QuantileHistMaker::Builder::BuildLocalHistograms( +template +void QuantileHistMaker::Builder::BuildLocalHistograms( const GHistIndexMatrix &gmat, const GHistIndexBlockMatrix &gmatb, RegTree *p_tree, @@ -289,7 +337,7 @@ void QuantileHistMaker::Builder::BuildLocalHistograms( return row_set_collection_[nid].Size(); }, 256); - std::vector target_hists(n_nodes); + std::vector target_hists(n_nodes); for (size_t i = 0; i < n_nodes; ++i) { const int32_t nid = nodes_for_explicit_hist_build_[i].nid; target_hists[i] = hist_[nid]; @@ -312,8 +360,8 @@ void QuantileHistMaker::Builder::BuildLocalHistograms( builder_monitor_.Stop("BuildLocalHistograms"); 
} - -void QuantileHistMaker::Builder::BuildNodeStats( +template +void QuantileHistMaker::Builder::BuildNodeStats( const GHistIndexMatrix &gmat, DMatrix *p_fmat, RegTree *p_tree, @@ -336,8 +384,8 @@ void QuantileHistMaker::Builder::BuildNodeStats( } builder_monitor_.Stop("BuildNodeStats"); } - -void QuantileHistMaker::Builder::AddSplitsToTree( +template +void QuantileHistMaker::Builder::AddSplitsToTree( const GHistIndexMatrix &gmat, RegTree *p_tree, int *num_leaves, @@ -377,8 +425,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree( } } - -void QuantileHistMaker::Builder::EvaluateAndApplySplits( +template +void QuantileHistMaker::Builder::EvaluateAndApplySplits( const GHistIndexMatrix &gmat, const ColumnMatrix &column_matrix, RegTree *p_tree, @@ -400,7 +448,8 @@ void QuantileHistMaker::Builder::EvaluateAndApplySplits( // Exception: in distributed setting, we always build the histogram for the left child node // and use 'Subtraction Trick' to built the histogram for the right child node. // This ensures that the workers operate on the same set of tree nodes. 
-void QuantileHistMaker::Builder::SplitSiblings(const std::vector& nodes, +template +void QuantileHistMaker::Builder::SplitSiblings(const std::vector& nodes, std::vector* small_siblings, std::vector* big_siblings, RegTree *p_tree) { @@ -427,8 +476,8 @@ void QuantileHistMaker::Builder::SplitSiblings(const std::vector& n } builder_monitor_.Stop("SplitSiblings"); } - -void QuantileHistMaker::Builder::ExpandWithDepthWise( +template +void QuantileHistMaker::Builder::ExpandWithDepthWise( const GHistIndexMatrix &gmat, const GHistIndexBlockMatrix &gmatb, const ColumnMatrix &column_matrix, @@ -468,8 +517,8 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise( } } } - -void QuantileHistMaker::Builder::ExpandWithLossGuide( +template +void QuantileHistMaker::Builder::ExpandWithLossGuide( const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, const ColumnMatrix& column_matrix, @@ -545,7 +594,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide( builder_monitor_.Stop("ExpandWithLossGuide"); } -void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat, +template +void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, const ColumnMatrix& column_matrix, HostDeviceVector* gpair, @@ -574,8 +624,8 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat, builder_monitor_.Stop("Update"); } - -bool QuantileHistMaker::Builder::UpdatePredictionCache( +template +bool QuantileHistMaker::Builder::UpdatePredictionCache( const DMatrix* data, HostDeviceVector* p_out_preds) { // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in @@ -624,8 +674,8 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( builder_monitor_.Stop("UpdatePredictionCache"); return true; } - -void QuantileHistMaker::Builder::InitSampling(const std::vector& gpair, +template +void QuantileHistMaker::Builder::InitSampling(const std::vector& gpair, const DMatrix& fmat, std::vector* row_indices) { 
const auto& info = fmat.Info(); @@ -682,7 +732,8 @@ void QuantileHistMaker::Builder::InitSampling(const std::vector& g row_indices_local.resize(prefix_sum); #endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG } -void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, +template +void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, const std::vector& gpair, const DMatrix& fmat, const RegTree& tree) { @@ -712,7 +763,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, { this->nthread_ = omp_get_num_threads(); } - hist_builder_ = GHistBuilder(this->nthread_, nbins); + hist_builder_ = GHistBuilder(this->nthread_, nbins); std::vector& row_indices = *row_set_collection_.Data(); row_indices.resize(info.num_row_); @@ -842,7 +893,8 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, // is equal to sum of statistics for all values: // then - there are no missing values // else - there are missing values -bool QuantileHistMaker::Builder::SplitContainsMissingValues(const GradStats e, +template +bool QuantileHistMaker::Builder::SplitContainsMissingValues(const GradStats e, const NodeEntry& snode) { if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) { return false; @@ -852,9 +904,11 @@ bool QuantileHistMaker::Builder::SplitContainsMissingValues(const GradStats e, } // nodes_set - set of nodes to be processed in parallel -void QuantileHistMaker::Builder::EvaluateSplits(const std::vector& nodes_set, +template +void QuantileHistMaker::Builder::EvaluateSplits( + const std::vector& nodes_set, const GHistIndexMatrix& gmat, - const HistCollection& hist, + const HistCollection& hist, const RegTree& tree) { builder_monitor_.Start("EvaluateSplits"); @@ -886,7 +940,7 @@ void QuantileHistMaker::Builder::EvaluateSplits(const std::vector& common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) { const int32_t nid = nodes_set[nid_in_set].nid; const auto tid = 
static_cast(omp_get_thread_num()); - GHistRow node_hist = hist[nid]; + GHistRowT node_hist = hist[nid]; for (auto idx_in_feature_set = r.begin(); idx_in_feature_set < r.end(); ++idx_in_feature_set) { const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set]; @@ -1014,8 +1068,9 @@ inline std::pair PartitionSparseKernel( return {nleft_elems, nright_elems}; } +template template -void QuantileHistMaker::Builder::PartitionKernel( +void QuantileHistMaker::Builder::PartitionKernel( const size_t node_in_set, const size_t nid, common::Range1d range, const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree) { const size_t* rid = row_set_collection_[nid].begin; @@ -1068,8 +1123,9 @@ void QuantileHistMaker::Builder::PartitionKernel( partition_builder_.SetNRightElems(node_in_set, range.begin(), range.end(), n_right); } - -void QuantileHistMaker::Builder::FindSplitConditions(const std::vector& nodes, +template +void QuantileHistMaker::Builder::FindSplitConditions( + const std::vector& nodes, const RegTree& tree, const GHistIndexMatrix& gmat, std::vector* split_conditions) { @@ -1095,9 +1151,10 @@ void QuantileHistMaker::Builder::FindSplitConditions(const std::vector& nodes, - RegTree* p_tree) { +template +void QuantileHistMaker::Builder::AddSplitsToRowSet( + const std::vector& nodes, + RegTree* p_tree) { const size_t n_nodes = nodes.size(); for (size_t i = 0; i < n_nodes; ++i) { const int32_t nid = nodes[i].nid; @@ -1109,11 +1166,11 @@ void QuantileHistMaker::Builder::AddSplitsToRowSet(const std::vector nodes, +template +void QuantileHistMaker::Builder::ApplySplit(const std::vector nodes, const GHistIndexMatrix& gmat, const ColumnMatrix& column_matrix, - const HistCollection& hist, + const HistCollection& hist, RegTree* p_tree) { builder_monitor_.Start("ApplySplit"); // 1. 
Find split condition for each split @@ -1169,8 +1226,8 @@ void QuantileHistMaker::Builder::ApplySplit(const std::vector nodes AddSplitsToRowSet(nodes, p_tree); builder_monitor_.Stop("ApplySplit"); } - -void QuantileHistMaker::Builder::InitNewNode(int nid, +template +void QuantileHistMaker::Builder::InitNewNode(int nid, const GHistIndexMatrix& gmat, const std::vector& gpair, const DMatrix& fmat, @@ -1181,8 +1238,8 @@ void QuantileHistMaker::Builder::InitNewNode(int nid, } { - auto& stats = snode_[nid].stats; - GHistRow hist = hist_[nid]; + GHistRowT hist = hist_[nid]; + GradientPairT grad_stat; if (tree[nid].IsRoot()) { if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { const std::vector& row_ptr = gmat.cut.Ptrs(); @@ -1190,16 +1247,17 @@ void QuantileHistMaker::Builder::InitNewNode(int nid, const uint32_t iend = row_ptr[fid_least_bins_ + 1]; auto begin = hist.data(); for (uint32_t i = ibegin; i < iend; ++i) { - const GradStats et = begin[i]; - stats.Add(et.sum_grad, et.sum_hess); + const GradientPairT et = begin[i]; + grad_stat.Add(et.GetGrad(), et.GetHess()); } } else { const RowSetCollection::Elem e = row_set_collection_[nid]; for (const size_t* it = e.begin; it < e.end; ++it) { - stats.Add(gpair[*it]); + grad_stat.Add(gpair[*it].GetGrad(), gpair[*it].GetHess()); } } - histred_.Allreduce(&snode_[nid].stats, 1); + histred_.Allreduce(&grad_stat, 1); + snode_[nid].stats = tree::GradStats(grad_stat.GetGrad(), grad_stat.GetHess()); } else { int parent_id = tree[nid].Parent(); if (tree[nid].IsLeftChild()) { @@ -1225,9 +1283,10 @@ void QuantileHistMaker::Builder::InitNewNode(int nid, // Enumerate the split values of specific feature. // Returns the sum of gradients corresponding to the data points that contains a non-missing value // for the particular feature fid. 
+template template -GradStats QuantileHistMaker::Builder::EnumerateSplit( - const GHistIndexMatrix &gmat, const GHistRow &hist, const NodeEntry &snode, +GradStats QuantileHistMaker::Builder::EnumerateSplit( + const GHistIndexMatrix &gmat, const GHistRowT &hist, const NodeEntry &snode, SplitEntry *p_best, bst_uint fid, bst_uint nodeID) const { CHECK(d_step == +1 || d_step == -1); diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 9402569a6..d74c16f72 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -78,8 +78,35 @@ using xgboost::common::GHistBuilder; using xgboost::common::ColumnMatrix; using xgboost::common::Column; +template class HistSynchronizer; + +template +class BatchHistSynchronizer; + +template +class DistributedHistSynchronizer; + +template class HistRowsAdder; + +template +class BatchHistRowsAdder; + +template +class DistributedHistRowsAdder; + +// training parameters specific to this algorithm +struct CPUHistMakerTrainParam + : public XGBoostParameter { + bool single_precision_histogram; + // declare parameters + DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) { + DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( + "Use single precision to build histograms."); + } +}; + /*! 
\brief construct a tree using quantized feature values */ class QuantileHistMaker: public TreeUpdater { public: @@ -98,10 +125,12 @@ class QuantileHistMaker: public TreeUpdater { void LoadConfig(Json const& in) override { auto const& config = get(in); FromJson(config.at("train_param"), &this->param_); + FromJson(config.at("cpu_hist_train_param"), &this->hist_maker_param_); } void SaveConfig(Json* p_out) const override { auto& out = *p_out; out["train_param"] = ToJson(param_); + out["cpu_hist_train_param"] = ToJson(hist_maker_param_); } char const* Name() const override { @@ -109,12 +138,21 @@ class QuantileHistMaker: public TreeUpdater { } protected: + template friend class HistSynchronizer; + template friend class BatchHistSynchronizer; + template friend class DistributedHistSynchronizer; + + template friend class HistRowsAdder; + template friend class BatchHistRowsAdder; + template friend class DistributedHistRowsAdder; + + CPUHistMakerTrainParam hist_maker_param_; // training parameter TrainParam param_; // quantized data matrix @@ -142,8 +180,11 @@ class QuantileHistMaker: public TreeUpdater { }; // actual builder that runs the algorithm + template struct Builder { public: + using GHistRowT = GHistRow; + using GradientPairT = xgboost::detail::GradientPairInternal; // constructor explicit Builder(const TrainParam& param, std::unique_ptr pruner, @@ -168,7 +209,7 @@ class QuantileHistMaker: public TreeUpdater { const RowSetCollection::Elem row_indices, const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, - GHistRow hist) { + GHistRowT hist) { if (param_.enable_feature_grouping > 0) { hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist); } else { @@ -176,7 +217,9 @@ class QuantileHistMaker: public TreeUpdater { } } - inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) { + inline void SubtractionTrick(GHistRowT self, + GHistRowT sibling, + GHistRowT parent) { builder_monitor_.Start("SubtractionTrick"); 
hist_builder_.SubtractionTrick(self, sibling, parent); builder_monitor_.Stop("SubtractionTrick"); @@ -184,16 +227,17 @@ class QuantileHistMaker: public TreeUpdater { bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector* p_out_preds); - void SetHistSynchronizer(HistSynchronizer* sync); - void SetHistRowsAdder(HistRowsAdder* adder); + void SetHistSynchronizer(HistSynchronizer* sync); + void SetHistRowsAdder(HistRowsAdder* adder); protected: - friend class HistSynchronizer; - friend class BatchHistSynchronizer; - friend class DistributedHistSynchronizer; - friend class HistRowsAdder; - friend class BatchHistRowsAdder; - friend class DistributedHistRowsAdder; + friend class HistSynchronizer; + friend class BatchHistSynchronizer; + friend class DistributedHistSynchronizer; + friend class HistRowsAdder; + friend class BatchHistRowsAdder; + friend class DistributedHistRowsAdder; + /* tree growing policies */ struct ExpandEntry { static const int kRootNid = 0; @@ -225,13 +269,13 @@ class QuantileHistMaker: public TreeUpdater { void EvaluateSplits(const std::vector& nodes_set, const GHistIndexMatrix& gmat, - const HistCollection& hist, + const HistCollection& hist, const RegTree& tree); void ApplySplit(std::vector nodes, const GHistIndexMatrix& gmat, const ColumnMatrix& column_matrix, - const HistCollection& hist, + const HistCollection& hist, RegTree* p_tree); template @@ -255,7 +299,7 @@ class QuantileHistMaker: public TreeUpdater { // Returns the sum of gradients corresponding to the data points that contains a non-missing // value for the particular feature fid. template - GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRow &hist, + GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRowT &hist, const NodeEntry &snode, SplitEntry *p_best, bst_uint fid, bst_uint nodeID) const; @@ -345,16 +389,16 @@ class QuantileHistMaker: public TreeUpdater { /*! 
\brief TreeNode Data: statistics for each constructed node */ std::vector snode_; /*! \brief culmulative histogram of gradients. */ - HistCollection hist_; + HistCollection hist_; /*! \brief culmulative local parent histogram of gradients. */ - HistCollection hist_local_worker_; + HistCollection hist_local_worker_; /*! \brief feature with least # of bins. to be used for dense specialization of InitNewNode() */ uint32_t fid_least_bins_; /*! \brief local prediction cache; maps node id to leaf value */ std::vector leaf_value_cache_; - GHistBuilder hist_builder_; + GHistBuilder hist_builder_; std::unique_ptr pruner_; std::unique_ptr spliteval_; FeatureInteractionConstraintHost interaction_constraints_; @@ -382,61 +426,92 @@ class QuantileHistMaker: public TreeUpdater { DataLayout data_layout_; common::Monitor builder_monitor_; - common::ParallelGHistBuilder hist_buffer_; - rabit::Reducer histred_; - std::unique_ptr hist_synchronizer_; - std::unique_ptr hist_rows_adder_; + common::ParallelGHistBuilder hist_buffer_; + rabit::Reducer histred_; + std::unique_ptr> hist_synchronizer_; + std::unique_ptr> hist_rows_adder_; }; common::Monitor updater_monitor_; - std::unique_ptr builder_; + + template + void SetBuilder(std::unique_ptr>*, DMatrix *dmat); + + template + void CallBuilderUpdate(const std::unique_ptr>& builder, + HostDeviceVector *gpair, + DMatrix *dmat, + const std::vector &trees); + + protected: + std::unique_ptr> float_builder_; + std::unique_ptr> double_builder_; + std::unique_ptr pruner_; std::unique_ptr spliteval_; FeatureInteractionConstraintHost int_constraint_; }; +template class HistSynchronizer { public: - virtual void SyncHistograms(QuantileHistMaker::Builder* builder, + using BuilderT = QuantileHistMaker::Builder; + + virtual void SyncHistograms(BuilderT* builder, int starting_index, int sync_count, RegTree *p_tree) = 0; + virtual ~HistSynchronizer() = default; }; -class BatchHistSynchronizer: public HistSynchronizer { +template +class 
BatchHistSynchronizer: public HistSynchronizer { public: - void SyncHistograms(QuantileHistMaker::Builder* builder, + using BuilderT = QuantileHistMaker::Builder; + void SyncHistograms(BuilderT* builder, int starting_index, int sync_count, RegTree *p_tree) override; }; -class DistributedHistSynchronizer: public HistSynchronizer { +template +class DistributedHistSynchronizer: public HistSynchronizer { public: - void SyncHistograms(QuantileHistMaker::Builder* builder_, - int starting_index, int sync_count, RegTree *p_tree) override; + using BuilderT = QuantileHistMaker::Builder; + using ExpandEntryT = typename BuilderT::ExpandEntry; - void ParallelSubtractionHist(QuantileHistMaker::Builder* builder, + void SyncHistograms(BuilderT* builder, int starting_index, + int sync_count, RegTree *p_tree) override; + + void ParallelSubtractionHist(BuilderT* builder, const common::BlockedSpace2d& space, - const std::vector& nodes, + const std::vector& nodes, const RegTree * p_tree); }; +template class HistRowsAdder { public: - virtual void AddHistRows(QuantileHistMaker::Builder* builder, - int *starting_index, int *sync_count, RegTree *p_tree) = 0; + using BuilderT = QuantileHistMaker::Builder; + + virtual void AddHistRows(BuilderT* builder, int *starting_index, + int *sync_count, RegTree *p_tree) = 0; + virtual ~HistRowsAdder() = default; }; -class BatchHistRowsAdder: public HistRowsAdder { +template +class BatchHistRowsAdder: public HistRowsAdder { public: - void AddHistRows(QuantileHistMaker::Builder* builder, - int *starting_index, int *sync_count, RegTree *p_tree) override; + using BuilderT = QuantileHistMaker::Builder; + void AddHistRows(BuilderT*, int *starting_index, + int *sync_count, RegTree *p_tree) override; }; -class DistributedHistRowsAdder: public HistRowsAdder { +template +class DistributedHistRowsAdder: public HistRowsAdder { public: - void AddHistRows(QuantileHistMaker::Builder* builder, - int *starting_index, int *sync_count, RegTree *p_tree) override; + using 
BuilderT = QuantileHistMaker::Builder; + void AddHistRows(BuilderT*, int *starting_index, + int *sync_count, RegTree *p_tree) override; }; diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 35d4a357c..0924db8a6 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -20,8 +20,8 @@ size_t GetNThreads() { return nthreads; } - -TEST(ParallelGHistBuilder, Reset) { +template +void ParallelGHistBuilderReset() { constexpr size_t kBins = 10; constexpr size_t kNodes = 5; constexpr size_t kNodesExtended = 10; @@ -29,16 +29,16 @@ TEST(ParallelGHistBuilder, Reset) { constexpr double kValue = 1.0; const size_t nthreads = GetNThreads(); - HistCollection collection; + HistCollection collection; collection.Init(kBins); for(size_t inode = 0; inode < kNodesExtended; inode++) { collection.AddHistRow(inode); } - ParallelGHistBuilder hist_builder; + ParallelGHistBuilder hist_builder; hist_builder.Init(kBins); - std::vector target_hist(kNodes); + std::vector> target_hist(kNodes); for(size_t i = 0; i < target_hist.size(); ++i) { target_hist[i] = collection[i]; } @@ -49,7 +49,7 @@ TEST(ParallelGHistBuilder, Reset) { common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) { const size_t tid = omp_get_thread_num(); - GHistRow hist = hist_builder.GetInitializedHist(tid, inode); + GHistRow hist = hist_builder.GetInitializedHist(tid, inode); // fill hist by some non-null values for(size_t j = 0; j < kBins; ++j) { hist[j].Add(kValue, kValue); @@ -67,7 +67,7 @@ TEST(ParallelGHistBuilder, Reset) { common::ParallelFor2d(space2, nthreads, [&](size_t inode, common::Range1d r) { const size_t tid = omp_get_thread_num(); - GHistRow hist = hist_builder.GetInitializedHist(tid, inode); + GHistRow hist = hist_builder.GetInitializedHist(tid, inode); // fill hist by some non-null values for(size_t j = 0; j < kBins; ++j) { ASSERT_EQ(0.0, hist[j].GetGrad()); @@ -76,23 +76,25 @@ TEST(ParallelGHistBuilder, Reset) { 
}); } -TEST(ParallelGHistBuilder, ReduceHist) { + +template +void ParallelGHistBuilderReduceHist(){ constexpr size_t kBins = 10; constexpr size_t kNodes = 5; constexpr size_t kTasksPerNode = 10; constexpr double kValue = 1.0; const size_t nthreads = GetNThreads(); - HistCollection collection; + HistCollection collection; collection.Init(kBins); for(size_t inode = 0; inode < kNodes; inode++) { collection.AddHistRow(inode); } - ParallelGHistBuilder hist_builder; + ParallelGHistBuilder hist_builder; hist_builder.Init(kBins); - std::vector target_hist(kNodes); + std::vector> target_hist(kNodes); for(size_t i = 0; i < target_hist.size(); ++i) { target_hist[i] = collection[i]; } @@ -104,7 +106,7 @@ TEST(ParallelGHistBuilder, ReduceHist) { common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) { const size_t tid = omp_get_thread_num(); - GHistRow hist = hist_builder.GetInitializedHist(tid, inode); + GHistRow hist = hist_builder.GetInitializedHist(tid, inode); for(size_t i = 0; i < kBins; ++i) { hist[i].Add(kValue, kValue); } @@ -122,6 +124,21 @@ TEST(ParallelGHistBuilder, ReduceHist) { } } +TEST(ParallelGHistBuilder, ResetDouble) { + ParallelGHistBuilderReset(); +} + +TEST(ParallelGHistBuilder, ResetFloat) { + ParallelGHistBuilderReset(); +} + +TEST(ParallelGHistBuilder, ReduceHistDouble) { + ParallelGHistBuilderReduceHist(); +} + +TEST(ParallelGHistBuilder, ReduceHistFloat) { + ParallelGHistBuilderReduceHist(); +} TEST(CutsBuilder, SearchGroupInd) { size_t constexpr kNumGroups = 4; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 3e6131b49..1b6ab89e9 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -21,8 +21,11 @@ namespace tree { class QuantileHistMock : public QuantileHistMaker { static double constexpr kEps = 1e-6; - struct BuilderMock : public QuantileHistMaker::Builder { - using RealImpl = QuantileHistMaker::Builder; + template + struct BuilderMock : 
public QuantileHistMaker::Builder { + using RealImpl = QuantileHistMaker::Builder; + using ExpandEntryT = typename RealImpl::ExpandEntry; + using GHistRowT = typename RealImpl::GHistRowT; BuilderMock(const TrainParam& param, std::unique_ptr pruner, @@ -30,7 +33,7 @@ class QuantileHistMock : public QuantileHistMaker { FeatureInteractionConstraintHost int_constraint, DMatrix const* fmat) : RealImpl(param, std::move(pruner), std::move(spliteval), - std::move(int_constraint), fmat) {} + std::move(int_constraint), fmat) {} public: void TestInitData(const GHistIndexMatrix& gmat, @@ -38,7 +41,7 @@ class QuantileHistMock : public QuantileHistMaker { DMatrix* p_fmat, const RegTree& tree) { RealImpl::InitData(gmat, gpair, *p_fmat, tree); - ASSERT_EQ(data_layout_, kSparseData); + ASSERT_EQ(this->data_layout_, RealImpl::kSparseData); /* The creation of HistCutMatrix and GHistIndexMatrix are not technically * part of QuantileHist updater logic, but we include it here because @@ -105,14 +108,14 @@ class QuantileHistMock : public QuantileHistMaker { // save state of global rng engine auto initial_rnd = common::GlobalRandom(); RealImpl::InitData(gmat, gpair, *p_fmat, tree); - std::vector row_indices_initial = *row_set_collection_.Data(); + std::vector row_indices_initial = *(this->row_set_collection_.Data()); for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) { omp_set_num_threads(i_nthreads); // return initial state of global rng engine common::GlobalRandom() = initial_rnd; RealImpl::InitData(gmat, gpair, *p_fmat, tree); - std::vector& row_indices = *row_set_collection_.Data(); + std::vector& row_indices = *(this->row_set_collection_.Data()); ASSERT_EQ(row_indices_initial.size(), row_indices.size()); for (size_t i = 0; i < row_indices_initial.size(); ++i) { ASSERT_EQ(row_indices_initial[i], row_indices[i]); @@ -129,26 +132,26 @@ class QuantileHistMock : public QuantileHistMaker { int starting_index = std::numeric_limits::max(); int sync_count = 0; - 
nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); + this->nodes_for_explicit_hist_build_.clear(); + this->nodes_for_subtraction_trick_.clear(); tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); - nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0); - nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0); - nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0); - nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0); + this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0); + this->nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0); + this->nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0); + this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0); - hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree); + this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree); ASSERT_EQ(sync_count, 2); ASSERT_EQ(starting_index, 3); - for (const ExpandEntry& node : nodes_for_explicit_hist_build_) { - ASSERT_EQ(hist_.RowExists(node.nid), true); + for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) { + ASSERT_EQ(this->hist_.RowExists(node.nid), true); } - for (const ExpandEntry& node : nodes_for_subtraction_trick_) { - ASSERT_EQ(hist_.RowExists(node.nid), true); + for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) { + ASSERT_EQ(this->hist_.RowExists(node.nid), true); } } @@ -162,60 +165,61 @@ class QuantileHistMock : public QuantileHistMaker { int starting_index = std::numeric_limits::max(); int sync_count = 0; - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); + 
this->nodes_for_explicit_hist_build_.clear(); + this->nodes_for_subtraction_trick_.clear(); // level 0 - nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0); - hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree); + this->nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0); + this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree); tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); + this->nodes_for_explicit_hist_build_.clear(); + this->nodes_for_subtraction_trick_.clear(); // level 1 - nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(), (*tree)[0].RightChild(), + this->nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(), + (*tree)[0].RightChild(), tree->GetDepth(1), 0.0f, 0); - nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(), (*tree)[0].LeftChild(), + this->nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(), + (*tree)[0].LeftChild(), tree->GetDepth(2), 0.0f, 0); - hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree); + this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree); tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); + this->nodes_for_explicit_hist_build_.clear(); + this->nodes_for_subtraction_trick_.clear(); // level 2 - nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0); - nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0); - nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0); - nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0); - hist_rows_adder_->AddHistRows(this, &starting_index, 
&sync_count, tree); + this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0); + this->nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0); + this->nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0); + this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0); + this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree); - const size_t n_nodes = nodes_for_explicit_hist_build_.size(); + const size_t n_nodes = this->nodes_for_explicit_hist_build_.size(); ASSERT_EQ(n_nodes, 2); - row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(), + this->row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(), (*tree)[0].RightChild(), 4, 4); - row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(), + this->row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(), (*tree)[1].RightChild(), 2, 2); - row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(), + this->row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(), (*tree)[2].RightChild(), 2, 2); common::BlockedSpace2d space(n_nodes, [&](size_t node) { - const int32_t nid = nodes_for_explicit_hist_build_[node].nid; - return row_set_collection_[nid].Size(); + const int32_t nid = this->nodes_for_explicit_hist_build_[node].nid; + return this->row_set_collection_[nid].Size(); }, 256); - std::vector target_hists(n_nodes); - for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) { - const int32_t nid = nodes_for_explicit_hist_build_[i].nid; - target_hists[i] = hist_[nid]; + std::vector target_hists(n_nodes); + for (size_t i = 0; i < this->nodes_for_explicit_hist_build_.size(); ++i) { + const int32_t nid = this->nodes_for_explicit_hist_build_[i].nid; + target_hists[i] = this->hist_[nid]; } - const size_t nbins = hist_builder_.GetNumBins(); + const size_t nbins = this->hist_builder_.GetNumBins(); // set values to specific nodes hist std::vector n_ids = {1, 2}; for (size_t i : n_ids) { - auto this_hist = 
hist_[i]; - using FPType = decltype(tree::GradStats::sum_grad); - FPType* p_hist = reinterpret_cast(this_hist.data()); + auto this_hist = this->hist_[i]; + GradientSumT* p_hist = reinterpret_cast(this_hist.data()); for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) { p_hist[bin_id] = 2*bin_id; } @@ -223,41 +227,39 @@ class QuantileHistMock : public QuantileHistMaker { n_ids[0] = 3; n_ids[1] = 5; for (size_t i : n_ids) { - auto this_hist = hist_[i]; - using FPType = decltype(tree::GradStats::sum_grad); - FPType* p_hist = reinterpret_cast(this_hist.data()); + auto this_hist = this->hist_[i]; + GradientSumT* p_hist = reinterpret_cast(this_hist.data()); for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) { p_hist[bin_id] = bin_id; } } - hist_buffer_.Reset(1, n_nodes, space, target_hists); + this->hist_buffer_.Reset(1, n_nodes, space, target_hists); // sync hist - hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, tree); + this->hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, tree); - auto check_hist = [] (const GHistRow parent, const GHistRow left, - const GHistRow right, size_t begin, size_t end) { - using FPType = decltype(tree::GradStats::sum_grad); - const FPType* p_parent = reinterpret_cast(parent.data()); - const FPType* p_left = reinterpret_cast(left.data()); - const FPType* p_right = reinterpret_cast(right.data()); + auto check_hist = [] (const GHistRowT parent, const GHistRowT left, + const GHistRowT right, size_t begin, size_t end) { + const GradientSumT* p_parent = reinterpret_cast(parent.data()); + const GradientSumT* p_left = reinterpret_cast(left.data()); + const GradientSumT* p_right = reinterpret_cast(right.data()); for (size_t i = 2 * begin; i < 2 * end; ++i) { ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]); } }; - for (const ExpandEntry& node : nodes_for_explicit_hist_build_) { - auto this_hist = hist_[node.nid]; + for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) { + auto this_hist = 
this->hist_[node.nid]; const size_t parent_id = (*tree)[node.nid].Parent(); - auto parent_hist = hist_[parent_id]; - auto sibling_hist = hist_[node.sibling_nid]; + auto parent_hist = this->hist_[parent_id]; + auto sibling_hist = this->hist_[node.sibling_nid]; check_hist(parent_hist, this_hist, sibling_hist, 0, nbins); } - for (const ExpandEntry& node : nodes_for_subtraction_trick_) { - auto this_hist = hist_[node.nid]; + for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) { + auto this_hist = this->hist_[node.nid]; const size_t parent_id = (*tree)[node.nid].Parent(); - auto parent_hist = hist_[parent_id]; - auto sibling_hist = hist_[node.sibling_nid]; + auto parent_hist = this->hist_[parent_id]; + auto sibling_hist = this->hist_[node.sibling_nid]; check_hist(parent_hist, this_hist, sibling_hist, 0, nbins); } @@ -272,13 +274,13 @@ class QuantileHistMock : public QuantileHistMaker { {0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f} }; RealImpl::InitData(gmat, gpair, fmat, tree); GHistIndexBlockMatrix dummy; - hist_.AddHistRow(nid); - BuildHist(gpair, row_set_collection_[nid], - gmat, dummy, hist_[nid]); + this->hist_.AddHistRow(nid); + this->BuildHist(gpair, this->row_set_collection_[nid], + gmat, dummy, this->hist_[nid]); // Check if number of histogram bins is correct - ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back()); - std::vector histogram_expected(hist_[nid].size()); + ASSERT_EQ(this->hist_[nid].size(), gmat.cut.Ptrs().back()); + std::vector histogram_expected(this->hist_[nid].size()); // Compute the correct histogram (histogram_expected) const size_t num_row = fmat.Info().num_row_; @@ -293,10 +295,10 @@ class QuantileHistMock : public QuantileHistMaker { } // Now validate the computed histogram returned by BuildHist - for (size_t i = 0; i < hist_[nid].size(); ++i) { + for (size_t i = 0; i < this->hist_[nid].size(); ++i) { GradientPairPrecise sol = histogram_expected[i]; - ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), 
kEps); - ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps); + ASSERT_NEAR(sol.GetGrad(), this->hist_[nid][i].GetGrad(), kEps); + ASSERT_NEAR(sol.GetHess(), this->hist_[nid][i].GetHess(), kEps); } } @@ -313,10 +315,10 @@ class QuantileHistMock : public QuantileHistMaker { gmat.Init(dmat.get(), kMaxBins); RealImpl::InitData(gmat, row_gpairs, *dmat, tree); - hist_.AddHistRow(0); + this->hist_.AddHistRow(0); - BuildHist(row_gpairs, row_set_collection_[0], - gmat, quantile_index_block, hist_[0]); + this->BuildHist(row_gpairs, this->row_set_collection_[0], + gmat, quantile_index_block, this->hist_[0]); RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree); @@ -331,7 +333,7 @@ class QuantileHistMock : public QuantileHistMaker { } // Initialize split evaluator std::unique_ptr evaluator(SplitEvaluator::Create("elastic_net")); - evaluator->Init(¶m_); + evaluator->Init(&this->param_); // Now enumerate all feature*threshold combination to get best split // To simplify logic, we make some assumptions: @@ -378,11 +380,13 @@ class QuantileHistMock : public QuantileHistMaker { } /* Now compare against result given by EvaluateSplit() */ - ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid, - tree.GetDepth(0), snode_[0].best.loss_chg, 0); - RealImpl::EvaluateSplits({node}, gmat, hist_, tree); - ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature); - ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]); + typename RealImpl::ExpandEntry node(RealImpl::ExpandEntry::kRootNid, + RealImpl::ExpandEntry::kEmptyNid, + tree.GetDepth(0), + this->snode_[0].best.loss_chg, 0); + RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree); + ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature); + ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]); } void TestEvaluateSplitParallel(const GHistIndexBlockMatrix &quantile_index_block, @@ -411,7 +415,7 @@ class QuantileHistMock : public QuantileHistMaker { 
// treat everything as dense, as this is what we intend to test here cm.Init(gmat, 0.0); RealImpl::InitData(gmat, row_gpairs, *dmat, tree); - hist_.AddHistRow(0); + this->hist_.AddHistRow(0); RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree); @@ -430,9 +434,9 @@ class QuantileHistMock : public QuantileHistMaker { const size_t bin_id = gmat.index[offset]; if (bin_id >= bin_id_min && bin_id < bin_id_max) { if (bin_id <= split) { - left_cnt ++; + left_cnt++; } else { - right_cnt ++; + right_cnt++; } } } @@ -450,7 +454,8 @@ class QuantileHistMock : public QuantileHistMaker { RealImpl::partition_builder_.Init(1, 1, [&](size_t node_in_set) { return 1; }); - RealImpl::PartitionKernel(0, 0, common::Range1d(0, kNRows), split, cm, tree); + this->template PartitionKernel(0, 0, common::Range1d(0, kNRows), + split, cm, tree); RealImpl::partition_builder_.CalculateRowOffsets(); ASSERT_EQ(RealImpl::partition_builder_.GetNLeftElems(0), left_cnt); ASSERT_EQ(RealImpl::partition_builder_.GetNRightElems(0), right_cnt); @@ -462,28 +467,47 @@ class QuantileHistMock : public QuantileHistMaker { int static constexpr kNRows = 8, kNCols = 16; std::shared_ptr dmat_; const std::vector > cfg_; - std::shared_ptr builder_; + std::shared_ptr > float_builder_; + std::shared_ptr > double_builder_; public: explicit QuantileHistMock( - const std::vector >& args, bool batch = true) : + const std::vector >& args, + const bool single_precision_histogram = false, bool batch = true) : cfg_{args} { QuantileHistMaker::Configure(args); spliteval_->Init(&param_); dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); - builder_.reset( - new BuilderMock( - param_, - std::move(pruner_), - std::unique_ptr(spliteval_->GetHostClone()), - int_constraint_, - dmat_.get())); - if (batch) { - builder_->SetHistSynchronizer(new BatchHistSynchronizer()); - builder_->SetHistRowsAdder(new BatchHistRowsAdder()); + if (single_precision_histogram) { + float_builder_.reset( + new BuilderMock( + param_, +
std::move(pruner_), + std::unique_ptr(spliteval_->GetHostClone()), + int_constraint_, + dmat_.get())); + if (batch) { + float_builder_->SetHistSynchronizer(new BatchHistSynchronizer()); + float_builder_->SetHistRowsAdder(new BatchHistRowsAdder()); + } else { + float_builder_->SetHistSynchronizer(new DistributedHistSynchronizer()); + float_builder_->SetHistRowsAdder(new DistributedHistRowsAdder()); + } } else { - builder_->SetHistSynchronizer(new DistributedHistSynchronizer()); - builder_->SetHistRowsAdder(new DistributedHistRowsAdder()); + double_builder_.reset( + new BuilderMock( + param_, + std::move(pruner_), + std::unique_ptr(spliteval_->GetHostClone()), + int_constraint_, + dmat_.get())); + if (batch) { + double_builder_->SetHistSynchronizer(new BatchHistSynchronizer()); + double_builder_->SetHistRowsAdder(new BatchHistRowsAdder()); + } else { + double_builder_->SetHistSynchronizer(new DistributedHistSynchronizer()); + double_builder_->SetHistRowsAdder(new DistributedHistRowsAdder()); + } } } ~QuantileHistMock() override = default; @@ -501,8 +525,11 @@ class QuantileHistMock : public QuantileHistMaker { std::vector gpair = { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }; - - builder_->TestInitData(gmat, gpair, dmat_.get(), tree); + if (double_builder_) { + double_builder_->TestInitData(gmat, gpair, dmat_.get(), tree); + } else { + float_builder_->TestInitData(gmat, gpair, dmat_.get(), tree); + } } void TestInitDataSampling() { @@ -516,8 +543,11 @@ class QuantileHistMock : public QuantileHistMaker { std::vector gpair = { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }; - - builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree); + if (double_builder_) { + double_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree); + } else { + float_builder_->TestInitDataSampling(gmat, gpair, 
dmat_.get(), tree); + } } void TestAddHistRows() { @@ -530,7 +560,11 @@ class QuantileHistMock : public QuantileHistMaker { std::vector gpair = { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }; - builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree); + if (double_builder_) { + double_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree); + } else { + float_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree); + } } void TestSyncHistograms() { @@ -543,7 +577,11 @@ class QuantileHistMock : public QuantileHistMaker { std::vector gpair = { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} }; - builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree); + if (double_builder_) { + double_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree); + } else { + float_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree); + } } @@ -554,22 +592,31 @@ class QuantileHistMock : public QuantileHistMaker { size_t constexpr kMaxBins = 4; common::GHistIndexMatrix gmat; gmat.Init(dmat_.get(), kMaxBins); - - builder_->TestBuildHist(0, gmat, *dmat_, tree); + if (double_builder_) { + double_builder_->TestBuildHist(0, gmat, *dmat_, tree); + } else { + float_builder_->TestBuildHist(0, gmat, *dmat_, tree); + } } void TestEvaluateSplit() { RegTree tree = RegTree(); tree.param.UpdateAllowUnknown(cfg_); - - builder_->TestEvaluateSplit(gmatb_, tree); + if (double_builder_) { + double_builder_->TestEvaluateSplit(gmatb_, tree); + } else { + float_builder_->TestEvaluateSplit(gmatb_, tree); + } } void TestApplySplit() { RegTree tree = RegTree(); tree.param.UpdateAllowUnknown(cfg_); - - builder_->TestApplySplit(gmatb_, tree); + if (double_builder_) { + double_builder_->TestApplySplit(gmatb_, tree); + } else { + float_builder_->TestApplySplit(gmatb_, tree); + } } }; @@ -578,6 +625,9 @@ TEST(QuantileHist,
InitData) { {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}}; QuantileHistMock maker(cfg); maker.TestInitData(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram); + maker_float.TestInitData(); } TEST(QuantileHist, InitDataSampling) { @@ -587,6 +637,9 @@ TEST(QuantileHist, InitDataSampling) { {"subsample", std::to_string(subsample)}}; QuantileHistMock maker(cfg); maker.TestInitDataSampling(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram); + maker_float.TestInitDataSampling(); } TEST(QuantileHist, AddHistRows) { @@ -594,6 +647,9 @@ TEST(QuantileHist, AddHistRows) { {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}}; QuantileHistMock maker(cfg); maker.TestAddHistRows(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram); + maker_float.TestAddHistRows(); } TEST(QuantileHist, SyncHistograms) { @@ -601,6 +657,9 @@ TEST(QuantileHist, SyncHistograms) { {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}}; QuantileHistMock maker(cfg); maker.TestSyncHistograms(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram); + maker_float.TestSyncHistograms(); } TEST(QuantileHist, DistributedAddHistRows) { @@ -608,6 +667,9 @@ TEST(QuantileHist, DistributedAddHistRows) { {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}}; - QuantileHistMock maker(cfg, false); + QuantileHistMock maker(cfg, false, false); maker.TestAddHistRows(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram, false); + maker_float.TestAddHistRows(); } TEST(QuantileHist, DistributedSyncHistograms) { @@ -615,6 +677,9 @@ TEST(QuantileHist, DistributedSyncHistograms) { {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}}; - QuantileHistMock maker(cfg, false); + QuantileHistMock maker(cfg, false, false);
maker.TestSyncHistograms(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram, false); + maker_float.TestSyncHistograms(); } TEST(QuantileHist, BuildHist) { @@ -624,6 +689,9 @@ TEST(QuantileHist, BuildHist) { {"enable_feature_grouping", std::to_string(0)}}; QuantileHistMock maker(cfg); maker.TestBuildHist(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram); + maker_float.TestBuildHist(); } TEST(QuantileHist, EvalSplits) { @@ -634,6 +702,9 @@ TEST(QuantileHist, EvalSplits) { {"min_child_weight", "0"}}; QuantileHistMock maker(cfg); maker.TestEvaluateSplit(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram); + maker_float.TestEvaluateSplit(); } TEST(QuantileHist, ApplySplit) { @@ -644,6 +715,9 @@ TEST(QuantileHist, ApplySplit) { {"min_child_weight", "0"}}; QuantileHistMock maker(cfg); maker.TestApplySplit(); + const bool single_precision_histogram = true; + QuantileHistMock maker_float(cfg, single_precision_histogram); + maker_float.TestApplySplit(); } } // namespace tree diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 6dc9c77b6..673cacc79 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -57,7 +57,8 @@ class TestUpdaters(unittest.TestCase): 'max_bin': [2, 256], 'grow_policy': ['depthwise', 'lossguide'], 'max_leaves': [64, 0], - 'verbosity': [0]} + 'verbosity': [0], + 'single_precision_histogram': [True, False]} for param in parameter_combinations(variable_param): result = run_suite(param) assert_results_non_increasing(result, 1e-2)