diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h
index 510206f50..e55e1ef57 100644
--- a/src/common/column_matrix.h
+++ b/src/common/column_matrix.h
@@ -8,11 +8,11 @@
 #ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
 #define XGBOOST_COMMON_COLUMN_MATRIX_H_
 
+#include
 #include
 #include
 #include "hist_util.h"
-
 namespace xgboost {
 namespace common {
@@ -51,6 +51,10 @@ class Column {
   }
   const size_t* GetRowData() const { return row_ind_; }
 
+  const uint32_t* GetIndex() const {
+    return index_;
+  }
+
  private:
   ColumnType type_;
   const uint32_t* index_;
@@ -80,7 +84,7 @@ class ColumnMatrix {
     std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
     uint32_t max_val = std::numeric_limits<uint32_t>::max();
-    for (bst_uint fid = 0; fid < nfeature; ++fid) {
+    for (int32_t fid = 0; fid < nfeature; ++fid) {
       CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
     }
@@ -113,13 +117,12 @@ class ColumnMatrix {
       boundary_[fid].index_end = accum_index_;
       boundary_[fid].row_ind_end = accum_row_ind_;
     }
-
     index_.resize(boundary_[nfeature - 1].index_end);
     row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
 
     // store least bin id for each feature
     index_base_.resize(nfeature);
-    for (bst_uint fid = 0; fid < nfeature; ++fid) {
+    for (int32_t fid = 0; fid < nfeature; ++fid) {
       index_base_[fid] = gmat.cut.row_ptr[fid];
     }
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index 78a5c950b..af420db4d 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -1,15 +1,15 @@
 /*!
  * Copyright 2017-2019 by Contributors
- * \file hist_util.h
+ * \file hist_util.cc
 */
+#include "./hist_util.h"
+#include
 #include
 #include
 #include
 #include
-
 #include "./random.h"
 #include "./column_matrix.h"
-#include "./hist_util.h"
 #include "./quantile.h"
 #include "./../tree/updater_quantile_hist.h"
@@ -178,7 +178,7 @@ uint32_t HistCutMatrix::GetBinIdx(const Entry& e) {
 
 void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
   cut.Init(p_fmat, max_num_bins);
-  const size_t nthread = omp_get_max_threads();
+  const int32_t nthread = omp_get_max_threads();
   const uint32_t nbins = cut.row_ptr.back();
   hit_count.resize(nbins, 0);
   hit_count_tloc_.resize(nthread * nbins, 0);
@@ -260,8 +260,8 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
   }
 
 #pragma omp parallel for num_threads(nthread) schedule(static)
-  for (bst_omp_uint idx = 0; idx < bst_omp_uint(nbins); ++idx) {
-    for (size_t tid = 0; tid < nthread; ++tid) {
+  for (int32_t idx = 0; idx < int32_t(nbins); ++idx) {
+    for (int32_t tid = 0; tid < nthread; ++tid) {
       hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
     }
   }
@@ -411,7 +411,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
     for (auto fid : group) {
       nnz += feature_nnz[fid];
     }
-    double nnz_rate = static_cast<double>(nnz) / nrow;
+    float nnz_rate = static_cast<float>(nnz) / nrow;
     // break up small sparse groups, since they will not gain any speed
     if (nnz_rate <= param.sparse_threshold) {
       for (auto fid : group) {
@@ -496,176 +496,144 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   }
 }
 
-void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
-                             const RowSetCollection::Elem row_indices,
-                             const GHistIndexMatrix& gmat,
-                             GHistRow hist) {
-  const size_t nthread = static_cast<size_t>(this->nthread_);
-  data_.resize(nbins_ * nthread_);
-
-  const size_t* rid = row_indices.begin;
-  const size_t nrows = row_indices.Size();
-  const uint32_t* index = gmat.index.data();
-  const size_t* row_ptr = gmat.row_ptr.data();
-  const float* pgh = reinterpret_cast<const float*>(gpair.data());
-
-  double* hist_data = reinterpret_cast<double*>(hist.data());
-  double* data = reinterpret_cast<double*>(data_.data());
-
-  const size_t block_size = 512;
-  size_t n_blocks = nrows/block_size;
-  n_blocks += !!(nrows - n_blocks*block_size);
-
-  const size_t nthread_to_process = std::min(nthread, n_blocks);
-  memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));
+// used when data layout is kDenseDataZeroBased or kDenseDataOneBased
+// it means that "row_ptr" is not needed for hist computations
+void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
+    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
+    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
+  GradStatHist grad_stat;  // make local var to prevent false sharing
 
+  const size_t n_features = row_ptr[rid[istart]+1] - row_ptr[rid[istart]];
   const size_t cache_line_size = 64;
+  const size_t prefetch_step = cache_line_size / sizeof(*index);
   const size_t prefetch_offset = 10;
+
   size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
   no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
 
-#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
-  for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-    dmlc::omp_uint tid = omp_get_thread_num();
-    double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
-                               reinterpret_cast<double*>(data_.data() + tid * nbins_));
+  // if every row in this block of the bin-matrix is read, the block is dense,
+  // and we don't need SW prefetch in that case
+  const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart - 1);
 
-    if (!thread_init_[tid]) {
-      memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
-      thread_init_[tid] = true;
-    }
-
-    const size_t istart = iblock*block_size;
-    const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
 
+  if (iend < nrows - no_prefetch_size && !denseBlock) {
     for (size_t i = istart; i < iend; ++i) {
-      const size_t icol_start = row_ptr[rid[i]];
-      const size_t icol_end = row_ptr[rid[i]+1];
+      const size_t icol_start = rid[i] * n_features;
+      const size_t icol_start_prefetch = rid[i+prefetch_offset] * n_features;
+      const size_t idx_gh = 2*rid[i];
 
-      if (i < nrows - no_prefetch_size) {
-        PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
-        PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
+      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
+
+      for (size_t j = icol_start_prefetch; j < icol_start_prefetch + n_features;
+          j += prefetch_step) {
+        PREFETCH_READ_T0(index + j);
       }
 
-      for (size_t j = icol_start; j < icol_end; ++j) {
-        const uint32_t idx_bin = 2*index[j];
-        const size_t idx_gh = 2*rid[i];
+      grad_stat.sum_grad += pgh[idx_gh];
+      grad_stat.sum_hess += pgh[idx_gh+1];
 
+      for (size_t j = icol_start; j < icol_start + n_features; ++j) {
+        const uint32_t idx_bin = 2*index[j];
         data_local_hist[idx_bin] += pgh[idx_gh];
         data_local_hist[idx_bin+1] += pgh[idx_gh+1];
       }
     }
-  }
+  } else {
+    for (size_t i = istart; i < iend; ++i) {
+      const size_t icol_start = rid[i] * n_features;
+      const size_t idx_gh = 2*rid[i];
 
-  if (nthread_to_process > 1) {
-    const size_t size = (2*nbins_);
-    const size_t block_size = 1024;
-    size_t n_blocks = size/block_size;
-    n_blocks += !!(size - n_blocks*block_size);
-
-    size_t n_worked_bins = 0;
-    for (size_t i = 0; i < nthread_to_process; ++i) {
-      if (thread_init_[i]) {
-        thread_init_[n_worked_bins++] = i;
-      }
-    }
-
+      grad_stat.sum_grad += pgh[idx_gh];
+      grad_stat.sum_hess += pgh[idx_gh+1];
-#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
-    for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-      const size_t istart = iblock * block_size;
-      const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
-
-      const size_t bin = 2 * thread_init_[0] * nbins_;
-      memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
-
-      for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
-        const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
-        for (size_t i = istart; i < iend; i++) {
-          hist_data[i] += data[bin + i];
-        }
+      for (size_t j = icol_start; j < icol_start + n_features; ++j) {
+        const uint32_t idx_bin = 2*index[j];
+        data_local_hist[idx_bin] += pgh[idx_gh];
+        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
       }
     }
   }
+  grad_stat_global->Add(grad_stat);
 }
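
Editorial note on the kernel above: it issues software prefetches roughly prefetch_offset rows ahead (one request per 64-byte cache line of the index array) while row i is being accumulated, and the last no_prefetch_size rows are handled without prefetch so that rid[i + prefetch_offset] never reads past the end. The PREFETCH_READ_T0 macro is not shown in this patch; on GCC/Clang it would plausibly be defined along these lines (an assumption, not the project's actual definition):

    // Hypothetical definition; real builds may instead use
    // _mm_prefetch(reinterpret_cast<const char*>(addr), _MM_HINT_T0).
    #define PREFETCH_READ_T0(addr) \
      __builtin_prefetch(reinterpret_cast<const char*>(addr), /*rw=*/0, /*locality=*/3)
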
 
-void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
-                                  const RowSetCollection::Elem row_indices,
-                                  const GHistIndexBlockMatrix& gmatb,
-                                  GHistRow hist) {
-  constexpr int kUnroll = 8;  // loop unrolling factor
-  const size_t nblock = gmatb.GetNumBlock();
-  const size_t nrows = row_indices.end - row_indices.begin;
-  const size_t rest = nrows % kUnroll;
+// used when data layout is kSparseData
+// it means that "row_ptr" is needed for hist computations
+void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
+    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
+    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
+  GradStatHist grad_stat;  // make local var to prevent false sharing
 
-#if defined(_OPENMP)
-  const auto nthread = static_cast<dmlc::omp_uint>(this->nthread_);  // NOLINT
-#endif  // defined(_OPENMP)
-  tree::GradStats* p_hist = hist.data();
+  const size_t cache_line_size = 64;
+  const size_t prefetch_step = cache_line_size / sizeof(index[0]);
+  const size_t prefetch_offset = 10;
 
-#pragma omp parallel for num_threads(nthread) schedule(guided)
-  for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
-    auto gmat = gmatb[bid];
+  size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
+  no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
 
-    for (size_t i = 0; i < nrows - rest; i += kUnroll) {
-      size_t rid[kUnroll];
-      size_t ibegin[kUnroll];
-      size_t iend[kUnroll];
-      GradientPair stat[kUnroll];
+  // if every row in this block of the bin-matrix is read, the block is dense,
+  // and we don't need SW prefetch in that case
+  const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart);
 
-      for (int k = 0; k < kUnroll; ++k) {
-        rid[k] = row_indices.begin[i + k];
-        ibegin[k] = gmat.row_ptr[rid[k]];
-        iend[k] = gmat.row_ptr[rid[k] + 1];
-        stat[k] = gpair[rid[k]];
+  if (iend < nrows - no_prefetch_size && !denseBlock) {
+    for (size_t i = istart; i < iend; ++i) {
+      const size_t icol_start = row_ptr[rid[i]];
+      const size_t icol_end = row_ptr[rid[i]+1];
+      const size_t idx_gh = 2*rid[i];
+
+      const size_t icol_start10 = row_ptr[rid[i+prefetch_offset]];
+      const size_t icol_end10 = row_ptr[rid[i+prefetch_offset]+1];
+
+      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
+
+      for (size_t j = icol_start10; j < icol_end10; j+=prefetch_step) {
+        PREFETCH_READ_T0(index + j);
       }
-      for (int k = 0; k < kUnroll; ++k) {
-        for (size_t j = ibegin[k]; j < iend[k]; ++j) {
-          const uint32_t bin = gmat.index[j];
-          p_hist[bin].Add(stat[k]);
-        }
+
+      grad_stat.sum_grad += pgh[idx_gh];
+      grad_stat.sum_hess += pgh[idx_gh+1];
+
+      for (size_t j = icol_start; j < icol_end; ++j) {
+        const uint32_t idx_bin = 2*index[j];
+        data_local_hist[idx_bin] += pgh[idx_gh];
+        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
       }
     }
-    for (size_t i = nrows - rest; i < nrows; ++i) {
-      const size_t rid = row_indices.begin[i];
-      const size_t ibegin = gmat.row_ptr[rid];
-      const size_t iend = gmat.row_ptr[rid + 1];
-      const GradientPair stat = gpair[rid];
-      for (size_t j = ibegin; j < iend; ++j) {
-        const uint32_t bin = gmat.index[j];
-        p_hist[bin].Add(stat);
+  } else {
+    for (size_t i = istart; i < iend; ++i) {
+      const size_t icol_start = row_ptr[rid[i]];
+      const size_t icol_end = row_ptr[rid[i]+1];
+      const size_t idx_gh = 2*rid[i];
+
+      grad_stat.sum_grad += pgh[idx_gh];
+      grad_stat.sum_hess += pgh[idx_gh+1];
+
+      for (size_t j = icol_start; j < icol_end; ++j) {
+        const uint32_t idx_bin = 2*index[j];
+        data_local_hist[idx_bin] += pgh[idx_gh];
+        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
       }
     }
   }
+  grad_stat_global->Add(grad_stat);
 }
 
-void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
-  const uint32_t nbins = static_cast<uint32_t>(nbins_);
-  constexpr int kUnroll = 8;  // loop unrolling factor
-  const uint32_t rest = nbins % kUnroll;
+void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
+  GradStatHist* p_self = self.data();
+  GradStatHist* p_sibling = sibling.data();
+  GradStatHist* p_parent = parent.data();
 
-#if defined(_OPENMP)
-  const auto nthread = static_cast<dmlc::omp_uint>(this->nthread_);  // NOLINT
-#endif  // defined(_OPENMP)
-  tree::GradStats* p_self = self.data();
-  tree::GradStats* p_sibling = sibling.data();
-  tree::GradStats* p_parent = parent.data();
+  const size_t size = self.size();
+  CHECK_EQ(sibling.size(), size);
+  CHECK_EQ(parent.size(), size);
 
-#pragma omp parallel for num_threads(nthread) schedule(static)
-  for (bst_omp_uint bin_id = 0;
-       bin_id < static_cast<bst_omp_uint>(nbins - rest); bin_id += kUnroll) {
-    tree::GradStats pb[kUnroll];
-    tree::GradStats sb[kUnroll];
-    for (int k = 0; k < kUnroll; ++k) {
-      pb[k] = p_parent[bin_id + k];
+  const size_t block_size = 1024;  // approximately 1024 values per block
+  size_t n_blocks = size/block_size + !!(size%block_size);
+
+  #pragma omp parallel for
+  for (int iblock = 0; iblock < n_blocks; ++iblock) {
+    const size_t ibegin = iblock*block_size;
+    const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
+    for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) {
+      p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
     }
-    for (int k = 0; k < kUnroll; ++k) {
-      sb[k] = p_sibling[bin_id + k];
-    }
-    for (int k = 0; k < kUnroll; ++k) {
-      p_self[bin_id + k].SetSubstract(pb[k], sb[k]);
-    }
-  }
-  for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
-    p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
   }
 }
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index dc2b80bb8..0cef1878e 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -11,13 +11,50 @@
 #include
 #include
 #include
+#include
+#include
 #include "row_set.h"
 #include "../tree/param.h"
 #include "./quantile.h"
 #include "./timer.h"
 #include "../include/rabit/rabit.h"
+#include "random.h"
 
 namespace xgboost {
+
+/*!
+ * \brief A C-style array with in-stack allocation. As long as the array is smaller
+ * than MaxStackSize, it will be allocated on the stack. Otherwise, it will be
+ * heap-allocated.
+ */
+template <typename T, size_t MaxStackSize>
+class MemStackAllocator {
+ public:
+  explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
+  }
+
+  T* Get() {
+    if (!ptr_) {
+      if (MaxStackSize >= required_size_) {
+        ptr_ = stack_mem_;
+      } else {
+        ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
+        do_free_ = true;
+      }
+    }
+
+    return ptr_;
+  }
+
+  ~MemStackAllocator() {
+    if (do_free_) free(ptr_);
+  }
+
+ private:
+  T* ptr_ = nullptr;
+  bool do_free_ = false;
+  size_t required_size_;
+  T stack_mem_[MaxStackSize];
+};
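
A minimal usage sketch of this helper (variable names are made up): callers size it with the runtime count, small requests live in the stack buffer, larger ones fall back to malloc, and the destructor releases the heap path.

    // flags for up to 128 threads without touching the heap
    xgboost::MemStackAllocator<bool, 128> flags_alloc(nthread);
    bool* flags = flags_alloc.Get();
    std::fill(flags, flags + nthread, false);
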
+
 namespace common {
 
 /*
@@ -114,7 +151,7 @@ struct HistCutMatrix {
   };
 
   /*! \brief Builds the cut matrix on the GPU.
-   * 
+   *
   * \return The row stride across the entire dataset.
   */
  size_t DeviceSketch
@@ -134,9 +171,10 @@ using GHistIndexRow = Span<const uint32_t>;
  */
 struct GHistIndexMatrix {
   /*! \brief row pointer to rows by element position */
-  std::vector<size_t> row_ptr;
+  // std::vector<size_t> row_ptr;
+  SimpleArray<size_t> row_ptr;
   /*! \brief The index data */
-  std::vector<uint32_t> index;
+  SimpleArray<uint32_t> index;
   /*! \brief hit count of each index */
   std::vector<size_t> hit_count;
   /*! \brief The corresponding cuts */
@@ -170,6 +208,11 @@ struct GHistIndexBlock {
   inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
     : row_ptr(row_ptr), index(index) {}
+
+  // get i-th row
+  inline GHistIndexRow operator[](size_t i) const {
+    return {&index[0] + row_ptr[i], detail::ptrdiff_t(row_ptr[i + 1] - row_ptr[i])};
+  }
 };
 
 class ColumnMatrix;
@@ -202,12 +245,63 @@ class GHistIndexBlockMatrix {
 };
 
 /*!
- * \brief histogram of gradient statistics for a single node.
- * Consists of multiple GradStats, each entry showing total gradient statistics
- * for that particular bin
- * Uses global bin id so as to represent all features simultaneously
+ * \brief used instead of GradStats to store float instead of double in histograms;
+ * this improves performance by 10-30% and reduces histogram memory consumption by 2x,
+ * while accuracy is the same in both cases
  */
-using GHistRow = Span<tree::GradStats>;
+struct GradStatHist {
+  typedef float GradType;
+  /*! \brief sum gradient statistics */
+  GradType sum_grad;
+  /*! \brief sum hessian statistics */
+  GradType sum_hess;
+
+  GradStatHist() : sum_grad{0}, sum_hess{0} {
+    static_assert(sizeof(GradStatHist) == 8,
+                  "Size of GradStatHist is not 8 bytes.");
+  }
+
+  inline void Add(const GradStatHist& b) {
+    sum_grad += b.sum_grad;
+    sum_hess += b.sum_hess;
+  }
+
+  inline void Add(const tree::GradStats& b) {
+    sum_grad += b.sum_grad;
+    sum_hess += b.sum_hess;
+  }
+
+  inline void Add(const GradientPair& p) {
+    this->Add(p.GetGrad(), p.GetHess());
+  }
+
+  inline void Add(const GradType& grad, const GradType& hess) {
+    sum_grad += grad;
+    sum_hess += hess;
+  }
+
+  inline tree::GradStats ToGradStat() const {
+    return tree::GradStats(sum_grad, sum_hess);
+  }
+
+  inline void SetSubstract(const GradStatHist& a, const GradStatHist& b) {
+    sum_grad = a.sum_grad - b.sum_grad;
+    sum_hess = a.sum_hess - b.sum_hess;
+  }
+
+  inline void SetSubstract(const tree::GradStats& a, const GradStatHist& b) {
+    sum_grad = a.sum_grad - b.sum_grad;
+    sum_hess = a.sum_hess - b.sum_hess;
+  }
+
+  inline GradType GetGrad() const { return sum_grad; }
+  inline GradType GetHess() const { return sum_hess; }
+  inline static void Reduce(GradStatHist& a, const GradStatHist& b) {  // NOLINT(*)
+    a.Add(b);
+  }
+};
+
+using GHistRow = Span<GradStatHist>;
 
 /*!
  * \brief histogram of gradient statistics for multiple nodes
@@ -215,49 +309,43 @@ using GHistRow = Span<GradStatHist>;
 class HistCollection {
  public:
   // access histogram for i-th node
-  GHistRow operator[](bst_uint nid) const {
-    constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
-    CHECK_NE(row_ptr_[nid], kMax);
-    tree::GradStats* ptr =
-        const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
-    return {ptr, nbins_};
+  inline GHistRow operator[](bst_uint nid) {
+    AddHistRow(nid);
+    return { const_cast<GradStatHist*>(dmlc::BeginPtr(data_arr_[nid])), nbins_};
   }
 
   // have we computed a histogram for i-th node?
-  bool RowExists(bst_uint nid) const {
-    const uint32_t k_max = std::numeric_limits<uint32_t>::max();
-    return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
+  inline bool RowExists(bst_uint nid) const {
+    return nid < data_arr_.size();
   }
 
   // initialize histogram collection
-  void Init(uint32_t nbins) {
-    nbins_ = nbins;
-    row_ptr_.clear();
-    data_.clear();
+  inline void Init(uint32_t nbins) {
+    if (nbins_ != nbins) {
+      data_arr_.clear();
+      nbins_ = nbins;
+    }
   }
 
   // create an empty histogram for i-th node
-  void AddHistRow(bst_uint nid) {
-    constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
-    if (nid >= row_ptr_.size()) {
-      row_ptr_.resize(nid + 1, kMax);
-    }
-    CHECK_EQ(row_ptr_[nid], kMax);
+  inline void AddHistRow(bst_uint nid) {
+    if (data_arr_.size() <= nid) {
+      size_t prev = data_arr_.size();
+      data_arr_.resize(nid + 1);
 
-    row_ptr_[nid] = data_.size();
-    data_.resize(data_.size() + nbins_);
+      for (size_t i = prev; i < data_arr_.size(); ++i) {
+        data_arr_[i].resize(nbins_);
+      }
+    }
   }
 
  private:
   /*! \brief number of all bins over all features */
-  uint32_t nbins_;
-
-  std::vector<tree::GradStats> data_;
-
-  /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
-  std::vector<size_t> row_ptr_;
+  uint32_t nbins_ = 0;
+  std::vector<std::vector<GradStatHist>> data_arr_;
 };
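
Compared with the old flat data_ vector plus row_ptr_ bookkeeping, each node now owns its row and allocation happens lazily on first access, so callers no longer have to pair AddHistRow with operator[]. A usage sketch under these semantics:

    common::HistCollection hist;
    hist.Init(nbins);                  // keeps existing rows if nbins is unchanged
    common::GHistRow row = hist[nid];  // allocates nbins GradStatHist on first access
    CHECK(hist.RowExists(nid));
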
 
 /*!
  * \brief builder for histograms of gradient statistics
  */
@@ -267,21 +355,55 @@ class GHistBuilder {
   inline void Init(size_t nthread, uint32_t nbins) {
     nthread_ = nthread;
     nbins_ = nbins;
-    thread_init_.resize(nthread_);
   }
 
-  // construct a histogram via histogram aggregation
-  void BuildHist(const std::vector<GradientPair>& gpair,
-                 const RowSetCollection::Elem row_indices,
-                 const GHistIndexMatrix& gmat,
-                 GHistRow hist);
-  // same, with feature grouping
   void BuildBlockHist(const std::vector<GradientPair>& gpair,
-                      const RowSetCollection::Elem row_indices,
-                      const GHistIndexBlockMatrix& gmatb,
-                      GHistRow hist);
-  // construct a histogram via subtraction trick
-  void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
+                      const RowSetCollection::Elem row_indices,
+                      const GHistIndexBlockMatrix& gmatb,
+                      GHistRow hist) {
+    constexpr int kUnroll = 8;  // loop unrolling factor
+    const int32_t nblock = gmatb.GetNumBlock();
+    const size_t nrows = row_indices.end - row_indices.begin;
+    const size_t rest = nrows % kUnroll;
+
+    #pragma omp parallel for
+    for (int32_t bid = 0; bid < nblock; ++bid) {
+      auto gmat = gmatb[bid];
+
+      for (size_t i = 0; i < nrows - rest; i += kUnroll) {
+        size_t rid[kUnroll];
+        size_t ibegin[kUnroll];
+        size_t iend[kUnroll];
+        GradientPair stat[kUnroll];
+
+        for (int k = 0; k < kUnroll; ++k) {
+          rid[k] = row_indices.begin[i + k];
+        }
+        for (int k = 0; k < kUnroll; ++k) {
+          ibegin[k] = gmat.row_ptr[rid[k]];
+          iend[k] = gmat.row_ptr[rid[k] + 1];
+        }
+        for (int k = 0; k < kUnroll; ++k) {
+          stat[k] = gpair[rid[k]];
+        }
+        for (int k = 0; k < kUnroll; ++k) {
+          for (size_t j = ibegin[k]; j < iend[k]; ++j) {
+            const uint32_t bin = gmat.index[j];
+            hist[bin].Add(stat[k]);
+          }
+        }
+      }
+      for (size_t i = nrows - rest; i < nrows; ++i) {
+        const size_t rid = row_indices.begin[i];
+        const size_t ibegin = gmat.row_ptr[rid];
+        const size_t iend = gmat.row_ptr[rid + 1];
+        const GradientPair stat = gpair[rid];
+        for (size_t j = ibegin; j < iend; ++j) {
+          const uint32_t bin = gmat.index[j];
+          hist[bin].Add(stat);
+        }
+      }
+    }
+  }
 
   uint32_t GetNumBins() {
     return nbins_;
@@ -292,11 +414,19 @@ class GHistBuilder {
   size_t nthread_;
   /*! \brief number of all bins over all features */
   uint32_t nbins_;
-  std::vector<size_t> thread_init_;
-  std::vector<tree::GradStats> data_;
 };
 
+void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
+    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
+    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat);
+
+void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
+    const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
+    GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat);
+
+void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
+
 }  // namespace common
 }  // namespace xgboost
 
 #endif  // XGBOOST_COMMON_HIST_UTIL_H_
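
For orientation, a hypothetical driver for the freestanding kernels declared above (gmat, gpair, rid, n and nbins stand in for the caller's data): a partial histogram is a flat float array with two slots (grad, hess) per bin, and the kernel also accumulates the node's gradient totals.

    std::vector<common::GradStatHist::GradType> hist(2 * nbins, 0.0f);
    common::GradStatHist total;  // node-level grad/hess sums
    const auto* pgh = reinterpret_cast<const GradientPair::ValueT*>(gpair.data());
    common::BuildHistLocalDense(/*istart=*/0, /*iend=*/n, /*nrows=*/n, rid,
                                gmat.index.data(), pgh, gmat.row_ptr.data(),
                                hist.data(), &total);
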
diff --git a/src/common/row_set.h b/src/common/row_set.h
index 285988b15..39ae404f8 100644
--- a/src/common/row_set.h
+++ b/src/common/row_set.h
@@ -27,10 +27,10 @@ class RowSetCollection {
     // id of node associated with this instance set; -1 means uninitialized
     Elem() = default;
-    Elem(const size_t* begin,
-         const size_t* end,
-         int node_id)
-        : begin(begin), end(end), node_id(node_id) {}
+    Elem(const size_t* begin_,
+         const size_t* end_,
+         int node_id_)
+        : begin(begin_), end(end_), node_id(node_id_) {}
 
     inline size_t Size() const {
       return end - begin;
@@ -42,6 +42,10 @@ class RowSetCollection {
     std::vector<size_t> left;
     std::vector<size_t> right;
   };
 
+  size_t Size(unsigned node_id) {
+    return elem_of_each_node_[node_id].Size();
+  }
+
   inline std::vector<Elem>::const_iterator begin() const {  // NOLINT
     return elem_of_each_node_.begin();
   }
@@ -51,12 +55,12 @@ class RowSetCollection {
   }
 
   /*! \brief return corresponding element set given the node_id */
-  inline const Elem& operator[](unsigned node_id) const {
-    const Elem& e = elem_of_each_node_[node_id];
-    CHECK(e.begin != nullptr)
-        << "access element that is not in the set";
+  inline Elem operator[](unsigned node_id) const {
+    const Elem e = elem_of_each_node_[node_id];
     return e;
   }
+
+  // clean up things
   inline void Clear() {
     elem_of_each_node_.clear();
@@ -81,38 +85,29 @@ class RowSetCollection {
     const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
     elem_of_each_node_.emplace_back(Elem(begin, end, 0));
   }
+
   // split rowset into two
   inline void AddSplit(unsigned node_id,
-                       const std::vector<Split>& row_split_tloc,
+                       size_t iLeft,
                        unsigned left_node_id,
                        unsigned right_node_id) {
-    const Elem e = elem_of_each_node_[node_id];
-    const auto nthread = static_cast<bst_omp_uint>(row_split_tloc.size());
-    CHECK(e.begin != nullptr);
-    size_t* all_begin = dmlc::BeginPtr(row_indices_);
-    size_t* begin = all_begin + (e.begin - all_begin);
+    Elem e = elem_of_each_node_[node_id];
 
-    size_t* it = begin;
-    for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
-      std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
-      it += row_split_tloc[tid].left.size();
-    }
-    size_t* split_pt = it;
-    for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
-      std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
-      it += row_split_tloc[tid].right.size();
-    }
+    CHECK(e.begin != nullptr);
+
+    size_t* begin = const_cast<size_t*>(e.begin);
+    size_t* split_pt = begin + iLeft;
 
     if (left_node_id >= elem_of_each_node_.size()) {
-      elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
+      elem_of_each_node_.resize((left_node_id + 1)*2, Elem(nullptr, nullptr, -1));
     }
     if (right_node_id >= elem_of_each_node_.size()) {
-      elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
+      elem_of_each_node_.resize((right_node_id + 1)*2, Elem(nullptr, nullptr, -1));
     }
 
     elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
     elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
-    elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
+    elem_of_each_node_[node_id] = Elem(begin, e.end, -1);
   }
 
   // stores the row indices in the set
diff --git a/src/tree/param.h b/src/tree/param.h
index b7594cbec..93f0797c6 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -291,7 +291,7 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess
     }
   } else {
     T w = CalcWeight(p, sum_grad, sum_hess);
-    T ret = CalcGainGivenWeight(p, sum_grad, sum_hess, w);
+    T ret = CalcGainGivenWeight<TrainingParams, T>(p, sum_grad, sum_hess, w);
     if (p.reg_alpha == 0.0f) {
       return ret;
     } else {
@@ -311,7 +311,7 @@ template <typename TrainingParams, typename T>
 XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
                                  T test_grad, T test_hess) {
   T w = CalcWeight(sum_grad, sum_hess);
-  T ret = CalcGainGivenWeight(p, test_grad, test_hess);
+  T ret = CalcGainGivenWeight<TrainingParams, T>(p, test_grad, test_hess);
   if (p.reg_alpha == 0.0f) {
     return ret;
   } else {
@@ -350,15 +350,16 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad)
 }
 
 /*! \brief core statistics used for tree construction */
-struct XGBOOST_ALIGNAS(16) GradStats {
+struct GradStats {
+  typedef double GradType;
   /*! \brief sum gradient statistics */
-  double sum_grad;
+  GradType sum_grad;
   /*! \brief sum hessian statistics */
-  double sum_hess;
+  GradType sum_hess;
 
  public:
-  XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
-  XGBOOST_DEVICE double GetHess() const { return sum_hess; }
+  XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
+  XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }
 
   XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
     static_assert(sizeof(GradStats) == 16,
@@ -368,7 +369,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   template <typename GpairT>
   XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
       : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
-  explicit GradStats(const double grad, const double hess)
+  explicit GradStats(const GradType grad, const GradType hess)
       : sum_grad(grad), sum_hess(hess) {}
   /*!
    * \brief accumulate statistics
@@ -393,7 +394,7 @@ struct GradStats {
   /*! \return whether the statistics is not used yet */
   inline bool Empty() const { return sum_hess == 0.0; }
   /*! \brief add statistics to the data */
-  inline void Add(double grad, double hess) {
+  inline void Add(GradType grad, GradType hess) {
     sum_grad += grad;
     sum_hess += hess;
   }
@@ -423,7 +424,7 @@ struct ValueConstraint {
   template <typename ParamT>
   XGBOOST_DEVICE inline double CalcGain(const ParamT &param, GradStats stats) const {
-    return CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess,
+    return CalcGainGivenWeight<ParamT, double>(param, stats.sum_grad, stats.sum_hess,
                                CalcWeight(param, stats));
   }
 
@@ -434,8 +435,8 @@ struct ValueConstraint {
     double wleft = CalcWeight(param, left);
     double wright = CalcWeight(param, right);
     double gain =
-        CalcGainGivenWeight(param, left.sum_grad, left.sum_hess, wleft) +
-        CalcGainGivenWeight(param, right.sum_grad, right.sum_hess, wright);
+        CalcGainGivenWeight<ParamT, double>(param, left.sum_grad, left.sum_hess, wleft) +
+        CalcGainGivenWeight<ParamT, double>(param, right.sum_grad, right.sum_hess, wright);
     if (constraint == 0) {
       return gain;
     } else if (constraint > 0) {
@@ -480,6 +481,7 @@ struct SplitEntry {
   bst_float split_value{0.0f};
   GradStats left_sum;
   GradStats right_sum;
+  bool default_left{true};
   /*! \brief constructor */
   SplitEntry() = default;
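
Previously the default direction was packed into bit 31 of sindex, which is why the old SplitIndex() had to mask it off; the new field stores it directly. A before/after sketch of the encoding (illustrative values only):

    // old: direction folded into the index word
    unsigned sindex = split_index | (default_left ? (1U << 31) : 0);
    bool dir     = (sindex >> 31) != 0;            // DefaultLeft()
    unsigned idx = sindex & ((1U << 31) - 1U);     // SplitIndex()

    // new: sindex == split_index unchanged; direction kept in default_left
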
@@ -494,7 +496,11 @@ struct SplitEntry {
    * \param split_index the feature index where the split is on
    */
   inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
-    if (this->SplitIndex() <= split_index) {
+    if (!std::isfinite(new_loss_chg)) {  // in some cases new_loss_chg can be NaN or Inf,
+                                         // for example when lambda = 0 & min_child_weight = 0
+                                         // skip the value in this case
+      return false;
+    } else if (this->SplitIndex() <= split_index) {
       return new_loss_chg > this->loss_chg;
     } else {
       return !(this->loss_chg > new_loss_chg);
@@ -512,6 +518,7 @@ struct SplitEntry {
       this->split_value = e.split_value;
       this->left_sum = e.left_sum;
       this->right_sum = e.right_sum;
+      this->default_left = e.default_left;
       return true;
     } else {
       return false;
@@ -526,13 +533,11 @@ struct SplitEntry {
    * \return whether the proposed split is better and can replace current split
    */
   inline bool Update(bst_float new_loss_chg, unsigned split_index,
-                     bst_float new_split_value, bool default_left,
+                     bst_float new_split_value, bool new_default_left,
                      const GradStats &left_sum, const GradStats &right_sum) {
     if (this->NeedReplace(new_loss_chg, split_index)) {
       this->loss_chg = new_loss_chg;
-      if (default_left) {
-        split_index |= (1U << 31);
-      }
+      this->default_left = new_default_left;
       this->sindex = split_index;
       this->split_value = new_split_value;
       this->left_sum = left_sum;
@@ -548,9 +553,9 @@ struct SplitEntry {
     dst.Update(src);
   }
   /*!\return feature index to split on */
-  inline unsigned SplitIndex() const { return sindex & ((1U << 31) - 1U); }
+  inline unsigned SplitIndex() const { return sindex; }
   /*!\return whether missing value goes to left branch */
-  inline bool DefaultLeft() const { return (sindex >> 31) != 0; }
+  inline bool DefaultLeft() const { return default_left; }
 };
 
 }  // namespace tree
diff --git a/src/tree/split_evaluator.cc b/src/tree/split_evaluator.cc
index 5c43567de..f27b57991 100644
--- a/src/tree/split_evaluator.cc
+++ b/src/tree/split_evaluator.cc
@@ -283,7 +283,9 @@ class MonotonicConstraint final : public SplitEvaluator {
                 bst_float leftweight,
                 bst_float rightweight) override {
     inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
-    bst_uint newsize = std::max(leftid, rightid) + 1;
+
+    bst_uint newsize = std::max(bst_uint(lower_.size()), bst_uint(std::max(leftid, rightid) + 1u));
+
     lower_.resize(newsize);
     upper_.resize(newsize);
     bst_int constraint = GetConstraint(featureid);
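
The std::isfinite guard added to NeedReplace above matters because NaN compares false with everything: with lambda = 0 and min_child_weight = 0 the candidate gain can evaluate to NaN, and the old `!(this->loss_chg > new_loss_chg)` branch would then return true and install a garbage split. A toy check:

    #include <cmath>
    #include <limits>

    const float nan_chg = std::numeric_limits<float>::quiet_NaN();
    // (best.loss_chg > nan_chg) is false, so !(best.loss_chg > nan_chg) is true:
    // without the isfinite() test the NaN candidate would be accepted.
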
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 52633c099..9899ea61d 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -1,8 +1,8 @@
 /*!
- * Copyright 2017-2018 by Contributors
+ * Copyright 2017-2019 by Contributors
  * \file updater_quantile_hist.cc
  * \brief use quantized feature values to construct a tree
- * \author Philip Cho, Tianqi Chen
+ * \author Philip Cho, Tianqi Chen, Egor Smirnov
  */
 #include
 #include
@@ -41,7 +41,7 @@ void QuantileHistMaker::Init(const std::vector
                               *gpair, DMatrix *dmat,
                               const std::vector<RegTree *> &trees) {
+  // omp_set_nested(1);
   if (is_gmat_initialized_ == false) {
     double tstart = dmlc::GetTime();
     gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
@@ -88,94 +89,231 @@ bool QuantileHistMaker::UpdatePredictionCache(
   }
 }
 
-void QuantileHistMaker::Builder::SyncHistograms(
-    int starting_index,
-    int sync_count,
-    RegTree *p_tree) {
-  builder_monitor_.Start("SyncHistograms");
-  this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count);
-  // use Subtraction Trick
-  for (auto const& node_pair : nodes_for_subtraction_trick_) {
-    hist_.AddHistRow(node_pair.first);
-    SubtractionTrick(hist_[node_pair.first], hist_[node_pair.second],
-                     hist_[(*p_tree)[node_pair.first].Parent()]);
-  }
-  builder_monitor_.Stop("SyncHistograms");
-}
-
-void QuantileHistMaker::Builder::BuildLocalHistograms(
-    int *starting_index,
-    int *sync_count,
-    const GHistIndexMatrix &gmat,
-    const GHistIndexBlockMatrix &gmatb,
-    RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h) {
-  builder_monitor_.Start("BuildLocalHistograms");
-  for (auto const& entry : qexpand_depth_wise_) {
-    int nid = entry.nid;
-    RegTree::Node &node = (*p_tree)[nid];
-    if (rabit::IsDistributed()) {
-      if (node.IsRoot() || node.IsLeftChild()) {
-        hist_.AddHistRow(nid);
-        // in distributed setting, we always calculate from left child or root node
-        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
-        if (!node.IsRoot()) {
-          nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].RightChild()] = nid;
-        }
-        (*sync_count)++;
-        (*starting_index) = std::min((*starting_index), nid);
-      }
-    } else {
-      if (!node.IsRoot() && node.IsLeftChild() &&
-          (row_set_collection_[nid].Size() <
-           row_set_collection_[(*p_tree)[node.Parent()].RightChild()].Size())) {
-        hist_.AddHistRow(nid);
-        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
-        nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].RightChild()] = nid;
-        (*sync_count)++;
-        (*starting_index) = std::min((*starting_index), nid);
-      } else if (!node.IsRoot() && !node.IsLeftChild() &&
-                 (row_set_collection_[nid].Size() <=
-                  row_set_collection_[(*p_tree)[node.Parent()].LeftChild()].Size())) {
-        hist_.AddHistRow(nid);
-        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
-        nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].LeftChild()] = nid;
-        (*sync_count)++;
-        (*starting_index) = std::min((*starting_index), nid);
-      } else if (node.IsRoot()) {
-        hist_.AddHistRow(nid);
-        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
-        (*sync_count)++;
-        (*starting_index) = std::min((*starting_index), nid);
-      }
-    }
-  }
-  builder_monitor_.Stop("BuildLocalHistograms");
-}
-
-void QuantileHistMaker::Builder::BuildNodeStats(
+void QuantileHistMaker::Builder::BuildNodeStat(
     const GHistIndexMatrix &gmat,
     DMatrix *p_fmat,
     RegTree *p_tree,
-    const std::vector<GradientPair> &gpair_h) {
-  builder_monitor_.Start("BuildNodeStats");
-  for (auto const& entry : qexpand_depth_wise_) {
-    int nid = entry.nid;
-    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
-    // add constraints
-    if (!(*p_tree)[nid].IsLeftChild() && !(*p_tree)[nid].IsRoot()) {
-      // it's a right child
-      auto parent_id = (*p_tree)[nid].Parent();
-      auto left_sibling_id = (*p_tree)[parent_id].LeftChild();
-      auto parent_split_feature_id = snode_[parent_id].best.SplitIndex();
-      spliteval_->AddSplit(parent_id, left_sibling_id, nid, parent_split_feature_id,
-                           snode_[left_sibling_id].weight, snode_[nid].weight);
-    }
+    const std::vector<GradientPair> &gpair_h,
+    int32_t nid) {
+
+  // add constraints
+  if (!(*p_tree)[nid].IsLeftChild() && !(*p_tree)[nid].IsRoot()) {
+    auto parent_id = (*p_tree)[nid].Parent();
+    // it's a right child
+    auto left_sibling_id = (*p_tree)[parent_id].LeftChild();
+    auto parent_split_feature_id = snode_[parent_id].best.SplitIndex();
+
+    spliteval_->AddSplit(parent_id, left_sibling_id, nid, parent_split_feature_id,
+                         snode_[left_sibling_id].weight, snode_[nid].weight);
   }
-  builder_monitor_.Stop("BuildNodeStats");
 }
 
+void QuantileHistMaker::Builder::BuildNodeStatBatch(
+    const GHistIndexMatrix &gmat,
+    DMatrix *p_fmat,
+    RegTree *p_tree,
+    const std::vector<GradientPair> &gpair_h,
+    const std::vector<ExpandEntry>& nodes) {
+  perf_monitor.TickStart();
+  for (const auto& node : nodes) {
+    const int32_t nid = node.nid;
+    const int32_t sibling_nid = node.sibling_nid;
+    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, p_tree, &(snode_[nid]), (*p_tree)[nid].Parent());
+    if (sibling_nid > -1) {
+      this->InitNewNode(nid, gmat, gpair_h, *p_fmat, p_tree,
+          &(snode_[sibling_nid]), (*p_tree)[sibling_nid].Parent());
+    }
+  }
+  for (const auto& node : nodes) {
+    const int32_t nid = node.nid;
+    const int32_t sibling_nid = node.sibling_nid;
+    BuildNodeStat(gmat, p_fmat, p_tree, gpair_h, nid);
+    if (sibling_nid > -1) {
+      BuildNodeStat(gmat, p_fmat, p_tree, gpair_h, sibling_nid);
+    }
+  }
+  perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::INIT_NEW_NODE);
+}
+
+template <typename RowIdxType, typename IdxType>
+inline std::pair<size_t, size_t> PartitionDenseLeftDefaultKernel(const RowIdxType* rid,
+    const IdxType* idx, const IdxType offset, const int32_t split_cond,
+    const size_t istart, const size_t iend, RowIdxType* p_left, RowIdxType* p_right) {
+  size_t ileft = 0;
+  size_t iright = 0;
+
+  const IdxType max_val = std::numeric_limits<IdxType>::max();
+
+  for (size_t i = istart; i < iend; i++) {
+    if (idx[rid[i]] == max_val || static_cast<int32_t>(idx[rid[i]] + offset) <= split_cond) {
+      p_left[ileft++] = rid[i];
+    } else {
+      p_right[iright++] = rid[i];
+    }
+  }
+
+  return { ileft, iright };
+}
+
+template <typename RowIdxType, typename IdxType>
+inline std::pair<size_t, size_t> PartitionDenseRightDefaultKernel(const RowIdxType* rid,
+    const IdxType* idx, const IdxType offset, const int32_t split_cond,
+    const size_t istart, const size_t iend, RowIdxType* p_left, RowIdxType* p_right) {
+  size_t ileft = 0;
+  size_t iright = 0;
+
+  const IdxType max_val = std::numeric_limits<IdxType>::max();
+
+  for (size_t i = istart; i < iend; i++) {
+    if (idx[rid[i]] == max_val || static_cast<int32_t>(idx[rid[i]] + offset) > split_cond) {
+      p_right[iright++] = rid[i];
+    } else {
+      p_left[ileft++] = rid[i];
+    }
+  }
+  return { ileft, iright };
+}
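
A self-contained check of the left-default kernel above with made-up data (the snippet assumes the template is visible in this translation unit): missing values are encoded as the maximum of IdxType and follow the default direction.

    const std::vector<uint8_t> idx = {0, 255, 2, 3};   // 255 == missing for uint8_t
    const std::vector<size_t>  rid = {0, 1, 2, 3};
    std::vector<size_t> left(4), right(4);
    auto r = PartitionDenseLeftDefaultKernel(
        rid.data(), idx.data(), /*offset=*/uint8_t(0), /*split_cond=*/1,
        /*istart=*/0, /*iend=*/4, left.data(), right.data());
    // rows 0 (bin 0 <= 1) and 1 (missing) go left; rows 2 and 3 go right:
    // r == {2, 2}, left == {0, 1, ...}, right == {2, 3, ...}
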
+
+template <typename RowIdxType, typename IdxType>
+inline std::pair<size_t, size_t> PartitionSparseKernel(const RowIdxType* rowid,
+    const IdxType* idx, const int32_t split_cond, const size_t ibegin,
+    const size_t iend, RowIdxType* p_left, RowIdxType* p_right,
+    Column column, bool default_left) {
+  size_t ileft = 0;
+  size_t iright = 0;
+
+  if (ibegin < iend) {  // ensure that [ibegin, iend) is a nonempty range
+    // search first nonzero row with index >= rowid[ibegin]
+    const size_t* p = std::lower_bound(column.GetRowData(),
+                                       column.GetRowData() + column.Size(),
+                                       rowid[ibegin]);
+
+    if (p != column.GetRowData() + column.Size() && *p <= rowid[iend - 1]) {
+      size_t cursor = p - column.GetRowData();
+
+      for (size_t i = ibegin; i < iend; ++i) {
+        const size_t rid = rowid[i];
+        while (cursor < column.Size()
+               && column.GetRowIdx(cursor) < rid
+               && column.GetRowIdx(cursor) <= rowid[iend - 1]) {
+          ++cursor;
+        }
+        if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) {
+          const uint32_t rbin = column.GetFeatureBinIdx(cursor);
+          if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
+            p_left[ileft++] = rid;
+          } else {
+            p_right[iright++] = rid;
+          }
+          ++cursor;
+        } else {
+          // missing value
+          if (default_left) {
+            p_left[ileft++] = rid;
+          } else {
+            p_right[iright++] = rid;
+          }
+        }
+      }
+    } else {  // all rows in [ibegin, iend) have missing values
+      if (default_left) {
+        for (size_t i = ibegin; i < iend; ++i) {
+          const size_t rid = rowid[i];
+          p_left[ileft++] = rid;
+        }
+      } else {
+        for (size_t i = ibegin; i < iend; ++i) {
+          const size_t rid = rowid[i];
+          p_right[iright++] = rid;
+        }
+      }
+    }
+  }
+  return {ileft, iright};
+}
+
+int32_t QuantileHistMaker::Builder::FindSplitCond(int32_t nid,
+                                                  RegTree *p_tree,
+                                                  const GHistIndexMatrix &gmat) {
+  bst_float left_leaf_weight = spliteval_->ComputeWeight(nid,
+      snode_[nid].best.left_sum) * param_.learning_rate;
+  bst_float right_leaf_weight = spliteval_->ComputeWeight(nid,
+      snode_[nid].best.right_sum) * param_.learning_rate;
+  p_tree->ExpandNode(nid, snode_[nid].best.SplitIndex(), snode_[nid].best.split_value,
+                     snode_[nid].best.DefaultLeft(), snode_[nid].weight, left_leaf_weight,
+                     right_leaf_weight, snode_[nid].best.loss_chg, snode_[nid].stats.sum_hess);
+
+  RegTree::Node node = (*p_tree)[nid];
+  // Categorize member rows
+  const bst_uint fid = node.SplitIndex();
+  const bst_float split_pt = node.SplitCond();
+  const uint32_t lower_bound = gmat.cut.row_ptr[fid];
+  const uint32_t upper_bound = gmat.cut.row_ptr[fid + 1];
+  int32_t split_cond = -1;
+  // convert floating-point split_pt into the corresponding bin_id
+  // split_cond = -1 indicates that split_pt is less than all known cut points
+  CHECK_LT(upper_bound,
+           static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+  for (uint32_t i = lower_bound; i < upper_bound; ++i) {
+    if (split_pt == gmat.cut.cut[i]) {
+      split_cond = static_cast<int32_t>(i);
+    }
+  }
+  return split_cond;
+}
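
In other words, the floating-point threshold chosen by split evaluation is mapped back to its global bin id once per node, so the per-row partition can compare small integers instead of floats. With hypothetical cut points {0.5, 1.5, 2.5} for feature fid occupying bins 7..9 and split_pt == 1.5, the scan yields split_cond == 8; rows then go left when their bin index is <= 8 (plus missing values whenever DefaultLeft() is set).
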
+
+// split rows of each node into blocks of rows
+// for future parallel execution
+template <typename TaskType, typename NodeType>
+void QuantileHistMaker::Builder::CreateTasksForApplySplit(
+    const std::vector<ExpandEntry>& nodes,
+    const GHistIndexMatrix &gmat,
+    RegTree *p_tree,
+    int *num_leaves,
+    const int depth,
+    const size_t block_size,
+    std::vector<TaskType>* tasks,
+    std::vector<NodeType>* nodes_bounds) {
+  size_t* buffer = buffer_for_partition_.data();
+  size_t cur_buff_offset = 0;
+
+  auto create_nodes = [&](int32_t this_nid) {
+    if (snode_[this_nid].best.loss_chg < kRtEps ||
+        (param_.max_depth > 0 && depth == param_.max_depth) ||
+        (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) {
+      (*p_tree)[this_nid].SetLeaf(snode_[this_nid].weight * param_.learning_rate);
+    } else {
+      const size_t nrows = row_set_collection_[this_nid].Size();
+      const size_t n_blocks = nrows / block_size + !!(nrows % block_size);
+
+      nodes_bounds->emplace_back(this_nid, tasks->size(), tasks->size() + n_blocks);
+
+      const int32_t split_cond = FindSplitCond(this_nid, p_tree, gmat);
+
+      for (size_t i = 0; i < n_blocks; ++i) {
+        const size_t istart = i*block_size;
+        const size_t iend = (i == n_blocks-1) ? nrows : istart + block_size;
+
+        TaskType task {this_nid, split_cond, n_blocks, i, istart, iend, nodes_bounds->size()-1,
+            buffer + cur_buff_offset, buffer + cur_buff_offset + (iend-istart), 0, 0, 0, 0};
+        tasks->push_back(task);
+        cur_buff_offset += 2*(iend-istart);
+      }
+    }
+  };
+  for (const auto& node : nodes) {
+    const int32_t nid = node.nid;
+    const int32_t sibling_nid = node.sibling_nid;
+    create_nodes(nid);
+
+    if (sibling_nid > -1) {
+      create_nodes(sibling_nid);
+    }
+  }
+}
+
-void QuantileHistMaker::Builder::EvaluateSplits(
+void QuantileHistMaker::Builder::CreateNewNodesBatch(
+    const std::vector<ExpandEntry>& nodes,
     const GHistIndexMatrix &gmat,
     const ColumnMatrix &column_matrix,
     DMatrix *p_fmat,
@@ -184,49 +322,403 @@ void QuantileHistMaker::Builder::EvaluateSplits(
     RegTree *p_tree,
     int *num_leaves,
     int depth,
     unsigned *timestamp,
     std::vector<ExpandEntry> *temp_qexpand_depth) {
-  for (auto const& entry : qexpand_depth_wise_) {
-    int nid = entry.nid;
-    this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
-    if (snode_[nid].best.loss_chg < kRtEps ||
-        (param_.max_depth > 0 && depth == param_.max_depth) ||
-        (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) {
-      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
+  perf_monitor.TickStart();
+  const size_t block_size = 2048;
+
+  struct ApplySplitTaskInfo {
+    // input
+    int32_t nid;
+    int32_t split_cond;
+    size_t n_blocks_this_node;
+    size_t i_block_this_node;
+    size_t istart;
+    size_t iend;
+    size_t inode;
+
+    // result
+    size_t* left;
+    size_t* right;
+    size_t n_left;
+    size_t n_right;
+    size_t ileft;
+    size_t iright;
+  };
+
+  struct NodeBoundsInfo {
+    NodeBoundsInfo(int32_t nid, size_t begin, size_t end):
+      nid(nid), begin(begin), end(end) {
+    }
+
+    int32_t nid;
+    size_t begin;
+    size_t end;
+  };
+
+  // create tasks for partition of row_set_collection_
+  std::vector<ApplySplitTaskInfo> tasks;
+  std::vector<NodeBoundsInfo> nodes_bounds;
+
+  // 1. Split the row-indexes of each node into blocks of rows
+  CreateTasksForApplySplit(nodes, gmat, p_tree, num_leaves,
+      depth, block_size, &tasks, &nodes_bounds);
+
+  // buffer to store the number of rows in the left part for each node
+  std::vector<size_t> left_sizes;
+  left_sizes.reserve(nodes_bounds.size());
+  const int size = tasks.size();
+
+  // execute tasks in parallel
+  #pragma omp parallel
+  {
+    // 2. For each block of rows:
+    //   a) write row-indexes which should go to the left child into the 1st buffer
+    //   b) write row-indexes which should go to the right child into the 2nd buffer
+    // values in each buffer are kept in the original order
+    #pragma omp for
+    for (int32_t i = 0; i < size; ++i) {
+      const int32_t nid = tasks[i].nid;
+      const int32_t split_cond = tasks[i].split_cond;
+      const size_t istart = tasks[i].istart;
+      const size_t iend = tasks[i].iend;
+
+      const bst_uint fid = (*p_tree)[nid].SplitIndex();
+      const bool default_left = (*p_tree)[nid].DefaultLeft();
+      const Column column = column_matrix.GetColumn(fid);
+
+      const uint32_t* idx = column.GetIndex();
+      const size_t* rid = row_set_collection_[nid].begin;
+
+      if (column.GetType() == xgboost::common::kDenseColumn) {
+        if (default_left) {
+          auto res = PartitionDenseLeftDefaultKernel(
+              rid, idx, column.GetBaseIdx(), split_cond, istart, iend,
+              tasks[i].left, tasks[i].right);
+          tasks[i].n_left = res.first;
+          tasks[i].n_right = res.second;
+        } else {
+          auto res = PartitionDenseRightDefaultKernel(
+              rid, idx, column.GetBaseIdx(), split_cond, istart, iend,
+              tasks[i].left, tasks[i].right);
+          tasks[i].n_left = res.first;
+          tasks[i].n_right = res.second;
+        }
+      } else {
+        auto res = PartitionSparseKernel(
+            rid, idx, split_cond, istart, iend, tasks[i].left, tasks[i].right,
+            column, default_left);
+        tasks[i].n_left = res.first;
+        tasks[i].n_right = res.second;
+      }
+    }
+
+    // 3. For each node, find the number of elements in the left part
+    #pragma omp single
+    {
+      for (auto& node : nodes_bounds) {
+        size_t n_left = 0;
+        size_t n_right = 0;
+
+        for (size_t i = node.begin; i < node.end; ++i) {
+          tasks[i].ileft = n_left;
+          tasks[i].iright = n_right;
+
+          n_left += tasks[i].n_left;
+          n_right += tasks[i].n_right;
+        }
+        left_sizes.push_back(n_left);
+      }
+    }
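
Step 3 is an exclusive prefix sum over the per-block counts, done serially because the number of blocks per node is small. With one node split into three blocks whose kernels returned n_left = {5, 3, 4}, the write offsets become ileft = {0, 5, 8} and left_sizes records 12 for that node; the right side runs the same way independently. Equivalent scalar code (tasks_of_node is a hypothetical view over tasks[node.begin..node.end)):

    size_t offset = 0;
    for (auto& t : tasks_of_node) {
      t.ileft = offset;      // destination offset of this block's left rows
      offset += t.n_left;
    }
    const size_t n_left_total = offset;  // what gets pushed into left_sizes
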
+
+    // 4. Copy data from the buffers back to the original row_set_collection_
+    #pragma omp for
+    for (int32_t i = 0; i < size; ++i) {
+      const size_t node_idx = tasks[i].inode;
+      const int32_t nid = tasks[i].nid;
+      const size_t n_left = left_sizes[node_idx];
+
+      CHECK_LE(tasks[i].ileft + tasks[i].n_left, row_set_collection_[nid].Size());
+      CHECK_LE(n_left + tasks[i].iright + tasks[i].n_right, row_set_collection_[nid].Size());
+
+      auto* rid = const_cast<size_t*>(row_set_collection_[nid].begin);
+      std::memcpy(rid + tasks[i].ileft, tasks[i].left,
+          tasks[i].n_left * sizeof(rid[0]));
+      std::memcpy(rid + n_left + tasks[i].iright, tasks[i].right,
+          tasks[i].n_right * sizeof(rid[0]));
+    }
+  }
+
+  // register new nodes
+  for (size_t i = 0; i < nodes_bounds.size(); ++i) {
+    const int32_t nid = nodes_bounds[i].nid;
+    const size_t n_left = left_sizes[i];
+    RegTree::Node node = (*p_tree)[nid];
+
+    const int32_t left_id = node.LeftChild();
+    const int32_t right_id = node.RightChild();
+    row_set_collection_.AddSplit(nid, n_left, left_id, right_id);
+
+    if (rabit::IsDistributed() ||
+        row_set_collection_[left_id].Size() < row_set_collection_[right_id].Size()) {
+      temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id, nid,
+          depth + 1, 0.0, (*timestamp)++));
     } else {
-      this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
-      int left_id = (*p_tree)[nid].LeftChild();
-      int right_id = (*p_tree)[nid].RightChild();
-      temp_qexpand_depth->push_back(ExpandEntry(left_id,
-          p_tree->GetDepth(left_id), 0.0, (*timestamp)++));
-      temp_qexpand_depth->push_back(ExpandEntry(right_id,
-          p_tree->GetDepth(right_id), 0.0, (*timestamp)++));
-      // - 1 parent + 2 new children
-      (*num_leaves)++;
+      temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id, nid,
+          depth + 1, 0.0, (*timestamp)++));
+    }
+  }
+  perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::APPLY_SPLIT);
+}
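
Taken together, steps 2-4 perform a stable parallel partition: each block writes its left/right rows to private buffers in original order, the prefix sums assign disjoint destination ranges, and the memcpy pass stitches them back so the node's row list becomes [all left rows | all right rows] with relative order preserved. Single-threaded, the result is equivalent to the sketch below (goes_left stands in for the bin comparison against split_cond):

    std::stable_partition(rows.begin(), rows.end(),
                          [&](size_t r) { return goes_left(r); });
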
+
+std::tuple<common::GradStatHist::GradType*, common::GradStatHist*>
+    QuantileHistMaker::Builder::GetHistBuffer(
+    std::vector<uint8_t>* hist_is_init, std::vector<common::GradStatHist>* grad_stats,
+    size_t block_id, size_t nthread, size_t tid,
+    std::vector<common::GradStatHist::GradType*>* data_hist, size_t hist_size) {
+
+  const size_t n_hist_for_current_node = hist_is_init->size();
+  const size_t hist_id = ((n_hist_for_current_node == nthread) ? tid : block_id);
+
+  common::GradStatHist::GradType* local_data_hist = (*data_hist)[hist_id];
+  if (!(*hist_is_init)[hist_id]) {
+    std::fill(local_data_hist, local_data_hist + hist_size, 0.0f);
+    (*hist_is_init)[hist_id] = true;
+  }
+
+  return std::make_tuple(local_data_hist, &(*grad_stats)[hist_id]);
+}
+
+void QuantileHistMaker::Builder::CreateTasksForBuildHist(
+    size_t block_size_rows,
+    size_t nthread,
+    const std::vector<ExpandEntry>& nodes,
+    std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
+    std::vector<std::vector<uint8_t>>* hist_is_init,
+    std::vector<std::vector<common::GradStatHist>>* grad_stats,
+    std::vector<int32_t>* task_nid,
+    std::vector<int32_t>* task_node_idx,
+    std::vector<int32_t>* task_block_idx) {
+  size_t i_hist = 0;
+
+  // prepare tasks for parallel execution
+  for (size_t i = 0; i < nodes.size(); ++i) {
+    const int32_t nid = nodes[i].nid;
+    const int32_t sibling_nid = nodes[i].sibling_nid;
+    hist_.AddHistRow(nid);
+    if (sibling_nid > -1) {
+      hist_.AddHistRow(sibling_nid);
+    }
+    const size_t nrows = row_set_collection_[nid].Size();
+    const size_t n_local_blocks = nrows / block_size_rows + !!(nrows % block_size_rows);
+    const size_t n_local_histograms = std::min(nthread, n_local_blocks);
+
+    task_nid->resize(task_nid->size() + n_local_blocks, nid);
+    for (size_t j = 0; j < n_local_blocks; ++j) {
+      task_node_idx->push_back(i);
+      task_block_idx->push_back(j);
+    }
+
+    (*hist_buffers)[i].clear();
+    for (size_t j = 0; j < n_local_histograms; j++) {
+      (*hist_buffers)[i].push_back(
+          reinterpret_cast<common::GradStatHist::GradType*>(hist_buff_[i_hist++].data()));
+    }
+    (*hist_is_init)[i].clear();
+    (*hist_is_init)[i].resize(n_local_histograms, false);
+    (*grad_stats)[i].resize(n_local_histograms);
+  }
+}
+
+void QuantileHistMaker::Builder::BuildHistsBatch(const std::vector<ExpandEntry>& nodes,
+    RegTree* p_tree, const GHistIndexMatrix &gmat, const std::vector<GradientPair>& gpair,
+    std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
+    std::vector<std::vector<uint8_t>>* hist_is_init) {
+  perf_monitor.TickStart();
+  const size_t block_size_rows = 256;
+  const size_t nthread = static_cast<size_t>(this->nthread_);
+  const size_t nbins = gmat.cut.row_ptr.back();
+  const size_t hist_size = 2 * nbins;
+
+  hist_buffers->resize(nodes.size());
+  hist_is_init->resize(nodes.size());
+
+  // input data for tasks
+  std::vector<int32_t> task_nid;
+  std::vector<int32_t> task_node_idx;
+  std::vector<int32_t> task_block_idx;
+
+  // result vector
+  std::vector<std::vector<common::GradStatHist>> grad_stats(nodes.size());
+
+  // 1. Create tasks for hist construction by blocks of rows for each node
+  CreateTasksForBuildHist(block_size_rows, nthread, nodes, hist_buffers, hist_is_init, &grad_stats,
+      &task_nid, &task_node_idx, &task_block_idx);
+  int32_t n_hist_buidling_tasks = task_node_idx.size();
+
+  const GradientPair::ValueT* const pgh =
+      reinterpret_cast<const GradientPair::ValueT*>(gpair.data());
+
+  // 2. Build partial histograms for each node
+  #pragma omp parallel for schedule(guided)
+  for (int32_t itask = 0; itask < n_hist_buidling_tasks; ++itask) {
+    const size_t tid = omp_get_thread_num();
+    const int32_t nid = task_nid[itask];
+    const int32_t block_id = task_block_idx[itask];
+    // node_idx : location of node `nid` within the `nodes` list; in general, node_idx != nid
+    const int32_t node_idx = task_node_idx[itask];
+
+    common::GradStatHist::GradType* data_local_hist;
+    common::GradStatHist* grad_stat;  // total gradient/hessian value for node `nid`
+    std::tie(data_local_hist, grad_stat) = GetHistBuffer(&(*hist_is_init)[node_idx],
+        &grad_stats[node_idx], block_id, nthread, tid,
+        &(*hist_buffers)[node_idx], hist_size);
+
+    const size_t* row_ptr = gmat.row_ptr.data();
+    const size_t* rid = row_set_collection_[nid].begin;
+
+    const size_t nrows = row_set_collection_[nid].Size();
+    const size_t istart = block_id * block_size_rows;
+    const size_t iend = (((block_id+1)*block_size_rows > nrows) ? nrows : istart + block_size_rows);
+
+    // call the hist building kernel depending on the bin-matrix layout
+    if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
+      common::BuildHistLocalDense(istart, iend, nrows, rid, gmat.index.data(), pgh,
+          row_ptr, data_local_hist, grad_stat);
+    } else {
+      common::BuildHistLocalSparse(istart, iend, nrows, rid, gmat.index.data(), pgh,
+          row_ptr, data_local_hist, grad_stat);
+    }
+  }
+
+  // 3. Merge grad stats for each node
+  //    Sync histograms in case of distributed computation
+  SyncHistograms(p_tree, nodes, hist_buffers, hist_is_init, grad_stats);
+
+  perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::BUILD_HIST);
+}
+
+void QuantileHistMaker::Builder::SyncHistograms(
+    RegTree* p_tree,
+    const std::vector<ExpandEntry>& nodes,
+    std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
+    std::vector<std::vector<uint8_t>>* hist_is_init,
+    const std::vector<std::vector<common::GradStatHist>>& grad_stats) {
+  if (rabit::IsDistributed()) {
+    const int size = nodes.size();
+    #pragma omp parallel for  // TODO(egorsmir): replace with n_features * nodes.size()
+    for (int i = 0; i < size; ++i) {
+      const int32_t nid = nodes[i].nid;
+      common::GradStatHist::GradType* hist_data =
+          reinterpret_cast<common::GradStatHist::GradType*>(hist_[nid].data());
+
+      ReduceHistograms(hist_data, nullptr, nullptr, 0, hist_builder_.GetNumBins() * 2, i,
+          *hist_is_init, *hist_buffers);
+    }
+
+    for (auto elem : nodes) {
+      this->histred_.Allreduce(hist_[elem.nid].data(), hist_builder_.GetNumBins());
+    }
+
+    // TODO(egorsmir): add parallel for
+    for (auto elem : nodes) {
+      if (elem.sibling_nid > -1) {
+        SubtractionTrick(hist_[elem.sibling_nid], hist_[elem.nid],
+            hist_[(*p_tree)[elem.sibling_nid].Parent()]);
+      }
+    }
+  }
+
+  // merge grad stats
+  {
+    for (size_t inode = 0; inode < nodes.size(); ++inode) {
+      const int32_t nid = nodes[inode].nid;
+
+      if (snode_.size() <= size_t(nid)) {
+        snode_.resize(nid + 1, NodeEntry(param_));
+      }
+
+      common::GradStatHist grad_stat;
+      for (size_t ihist = 0; ihist < (*hist_is_init)[inode].size(); ++ihist) {
+        if ((*hist_is_init)[inode][ihist]) {
+          grad_stat.Add(grad_stats[inode][ihist]);
+        }
+      }
+      this->histred_.Allreduce(&grad_stat, 1);
+      snode_[nid].stats = grad_stat.ToGradStat();
+
+      const int32_t sibling_nid = nodes[inode].sibling_nid;
+      if (sibling_nid > -1) {
+        if (snode_.size() <= size_t(sibling_nid)) {
+          snode_.resize(sibling_nid + 1, NodeEntry(param_));
+        }
+        const int parent_id = (*p_tree)[nid].Parent();
+        snode_[sibling_nid].stats.SetSubstract(snode_[parent_id].stats, snode_[nid].stats);
+      }
+    }
+  }
+}
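
The sibling statistics come from the parent-child identity rather than a second pass over the data: every parent row lands in exactly one child, so sum_parent = sum_left + sum_right, and the child that was not built explicitly is recovered by subtraction. With made-up numbers:

    // parent {grad=10, hess=20}, explicitly built child {grad=4, hess=9}
    tree::GradStats parent{10.0, 20.0}, built{4.0, 9.0}, sibling;
    sibling.SetSubstract(parent, built);  // -> {grad=6, hess=11}
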
 
-void QuantileHistMaker::Builder::ExpandWithDepthWidth(
+// merge some blocks of partial histograms
+void QuantileHistMaker::Builder::ReduceHistograms(
+    common::GradStatHist::GradType* hist_data,
+    common::GradStatHist::GradType* sibling_hist_data,
+    common::GradStatHist::GradType* parent_hist_data,
+    const size_t ibegin,
+    const size_t iend,
+    const size_t inode,
+    const std::vector<std::vector<uint8_t>>& hist_is_init,
+    const std::vector<std::vector<common::GradStatHist::GradType*>>& hist_buffers) {
+  bool is_init = false;
+  for (size_t ihist = 0; ihist < hist_is_init[inode].size(); ++ihist) {
+    common::GradStatHist::GradType* partial_data = hist_buffers[inode][ihist];
+    if (hist_is_init[inode][ihist] && is_init) {
+      for (size_t i = ibegin; i < iend; ++i) {
+        hist_data[i] += partial_data[i];
+      }
+    } else if (hist_is_init[inode][ihist]) {
+      for (size_t i = ibegin; i < iend; ++i) {
+        hist_data[i] = partial_data[i];
+      }
+      is_init = true;
+    }
+  }
+
+  if (sibling_hist_data) {
+    for (size_t i = ibegin; i < iend; ++i) {
+      sibling_hist_data[i] = parent_hist_data[i] - hist_data[i];
+    }
+  }
+}
+
+void QuantileHistMaker::Builder::ExpandWithDepthWise(
     const GHistIndexMatrix &gmat,
     const GHistIndexBlockMatrix &gmatb,
     const ColumnMatrix &column_matrix,
-    DMatrix *p_fmat,
-    RegTree *p_tree,
+    DMatrix* p_fmat,
+    RegTree* p_tree,
     const std::vector<GradientPair> &gpair_h) {
   unsigned timestamp = 0;
   int num_leaves = 0;
 
   // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
-  qexpand_depth_wise_.emplace_back(ExpandEntry(0, p_tree->GetDepth(0), 0.0, timestamp++));
+  qexpand_depth_wise_.emplace_back(0, -1, ROOT_PARENT_ID, p_tree->GetDepth(0), 0.0, timestamp++);
   ++num_leaves;
+
   for (int depth = 0; depth < param_.max_depth + 1; depth++) {
-    int starting_index = std::numeric_limits<int>::max();
-    int sync_count = 0;
     std::vector<ExpandEntry> temp_qexpand_depth;
-    BuildLocalHistograms(&starting_index, &sync_count, gmat, gmatb, p_tree, gpair_h);
-    SyncHistograms(starting_index, sync_count, p_tree);
-    BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);
-    EvaluateSplits(gmat, column_matrix, p_fmat, p_tree, &num_leaves, depth, &timestamp,
-                   &temp_qexpand_depth);
+
+    // buffer to store partial histograms
+    std::vector<std::vector<common::GradStatHist::GradType*>> hist_buffers;
+    // uint8_t is used instead of bool, since concurrent read/write
+    // access to std::vector<bool> is not thread-safe
+    std::vector<std::vector<uint8_t>> hist_is_init;
+
+    BuildHistsBatch(qexpand_depth_wise_, p_tree, gmat, gpair_h,
+        &hist_buffers, &hist_is_init);
+    BuildNodeStatBatch(gmat, p_fmat, p_tree, gpair_h, qexpand_depth_wise_);
+    EvaluateSplitsBatch(qexpand_depth_wise_, gmat, *p_fmat, hist_is_init, hist_buffers);
+    CreateNewNodesBatch(qexpand_depth_wise_, gmat, column_matrix, p_fmat, p_tree,
+        &num_leaves, depth, &timestamp, &temp_qexpand_depth);
+
+    num_leaves += temp_qexpand_depth.size();
+
     // clean up
     qexpand_depth_wise_.clear();
     nodes_for_subtraction_trick_.clear();
@@ -246,18 +738,21 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
     DMatrix* p_fmat,
     RegTree* p_tree,
     const std::vector<GradientPair>& gpair_h) {
-  unsigned timestamp = 0;
   int num_leaves = 0;
 
+  std::vector<std::vector<common::GradStatHist::GradType*>> hist_buffers;
+  std::vector<std::vector<uint8_t>> hist_is_init;
+
   for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
-    hist_.AddHistRow(nid);
-    BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], true);
+    std::vector<ExpandEntry> nodes_to_build{ExpandEntry(
+        0, -1, ROOT_PARENT_ID, p_tree->GetDepth(0), 0.0, timestamp++)};
 
-    this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
+    BuildHistsBatch(nodes_to_build, p_tree, gmat, gpair_h, &hist_buffers, &hist_is_init);
+    BuildNodeStatBatch(gmat, p_fmat, p_tree, gpair_h, nodes_to_build);
+    EvaluateSplitsBatch(nodes_to_build, gmat, *p_fmat, hist_is_init, hist_buffers);
 
-    this->EvaluateSplit(nid, gmat, hist_, *p_fmat, *p_tree);
-    qexpand_loss_guided_->push(ExpandEntry(nid, p_tree->GetDepth(nid),
+    qexpand_loss_guided_->push(ExpandEntry(nid, -1, -1, p_tree->GetDepth(nid),
                                snode_[nid].best.loss_chg,
                                timestamp++));
     ++num_leaves;
@@ -265,50 +760,29 @@ void
QuantileHistMaker::Builder::ExpandWithLossGuide( while (!qexpand_loss_guided_->empty()) { const ExpandEntry candidate = qexpand_loss_guided_->top(); - const int nid = candidate.nid; + const int32_t nid = candidate.nid; qexpand_loss_guided_->pop(); - if (candidate.loss_chg <= kRtEps - || (param_.max_depth > 0 && candidate.depth == param_.max_depth) - || (param_.max_leaves > 0 && num_leaves == param_.max_leaves) ) { - (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate); - } else { - this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree); - const int cleft = (*p_tree)[nid].LeftChild(); - const int cright = (*p_tree)[nid].RightChild(); - hist_.AddHistRow(cleft); - hist_.AddHistRow(cright); + std::vector nodes_to_build{candidate}; + std::vector successors; - if (rabit::IsDistributed()) { - // in distributed mode, we need to keep consistent across workers - BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft], true); - SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]); - } else { - if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) { - BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft], true); - SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]); - } else { - BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright], true); - SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]); - } - } + CreateNewNodesBatch(nodes_to_build, gmat, column_matrix, p_fmat, p_tree, + &num_leaves, candidate.depth, ×tamp, &successors); - this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree); - this->InitNewNode(cright, gmat, gpair_h, *p_fmat, *p_tree); - bst_uint featureid = snode_[nid].best.SplitIndex(); - spliteval_->AddSplit(nid, cleft, cright, featureid, - snode_[cleft].weight, snode_[cright].weight); + if (!successors.empty()) { + BuildHistsBatch(successors, p_tree, gmat, gpair_h, &hist_buffers, &hist_is_init); + BuildNodeStatBatch(gmat, p_fmat, p_tree, gpair_h, successors); + EvaluateSplitsBatch(successors, gmat, *p_fmat, hist_is_init, hist_buffers); - this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree); - this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree); + const int32_t cleft = (*p_tree)[nid].LeftChild(); + const int32_t cright = (*p_tree)[nid].RightChild(); - qexpand_loss_guided_->push(ExpandEntry(cleft, p_tree->GetDepth(cleft), + qexpand_loss_guided_->push(ExpandEntry(cleft, -1, nid, p_tree->GetDepth(cleft), snode_[cleft].best.loss_chg, timestamp++)); - qexpand_loss_guided_->push(ExpandEntry(cright, p_tree->GetDepth(cright), + qexpand_loss_guided_->push(ExpandEntry(cright, -1, nid, p_tree->GetDepth(cright), snode_[cright].best.loss_chg, timestamp++)); - ++num_leaves; // give two and take one, as parent is no longer a leaf } } @@ -320,34 +794,36 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat, HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree) { - builder_monitor_.Start("Update"); + perf_monitor.StartPerfMonitor(); const std::vector& gpair_h = gpair->ConstHostVector(); - spliteval_->Reset(); + perf_monitor.TickStart(); this->InitData(gmat, gpair_h, *p_fmat, *p_tree); + perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::INIT_DATA); if (param_.grow_policy == TrainParam::kLossGuide) { ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h); } else { - ExpandWithDepthWidth(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h); + ExpandWithDepthWise(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h); } for 
(int nid = 0; nid < p_tree->param.num_nodes; ++nid) { p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg; p_tree->Stat(nid).base_weight = snode_[nid].weight; - p_tree->Stat(nid).sum_hess = static_cast(snode_[nid].stats.sum_hess); + p_tree->Stat(nid).sum_hess = + static_cast(snode_[nid].stats.sum_hess); } pruner_->Update(gpair, p_fmat, std::vector{p_tree}); - builder_monitor_.Stop("Update"); + perf_monitor.EndPerfMonitor(); } bool QuantileHistMaker::Builder::UpdatePredictionCache( - const DMatrix* data, - HostDeviceVector* p_out_preds) { + const DMatrix* data, + HostDeviceVector* p_out_preds) { std::vector& out_preds = p_out_preds->HostVector(); // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in @@ -363,8 +839,31 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( CHECK_GT(out_preds.size(), 0U); - for (const RowSetCollection::Elem rowset : row_set_collection_) { - if (rowset.begin != nullptr && rowset.end != nullptr) { + const size_t block_size = 2048; + const size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin(); + std::vector tasks_elem; + std::vector tasks_iblock; + std::vector tasks_nblock; + + for (size_t k = 0; k < n_nodes; ++k) { + const size_t nrows = row_set_collection_[k].Size(); + const size_t nblocks = nrows / block_size + !!(nrows % block_size); + + for (size_t i = 0; i < nblocks; ++i) { + tasks_elem.push_back(row_set_collection_[k]); + tasks_iblock.push_back(i); + tasks_nblock.push_back(nblocks); + } + } + + #pragma omp parallel for schedule(guided) + for (int32_t k = 0; k < tasks_elem.size(); ++k) { + const RowSetCollection::Elem rowset = tasks_elem[k]; + if (rowset.begin != nullptr && rowset.end != nullptr && rowset.node_id != -1) { + const size_t nrows = rowset.Size(); + const size_t iblock = tasks_iblock[k]; + const size_t nblocks = tasks_nblock[k]; + int nid = rowset.node_id; bst_float leaf_value; // if a node is marked as deleted by the pruner, traverse upward to locate @@ -377,8 +876,11 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( } leaf_value = (*p_last_tree_)[nid].LeafValue(); - for (const size_t* it = rowset.begin; it < rowset.end; ++it) { - out_preds[*it] += leaf_value; + const size_t istart = iblock*block_size; + const size_t iend = (iblock == nblocks-1) ? 
nrows : istart + block_size; + + for (size_t it = istart; it < iend; ++it) { + out_preds[rowset.begin[it]] += leaf_value; } } } @@ -399,7 +901,6 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, CHECK(param_.max_depth > 0) << "max_depth cannot be 0 (unlimited) " << "when grow_policy is depthwise."; } - builder_monitor_.Start("InitData"); const auto& info = fmat.Info(); { @@ -410,12 +911,16 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, // initialize histogram collection uint32_t nbins = gmat.cut.row_ptr.back(); hist_.Init(nbins); + hist_buff_.Init(nbins); // initialize histogram builder -#pragma omp parallel + #pragma omp parallel { this->nthread_ = omp_get_num_threads(); } + + const auto nthread = static_cast(this->nthread_); + row_split_tloc_.resize(nthread); hist_builder_.Init(this->nthread_, nbins); CHECK_EQ(info.root_index_.size(), 0U); @@ -457,7 +962,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, } bool has_neg_hess = false; - for (size_t tid = 0; tid < this->nthread_; ++tid) { + for (int32_t tid = 0; tid < this->nthread_; ++tid) { if (p_buff[tid]) { has_neg_hess = true; } @@ -485,8 +990,8 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, } } } - row_set_collection_.Init(); + buffer_for_partition_.reserve(2 * info.num_row_); { /* determine layout of data */ @@ -549,290 +1054,121 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, qexpand_depth_wise_.clear(); } } - builder_monitor_.Stop("InitData"); } -void QuantileHistMaker::Builder::EvaluateSplit(const int nid, - const GHistIndexMatrix& gmat, - const HistCollection& hist, - const DMatrix& fmat, - const RegTree& tree) { - builder_monitor_.Start("EvaluateSplit"); - // start enumeration +void QuantileHistMaker::Builder::EvaluateSplitsBatch( + const std::vector& nodes, + const GHistIndexMatrix& gmat, + const DMatrix& fmat, + const std::vector>& hist_is_init, + const std::vector>& hist_buffers) { + perf_monitor.TickStart(); const MetaInfo& info = fmat.Info(); - auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid)); - const auto& feature_set = p_feature_set->HostVector(); - const auto nfeature = static_cast(feature_set.size()); - const auto nthread = static_cast(this->nthread_); - best_split_tloc_.resize(nthread); -#pragma omp parallel for schedule(static) num_threads(nthread) - for (bst_omp_uint tid = 0; tid < nthread; ++tid) { - best_split_tloc_[tid] = snode_[nid].best; - } - GHistRow node_hist = hist[nid]; + // prepare tasks + std::vector> tasks; + for (size_t i = 0; i < nodes.size(); ++i) { + auto p_feature_set = column_sampler_.GetFeatureSet(nodes[i].depth); -#pragma omp parallel for schedule(dynamic) num_threads(nthread) - for (bst_omp_uint i = 0; i < nfeature; ++i) { // NOLINT(*) - const auto feature_id = static_cast(feature_set[i]); - const auto tid = static_cast(omp_get_thread_num()); - const auto node_id = static_cast(nid); - // Narrow search space by dropping features that are not feasible under the - // given set of constraints (e.g. 
feature interaction constraints) - if (spliteval_->CheckFeatureConstraint(node_id, feature_id)) { - this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info, - &best_split_tloc_[tid], feature_id, node_id); - this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info, - &best_split_tloc_[tid], feature_id, node_id); - } - } - for (unsigned tid = 0; tid < nthread; ++tid) { - snode_[nid].best.Update(best_split_tloc_[tid]); - } - builder_monitor_.Stop("EvaluateSplit"); -} - -void QuantileHistMaker::Builder::ApplySplit(int nid, - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - const HistCollection& hist, - const DMatrix& fmat, - RegTree* p_tree) { - builder_monitor_.Start("ApplySplit"); - // TODO(hcho3): support feature sampling by levels - - /* 1. Create child nodes */ - NodeEntry& e = snode_[nid]; - bst_float left_leaf_weight = - spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate; - bst_float right_leaf_weight = - spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate; - p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, - e.best.DefaultLeft(), e.weight, left_leaf_weight, - right_leaf_weight, e.best.loss_chg, e.stats.sum_hess); - - /* 2. Categorize member rows */ - const auto nthread = static_cast(this->nthread_); - row_split_tloc_.resize(nthread); - for (bst_omp_uint i = 0; i < nthread; ++i) { - row_split_tloc_[i].left.clear(); - row_split_tloc_[i].right.clear(); - } - const bool default_left = (*p_tree)[nid].DefaultLeft(); - const bst_uint fid = (*p_tree)[nid].SplitIndex(); - const bst_float split_pt = (*p_tree)[nid].SplitCond(); - const uint32_t lower_bound = gmat.cut.row_ptr[fid]; - const uint32_t upper_bound = gmat.cut.row_ptr[fid + 1]; - int32_t split_cond = -1; - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - CHECK_LT(upper_bound, - static_cast(std::numeric_limits::max())); - for (uint32_t i = lower_bound; i < upper_bound; ++i) { - if (split_pt == gmat.cut.cut[i]) { - split_cond = static_cast(i); + const auto& feature_set = p_feature_set->HostVector(); + const auto nfeature = static_cast(feature_set.size()); + for (size_t j = 0; j < nfeature; ++j) { + tasks.emplace_back(i, feature_set[j]); } } - const auto& rowset = row_set_collection_[nid]; + // partial results + std::vector> splits(tasks.size()); + // parallel enumeration + #pragma omp parallel for schedule(guided) + for (int32_t i = 0; i < tasks.size(); ++i) { + // node_idx : offset within `nodes` list + const int32_t node_idx = tasks[i].first; + const size_t fid = tasks[i].second; + const int32_t nid = nodes[node_idx].nid; // usually node_idx != nid + const int32_t sibling_nid = nodes[node_idx].sibling_nid; + const int32_t parent_nid = nodes[node_idx].parent_nid; - Column column = column_matrix.GetColumn(fid); - if (column.GetType() == xgboost::common::kDenseColumn) { - ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column, split_cond, - default_left); - } else { - ApplySplitSparseData(rowset, gmat, &row_split_tloc_, column, lower_bound, - upper_bound, split_cond, default_left); - } + common::GradStatHist::GradType* hist_data = + reinterpret_cast(hist_[nid].data()); + common::GradStatHist::GradType* sibling_hist_data = sibling_nid > -1 ? + reinterpret_cast( + hist_[sibling_nid].data()) : nullptr; + common::GradStatHist::GradType* parent_hist_data = sibling_nid > -1 ? 
+        reinterpret_cast<common::GradStatHist::GradType*>(hist_[parent_nid].data()) : nullptr;

-  row_set_collection_.AddSplit(
-      nid, row_split_tloc_, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild());
-  builder_monitor_.Stop("ApplySplit");
-}
-
-void QuantileHistMaker::Builder::ApplySplitDenseData(
-    const RowSetCollection::Elem rowset,
-    const GHistIndexMatrix& gmat,
-    std::vector<RowSetCollection::Split>* p_row_split_tloc,
-    const Column& column,
-    bst_int split_cond,
-    bool default_left) {
-  std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
-  constexpr int kUnroll = 8;  // loop unrolling factor
-  const size_t nrows = rowset.end - rowset.begin;
-  const size_t rest = nrows % kUnroll;
-#pragma omp parallel for num_threads(nthread_) schedule(static)
-  for (bst_omp_uint i = 0; i < nrows - rest; i += kUnroll) {
-    const bst_uint tid = omp_get_thread_num();
-    auto& left = row_split_tloc[tid].left;
-    auto& right = row_split_tloc[tid].right;
-    size_t rid[kUnroll];
-    uint32_t rbin[kUnroll];
-    for (int k = 0; k < kUnroll; ++k) {
-      rid[k] = rowset.begin[i + k];
+
+    // reduce the needed part of the histogram here, so it is in cache before enumeration
+    if (!rabit::IsDistributed()) {
+      const std::vector<uint32_t>& cut_ptr = gmat.cut.row_ptr;
+      const size_t ibegin = 2 * cut_ptr[fid];
+      const size_t iend = 2 * cut_ptr[fid + 1];
+      ReduceHistograms(hist_data, sibling_hist_data, parent_hist_data, ibegin, iend, node_idx,
+          hist_is_init, hist_buffers);
+    }
-    for (int k = 0; k < kUnroll; ++k) {
-      rbin[k] = column.GetFeatureBinIdx(rid[k]);
+
+    if (spliteval_->CheckFeatureConstraint(nid, fid)) {
+      auto& snode = snode_[nid];
+      const bool compute_backward = this->EnumerateSplit(+1, gmat, hist_[nid], snode,
+          info, &splits[i].first, fid, nid);
+
+      // Sometimes we don't need to enumerate backward, because forward and backward
+      // enumeration will give the same loss values. This is the case if the particular
+      // feature column contains no missing values. So enumerate backward only if necessary.
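To make the guard that follows (`if (compute_backward)`) concrete: the forward scan already accumulates the per-bin gradient/hessian sums, and when those sums reproduce the node totals (the e.GetGrad()/GetHess() comparison inside EnumerateSplit further below), no row of this node had a missing value for the feature, so a backward scan would rank exactly the same partitions. A minimal standalone sketch of that check, with a toy struct standing in for GradStats:

    #include <vector>

    struct Stats { double grad = 0.0, hess = 0.0; };  // stand-in for GradStats

    // Returns true when the backward scan is still required: the per-bin sums
    // did not account for the whole node, i.e. some rows had a missing value
    // for this feature and the two scan directions can yield different splits.
    bool NeedBackwardScan(const std::vector<Stats>& bins, const Stats& node_total) {
      Stats acc;
      for (const Stats& b : bins) {
        acc.grad += b.grad;  // the real scan also scores a candidate split per bin
        acc.hess += b.hess;
      }
      return !(acc.grad == node_total.grad && acc.hess == node_total.hess);
    }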
+ if (compute_backward) { + this->EnumerateSplit(-1, gmat, hist_[nid], snode, info, + &splits[i].first, fid, nid); + } } - for (int k = 0; k < kUnroll; ++k) { // NOLINT - if (rbin[k] == std::numeric_limits::max()) { // missing value - if (default_left) { - left.push_back(rid[k]); - } else { - right.push_back(rid[k]); - } - } else { - if (static_cast(rbin[k] + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid[k]); - } else { - right.push_back(rid[k]); - } + + if (sibling_nid > -1 && spliteval_->CheckFeatureConstraint(sibling_nid, fid)) { + auto& snode = snode_[sibling_nid]; + + const bool compute_backward = this->EnumerateSplit(+1, gmat, hist_[sibling_nid], snode, + info, &splits[i].second, fid, sibling_nid); + + if (compute_backward) { + this->EnumerateSplit(-1, gmat, hist_[sibling_nid], snode, info, + &splits[i].second, fid, sibling_nid); } } } - for (size_t i = nrows - rest; i < nrows; ++i) { - auto& left = row_split_tloc[nthread_-1].left; - auto& right = row_split_tloc[nthread_-1].right; - const size_t rid = rowset.begin[i]; - const uint32_t rbin = column.GetFeatureBinIdx(rid); - if (rbin == std::numeric_limits::max()) { // missing value - if (default_left) { - left.push_back(rid); - } else { - right.push_back(rid); - } - } else { - if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid); - } else { - right.push_back(rid); - } + + // choice of the best splits + for (size_t i = 0; i < splits.size(); ++i) { + const int32_t node_idx = tasks[i].first; + const int32_t nid = nodes[node_idx].nid; + const int32_t sibling_nid = nodes[node_idx].sibling_nid; + snode_[nid].best.Update(splits[i].first); + if (sibling_nid > -1) { + snode_[sibling_nid].best.Update(splits[i].second); } } -} -void QuantileHistMaker::Builder::ApplySplitSparseData( - const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_uint lower_bound, - bst_uint upper_bound, - bst_int split_cond, - bool default_left) { - std::vector& row_split_tloc = *p_row_split_tloc; - const size_t nrows = rowset.end - rowset.begin; - -#pragma omp parallel num_threads(nthread_) - { - const auto tid = static_cast(omp_get_thread_num()); - const size_t ibegin = tid * nrows / nthread_; - const size_t iend = (tid + 1) * nrows / nthread_; - if (ibegin < iend) { // ensure that [ibegin, iend) is nonempty range - // search first nonzero row with index >= rowset[ibegin] - const size_t* p = std::lower_bound(column.GetRowData(), - column.GetRowData() + column.Size(), - rowset.begin[ibegin]); - - auto& left = row_split_tloc[tid].left; - auto& right = row_split_tloc[tid].right; - if (p != column.GetRowData() + column.Size() && *p <= rowset.begin[iend - 1]) { - size_t cursor = p - column.GetRowData(); - - for (size_t i = ibegin; i < iend; ++i) { - const size_t rid = rowset.begin[i]; - while (cursor < column.Size() - && column.GetRowIdx(cursor) < rid - && column.GetRowIdx(cursor) <= rowset.begin[iend - 1]) { - ++cursor; - } - if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) { - const uint32_t rbin = column.GetFeatureBinIdx(cursor); - if (static_cast(rbin + column.GetBaseIdx()) <= split_cond) { - left.push_back(rid); - } else { - right.push_back(rid); - } - ++cursor; - } else { - // missing value - if (default_left) { - left.push_back(rid); - } else { - right.push_back(rid); - } - } - } - } else { // all rows in [ibegin, iend) have missing values - if (default_left) { - for (size_t i = ibegin; i < iend; ++i) { - const size_t rid = 
rowset.begin[i];
-          left.push_back(rid);
-        }
-      } else {
-        for (size_t i = ibegin; i < iend; ++i) {
-          const size_t rid = rowset.begin[i];
-          right.push_back(rid);
-        }
-      }
-    }
-  }
-}
+  perf_monitor.UpdatePerfTimer(TreeGrowingPerfMonitor::timer_name::EVALUATE_SPLIT);
 }

 void QuantileHistMaker::Builder::InitNewNode(int nid,
                                              const GHistIndexMatrix& gmat,
                                              const std::vector<GradientPair>& gpair,
                                              const DMatrix& fmat,
-                                             const RegTree& tree) {
-  builder_monitor_.Start("InitNewNode");
-  {
-    snode_.resize(tree.param.num_nodes, NodeEntry(param_));
-  }
-
-  {
-    auto& stats = snode_[nid].stats;
-    GHistRow hist = hist_[nid];
-    if (tree[nid].IsRoot()) {
-      if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
-        const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
-        const uint32_t ibegin = row_ptr[fid_least_bins_];
-        const uint32_t iend = row_ptr[fid_least_bins_ + 1];
-        auto begin = hist.data();
-        for (uint32_t i = ibegin; i < iend; ++i) {
-          const GradStats et = begin[i];
-          stats.Add(et.sum_grad, et.sum_hess);
-        }
-      } else {
-        const RowSetCollection::Elem e = row_set_collection_[nid];
-        for (const size_t* it = e.begin; it < e.end; ++it) {
-          stats.Add(gpair[*it]);
-        }
-      }
-      histred_.Allreduce(&snode_[nid].stats, 1);
-    } else {
-      int parent_id = tree[nid].Parent();
-      if (tree[nid].IsLeftChild()) {
-        snode_[nid].stats = snode_[parent_id].best.left_sum;
-      } else {
-        snode_[nid].stats = snode_[parent_id].best.right_sum;
-      }
-    }
-  }
-
+                                             RegTree* tree,
+                                             QuantileHistMaker::NodeEntry* snode,
+                                             int32_t parentid) {
   // calculating the weights
   {
-    bst_uint parentid = tree[nid].Parent();
-    snode_[nid].weight = static_cast<bst_float>(
-        spliteval_->ComputeWeight(parentid, snode_[nid].stats));
-    snode_[nid].root_gain = static_cast<bst_float>(
-        spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight));
+    snode->weight = static_cast<bst_float>(
+        spliteval_->ComputeWeight(parentid, snode->stats));
+    snode->root_gain = static_cast<bst_float>(
+        spliteval_->ComputeScore(parentid, snode->stats,
+            snode->weight));
   }
-  builder_monitor_.Stop("InitNewNode");
 }

 // enumerate the split values of specific feature
-void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
+// d_step: +1 or -1, indicating the direction in which we scan candidate thresholds
+// fid: the feature for which we seek the best threshold
+// Returns false if we don't need to enumerate in the opposite direction.
+// This is the case if the particular feature (fid) column contains no missing values.
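For orientation before the implementation below: forward enumeration walks the feature's bins left to right, maintains running left-child sums, and scores each candidate threshold with the regularized gain. The following is a simplified model of one forward pass, not the actual implementation — plain doubles instead of GradStats, only an L2 term (lambda), and no gamma, alpha, or monotonicity handling, all of which the real spliteval_ encapsulates:

    #include <cstddef>
    #include <vector>

    // XGBoost-style score for a child with gradient sum g and hessian sum h.
    double Score(double g, double h, double lambda) {
      return g * g / (h + lambda);
    }

    // Forward pass: every prefix of bins becomes a left-child candidate.
    double BestForwardGain(const std::vector<double>& grad,
                           const std::vector<double>& hess,
                           double total_grad, double total_hess,
                           double lambda, double min_child_weight) {
      double gl = 0.0, hl = 0.0, best = 0.0;
      const double parent = Score(total_grad, total_hess, lambda);
      for (std::size_t i = 0; i + 1 < grad.size(); ++i) {  // last bin: empty right child
        gl += grad[i];
        hl += hess[i];
        const double hr = total_hess - hl;
        if (hl < min_child_weight || hr < min_child_weight) continue;  // mirrors the e/c checks
        const double gain =
            Score(gl, hl, lambda) + Score(total_grad - gl, hr, lambda) - parent;
        if (gain > best) best = gain;
      }
      return best;
    }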
+bool QuantileHistMaker::Builder::EnumerateSplit(int d_step, const GHistIndexMatrix& gmat, const GHistRow& hist, const NodeEntry& snode, @@ -871,39 +1207,54 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step, iend = static_cast(cut_ptr[fid]) - 1; } - for (int32_t i = ibegin; i != iend; i += d_step) { - // start working - // try to find a split - e.Add(hist[i].GetGrad(), hist[i].GetHess()); - if (e.sum_hess >= param_.min_child_weight) { - c.SetSubstract(snode.stats, e); - if (c.sum_hess >= param_.min_child_weight) { - bst_float loss_chg; - bst_float split_pt; - if (d_step > 0) { - // forward enumeration: split at right bound of each bin - loss_chg = static_cast( - spliteval_->ComputeSplitScore(nodeID, fid, e, c) - - snode.root_gain); - split_pt = cut_val[i]; - best.Update(loss_chg, fid, split_pt, d_step == -1, e, c); - } else { + if (d_step == 1) { + for (int32_t i = ibegin; i < iend; i++) { + e.Add(hist[i].GetGrad(), hist[i].GetHess()); + if (e.sum_hess >= param_.min_child_weight) { + c.SetSubstract(snode.stats, e); + if (c.sum_hess >= param_.min_child_weight) { + bst_float loss_chg = static_cast(spliteval_->ComputeSplitScore(nodeID, + fid, e, c) - snode.root_gain); + bst_float split_pt = cut_val[i]; + best.Update(loss_chg, fid, split_pt, false, e, c); + } + } + } + p_best->Update(best); + + if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) { + return false; + } + } else { + for (int32_t i = ibegin; i != iend; i--) { + e.Add(hist[i].GetGrad(), hist[i].GetHess()); + if (e.sum_hess >= param_.min_child_weight) { + c.SetSubstract(snode.stats, e); + if (c.sum_hess >= param_.min_child_weight) { + bst_float split_pt; // backward enumeration: split at left bound of each bin - loss_chg = static_cast( + bst_float loss_chg = static_cast( spliteval_->ComputeSplitScore(nodeID, fid, c, e) - snode.root_gain); + if (i == imin) { // for leftmost bin, left bound is the smallest feature value split_pt = gmat.cut.min_val[fid]; } else { split_pt = cut_val[i - 1]; } - best.Update(loss_chg, fid, split_pt, d_step == -1, c, e); + best.Update(loss_chg, fid, split_pt, true, c, e); } } } + p_best->Update(best); + + if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) { + return false; + } } - p_best->Update(best); + + return true; } XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker") diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 17688f86a..592257321 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -1,8 +1,8 @@ /*! - * Copyright 2017-2018 by Contributors + * Copyright 2017-2019 by Contributors * \file updater_quantile_hist.h * \brief use quantized feature values to construct a tree - * \author Philip Cho, Tianqi Chen + * \author Philip Cho, Tianqi Chen, Egor Smirnov */ #ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ #define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ @@ -18,51 +18,19 @@ #include #include #include +#include #include "./param.h" #include "./split_evaluator.h" #include "../common/random.h" -#include "../common/timer.h" #include "../common/hist_util.h" #include "../common/row_set.h" #include "../common/column_matrix.h" namespace xgboost { - -/*! - * \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated. 
- */ -template -class MemStackAllocator { - public: - explicit MemStackAllocator(size_t required_size): required_size_(required_size) { - } - - T* Get() { - if (!ptr_) { - if (MaxStackSize >= required_size_) { - ptr_ = stack_mem_; - } else { - ptr_ = reinterpret_cast(malloc(required_size_ * sizeof(T))); - do_free_ = true; - } - } - - return ptr_; - } - - ~MemStackAllocator() { - if (do_free_) free(ptr_); - } - - - private: - T* ptr_ = nullptr; - bool do_free_ = false; - size_t required_size_; - T stack_mem_[MaxStackSize]; -}; - +namespace common { + struct GradStatHist; +} namespace tree { using xgboost::common::HistCutMatrix; @@ -88,6 +56,7 @@ class QuantileHistMaker: public TreeUpdater { bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector* out_preds) override; + protected: // training parameter TrainParam param_; @@ -100,6 +69,7 @@ class QuantileHistMaker: public TreeUpdater { bool is_gmat_initialized_; // data structure + public: struct NodeEntry { /*! \brief statics for node entry */ GradStats stats; @@ -111,7 +81,8 @@ class QuantileHistMaker: public TreeUpdater { SplitEntry best; // constructor explicit NodeEntry(const TrainParam& param) - : root_gain(0.0f), weight(0.0f) {} + : root_gain(0.0f), weight(0.0f) { + } }; // actual builder that runs the algorithm @@ -121,11 +92,8 @@ class QuantileHistMaker: public TreeUpdater { explicit Builder(const TrainParam& param, std::unique_ptr pruner, std::unique_ptr spliteval) - : param_(param), pruner_(std::move(pruner)), - spliteval_(std::move(spliteval)), p_last_tree_(nullptr), - p_last_fmat_(nullptr) { - builder_monitor_.Init("Quantile::Builder"); - } + : param_(param), pruner_(std::move(pruner)), spliteval_(std::move(spliteval)), + p_last_tree_(nullptr), p_last_fmat_(nullptr) { } // update one tree, growing virtual void Update(const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, @@ -134,42 +102,104 @@ class QuantileHistMaker: public TreeUpdater { DMatrix* p_fmat, RegTree* p_tree); - inline void BuildHist(const std::vector& gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, - const GHistIndexBlockMatrix& gmatb, - GHistRow hist, - bool sync_hist) { - builder_monitor_.Start("BuildHist"); - if (param_.enable_feature_grouping > 0) { - hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist); - } else { - hist_builder_.BuildHist(gpair, row_indices, gmat, hist); - } - if (sync_hist) { - this->histred_.Allreduce(hist.data(), hist_builder_.GetNumBins()); - } - builder_monitor_.Stop("BuildHist"); - } - - inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) { - builder_monitor_.Start("SubtractionTrick"); - hist_builder_.SubtractionTrick(self, sibling, parent); - builder_monitor_.Stop("SubtractionTrick"); - } - bool UpdatePredictionCache(const DMatrix* data, HostDeviceVector* p_out_preds); + std::tuple + GetHistBuffer(std::vector* hist_is_init, + std::vector* grad_stats, size_t block_id, size_t nthread, + size_t tid, std::vector* data_hist, size_t hist_size); + protected: /* tree growing policies */ struct ExpandEntry { int nid; + int sibling_nid; + int parent_nid; int depth; bst_float loss_chg; unsigned timestamp; - ExpandEntry(int nid, int depth, bst_float loss_chg, unsigned tstmp) - : nid(nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {} + ExpandEntry(int nid, int sibling_nid, int parent_nid, int depth, bst_float loss_chg, + unsigned tstmp) : nid(nid), sibling_nid(sibling_nid), parent_nid(parent_nid), + depth(depth), loss_chg(loss_chg), timestamp(tstmp) 
{} + }; + + struct TreeGrowingPerfMonitor { + enum timer_name {INIT_DATA, INIT_NEW_NODE, BUILD_HIST, EVALUATE_SPLIT, APPLY_SPLIT}; + + double global_start; + + // performance counters + double tstart; + double time_init_data = 0; + double time_init_new_node = 0; + double time_build_hist = 0; + double time_evaluate_split = 0; + double time_apply_split = 0; + + inline void StartPerfMonitor() { + global_start = dmlc::GetTime(); + } + + inline void EndPerfMonitor() { + CHECK_GT(global_start, 0); + double total_time = dmlc::GetTime() - global_start; + LOG(INFO) << "\nInitData: " + << std::fixed << std::setw(6) << std::setprecision(4) << time_init_data + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_init_data / total_time * 100 << "%)\n" + << "InitNewNode: " + << std::fixed << std::setw(6) << std::setprecision(4) << time_init_new_node + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_init_new_node / total_time * 100 << "%)\n" + << "BuildHist: " + << std::fixed << std::setw(6) << std::setprecision(4) << time_build_hist + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_build_hist / total_time * 100 << "%)\n" + << "EvaluateSplit: " + << std::fixed << std::setw(6) << std::setprecision(4) << time_evaluate_split + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_evaluate_split / total_time * 100 << "%)\n" + << "ApplySplit: " + << std::fixed << std::setw(6) << std::setprecision(4) << time_apply_split + << " (" << std::fixed << std::setw(5) << std::setprecision(2) + << time_apply_split / total_time * 100 << "%)\n" + << "========================================\n" + << "Total: " + << std::fixed << std::setw(6) << std::setprecision(4) << total_time << std::endl; + // clear performance counters + time_init_data = 0; + time_init_new_node = 0; + time_build_hist = 0; + time_evaluate_split = 0; + time_apply_split = 0; + } + + inline void TickStart() { + tstart = dmlc::GetTime(); + } + + inline void UpdatePerfTimer(const timer_name &timer_name) { + // CHECK_GT(tstart, 0); // TODO Fix + switch (timer_name) { + case INIT_DATA: + time_init_data += dmlc::GetTime() - tstart; + break; + case INIT_NEW_NODE: + time_init_new_node += dmlc::GetTime() - tstart; + break; + case BUILD_HIST: + time_build_hist += dmlc::GetTime() - tstart; + break; + case EVALUATE_SPLIT: + time_evaluate_split += dmlc::GetTime() - tstart; + break; + case APPLY_SPLIT: + time_apply_split += dmlc::GetTime() - tstart; + break; + } + tstart = -1; + } }; // initialize temp data structure @@ -178,43 +208,16 @@ class QuantileHistMaker: public TreeUpdater { const DMatrix& fmat, const RegTree& tree); - void EvaluateSplit(const int nid, - const GHistIndexMatrix& gmat, - const HistCollection& hist, - const DMatrix& fmat, - const RegTree& tree); - - void ApplySplit(int nid, - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - const HistCollection& hist, - const DMatrix& fmat, - RegTree* p_tree); - - void ApplySplitDenseData(const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_int split_cond, - bool default_left); - - void ApplySplitSparseData(const RowSetCollection::Elem rowset, - const GHistIndexMatrix& gmat, - std::vector* p_row_split_tloc, - const Column& column, - bst_uint lower_bound, - bst_uint upper_bound, - bst_int split_cond, - bool default_left); - void InitNewNode(int nid, const GHistIndexMatrix& gmat, const std::vector& gpair, const DMatrix& fmat, - const 
RegTree& tree); + RegTree* tree, + QuantileHistMaker::NodeEntry* snode, + int32_t parentid); // enumerate the split values of specific feature - void EnumerateSplit(int d_step, + bool EnumerateSplit(int d_step, const GHistIndexMatrix& gmat, const GHistRow& hist, const NodeEntry& snode, @@ -223,37 +226,36 @@ class QuantileHistMaker: public TreeUpdater { bst_uint fid, bst_uint nodeID); - void ExpandWithDepthWidth(const GHistIndexMatrix &gmat, + void EvaluateSplitsBatch(const std::vector& nodes, + const GHistIndexMatrix& gmat, + const DMatrix& fmat, + const std::vector>& hist_is_init, + const std::vector>& hist_buffers); + + void ReduceHistograms( + common::GradStatHist::GradType* hist_data, + common::GradStatHist::GradType* sibling_hist_data, + common::GradStatHist::GradType* parent_hist_data, + const size_t ibegin, + const size_t iend, + const size_t inode, + const std::vector>& hist_is_init, + const std::vector>& hist_buffers); + + void SyncHistograms( + RegTree* p_tree, + const std::vector& nodes, + std::vector>* hist_buffers, + std::vector>* hist_is_init, + const std::vector>& grad_stats); + + void ExpandWithDepthWise(const GHistIndexMatrix &gmat, const GHistIndexBlockMatrix &gmatb, const ColumnMatrix &column_matrix, DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h); - void BuildLocalHistograms(int *starting_index, - int *sync_count, - const GHistIndexMatrix &gmat, - const GHistIndexBlockMatrix &gmatb, - RegTree *p_tree, - const std::vector &gpair_h); - - void SyncHistograms(int starting_index, - int sync_count, - RegTree *p_tree); - - void BuildNodeStats(const GHistIndexMatrix &gmat, - DMatrix *p_fmat, - RegTree *p_tree, - const std::vector &gpair_h); - - void EvaluateSplits(const GHistIndexMatrix &gmat, - const ColumnMatrix &column_matrix, - DMatrix *p_fmat, - RegTree *p_tree, - int *num_leaves, - int depth, - unsigned *timestamp, - std::vector *temp_qexpand_depth); void ExpandWithLossGuide(const GHistIndexMatrix& gmat, const GHistIndexBlockMatrix& gmatb, @@ -262,6 +264,62 @@ class QuantileHistMaker: public TreeUpdater { RegTree* p_tree, const std::vector& gpair_h); + + void BuildHistsBatch(const std::vector& nodes, RegTree* tree, + const GHistIndexMatrix &gmat, const std::vector& gpair, + std::vector>* hist_buffers, + std::vector>* hist_is_init); + + void BuildNodeStat(const GHistIndexMatrix &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + const std::vector &gpair_h, + int32_t nid); + + void BuildNodeStatBatch( + const GHistIndexMatrix &gmat, + DMatrix *p_fmat, + RegTree *p_tree, + const std::vector &gpair_h, + const std::vector& nodes); + + int32_t FindSplitCond(int32_t nid, + RegTree *p_tree, + const GHistIndexMatrix &gmat); + + void CreateNewNodesBatch( + const std::vector& nodes, + const GHistIndexMatrix &gmat, + const ColumnMatrix &column_matrix, + DMatrix *p_fmat, + RegTree *p_tree, + int *num_leaves, + int depth, + unsigned *timestamp, + std::vector *temp_qexpand_depth); + + template + void CreateTasksForApplySplit( + const std::vector& nodes, + const GHistIndexMatrix &gmat, + RegTree *p_tree, + int *num_leaves, + const int depth, + const size_t block_size, + std::vector* tasks, + std::vector* nodes_bounds); + + void CreateTasksForBuildHist( + size_t block_size_rows, + size_t nthread, + const std::vector& nodes, + std::vector>* hist_buffers, + std::vector>* hist_is_init, + std::vector>* grad_stats, + std::vector* task_nid, + std::vector* task_node_idx, + std::vector* task_block_idx); + inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) { if (lhs.loss_chg 
== rhs.loss_chg) { return lhs.timestamp > rhs.timestamp; // favor small timestamp @@ -270,6 +328,8 @@ class QuantileHistMaker: public TreeUpdater { } } + HistCollection hist_buff_; + // --data fields-- const TrainParam& param_; // number of omp thread used during training @@ -280,6 +340,7 @@ class QuantileHistMaker: public TreeUpdater { // the temp space for split std::vector row_split_tloc_; std::vector best_split_tloc_; + std::vector buffer_for_partition_; /*! \brief TreeNode Data: statistics for each constructed node */ std::vector snode_; /*! \brief culmulative histogram of gradients. */ @@ -311,8 +372,8 @@ class QuantileHistMaker: public TreeUpdater { enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; DataLayout data_layout_; - common::Monitor builder_monitor_; - rabit::Reducer histred_; + TreeGrowingPerfMonitor perf_monitor; + rabit::Reducer histred_; }; std::unique_ptr builder_; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index f1f567198..3d0c09e6a 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -101,8 +101,13 @@ class QuantileHistMock : public QuantileHistMaker { RealImpl::InitData(gmat, gpair, fmat, tree); GHistIndexBlockMatrix dummy; hist_.AddHistRow(nid); - BuildHist(gpair, row_set_collection_[nid], - gmat, dummy, hist_[nid], false); + + std::vector> hist_buffers; + std::vector> hist_is_init; + std::vector nodes = {ExpandEntry(nid, -1, -1, tree.GetDepth(0), 0.0, 0)}; + BuildHistsBatch(nodes, const_cast(&tree), gmat, gpair, &hist_buffers, &hist_is_init); + RealImpl::InitNewNode(nid, gmat, gpair, fmat, const_cast(&tree), &snode_[0], tree[0].Parent()); + EvaluateSplitsBatch(nodes, gmat, fmat, hist_is_init, hist_buffers); // Check if number of histogram bins is correct ASSERT_EQ(hist_[nid].size(), gmat.cut.row_ptr.back()); @@ -143,10 +148,12 @@ class QuantileHistMock : public QuantileHistMaker { RealImpl::InitData(gmat, row_gpairs, *(*dmat), tree); hist_.AddHistRow(0); - BuildHist(row_gpairs, row_set_collection_[0], - gmat, quantile_index_block, hist_[0], false); - - RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), tree); + std::vector nodes = {ExpandEntry(0, -1, -1, tree.GetDepth(0), 0.0, 0)}; + std::vector> hist_buffers; + std::vector> hist_is_init; + BuildHistsBatch(nodes, const_cast(&tree), gmat, row_gpairs, &hist_buffers, &hist_is_init); + RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), const_cast(&tree), &snode_[0], tree[0].Parent()); + EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers); /* Compute correct split (best_split) using the computed histogram */ const size_t num_row = dmat->get()->Info().num_row_; @@ -197,6 +204,7 @@ class QuantileHistMock : public QuantileHistMaker { const auto split_gain = evaluator->ComputeSplitScore(0, fid, GradStats(left_sum), GradStats(right_sum)); + if (split_gain > best_split_gain) { best_split_gain = split_gain; best_split_feature = fid; @@ -206,7 +214,8 @@ class QuantileHistMock : public QuantileHistMaker { } /* Now compare against result given by EvaluateSplit() */ - RealImpl::EvaluateSplit(0, gmat, hist_, *(*dmat), tree); + EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers); + ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature); ASSERT_EQ(snode_[0].best.split_value, gmat.cut.cut[best_split_threshold]); @@ -289,7 +298,7 @@ TEST(Updater, QuantileHist_EvalSplits) { std::vector> cfg {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}, {"split_evaluator", 
"elastic_net"}, - {"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"}, + {"reg_lambda", "1.0f"}, {"reg_alpha", "0"}, {"max_delta_step", "0"}, {"min_child_weight", "0"}}; QuantileHistMock maker(cfg); maker.TestEvaluateSplit();