From 017c97b8ce62935429c797b085bf46c63b2be617 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 7 Feb 2019 14:22:13 +0800 Subject: [PATCH] Clean up training code. (#3825) * Remove GHistRow, GHistEntry, GHistIndexRow. * Remove kSimpleStats. * Remove CheckInfo, SetLeafVec in GradStats and in SKStats. * Clean up the GradStats. * Cleanup calcgain. * Move LossChangeMissing out of common. * Remove [] operator from GHistIndexBlock. --- src/common/device_helpers.cuh | 4 +- src/common/hist_util.cc | 39 ++++----- src/common/hist_util.h | 87 ++++---------------- src/common/host_device_vector.h | 4 +- src/common/span.h | 4 +- src/common/transform.h | 10 +-- src/tree/param.h | 117 ++++++++++++--------------- src/tree/split_evaluator.cc | 1 + src/tree/updater_basemaker-inl.h | 12 +-- src/tree/updater_colmaker.cc | 44 +++++----- src/tree/updater_gpu.cu | 25 +++++- src/tree/updater_gpu_common.cuh | 47 +---------- src/tree/updater_gpu_hist.cu | 30 ++++++- src/tree/updater_histmaker.cc | 114 ++++++++++++-------------- src/tree/updater_quantile_hist.cc | 102 +++++++++++------------ src/tree/updater_quantile_hist.h | 10 +-- src/tree/updater_refresh.cc | 31 ++++--- src/tree/updater_skmaker.cc | 25 ++---- tests/cpp/tree/test_quantile_hist.cc | 6 +- 19 files changed, 306 insertions(+), 406 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 72efcbdce..1fd8e2407 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1005,7 +1005,7 @@ class AllReducer { */ void Synchronize() { #ifdef XGBOOST_USE_NCCL - for (int i = 0; i < device_ordinals.size(); i++) { + for (size_t i = 0; i < device_ordinals.size(); i++) { dh::safe_cuda(cudaSetDevice(device_ordinals[i])); dh::safe_cuda(cudaStreamSynchronize(streams[i])); } @@ -1051,7 +1051,7 @@ template void ExecuteIndexShards(std::vector *shards, FunctionT f) { SaveCudaContext{[&]() { #pragma omp parallel for schedule(static, 1) if (shards->size() > 1) - for (int shard = 0; shard < shards->size(); ++shard) { + for (size_t shard = 0; shard < shards->size(); ++shard) { f(shard, shards->at(shard)); } }}; diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 1cca8179b..0bde67f7d 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2017 by Contributors + * Copyright 2017-2018 by Contributors * \file hist_util.h * \brief Utilities to store histograms * \author Philip Cho, Tianqi Chen @@ -417,7 +417,7 @@ void GHistBuilder::BuildHist(const std::vector& gpair, const size_t* row_ptr = gmat.row_ptr.data(); const float* pgh = reinterpret_cast(gpair.data()); - double* hist_data = reinterpret_cast(hist.begin); + double* hist_data = reinterpret_cast(hist.data()); double* data = reinterpret_cast(data_.data()); const size_t block_size = 512; @@ -432,11 +432,11 @@ void GHistBuilder::BuildHist(const std::vector& gpair, size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid); no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size; - #pragma omp parallel for num_threads(nthread_to_process) schedule(guided) +#pragma omp parallel for num_threads(nthread_to_process) schedule(guided) for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) { dmlc::omp_uint tid = omp_get_thread_num(); double* data_local_hist = ((nthread_to_process == 1) ? 
hist_data : - reinterpret_cast(data_.data() + tid * nbins_)); + reinterpret_cast(data_.data() + tid * nbins_)); if (!thread_init_[tid]) { memset(data_local_hist, '\0', 2*nbins_*sizeof(double)); @@ -477,7 +477,7 @@ void GHistBuilder::BuildHist(const std::vector& gpair, } } - #pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided) +#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided) for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) { const size_t istart = iblock * block_size; const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size); @@ -507,8 +507,9 @@ void GHistBuilder::BuildBlockHist(const std::vector& gpair, #if defined(_OPENMP) const auto nthread = static_cast(this->nthread_); #endif + tree::GradStats* p_hist = hist.data(); - #pragma omp parallel for num_threads(nthread) schedule(guided) +#pragma omp parallel for num_threads(nthread) schedule(guided) for (bst_omp_uint bid = 0; bid < nblock; ++bid) { auto gmat = gmatb[bid]; @@ -517,20 +518,17 @@ void GHistBuilder::BuildBlockHist(const std::vector& gpair, size_t ibegin[kUnroll]; size_t iend[kUnroll]; GradientPair stat[kUnroll]; + for (int k = 0; k < kUnroll; ++k) { rid[k] = row_indices.begin[i + k]; - } - for (int k = 0; k < kUnroll; ++k) { ibegin[k] = gmat.row_ptr[rid[k]]; iend[k] = gmat.row_ptr[rid[k] + 1]; - } - for (int k = 0; k < kUnroll; ++k) { stat[k] = gpair[rid[k]]; } for (int k = 0; k < kUnroll; ++k) { for (size_t j = ibegin[k]; j < iend[k]; ++j) { const uint32_t bin = gmat.index[j]; - hist.begin[bin].Add(stat[k]); + p_hist[bin].Add(stat[k]); } } } @@ -541,7 +539,7 @@ void GHistBuilder::BuildBlockHist(const std::vector& gpair, const GradientPair stat = gpair[rid]; for (size_t j = ibegin; j < iend; ++j) { const uint32_t bin = gmat.index[j]; - hist.begin[bin].Add(stat); + p_hist[bin].Add(stat); } } } @@ -555,24 +553,27 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa #if defined(_OPENMP) const auto nthread = static_cast(this->nthread_); #endif + tree::GradStats* p_self = self.data(); + tree::GradStats* p_sibling = sibling.data(); + tree::GradStats* p_parent = parent.data(); - #pragma omp parallel for num_threads(nthread) schedule(static) +#pragma omp parallel for num_threads(nthread) schedule(static) for (bst_omp_uint bin_id = 0; bin_id < static_cast(nbins - rest); bin_id += kUnroll) { - GHistEntry pb[kUnroll]; - GHistEntry sb[kUnroll]; + tree::GradStats pb[kUnroll]; + tree::GradStats sb[kUnroll]; for (int k = 0; k < kUnroll; ++k) { - pb[k] = parent.begin[bin_id + k]; + pb[k] = p_parent[bin_id + k]; } for (int k = 0; k < kUnroll; ++k) { - sb[k] = sibling.begin[bin_id + k]; + sb[k] = p_sibling[bin_id + k]; } for (int k = 0; k < kUnroll; ++k) { - self.begin[bin_id + k].SetSubtract(pb[k], sb[k]); + p_self[bin_id + k].SetSubstract(pb[k], sb[k]); } } for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) { - self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]); + p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]); } } diff --git a/src/common/hist_util.h b/src/common/hist_util.h index ff5542768..64bf3ec0d 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -16,45 +16,8 @@ #include "../include/rabit/rabit.h" namespace xgboost { - namespace common { -/*! \brief sums of gradient statistics corresponding to a histogram bin */ -struct GHistEntry { - /*! \brief sum of first-order gradient statistics */ - double sum_grad{0}; - /*! 
\brief sum of second-order gradient statistics */ - double sum_hess{0}; - - GHistEntry() = default; - - inline void Clear() { - sum_grad = sum_hess = 0; - } - - /*! \brief add a GradientPair to the sum */ - inline void Add(const GradientPair& e) { - sum_grad += e.GetGrad(); - sum_hess += e.GetHess(); - } - - /*! \brief add a GHistEntry to the sum */ - inline void Add(const GHistEntry& e) { - sum_grad += e.sum_grad; - sum_hess += e.sum_hess; - } - - inline static void Reduce(GHistEntry& a, const GHistEntry& b) { // NOLINT(*) - a.Add(b); - } - - /*! \brief set sum to be difference of two GHistEntry's */ - inline void SetSubtract(const GHistEntry& a, const GHistEntry& b) { - sum_grad = a.sum_grad - b.sum_grad; - sum_hess = a.sum_hess - b.sum_hess; - } -}; - /*! \brief Cut configuration for all the features. */ struct HistCutMatrix { /*! \brief Unit pointer to rows by element position */ @@ -83,15 +46,7 @@ void DeviceSketch * \brief A single row in global histogram index. * Directly represent the global index in the histogram entry. */ -struct GHistIndexRow { - /*! \brief The index of the histogram */ - const uint32_t* index; - /*! \brief The size of the histogram */ - size_t size; - GHistIndexRow() = default; - GHistIndexRow(const uint32_t* index, size_t size) - : index(index), size(size) {} -}; +using GHistIndexRow = Span; /*! * \brief preprocessed global index matrix, in CSR format @@ -111,7 +66,9 @@ struct GHistIndexMatrix { void Init(DMatrix* p_fmat, int max_num_bins); // get i-th row inline GHistIndexRow operator[](size_t i) const { - return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]}; + return {&index[0] + row_ptr[i], + static_cast( + row_ptr[i + 1] - row_ptr[i])}; } inline void GetFeatureCounts(size_t* counts) const { auto nfeature = cut.row_ptr.size() - 1; @@ -134,11 +91,6 @@ struct GHistIndexBlock { inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index) : row_ptr(row_ptr), index(index) {} - - // get i-th row - inline GHistIndexRow operator[](size_t i) const { - return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]}; - } }; class ColumnMatrix; @@ -171,21 +123,12 @@ class GHistIndexBlockMatrix { }; /*! - * \brief histogram of gradient statistics for a single node. - * Consists of multiple GHistEntry's, each entry showing total graident statistics + * \brief histogram of graident statistics for a single node. + * Consists of multiple GradStats, each entry showing total graident statistics * for that particular bin * Uses global bin id so as to represent all features simultaneously */ -struct GHistRow { - /*! \brief base pointer to first entry */ - GHistEntry* begin; - /*! \brief number of entries */ - uint32_t size; - - GHistRow() = default; - GHistRow(GHistEntry* begin, uint32_t size) - : begin(begin), size(size) {} -}; +using GHistRow = Span; /*! * \brief histogram of gradient statistics for multiple nodes @@ -193,27 +136,29 @@ struct GHistRow { class HistCollection { public: // access histogram for i-th node - inline GHistRow operator[](bst_uint nid) const { + GHistRow operator[](bst_uint nid) const { constexpr uint32_t kMax = std::numeric_limits::max(); CHECK_NE(row_ptr_[nid], kMax); - return {const_cast(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_}; + tree::GradStats* ptr = + const_cast(dmlc::BeginPtr(data_) + row_ptr_[nid]); + return {ptr, nbins_}; } // have we computed a histogram for i-th node? 
- inline bool RowExists(bst_uint nid) const { + bool RowExists(bst_uint nid) const { const uint32_t k_max = std::numeric_limits::max(); return (nid < row_ptr_.size() && row_ptr_[nid] != k_max); } // initialize histogram collection - inline void Init(uint32_t nbins) { + void Init(uint32_t nbins) { nbins_ = nbins; row_ptr_.clear(); data_.clear(); } // create an empty histogram for i-th node - inline void AddHistRow(bst_uint nid) { + void AddHistRow(bst_uint nid) { constexpr uint32_t kMax = std::numeric_limits::max(); if (nid >= row_ptr_.size()) { row_ptr_.resize(nid + 1, kMax); @@ -228,7 +173,7 @@ class HistCollection { /*! \brief number of all bins over all features */ uint32_t nbins_; - std::vector data_; + std::vector data_; /*! \brief row_ptr_[nid] locates bin for historgram of node nid */ std::vector row_ptr_; @@ -268,8 +213,8 @@ class GHistBuilder { size_t nthread_; /*! \brief number of all bins over all features */ uint32_t nbins_; - std::vector data_; std::vector thread_init_; + std::vector data_; }; diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h index d1abb604b..c2ae0110a 100644 --- a/src/common/host_device_vector.h +++ b/src/common/host_device_vector.h @@ -140,7 +140,7 @@ class GPUDistribution { return begin; } - size_t ShardSize(size_t size, int index) const { + size_t ShardSize(size_t size, size_t index) const { if (size == 0) { return 0; } if (offsets_.size() > 0) { // explicit offsets are provided @@ -154,7 +154,7 @@ class GPUDistribution { return end - begin; } - size_t ShardProperSize(size_t size, int index) const { + size_t ShardProperSize(size_t size, size_t index) const { if (size == 0) { return 0; } return ShardSize(size, index) - (devices_.Size() - 1 > index ? overlap_ : 0); } diff --git a/src/common/span.h b/src/common/span.h index cb49b84cd..b62162552 100644 --- a/src/common/span.h +++ b/src/common/span.h @@ -554,8 +554,8 @@ class Span { detail::ptrdiff_t _offset, detail::ptrdiff_t _count = dynamic_extent) const { SPAN_CHECK(_offset >= 0 && _offset < size()); - SPAN_CHECK(_count == dynamic_extent || - _count >= 0 && _offset + _count <= size()); + SPAN_CHECK((_count == dynamic_extent) || + (_count >= 0 && _offset + _count <= size())); return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count}; diff --git a/src/common/transform.h b/src/common/transform.h index ce452814e..37a9236f8 100644 --- a/src/common/transform.h +++ b/src/common/transform.h @@ -58,12 +58,12 @@ class Transform { public: Evaluator(Functor func, Range range, GPUSet devices, bool reshard) : func_(func), range_{std::move(range)}, - distribution_{std::move(GPUDistribution::Block(devices))}, - reshard_{reshard} {} + reshard_{reshard}, + distribution_{std::move(GPUDistribution::Block(devices))} {} Evaluator(Functor func, Range range, GPUDistribution dist, bool reshard) : - func_(func), range_{std::move(range)}, distribution_{std::move(dist)}, - reshard_{reshard} {} + func_(func), range_{std::move(range)}, reshard_{reshard}, + distribution_{std::move(dist)} {} /*! * \brief Evaluate the functor with input pointers to HostDeviceVector. @@ -159,7 +159,7 @@ class Transform { template void LaunchCPU(Functor func, HDV*... 
vectors) const { - auto end = *(range_.end()); + omp_ulong end = static_cast(*(range_.end())); #pragma omp parallel for schedule(static) for (omp_ulong idx = 0; idx < end; ++idx) { func(idx, UnpackHDV(vectors)...); diff --git a/src/tree/param.h b/src/tree/param.h index c55543a79..0e0ce4911 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -256,12 +256,12 @@ struct TrainParam : public dmlc::Parameter { // functions for L1 cost template -XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 lambda) { - if (w > +lambda) { - return w - lambda; +XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 alpha) { + if (w > + alpha) { + return w - alpha; } - if (w < -lambda) { - return w + lambda; + if (w < - alpha) { + return w + alpha; } return 0.0; } @@ -271,9 +271,9 @@ XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; } // calculate the cost of loss function template -XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad, - T sum_hess, T w) { - return -(2.0 * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w)); +XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, + T sum_grad, T sum_hess, T w) { + return -(T(2.0) * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w)); } // calculate the cost of loss function @@ -281,44 +281,51 @@ template XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) { if (sum_hess < p.min_child_weight) { return T(0.0); -} + } if (p.max_delta_step == 0.0f) { if (p.reg_alpha == 0.0f) { return Sqr(sum_grad) / (sum_hess + p.reg_lambda); } else { return Sqr(ThresholdL1(sum_grad, p.reg_alpha)) / - (sum_hess + p.reg_lambda); + (sum_hess + p.reg_lambda); } } else { T w = CalcWeight(p, sum_grad, sum_hess); - T ret = sum_grad * w + T(0.5) * (sum_hess + p.reg_lambda) * Sqr(w); + T ret = CalcGainGivenWeight(p, sum_grad, sum_hess, w); if (p.reg_alpha == 0.0f) { - return T(-2.0) * ret; + return ret; } else { - return T(-2.0) * (ret + p.reg_alpha * std::abs(w)); + return ret + p.reg_alpha * std::abs(w); } } } + +template +XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) { + return CalcGain(p, stat.GetGrad(), stat.GetHess()); +} + // calculate cost of loss function with four statistics template XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess, - T test_grad, T test_hess) { + T test_grad, T test_hess) { T w = CalcWeight(sum_grad, sum_hess); - T ret = test_grad * w + 0.5 * (test_hess + p.reg_lambda) * Sqr(w); + T ret = CalcGainGivenWeight(p, test_grad, test_hess); if (p.reg_alpha == 0.0f) { - return -2.0 * ret; + return ret; } else { - return -2.0 * (ret + p.reg_alpha * std::abs(w)); + return ret + p.reg_alpha * std::abs(w); } } // calculate weight given the statistics template XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad, - T sum_hess) { + T sum_hess) { if (sum_hess < p.min_child_weight || sum_hess <= 0.0) { return 0.0; -} + } T dw; if (p.reg_alpha == 0.0f) { dw = -sum_grad / (sum_hess + p.reg_lambda); @@ -328,14 +335,15 @@ XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad, if (p.max_delta_step != 0.0f) { if (dw > p.max_delta_step) { dw = p.max_delta_step; -} + } if (dw < -p.max_delta_step) { dw = -p.max_delta_step; -} + } } return dw; } +// Used in gpu code where GradientPair is used for gradient sum, not GradStats. 
template XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad) { return CalcWeight(p, sum_grad.GetGrad(), sum_grad.GetHess()); @@ -347,49 +355,27 @@ struct XGBOOST_ALIGNAS(16) GradStats { double sum_grad; /*! \brief sum hessian statistics */ double sum_hess; - /*! - * \brief whether this is simply statistics and we only need to call - * Add(gpair), instead of Add(gpair, info, ridx) - */ - static const int kSimpleStats = 1; - /*! \brief constructor, the object must be cleared during construction */ - explicit GradStats(const TrainParam& param) { this->Clear(); } - explicit GradStats(double sum_grad, double sum_hess) - : sum_grad(sum_grad), sum_hess(sum_hess) {} + + public: + XGBOOST_DEVICE double GetGrad() const { return sum_grad; } + XGBOOST_DEVICE double GetHess() const { return sum_hess; } + + XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} { + static_assert(sizeof(GradStats) == 16, + "Size of GradStats is not 16 bytes."); + } template XGBOOST_DEVICE explicit GradStats(const GpairT &sum) : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {} - /*! \brief clear the statistics */ - inline void Clear() { sum_grad = sum_hess = 0.0f; } - /*! \brief check if necessary information is ready */ - inline static void CheckInfo(const MetaInfo& info) {} + explicit GradStats(const double grad, const double hess) + : sum_grad(grad), sum_hess(hess) {} /*! * \brief accumulate statistics * \param p the gradient pair */ inline void Add(GradientPair p) { this->Add(p.GetGrad(), p.GetHess()); } - /*! - * \brief accumulate statistics, more complicated version - * \param gpair the vector storing the gradient statistics - * \param info the additional information - * \param ridx instance index of this instance - */ - inline void Add(const std::vector& gpair, const MetaInfo& info, - bst_uint ridx) { - const GradientPair& b = gpair[ridx]; - this->Add(b.GetGrad(), b.GetHess()); - } - /*! \brief calculate leaf weight */ - template - XGBOOST_DEVICE inline double CalcWeight(const ParamT ¶m) const { - return xgboost::tree::CalcWeight(param, sum_grad, sum_hess); - } - /*! \brief calculate gain of the solution */ -template - inline double CalcGain(const ParamT& param) const { - return xgboost::tree::CalcGain(param, sum_grad, sum_hess); - } + /*! \brief add statistics to the data */ inline void Add(const GradStats& b) { sum_grad += b.sum_grad; @@ -406,8 +392,6 @@ template } /*! \return whether the statistics is not used yet */ inline bool Empty() const { return sum_hess == 0.0; } - // constructor to allow inheritance - GradStats() = default; /*! \brief add statistics to the data */ inline void Add(double grad, double hess) { sum_grad += grad; @@ -415,6 +399,7 @@ template } }; +// TODO(trivialfis): Remove this class. 
struct ValueConstraint { double lower_bound; double upper_bound; @@ -424,9 +409,9 @@ struct ValueConstraint { inline static void Init(TrainParam *param, unsigned num_feature) { param->monotone_constraints.resize(num_feature, 0); } -template + template XGBOOST_DEVICE inline double CalcWeight(const ParamT ¶m, GradStats stats) const { - double w = stats.CalcWeight(param); + double w = xgboost::tree::CalcWeight(param, stats); if (w < lower_bound) { return lower_bound; } @@ -436,13 +421,13 @@ template return w; } -template + template XGBOOST_DEVICE inline double CalcGain(const ParamT ¶m, GradStats stats) const { return CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess, CalcWeight(param, stats)); } -template + template XGBOOST_DEVICE inline double CalcSplitGain(const ParamT ¶m, int constraint, GradStats left, GradStats right) const { const double negative_infinity = -std::numeric_limits::infinity(); @@ -468,7 +453,7 @@ template *cright = *this; if (c == 0) { return; -} + } double wleft = CalcWeight(param, left); double wright = CalcWeight(param, right); double mid = (wleft + wright) / 2; @@ -578,13 +563,13 @@ inline std::ostream &operator<<(std::ostream &os, const std::vector &t) { for (auto it = t.begin(); it != t.end(); ++it) { if (it != t.begin()) { os << ','; -} + } os << *it; } // python style tuple if (t.size() == 1) { os << ','; -} + } os << ')'; return os; } @@ -603,7 +588,7 @@ inline std::istream &operator>>(std::istream &is, std::vector &t) { is.get(); if (ch == '(') { break; -} + } if (!isspace(ch)) { is.setstate(std::ios::failbit); return is; @@ -635,7 +620,7 @@ inline std::istream &operator>>(std::istream &is, std::vector &t) { } if (ch == ')') { break; -} + } } else if (ch == ')') { break; } else { diff --git a/src/tree/split_evaluator.cc b/src/tree/split_evaluator.cc index 8b67ab107..52a022e64 100644 --- a/src/tree/split_evaluator.cc +++ b/src/tree/split_evaluator.cc @@ -50,6 +50,7 @@ void SplitEvaluator::AddSplit(bst_uint nodeid, bst_uint featureid, bst_float leftweight, bst_float rightweight) {} + bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid, bst_uint featureid, const GradStats& left_stats, diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index 2397f249d..1a8238e75 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -333,28 +333,28 @@ class BaseMaker: public TreeUpdater { const MetaInfo &info = fmat.Info(); thread_temp.resize(omp_get_max_threads()); p_node_stats->resize(tree.param.num_nodes); - #pragma omp parallel +#pragma omp parallel { const int tid = omp_get_thread_num(); - thread_temp[tid].resize(tree.param.num_nodes, TStats(param_)); + thread_temp[tid].resize(tree.param.num_nodes, TStats()); for (unsigned int nid : qexpand_) { - thread_temp[tid][nid].Clear(); + thread_temp[tid][nid] = TStats(); } } // setup position const auto ndata = static_cast(fmat.Info().num_row_); - #pragma omp parallel for schedule(static) +#pragma omp parallel for schedule(static) for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) { const int nid = position_[ridx]; const int tid = omp_get_thread_num(); if (nid >= 0) { - thread_temp[tid][nid].Add(gpair, info, ridx); + thread_temp[tid][nid].Add(gpair[ridx]); } } // sum the per thread statistics together for (int nid : qexpand_) { TStats &s = (*p_node_stats)[nid]; - s.Clear(); + s = TStats(); for (size_t tid = 0; tid < thread_temp.size(); ++tid) { s.Add(thread_temp[tid][nid]); } diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 
d03fdcefb..0e1e43620 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -33,7 +33,6 @@ class ColMaker: public TreeUpdater { void Update(HostDeviceVector *gpair, DMatrix* dmat, const std::vector &trees) override { - GradStats::CheckInfo(dmat->Info()); // rescale learning rate according to size of trees float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); @@ -66,9 +65,7 @@ class ColMaker: public TreeUpdater { /*! \brief current best solution */ SplitEntry best; // constructor - explicit ThreadEntry(const TrainParam ¶m) - : stats(param), stats_extra(param) { - } + ThreadEntry() : last_fvalue{0}, first_fvalue{0} {} }; struct NodeEntry { /*! \brief statics for node entry */ @@ -80,9 +77,7 @@ class ColMaker: public TreeUpdater { /*! \brief current best solution */ SplitEntry best; // constructor - explicit NodeEntry(const TrainParam& param) - : stats(param), root_gain(0.0f), weight(0.0f){ - } + NodeEntry() : root_gain(0.0f), weight(0.0f) {} }; // actual builder that runs the algorithm class Builder { @@ -200,9 +195,9 @@ class ColMaker: public TreeUpdater { { // setup statistics space for each tree node for (auto& i : stemp_) { - i.resize(tree.param.num_nodes, ThreadEntry(param_)); + i.resize(tree.param.num_nodes, ThreadEntry()); } - snode_.resize(tree.param.num_nodes, NodeEntry(param_)); + snode_.resize(tree.param.num_nodes, NodeEntry()); } const MetaInfo& info = fmat.Info(); // setup position @@ -211,11 +206,11 @@ class ColMaker: public TreeUpdater { for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) { const int tid = omp_get_thread_num(); if (position_[ridx] < 0) continue; - stemp_[tid][position_[ridx]].stats.Add(gpair, info, ridx); + stemp_[tid][position_[ridx]].stats.Add(gpair[ridx]); } // sum the per thread statistics together for (int nid : qexpand) { - GradStats stats(param_); + GradStats stats; for (auto& s : stemp_) { stats.Add(s[nid].stats); } @@ -261,7 +256,7 @@ class ColMaker: public TreeUpdater { std::vector &temp = stemp_[tid]; // cleanup temp statistics for (int j : qexpand) { - temp[j].stats.Clear(); + temp[j].stats = GradStats(); } bst_uint step = (col.size() + this->nthread_ - 1) / this->nthread_; bst_uint end = std::min(static_cast(col.size()), step * (tid + 1)); @@ -273,7 +268,7 @@ class ColMaker: public TreeUpdater { if (temp[nid].stats.Empty()) { temp[nid].first_fvalue = fvalue; } - temp[nid].stats.Add(gpair, info, ridx); + temp[nid].stats.Add(gpair[ridx]); temp[nid].last_fvalue = fvalue; } } @@ -282,7 +277,7 @@ class ColMaker: public TreeUpdater { #pragma omp parallel for schedule(static) for (bst_omp_uint j = 0; j < nnode; ++j) { const int nid = qexpand[j]; - GradStats sum(param_), tmp(param_), c(param_); + GradStats sum, tmp, c; for (int tid = 0; tid < this->nthread_; ++tid) { tmp = stemp_[tid][nid].stats; stemp_[tid][nid].stats = sum; @@ -342,7 +337,7 @@ class ColMaker: public TreeUpdater { // rescan, generate candidate split #pragma omp parallel { - GradStats c(param_), cright(param_); + GradStats c, cright; const int tid = omp_get_thread_num(); std::vector &temp = stemp_[tid]; bst_uint step = (col.size() + this->nthread_ - 1) / this->nthread_; @@ -355,7 +350,7 @@ class ColMaker: public TreeUpdater { // get the statistics of nid ThreadEntry &e = temp[nid]; if (e.stats.Empty()) { - e.stats.Add(gpair, info, ridx); + e.stats.Add(gpair[ridx]); e.first_fvalue = fvalue; } else { // forward default right @@ -383,7 +378,7 @@ class ColMaker: public TreeUpdater { } } } - e.stats.Add(gpair, info, ridx); + e.stats.Add(gpair[ridx]); 
e.first_fvalue = fvalue; } } @@ -436,10 +431,10 @@ class ColMaker: public TreeUpdater { const std::vector &qexpand = qexpand_; // clear all the temp statistics for (auto nid : qexpand) { - temp[nid].stats.Clear(); + temp[nid].stats = GradStats(); } // left statistics - GradStats c(param_); + GradStats c; // local cache buffer for position and gradient pair constexpr int kBuffer = 32; int buf_position[kBuffer] = {}; @@ -516,17 +511,17 @@ class ColMaker: public TreeUpdater { const MetaInfo &info, std::vector &temp) { // NOLINT(*) // use cacheline aware optimization - if (GradStats::kSimpleStats != 0 && param_.cache_opt != 0) { + if (param_.cache_opt != 0) { EnumerateSplitCacheOpt(begin, end, d_step, fid, gpair, temp); return; } const std::vector &qexpand = qexpand_; // clear all the temp statistics for (auto nid : qexpand) { - temp[nid].stats.Clear(); + temp[nid].stats = GradStats(); } // left statistics - GradStats c(param_); + GradStats c; for (const Entry *it = begin; it != end; it += d_step) { const bst_uint ridx = it->index; const int nid = position_[ridx]; @@ -537,7 +532,7 @@ class ColMaker: public TreeUpdater { ThreadEntry &e = temp[nid]; // test if first hit, this is fine, because we set 0 during init if (e.stats.Empty()) { - e.stats.Add(gpair, info, ridx); + e.stats.Add(gpair[ridx]); e.last_fvalue = fvalue; } else { // try to find a split @@ -562,7 +557,7 @@ class ColMaker: public TreeUpdater { } } // update the statistics - e.stats.Add(gpair, info, ridx); + e.stats.Add(gpair[ridx]); e.last_fvalue = fvalue; } } @@ -783,7 +778,6 @@ class DistColMaker : public ColMaker { void Update(HostDeviceVector *gpair, DMatrix* dmat, const std::vector &trees) override { - GradStats::CheckInfo(dmat->Info()); CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time"; Builder builder( param_, diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu index 7625afb77..cfbefa89e 100644 --- a/src/tree/updater_gpu.cu +++ b/src/tree/updater_gpu.cu @@ -16,6 +16,28 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_gpu); +template +XGBOOST_DEVICE float inline LossChangeMissing(const GradientPairT& scan, + const GradientPairT& missing, + const GradientPairT& parent_sum, + const float& parent_gain, + const GPUTrainingParam& param, + bool& missing_left_out) { // NOLINT + // Put gradients of missing values to left + float missing_left_loss = + DeviceCalcLossChange(param, scan + missing, parent_sum, parent_gain); + float missing_right_loss = + DeviceCalcLossChange(param, scan, parent_sum, parent_gain); + + if (missing_left_loss >= missing_right_loss) { + missing_left_out = true; + return missing_left_loss; + } else { + missing_left_out = false; + return missing_right_loss; + } +} + /** * @brief Absolute BFS order IDs to col-wise unique IDs based on user input * @param tid the index of the element that this thread should access @@ -565,7 +587,6 @@ class GPUMaker : public TreeUpdater { void Update(HostDeviceVector* gpair, DMatrix* dmat, const std::vector& trees) { - GradStats::CheckInfo(dmat->Info()); // rescale learning rate according to size of trees float lr = param.learning_rate; param.learning_rate = lr / trees.size(); @@ -633,7 +654,7 @@ class GPUMaker : public TreeUpdater { // get the default direction for the current node GradientPair missing = n.sum_gradients - gradSum; LossChangeMissing(gradScan, missing, n.sum_gradients, n.root_gain, - gpu_param, missingLeft); + gpu_param, missingLeft); // get the score/weight/id/gradSum for left and right child nodes GradientPair lGradSum = 
missingLeft ? gradScan + missing : gradScan; GradientPair rGradSum = n.sum_gradients - lGradSum; diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 63c886b5e..11b9b01d4 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -98,7 +98,7 @@ struct DeviceSplitCandidate { template XGBOOST_DEVICE void Update(const DeviceSplitCandidate& other, - const ParamT& param) { + const ParamT& param) { if (other.loss_chg > loss_chg && other.left_sum.GetHess() >= param.min_child_weight && other.right_sum.GetHess() >= param.min_child_weight) { @@ -213,51 +213,6 @@ XGBOOST_DEVICE inline float DeviceCalcLossChange(const GPUTrainingParam& param, return left_gain + right_gain - parent_gain; } -// Without constraints -template -XGBOOST_DEVICE float inline LossChangeMissing(const GradientPairT& scan, - const GradientPairT& missing, - const GradientPairT& parent_sum, - const float& parent_gain, - const GPUTrainingParam& param, - bool& missing_left_out) { // NOLINT - // Put gradients of missing values to left - float missing_left_loss = - DeviceCalcLossChange(param, scan + missing, parent_sum, parent_gain); - float missing_right_loss = - DeviceCalcLossChange(param, scan, parent_sum, parent_gain); - - if (missing_left_loss >= missing_right_loss) { - missing_left_out = true; - return missing_left_loss; - } else { - missing_left_out = false; - return missing_right_loss; - } -} - -// With constraints -template -XGBOOST_DEVICE float inline LossChangeMissing( - const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum, - const float& parent_gain, const GPUTrainingParam& param, int constraint, - const ValueConstraint& value_constraint, - bool& missing_left_out) { // NOLINT - float missing_left_gain = value_constraint.CalcSplitGain( - param, constraint, GradStats(scan + missing), - GradStats(parent_sum - (scan + missing))); - float missing_right_gain = value_constraint.CalcSplitGain( - param, constraint, GradStats(scan), GradStats(parent_sum - scan)); - - if (missing_left_gain >= missing_right_gain) { - missing_left_out = true; - return missing_left_gain - parent_gain; - } else { - missing_left_out = false; - return missing_right_gain - parent_gain; - } -} - // Total number of nodes in tree, given depth XGBOOST_DEVICE inline int MaxNodesDepth(int depth) { return (1 << (depth + 1)) - 1; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 831aa11b8..95a5008ba 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -50,6 +50,28 @@ struct GPUHistMakerTrainParam DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); +// With constraints +template +XGBOOST_DEVICE float inline LossChangeMissing( + const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum, + const float& parent_gain, const GPUTrainingParam& param, int constraint, + const ValueConstraint& value_constraint, + bool& missing_left_out) { // NOLINT + float missing_left_gain = value_constraint.CalcSplitGain( + param, constraint, GradStats(scan + missing), + GradStats(parent_sum - (scan + missing))); + float missing_right_gain = value_constraint.CalcSplitGain( + param, constraint, GradStats(scan), GradStats(parent_sum - scan)); + + if (missing_left_gain >= missing_right_gain) { + missing_left_out = true; + return missing_left_gain - parent_gain; + } else { + missing_left_out = false; + return missing_right_gain - parent_gain; + } +} + /*! 
* \brief * @@ -942,7 +964,6 @@ class GPUHistMakerSpecialised{ void Update(HostDeviceVector* gpair, DMatrix* dmat, const std::vector& trees) { monitor_.Start("Update", dist_.Devices()); - GradStats::CheckInfo(dmat->Info()); // rescale learning rate according to size of trees float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); @@ -1183,11 +1204,12 @@ class GPUHistMakerSpecialised{ void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) { RegTree& tree = *p_tree; - GradStats left_stats(param_); + + GradStats left_stats; left_stats.Add(candidate.split.left_sum); - GradStats right_stats(param_); + GradStats right_stats; right_stats.Add(candidate.split.right_sum); - GradStats parent_sum(param_); + GradStats parent_sum; parent_sum.Add(left_stats); parent_sum.Add(right_stats); node_value_constraints_.resize(tree.GetNodes().size()); diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 706034dcc..f863e8552 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -19,13 +19,11 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_histmaker); -template class HistMaker: public BaseMaker { public: void Update(HostDeviceVector *gpair, DMatrix *p_fmat, const std::vector &trees) override { - TStats::CheckInfo(p_fmat->Info()); // rescale learning rate according to size of trees float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); @@ -42,13 +40,13 @@ class HistMaker: public BaseMaker { /*! \brief cutting point of histogram, contains maximum point */ const bst_float *cut; /*! \brief content of statistics data */ - TStats *data; + GradStats *data; /*! \brief size of histogram */ unsigned size; // default constructor HistUnit() = default; // constructor - HistUnit(const bst_float *cut, TStats *data, unsigned size) + HistUnit(const bst_float *cut, GradStats *data, unsigned size) : cut(cut), data(data), size(size) {} /*! \brief add a histogram to data */ inline void Add(bst_float fv, @@ -58,7 +56,7 @@ class HistMaker: public BaseMaker { unsigned i = std::upper_bound(cut, cut + size, fv) - cut; CHECK_NE(size, 0U) << "try insert into size=0"; CHECK_LT(i, size); - data[i].Add(gpair, info, ridx); + data[i].Add(gpair[ridx]); } }; /*! \brief a set of histograms from different index */ @@ -68,7 +66,7 @@ class HistMaker: public BaseMaker { /*! \brief cutting points in each histunit */ const bst_float *cut; /*! \brief data in different hist unit */ - std::vector data; + std::vector data; /*! 
\brief */ inline HistUnit operator[](size_t fid) { return HistUnit(cut + rptr[fid], @@ -89,12 +87,10 @@ class HistMaker: public BaseMaker { hset.resize(nthread); // cleanup statistics for (int tid = 0; tid < nthread; ++tid) { - for (size_t i = 0; i < hset[tid].data.size(); ++i) { - hset[tid].data[i].Clear(); - } + for (auto& d : hset[tid].data) { d = GradStats(); } hset[tid].rptr = dmlc::BeginPtr(rptr); hset[tid].cut = dmlc::BeginPtr(cut); - hset[tid].data.resize(cut.size(), TStats(param)); + hset[tid].data.resize(cut.size(), GradStats()); } } // aggregate all statistics to hset[0] @@ -119,7 +115,7 @@ class HistMaker: public BaseMaker { // workspace of thread ThreadWSpace wspace_; // reducer for histogram - rabit::Reducer histred_; + rabit::Reducer histred_; // set of working features std::vector fwork_set_; // update function implementation @@ -147,8 +143,7 @@ class HistMaker: public BaseMaker { // if nothing left to be expand, break if (qexpand_.size() == 0) break; } - for (size_t i = 0; i < qexpand_.size(); ++i) { - const int nid = qexpand_[i]; + for (int const nid : qexpand_) { (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate); } } @@ -179,34 +174,35 @@ class HistMaker: public BaseMaker { private: inline void EnumerateSplit(const HistUnit &hist, - const TStats &node_sum, + const GradStats &node_sum, bst_uint fid, SplitEntry *best, - TStats *left_sum) { + GradStats *left_sum) { if (hist.size == 0) return; - double root_gain = node_sum.CalcGain(param_); - TStats s(param_), c(param_); + double root_gain = CalcGain(param_, node_sum.GetGrad(), node_sum.GetHess()); + GradStats s, c; for (bst_uint i = 0; i < hist.size; ++i) { s.Add(hist.data[i]); if (s.sum_hess >= param_.min_child_weight) { c.SetSubstract(node_sum, s); if (c.sum_hess >= param_.min_child_weight) { - double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain; - if (best->Update(static_cast(loss_chg), fid, hist.cut[i], - false, s, c)) { + double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) + + CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain; + if (best->Update(static_cast(loss_chg), fid, hist.cut[i], false, s, c)) { *left_sum = s; } } } } - s.Clear(); + s = GradStats(); for (bst_uint i = hist.size - 1; i != 0; --i) { s.Add(hist.data[i]); if (s.sum_hess >= param_.min_child_weight) { c.SetSubstract(node_sum, s); if (c.sum_hess >= param_.min_child_weight) { - double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain; + double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) + + CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain; if (best->Update(static_cast(loss_chg), fid, hist.cut[i-1], true, c, s)) { *left_sum = c; } @@ -222,14 +218,14 @@ class HistMaker: public BaseMaker { const size_t num_feature = fset.size(); // get the best split condition for each node std::vector sol(qexpand_.size()); - std::vector left_sum(qexpand_.size()); + std::vector left_sum(qexpand_.size()); auto nexpand = static_cast(qexpand_.size()); #pragma omp parallel for schedule(dynamic, 1) for (bst_omp_uint wid = 0; wid < nexpand; ++wid) { const int nid = qexpand_[wid]; CHECK_EQ(node2workindex_[nid], static_cast(wid)); SplitEntry &best = sol[wid]; - TStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0]; + GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0]; for (size_t i = 0; i < fset.size(); ++i) { EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature+1)], node_sum, fset[i], &best, &left_sum[wid]); @@ -239,13 +235,13 
@@ class HistMaker: public BaseMaker { for (bst_omp_uint wid = 0; wid < nexpand; ++wid) { const int nid = qexpand_[wid]; const SplitEntry &best = sol[wid]; - const TStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0]; + const GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0]; this->SetStats(p_tree, nid, node_sum); // set up the values p_tree->Stat(nid).loss_chg = best.loss_chg; // now we know the solution in snode[nid], set split if (best.loss_chg > kRtEps) { - bst_float base_weight = node_sum.CalcWeight(param_); + bst_float base_weight = CalcWeight(param_, node_sum); bst_float left_leaf_weight = CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) * param_.learning_rate; @@ -258,7 +254,7 @@ class HistMaker: public BaseMaker { right_leaf_weight, best.loss_chg, node_sum.sum_hess); // right side sum - TStats right_sum; + GradStats right_sum; right_sum.SetSubstract(node_sum, left_sum[wid]); this->SetStats(p_tree, (*p_tree)[nid].LeftChild(), left_sum[wid]); this->SetStats(p_tree, (*p_tree)[nid].RightChild(), right_sum); @@ -268,20 +264,20 @@ class HistMaker: public BaseMaker { } } - inline void SetStats(RegTree *p_tree, int nid, const TStats &node_sum) { - p_tree->Stat(nid).base_weight = static_cast(node_sum.CalcWeight(param_)); + inline void SetStats(RegTree *p_tree, int nid, const GradStats &node_sum) { + p_tree->Stat(nid).base_weight = + static_cast(CalcWeight(param_, node_sum)); p_tree->Stat(nid).sum_hess = static_cast(node_sum.sum_hess); } }; -template -class CQHistMaker: public HistMaker { +class CQHistMaker: public HistMaker { public: CQHistMaker() = default; protected: struct HistEntry { - typename HistMaker::HistUnit hist; + HistMaker::HistUnit hist; unsigned istart; /*! * \brief add a histogram to data, @@ -293,7 +289,7 @@ class CQHistMaker: public HistMaker { const bst_uint ridx) { while (istart < hist.size && !(fv < hist.cut[istart])) ++istart; CHECK_NE(istart, hist.size); - hist.data[istart].Add(gpair, info, ridx); + hist.data[istart].Add(gpair[ridx]); } /*! * \brief add a histogram to data, @@ -352,7 +348,7 @@ class CQHistMaker: public HistMaker { for (const auto &batch : p_fmat->GetSortedColumnBatches()) { // start enumeration const auto nsize = static_cast(fset.size()); - #pragma omp parallel for schedule(dynamic, 1) +#pragma omp parallel for schedule(dynamic, 1) for (bst_omp_uint i = 0; i < nsize; ++i) { int fid = fset[i]; int offset = feat2workindex_[fid]; @@ -366,8 +362,7 @@ class CQHistMaker: public HistMaker { // update node statistics. 
this->GetNodeStats(gpair, *p_fmat, tree, &thread_stats_, &node_stats_); - for (size_t i = 0; i < this->qexpand_.size(); ++i) { - const int nid = this->qexpand_[i]; + for (int const nid : this->qexpand_) { const int wid = this->node2workindex_[nid]; this->wspace_.hset[0][fset.size() + wid * (fset.size() + 1)] .data[0] = node_stats_[nid]; @@ -403,8 +398,8 @@ class CQHistMaker: public HistMaker { const size_t work_set_size = work_set_.size(); sketchs_.resize(this->qexpand_.size() * work_set_size); - for (size_t i = 0; i < sketchs_.size(); ++i) { - sketchs_[i].Init(info.num_row_, this->param_.sketch_eps); + for (auto& sketch : sketchs_) { + sketch.Init(info.num_row_, this->param_.sketch_eps); } // intitialize the summary array summary_array_.resize(sketchs_.size()); @@ -501,13 +496,12 @@ class CQHistMaker: public HistMaker { // initialize sbuilder for use std::vector &hbuilder = *p_temp; hbuilder.resize(tree.param.num_nodes); - for (size_t i = 0; i < this->qexpand_.size(); ++i) { - const unsigned nid = this->qexpand_[i]; + for (int const nid : this->qexpand_) { const unsigned wid = this->node2workindex_[nid]; hbuilder[nid].istart = 0; hbuilder[nid].hist = this->wspace_.hset[0][fid_offset + wid * (fset.size()+1)]; } - if (TStats::kSimpleStats != 0 && this->param_.cache_opt != 0) { + if (this->param_.cache_opt != 0) { constexpr bst_uint kBuffer = 32; bst_uint align_length = col.size() / kBuffer * kBuffer; int buf_position[kBuffer]; @@ -552,13 +546,11 @@ class CQHistMaker: public HistMaker { // initialize sbuilder for use std::vector &sbuilder = *p_temp; sbuilder.resize(tree.param.num_nodes); - for (size_t i = 0; i < this->qexpand_.size(); ++i) { - const unsigned nid = this->qexpand_[i]; + for (int const nid : this->qexpand_) { const unsigned wid = this->node2workindex_[nid]; sbuilder[nid].sum_total = 0.0f; sbuilder[nid].sketch = &sketchs_[wid * work_set_size + offset]; } - // first pass, get sum of weight, TODO, optimization to skip first pass for (const auto& c : col) { const bst_uint ridx = c.index; @@ -569,20 +561,19 @@ class CQHistMaker: public HistMaker { } // if only one value, no need to do second pass if (col[0].fvalue == col[col.size()-1].fvalue) { - for (size_t i = 0; i < this->qexpand_.size(); ++i) { - const int nid = this->qexpand_[i]; - sbuilder[nid].sketch->Push(col[0].fvalue, static_cast(sbuilder[nid].sum_total)); + for (int const nid : this->qexpand_) { + sbuilder[nid].sketch->Push( + col[0].fvalue, static_cast(sbuilder[nid].sum_total)); } return; } // two pass scan unsigned max_size = this->param_.MaxSketchSize(); - for (size_t i = 0; i < this->qexpand_.size(); ++i) { - const int nid = this->qexpand_[i]; + for (int const nid : this->qexpand_) { sbuilder[nid].Init(max_size); } // second pass, build the sketch - if (TStats::kSimpleStats != 0 && this->param_.cache_opt != 0) { + if (this->param_.cache_opt != 0) { constexpr bst_uint kBuffer = 32; bst_uint align_length = col.size() / kBuffer * kBuffer; int buf_position[kBuffer]; @@ -616,10 +607,7 @@ class CQHistMaker: public HistMaker { } } } - for (size_t i = 0; i < this->qexpand_.size(); ++i) { - const int nid = this->qexpand_[i]; - sbuilder[nid].Finalize(max_size); - } + for (int const nid : this->qexpand_) { sbuilder[nid].Finalize(max_size); } } // cached dmatrix where we initialized the feature on. 
const DMatrix* cache_dmatrix_{nullptr}; @@ -634,11 +622,11 @@ class CQHistMaker: public HistMaker { // thread temp data std::vector > thread_sketch_; // used to hold statistics - std::vector > thread_stats_; + std::vector > thread_stats_; // used to hold start pointer std::vector > thread_hist_; // node statistics - std::vector node_stats_; + std::vector node_stats_; // summary array std::vector summary_array_; // reducer for summary @@ -648,8 +636,7 @@ class CQHistMaker: public HistMaker { }; // global proposal -template -class GlobalProposalHistMaker: public CQHistMaker { +class GlobalProposalHistMaker: public CQHistMaker { protected: void ResetPosAndPropose(const std::vector &gpair, DMatrix *p_fmat, @@ -661,7 +648,7 @@ class GlobalProposalHistMaker: public CQHistMaker { } if (cached_rptr_.size() == 0) { CHECK_EQ(this->qexpand_.size(), 1U); - CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree); + CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree); cached_rptr_ = this->wspace_.rptr; cached_cut_ = this->wspace_.cut; } else { @@ -730,8 +717,7 @@ class GlobalProposalHistMaker: public CQHistMaker { // update node statistics. this->GetNodeStats(gpair, *p_fmat, tree, &(this->thread_stats_), &(this->node_stats_)); - for (size_t i = 0; i < this->qexpand_.size(); ++i) { - const int nid = this->qexpand_[i]; + for (const int nid : this->qexpand_) { const int wid = this->node2workindex_[nid]; this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)] .data[0] = this->node_stats_[nid]; @@ -750,19 +736,19 @@ class GlobalProposalHistMaker: public CQHistMaker { XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker") .describe("Tree constructor that uses approximate histogram construction.") .set_body([]() { - return new CQHistMaker(); + return new CQHistMaker(); }); XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_histmaker") .describe("Tree constructor that uses approximate global proposal of histogram construction.") .set_body([]() { - return new GlobalProposalHistMaker(); + return new GlobalProposalHistMaker(); }); XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker") .describe("Tree constructor that uses approximate global of histogram construction.") .set_body([]() { - return new GlobalProposalHistMaker(); + return new GlobalProposalHistMaker(); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 1d205c364..d96ca0a08 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -50,9 +50,8 @@ void QuantileHistMaker::Init(const std::vector *gpair, - DMatrix *dmat, - const std::vector &trees) { - GradStats::CheckInfo(dmat->Info()); + DMatrix *dmat, + const std::vector &trees) { if (is_gmat_initialized_ == false) { double tstart = dmlc::GetTime(); gmat_.Init(dmat, static_cast(param_.max_bin)); @@ -91,11 +90,11 @@ bool QuantileHistMaker::UpdatePredictionCache( } void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat, - const GHistIndexBlockMatrix& gmatb, - const ColumnMatrix& column_matrix, - HostDeviceVector* gpair, - DMatrix* p_fmat, - RegTree* p_tree) { + const GHistIndexBlockMatrix& gmatb, + const ColumnMatrix& column_matrix, + HostDeviceVector* gpair, + DMatrix* p_fmat, + RegTree* p_tree) { double gstart = dmlc::GetTime(); int num_leaves = 0; @@ -280,9 +279,9 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( } void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, - const std::vector& gpair, - const 
DMatrix& fmat, - const RegTree& tree) { + const std::vector& gpair, + const DMatrix& fmat, + const RegTree& tree) { CHECK_EQ(tree.param.num_nodes, tree.param.num_roots) << "ColMakerHist: can only grow new tree"; CHECK((param_.max_depth > 0 || param_.max_leaves > 0)) @@ -395,11 +394,11 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, } } -void QuantileHistMaker::Builder::EvaluateSplit(int nid, - const GHistIndexMatrix& gmat, - const HistCollection& hist, - const DMatrix& fmat, - const RegTree& tree) { +void QuantileHistMaker::Builder::EvaluateSplit(const int nid, + const GHistIndexMatrix& gmat, + const HistCollection& hist, + const DMatrix& fmat, + const RegTree& tree) { // start enumeration const MetaInfo& info = fmat.Info(); auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid)); @@ -411,13 +410,14 @@ void QuantileHistMaker::Builder::EvaluateSplit(int nid, for (bst_omp_uint tid = 0; tid < nthread; ++tid) { best_split_tloc_[tid] = snode_[nid].best; } + GHistRow node_hist = hist[nid]; #pragma omp parallel for schedule(dynamic) num_threads(nthread) for (bst_omp_uint i = 0; i < nfeature; ++i) { const bst_uint fid = feature_set[i]; const unsigned tid = omp_get_thread_num(); - this->EnumerateSplit(-1, gmat, hist[nid], snode_[nid], info, + this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info, &best_split_tloc_[tid], fid, nid); - this->EnumerateSplit(+1, gmat, hist[nid], snode_[nid], info, + this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info, &best_split_tloc_[tid], fid, nid); } for (unsigned tid = 0; tid < nthread; ++tid) { @@ -426,11 +426,11 @@ void QuantileHistMaker::Builder::EvaluateSplit(int nid, } void QuantileHistMaker::Builder::ApplySplit(int nid, - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - const HistCollection& hist, - const DMatrix& fmat, - RegTree* p_tree) { + const GHistIndexMatrix& gmat, + const ColumnMatrix& column_matrix, + const HistCollection& hist, + const DMatrix& fmat, + RegTree* p_tree) { // TODO(hcho3): support feature sampling by levels /* 1. 
Create child nodes */ @@ -613,10 +613,10 @@ void QuantileHistMaker::Builder::ApplySplitSparseData( } void QuantileHistMaker::Builder::InitNewNode(int nid, - const GHistIndexMatrix& gmat, - const std::vector& gpair, - const DMatrix& fmat, - const RegTree& tree) { + const GHistIndexMatrix& gmat, + const std::vector& gpair, + const DMatrix& fmat, + const RegTree& tree) { { snode_.resize(tree.param.num_nodes, NodeEntry(param_)); } @@ -628,22 +628,24 @@ void QuantileHistMaker::Builder::InitNewNode(int nid, // in distributed mode, the node's stats should be calculated from histogram, otherwise, // we will have wrong results in EnumerateSplit() // here we take the last feature in cut + auto begin = hist.data(); for (size_t i = gmat.cut.row_ptr[0]; i < gmat.cut.row_ptr[1]; i++) { - stats.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess); + stats.Add(begin[i].sum_grad, begin[i].sum_hess); } } else { if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased || rabit::IsDistributed()) { /* specialized code for dense data - For dense data (with no missing value), - the sum of gradient histogram is equal to snode[nid] - GHistRow hist = hist_[nid];*/ + For dense data (with no missing value), + the sum of gradient histogram is equal to snode[nid] + GHistRow hist = hist_[nid];*/ const std::vector& row_ptr = gmat.cut.row_ptr; const uint32_t ibegin = row_ptr[fid_least_bins_]; const uint32_t iend = row_ptr[fid_least_bins_ + 1]; + auto begin = hist.data(); for (uint32_t i = ibegin; i < iend; ++i) { - const GHistEntry et = hist.begin[i]; + const GradStats et = begin[i]; stats.Add(et.sum_grad, et.sum_hess); } } else { @@ -653,27 +655,27 @@ void QuantileHistMaker::Builder::InitNewNode(int nid, } } } - } - // calculating the weights - { - bst_uint parentid = tree[nid].Parent(); - snode_[nid].weight = static_cast( - spliteval_->ComputeWeight(parentid, snode_[nid].stats)); - snode_[nid].root_gain = static_cast( - spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight)); + // calculating the weights + { + bst_uint parentid = tree[nid].Parent(); + snode_[nid].weight = static_cast( + spliteval_->ComputeWeight(parentid, snode_[nid].stats)); + snode_[nid].root_gain = static_cast( + spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight)); + } } } // enumerate the split values of specific feature void QuantileHistMaker::Builder::EnumerateSplit(int d_step, - const GHistIndexMatrix& gmat, - const GHistRow& hist, - const NodeEntry& snode, - const MetaInfo& info, - SplitEntry* p_best, - bst_uint fid, - bst_uint nodeID) { + const GHistIndexMatrix& gmat, + const GHistRow& hist, + const NodeEntry& snode, + const MetaInfo& info, + SplitEntry* p_best, + bst_uint fid, + bst_uint nodeID) { CHECK(d_step == +1 || d_step == -1); // aliases @@ -681,8 +683,8 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step, const std::vector& cut_val = gmat.cut.cut; // statistics on both sides of split - GradStats c(param_); - GradStats e(param_); + GradStats c; + GradStats e; // best split so far SplitEntry best; @@ -708,7 +710,7 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step, for (int32_t i = ibegin; i != iend; i += d_step) { // start working // try to find a split - e.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess); + e.Add(hist[i].GetGrad(), hist[i].GetHess()); if (e.sum_hess >= param_.min_child_weight) { c.SetSubstract(snode.stats, e); if (c.sum_hess >= param_.min_child_weight) { diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h 
index 4e8a1f276..8e68aea44 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -30,7 +30,6 @@ using xgboost::common::HistCutMatrix; using xgboost::common::GHistIndexMatrix; using xgboost::common::GHistIndexBlockMatrix; using xgboost::common::GHistIndexRow; -using xgboost::common::GHistEntry; using xgboost::common::HistCollection; using xgboost::common::RowSetCollection; using xgboost::common::GHistRow; @@ -73,8 +72,7 @@ class QuantileHistMaker: public TreeUpdater { SplitEntry best; // constructor explicit NodeEntry(const TrainParam& param) - : stats(param), root_gain(0.0f), weight(0.0f) { - } + : root_gain(0.0f), weight(0.0f) {} }; // actual builder that runs the algorithm @@ -105,7 +103,7 @@ class QuantileHistMaker: public TreeUpdater { } else { hist_builder_.BuildHist(gpair, row_indices, gmat, hist); } - this->histred_.Allreduce(hist.begin, hist_builder_.GetNumBins()); + this->histred_.Allreduce(hist.data(), hist_builder_.GetNumBins()); } inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) { @@ -122,7 +120,7 @@ class QuantileHistMaker: public TreeUpdater { const DMatrix& fmat, const RegTree& tree); - void EvaluateSplit(int nid, + void EvaluateSplit(const int nid, const GHistIndexMatrix& gmat, const HistCollection& hist, const DMatrix& fmat, @@ -227,7 +225,7 @@ class QuantileHistMaker: public TreeUpdater { enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; DataLayout data_layout_; - rabit::Reducer histred_; + rabit::Reducer histred_; }; std::unique_ptr builder_; diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 98926cb22..7b0ab5dcc 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -19,7 +19,6 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_refresh); /*! 
\brief pruner that prunes a tree after growing finishs */ -template class TreeRefresher: public TreeUpdater { public: void Init(const std::vector >& args) override { @@ -31,14 +30,13 @@ class TreeRefresher: public TreeUpdater { const std::vector &trees) override { if (trees.size() == 0) return; const std::vector &gpair_h = gpair->ConstHostVector(); - // number of threads // thread temporal space - std::vector > stemp; + std::vector > stemp; std::vector fvec_temp; // setup temp space for each thread const int nthread = omp_get_max_threads(); fvec_temp.resize(nthread, RegTree::FVec()); - stemp.resize(nthread, std::vector()); + stemp.resize(nthread, std::vector()); #pragma omp parallel { int tid = omp_get_thread_num(); @@ -46,8 +44,8 @@ class TreeRefresher: public TreeUpdater { for (auto tree : trees) { num_nodes += tree->param.num_nodes; } - stemp[tid].resize(num_nodes, TStats(param_)); - std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param_)); + stemp[tid].resize(num_nodes, GradStats()); + std::fill(stemp[tid].begin(), stemp[tid].end(), GradStats()); fvec_temp[tid].Init(trees[0]->param.num_feature); } // if it is C++11, use lazy evaluation for Allreduce, @@ -104,21 +102,22 @@ class TreeRefresher: public TreeUpdater { const std::vector &gpair, const MetaInfo &info, const bst_uint ridx, - TStats *gstats) { + GradStats *gstats) { // start from groups that belongs to current data auto pid = static_cast(info.GetRoot(ridx)); - gstats[pid].Add(gpair, info, ridx); + gstats[pid].Add(gpair[ridx]); // tranverse tree while (!tree[pid].IsLeaf()) { unsigned split_index = tree[pid].SplitIndex(); pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index)); - gstats[pid].Add(gpair, info, ridx); + gstats[pid].Add(gpair[ridx]); } } - inline void Refresh(const TStats *gstats, + inline void Refresh(const GradStats *gstats, int nid, RegTree *p_tree) { RegTree &tree = *p_tree; - tree.Stat(nid).base_weight = static_cast(gstats[nid].CalcWeight(param_)); + tree.Stat(nid).base_weight = + static_cast(CalcWeight(param_, gstats[nid])); tree.Stat(nid).sum_hess = static_cast(gstats[nid].sum_hess); if (tree[nid].IsLeaf()) { if (param_.refresh_leaf) { @@ -126,9 +125,9 @@ class TreeRefresher: public TreeUpdater { } } else { tree.Stat(nid).loss_chg = static_cast( - gstats[tree[nid].LeftChild()].CalcGain(param_) + - gstats[tree[nid].RightChild()].CalcGain(param_) - - gstats[nid].CalcGain(param_)); + xgboost::tree::CalcGain(param_, gstats[tree[nid].LeftChild()]) + + xgboost::tree::CalcGain(param_, gstats[tree[nid].RightChild()]) - + xgboost::tree::CalcGain(param_, gstats[nid])); this->Refresh(gstats, tree[nid].LeftChild(), p_tree); this->Refresh(gstats, tree[nid].RightChild(), p_tree); } @@ -136,13 +135,13 @@ class TreeRefresher: public TreeUpdater { // training parameter TrainParam param_; // reducer - rabit::Reducer reducer_; + rabit::Reducer reducer_; }; XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh") .describe("Refresher that refreshes the weight and statistics according to data.") .set_body([]() { - return new TreeRefresher(); + return new TreeRefresher(); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index 9e94d5dae..dba1a1d5c 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -83,26 +83,17 @@ class SketchMaker: public BaseMaker { double neg_grad; /*! 
\brief sum of hessian statistics */ double sum_hess; - SKStats() = default; - // constructor - explicit SKStats(const TrainParam ¶m) { - this->Clear(); - } - /*! \brief clear the statistics */ - inline void Clear() { - neg_grad = pos_grad = sum_hess = 0.0f; - } + + SKStats() : pos_grad{0}, neg_grad{0}, sum_hess{0} {} + // accumulate statistics - inline void Add(const std::vector &gpair, - const MetaInfo &info, - bst_uint ridx) { - const GradientPair &b = gpair[ridx]; - if (b.GetGrad() >= 0.0f) { - pos_grad += b.GetGrad(); + void Add(const GradientPair& gpair) { + if (gpair.GetGrad() >= 0.0f) { + pos_grad += gpair.GetGrad(); } else { - neg_grad -= b.GetGrad(); + neg_grad -= gpair.GetGrad(); } - sum_hess += b.GetHess(); + sum_hess += gpair.GetHess(); } /*! \brief calculate gain of the solution */ inline double CalcGain(const TrainParam ¶m) const { diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index d91b69c48..de2cb6253 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -57,10 +57,10 @@ class QuantileHistMock : public QuantileHistMaker { {0.26f, 0.27f}, {0.23f, 0.24f}, {0.27f, 0.28f}, {0.57f, 0.59f}, {0.23f, 0.24f}, {0.47f, 0.49f}}; - for (size_t i = 0; i < hist_[nid].size; ++i) { + for (size_t i = 0; i < hist_[nid].size(); ++i) { GradientPairPrecise sol = solution[i]; - ASSERT_NEAR(sol.GetGrad(), hist_[nid].begin[i].sum_grad, kEps); - ASSERT_NEAR(sol.GetHess(), hist_[nid].begin[i].sum_hess, kEps); + ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps); + ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps); } }
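
A minimal sketch of how the refactored histogram types in this change are meant to be used: GHistRow is now an alias for common::Span<tree::GradStats>, so bins are reached through data()/operator[] and the GetGrad()/GetHess() accessors on GradStats instead of the removed GHistEntry::sum_grad/sum_hess fields. The Span and GradStats below are simplified stand-ins (not the real xgboost headers), and main() is a hypothetical usage example, not code from this patch.

// Sketch only: simplified stand-ins for xgboost::common::Span and
// xgboost::tree::GradStats; the real headers differ.
#include <cstddef>
#include <iostream>
#include <vector>

namespace tree {
// Mirrors the cleaned-up GradStats in src/tree/param.h: plain grad/hess sums
// with GetGrad()/GetHess() accessors, replacing the removed GHistEntry.
struct GradStats {
  double sum_grad{0};
  double sum_hess{0};
  double GetGrad() const { return sum_grad; }
  double GetHess() const { return sum_hess; }
  void Add(double grad, double hess) { sum_grad += grad; sum_hess += hess; }
  void SetSubstract(const GradStats& a, const GradStats& b) {
    sum_grad = a.sum_grad - b.sum_grad;
    sum_hess = a.sum_hess - b.sum_hess;
  }
};
}  // namespace tree

// Stand-in for common::Span<T>: GHistRow becomes a non-owning view over the
// node's bins instead of the removed {GHistEntry* begin; uint32_t size} struct.
template <typename T>
class Span {
 public:
  Span(T* data, std::size_t size) : data_{data}, size_{size} {}
  T* data() const { return data_; }
  std::size_t size() const { return size_; }
  T& operator[](std::size_t i) const { return data_[i]; }

 private:
  T* data_;
  std::size_t size_;
};

using GHistRow = Span<tree::GradStats>;

int main() {
  // Histogram storage for one node: one GradStats per bin.
  std::vector<tree::GradStats> storage(4);
  GHistRow hist{storage.data(), storage.size()};

  // Accumulate gradient statistics into bin 1, as BuildHist does per row.
  hist.data()[1].Add(0.5, 1.0);
  hist[1].Add(0.25, 1.0);

  // Subtraction trick: sibling = parent - self, applied bin by bin.
  tree::GradStats parent, self, sibling;
  parent.Add(1.0, 3.0);
  self.Add(0.75, 2.0);
  sibling.SetSubstract(parent, self);

  std::cout << hist[1].GetGrad() << " " << hist[1].GetHess() << "\n";   // 0.75 2
  std::cout << sibling.GetGrad() << " " << sibling.GetHess() << "\n";   // 0.25 1
  return 0;
}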