Optimize ‘hist’ for multi-core CPU (#4529)

* Initial performance optimizations for xgboost

* remove includes

* revert float->double

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* fix for CI

* Check existence of _mm_prefetch and __builtin_prefetch

* Fix lint

* optimizations for CPU

* appling comments in review

* add some comments, code refactoring

* fixing issues in CI

* adding runtime checks

* remove 1 extra check

* remove extra checks in BuildHist

* remove checks

* add debug info

* added debug info

* revert changes

* added comments

* Apply suggestions from code review

Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>

* apply review comments

* Remove unused function CreateNewNodes()

* Add descriptive comment on node_idx variable in QuantileHistMaker::Builder::BuildHistsBatch()
This commit is contained in:
Egor Smirnov
2019-06-27 22:33:49 +04:00
committed by Philip Hyunsu Cho
parent abffbe014e
commit 4d6590be3c
9 changed files with 1342 additions and 818 deletions

View File

@@ -8,11 +8,11 @@
#ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
#include <dmlc/timer.h>
#include <limits>
#include <vector>
#include "hist_util.h"
namespace xgboost {
namespace common {
@@ -51,6 +51,10 @@ class Column {
}
const size_t* GetRowData() const { return row_ind_; }
const uint32_t* GetIndex() const {
return index_;
}
private:
ColumnType type_;
const uint32_t* index_;
@@ -80,7 +84,7 @@ class ColumnMatrix {
std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
uint32_t max_val = std::numeric_limits<uint32_t>::max();
for (bst_uint fid = 0; fid < nfeature; ++fid) {
for (int32_t fid = 0; fid < nfeature; ++fid) {
CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
}
@@ -113,13 +117,12 @@ class ColumnMatrix {
boundary_[fid].index_end = accum_index_;
boundary_[fid].row_ind_end = accum_row_ind_;
}
index_.resize(boundary_[nfeature - 1].index_end);
row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
// store least bin id for each feature
index_base_.resize(nfeature);
for (bst_uint fid = 0; fid < nfeature; ++fid) {
for (int32_t fid = 0; fid < nfeature; ++fid) {
index_base_[fid] = gmat.cut.row_ptr[fid];
}

View File

@@ -1,15 +1,15 @@
/*!
* Copyright 2017-2019 by Contributors
* \file hist_util.h
* \file hist_util.cc
*/
#include "./hist_util.h"
#include <dmlc/timer.h>
#include <rabit/rabit.h>
#include <dmlc/omp.h>
#include <numeric>
#include <vector>
#include "./random.h"
#include "./column_matrix.h"
#include "./hist_util.h"
#include "./quantile.h"
#include "./../tree/updater_quantile_hist.h"
@@ -178,7 +178,7 @@ uint32_t HistCutMatrix::GetBinIdx(const Entry& e) {
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
cut.Init(p_fmat, max_num_bins);
const size_t nthread = omp_get_max_threads();
const int32_t nthread = omp_get_max_threads();
const uint32_t nbins = cut.row_ptr.back();
hit_count.resize(nbins, 0);
hit_count_tloc_.resize(nthread * nbins, 0);
@@ -260,8 +260,8 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
}
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint idx = 0; idx < bst_omp_uint(nbins); ++idx) {
for (size_t tid = 0; tid < nthread; ++tid) {
for (int32_t idx = 0; idx < int32_t(nbins); ++idx) {
for (int32_t tid = 0; tid < nthread; ++tid) {
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
}
}
@@ -411,7 +411,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
for (auto fid : group) {
nnz += feature_nnz[fid];
}
double nnz_rate = static_cast<double>(nnz) / nrow;
float nnz_rate = static_cast<float>(nnz) / nrow;
// take apart small sparse group, due it will not gain on speed
if (nnz_rate <= param.sparse_threshold) {
for (auto fid : group) {
@@ -496,176 +496,144 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
}
}
void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
GHistRow hist) {
const size_t nthread = static_cast<size_t>(this->nthread_);
data_.resize(nbins_ * nthread_);
const size_t* rid = row_indices.begin;
const size_t nrows = row_indices.Size();
const uint32_t* index = gmat.index.data();
const size_t* row_ptr = gmat.row_ptr.data();
const float* pgh = reinterpret_cast<const float*>(gpair.data());
double* hist_data = reinterpret_cast<double*>(hist.data());
double* data = reinterpret_cast<double*>(data_.data());
const size_t block_size = 512;
size_t n_blocks = nrows/block_size;
n_blocks += !!(nrows - n_blocks*block_size);
const size_t nthread_to_process = std::min(nthread, n_blocks);
memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));
// used when data layout is kDenseDataZeroBased or kDenseDataOneBased
// it means that "row_ptr" is not needed for hist computations
void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
GradStatHist grad_stat; // make local var to prevent false sharing
const size_t n_features = row_ptr[rid[istart]+1] - row_ptr[rid[istart]];
const size_t cache_line_size = 64;
const size_t prefetch_step = cache_line_size / sizeof(*index);
const size_t prefetch_offset = 10;
size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
dmlc::omp_uint tid = omp_get_thread_num();
double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
reinterpret_cast<double*>(data_.data() + tid * nbins_));
// if read each row in some block of bin-matrix - it's dense block
// and we dont need SW prefetch in this case
const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart - 1);
if (!thread_init_[tid]) {
memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
thread_init_[tid] = true;
}
const size_t istart = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
if (iend < nrows - no_prefetch_size && !denseBlock) {
for (size_t i = istart; i < iend; ++i) {
const size_t icol_start = row_ptr[rid[i]];
const size_t icol_end = row_ptr[rid[i]+1];
const size_t icol_start = rid[i] * n_features;
const size_t icol_start_prefetch = rid[i+prefetch_offset] * n_features;
const size_t idx_gh = 2*rid[i];
if (i < nrows - no_prefetch_size) {
PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
for (size_t j = icol_start_prefetch; j < icol_start_prefetch + n_features;
j += prefetch_step) {
PREFETCH_READ_T0(index + j);
}
for (size_t j = icol_start; j < icol_end; ++j) {
const uint32_t idx_bin = 2*index[j];
const size_t idx_gh = 2*rid[i];
grad_stat.sum_grad += pgh[idx_gh];
grad_stat.sum_hess += pgh[idx_gh+1];
for (size_t j = icol_start; j < icol_start + n_features; ++j) {
const uint32_t idx_bin = 2*index[j];
data_local_hist[idx_bin] += pgh[idx_gh];
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
}
}
}
} else {
for (size_t i = istart; i < iend; ++i) {
const size_t icol_start = rid[i] * n_features;
const size_t idx_gh = 2*rid[i];
grad_stat.sum_grad += pgh[idx_gh];
grad_stat.sum_hess += pgh[idx_gh+1];
if (nthread_to_process > 1) {
const size_t size = (2*nbins_);
const size_t block_size = 1024;
size_t n_blocks = size/block_size;
n_blocks += !!(size - n_blocks*block_size);
size_t n_worked_bins = 0;
for (size_t i = 0; i < nthread_to_process; ++i) {
if (thread_init_[i]) {
thread_init_[n_worked_bins++] = i;
}
}
#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
const size_t istart = iblock * block_size;
const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
const size_t bin = 2 * thread_init_[0] * nbins_;
memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
for (size_t i = istart; i < iend; i++) {
hist_data[i] += data[bin + i];
}
for (size_t j = icol_start; j < icol_start + n_features; ++j) {
const uint32_t idx_bin = 2*index[j];
data_local_hist[idx_bin] += pgh[idx_gh];
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
}
}
}
grad_stat_global->Add(grad_stat);
}
void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexBlockMatrix& gmatb,
GHistRow hist) {
constexpr int kUnroll = 8; // loop unrolling factor
const size_t nblock = gmatb.GetNumBlock();
const size_t nrows = row_indices.end - row_indices.begin;
const size_t rest = nrows % kUnroll;
// used when data layout is kSparseData
// it means that "row_ptr" is needed for hist computations
void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
GradStatHist grad_stat; // make local var to prevent false sharing
#if defined(_OPENMP)
const auto nthread = static_cast<bst_omp_uint>(this->nthread_); // NOLINT
#endif // defined(_OPENMP)
tree::GradStats* p_hist = hist.data();
const size_t cache_line_size = 64;
const size_t prefetch_step = cache_line_size / sizeof(index[0]);
const size_t prefetch_offset = 10;
#pragma omp parallel for num_threads(nthread) schedule(guided)
for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
auto gmat = gmatb[bid];
size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
for (size_t i = 0; i < nrows - rest; i += kUnroll) {
size_t rid[kUnroll];
size_t ibegin[kUnroll];
size_t iend[kUnroll];
GradientPair stat[kUnroll];
// if read each row in some block of bin-matrix - it's dense block
// and we dont need SW prefetch in this case
const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart);
for (int k = 0; k < kUnroll; ++k) {
rid[k] = row_indices.begin[i + k];
ibegin[k] = gmat.row_ptr[rid[k]];
iend[k] = gmat.row_ptr[rid[k] + 1];
stat[k] = gpair[rid[k]];
if (iend < nrows - no_prefetch_size && !denseBlock) {
for (size_t i = istart; i < iend; ++i) {
const size_t icol_start = row_ptr[rid[i]];
const size_t icol_end = row_ptr[rid[i]+1];
const size_t idx_gh = 2*rid[i];
const size_t icol_start10 = row_ptr[rid[i+prefetch_offset]];
const size_t icol_end10 = row_ptr[rid[i+prefetch_offset]+1];
PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
for (size_t j = icol_start10; j < icol_end10; j+=prefetch_step) {
PREFETCH_READ_T0(index + j);
}
for (int k = 0; k < kUnroll; ++k) {
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
const uint32_t bin = gmat.index[j];
p_hist[bin].Add(stat[k]);
}
grad_stat.sum_grad += pgh[idx_gh];
grad_stat.sum_hess += pgh[idx_gh+1];
for (size_t j = icol_start; j < icol_end; ++j) {
const uint32_t idx_bin = 2*index[j];
data_local_hist[idx_bin] += pgh[idx_gh];
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
}
}
for (size_t i = nrows - rest; i < nrows; ++i) {
const size_t rid = row_indices.begin[i];
const size_t ibegin = gmat.row_ptr[rid];
const size_t iend = gmat.row_ptr[rid + 1];
const GradientPair stat = gpair[rid];
for (size_t j = ibegin; j < iend; ++j) {
const uint32_t bin = gmat.index[j];
p_hist[bin].Add(stat);
} else {
for (size_t i = istart; i < iend; ++i) {
const size_t icol_start = row_ptr[rid[i]];
const size_t icol_end = row_ptr[rid[i]+1];
const size_t idx_gh = 2*rid[i];
grad_stat.sum_grad += pgh[idx_gh];
grad_stat.sum_hess += pgh[idx_gh+1];
for (size_t j = icol_start; j < icol_end; ++j) {
const uint32_t idx_bin = 2*index[j];
data_local_hist[idx_bin] += pgh[idx_gh];
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
}
}
}
grad_stat_global->Add(grad_stat);
}
void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
const uint32_t nbins = static_cast<bst_omp_uint>(nbins_);
constexpr int kUnroll = 8; // loop unrolling factor
const uint32_t rest = nbins % kUnroll;
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
GradStatHist* p_self = self.data();
GradStatHist* p_sibling = sibling.data();
GradStatHist* p_parent = parent.data();
#if defined(_OPENMP)
const auto nthread = static_cast<bst_omp_uint>(this->nthread_); // NOLINT
#endif // defined(_OPENMP)
tree::GradStats* p_self = self.data();
tree::GradStats* p_sibling = sibling.data();
tree::GradStats* p_parent = parent.data();
const size_t size = self.size();
CHECK_EQ(sibling.size(), size);
CHECK_EQ(parent.size(), size);
#pragma omp parallel for num_threads(nthread) schedule(static)
for (bst_omp_uint bin_id = 0;
bin_id < static_cast<bst_omp_uint>(nbins - rest); bin_id += kUnroll) {
tree::GradStats pb[kUnroll];
tree::GradStats sb[kUnroll];
for (int k = 0; k < kUnroll; ++k) {
pb[k] = p_parent[bin_id + k];
const size_t block_size = 1024; // aproximatly 1024 values per block
size_t n_blocks = size/block_size + !!(size%block_size);
#pragma omp parallel for
for (int iblock = 0; iblock < n_blocks; ++iblock) {
const size_t ibegin = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) {
p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
}
for (int k = 0; k < kUnroll; ++k) {
sb[k] = p_sibling[bin_id + k];
}
for (int k = 0; k < kUnroll; ++k) {
p_self[bin_id + k].SetSubstract(pb[k], sb[k]);
}
}
for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
}
}

View File

@@ -11,13 +11,50 @@
#include <xgboost/generic_parameters.h>
#include <limits>
#include <vector>
#include <algorithm>
#include <utility>
#include "row_set.h"
#include "../tree/param.h"
#include "./quantile.h"
#include "./timer.h"
#include "../include/rabit/rabit.h"
#include "random.h"
namespace xgboost {
/*!
* \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
*/
template<typename T, size_t MaxStackSize>
class MemStackAllocator {
public:
explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
}
T* Get() {
if (!ptr_) {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
do_free_ = true;
}
}
return ptr_;
}
~MemStackAllocator() {
if (do_free_) free(ptr_);
}
private:
T* ptr_ = nullptr;
bool do_free_ = false;
size_t required_size_;
T stack_mem_[MaxStackSize];
};
namespace common {
/*
@@ -114,7 +151,7 @@ struct HistCutMatrix {
};
/*! \brief Builds the cut matrix on the GPU.
*
*
* \return The row stride across the entire dataset.
*/
size_t DeviceSketch
@@ -134,9 +171,10 @@ using GHistIndexRow = Span<uint32_t const>;
*/
struct GHistIndexMatrix {
/*! \brief row pointer to rows by element position */
std::vector<size_t> row_ptr;
// std::vector<size_t> row_ptr;
SimpleArray<size_t> row_ptr;
/*! \brief The index data */
std::vector<uint32_t> index;
SimpleArray<uint32_t> index;
/*! \brief hit count of each index */
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
@@ -170,6 +208,11 @@ struct GHistIndexBlock {
inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
: row_ptr(row_ptr), index(index) {}
// get i-th row
inline GHistIndexRow operator[](size_t i) const {
return {&index[0] + row_ptr[i], detail::ptrdiff_t(row_ptr[i + 1] - row_ptr[i])};
}
};
class ColumnMatrix;
@@ -202,12 +245,63 @@ class GHistIndexBlockMatrix {
};
/*!
* \brief histogram of graident statistics for a single node.
* Consists of multiple GradStats, each entry showing total graident statistics
* for that particular bin
* Uses global bin id so as to represent all features simultaneously
* \brief used instead of GradStats to have float instead of double to reduce histograms
* this improves performance by 10-30% and memory consumption for histograms by 2x
* accuracy in both cases is the same
*/
using GHistRow = Span<tree::GradStats>;
struct GradStatHist {
typedef float GradType;
/*! \brief sum gradient statistics */
GradType sum_grad;
/*! \brief sum hessian statistics */
GradType sum_hess;
GradStatHist() : sum_grad{0}, sum_hess{0} {
static_assert(sizeof(GradStatHist) == 8,
"Size of GradStatHist is not 8 bytes.");
}
inline void Add(const GradStatHist& b) {
sum_grad += b.sum_grad;
sum_hess += b.sum_hess;
}
inline void Add(const tree::GradStats& b) {
sum_grad += b.sum_grad;
sum_hess += b.sum_hess;
}
inline void Add(const GradientPair& p) {
this->Add(p.GetGrad(), p.GetHess());
}
inline void Add(const GradType& grad, const GradType& hess) {
sum_grad += grad;
sum_hess += hess;
}
inline tree::GradStats ToGradStat() const {
return tree::GradStats(sum_grad, sum_hess);
}
inline void SetSubstract(const GradStatHist& a, const GradStatHist& b) {
sum_grad = a.sum_grad - b.sum_grad;
sum_hess = a.sum_hess - b.sum_hess;
}
inline void SetSubstract(const tree::GradStats& a, const GradStatHist& b) {
sum_grad = a.sum_grad - b.sum_grad;
sum_hess = a.sum_hess - b.sum_hess;
}
inline GradType GetGrad() const { return sum_grad; }
inline GradType GetHess() const { return sum_hess; }
inline static void Reduce(GradStatHist& a, const GradStatHist& b) { // NOLINT(*)
a.Add(b);
}
};
using GHistRow = Span<GradStatHist>;
/*!
* \brief histogram of gradient statistics for multiple nodes
@@ -215,49 +309,43 @@ using GHistRow = Span<tree::GradStats>;
class HistCollection {
public:
// access histogram for i-th node
GHistRow operator[](bst_uint nid) const {
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
CHECK_NE(row_ptr_[nid], kMax);
tree::GradStats* ptr =
const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
return {ptr, nbins_};
inline GHistRow operator[](bst_uint nid) {
AddHistRow(nid);
return { const_cast<GradStatHist*>(dmlc::BeginPtr(data_arr_[nid])), nbins_};
}
// have we computed a histogram for i-th node?
bool RowExists(bst_uint nid) const {
const uint32_t k_max = std::numeric_limits<uint32_t>::max();
return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
inline bool RowExists(bst_uint nid) const {
return nid < data_arr_.size();
}
// initialize histogram collection
void Init(uint32_t nbins) {
nbins_ = nbins;
row_ptr_.clear();
data_.clear();
inline void Init(uint32_t nbins) {
if (nbins_ != nbins) {
data_arr_.clear();
nbins_ = nbins;
}
}
// create an empty histogram for i-th node
void AddHistRow(bst_uint nid) {
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
if (nid >= row_ptr_.size()) {
row_ptr_.resize(nid + 1, kMax);
}
CHECK_EQ(row_ptr_[nid], kMax);
inline void AddHistRow(bst_uint nid) {
if (data_arr_.size() <= nid) {
size_t prev = data_arr_.size();
data_arr_.resize(nid + 1);
row_ptr_[nid] = data_.size();
data_.resize(data_.size() + nbins_);
for (size_t i = prev; i < data_arr_.size(); ++i) {
data_arr_[i].resize(nbins_);
}
}
}
private:
/*! \brief number of all bins over all features */
uint32_t nbins_;
std::vector<tree::GradStats> data_;
/*! \brief row_ptr_[nid] locates bin for historgram of node nid */
std::vector<size_t> row_ptr_;
uint32_t nbins_ = 0;
std::vector<std::vector<GradStatHist>> data_arr_;
};
/*!
* \brief builder for histograms of gradient statistics
*/
@@ -267,21 +355,55 @@ class GHistBuilder {
inline void Init(size_t nthread, uint32_t nbins) {
nthread_ = nthread;
nbins_ = nbins;
thread_init_.resize(nthread_);
}
// construct a histogram via histogram aggregation
void BuildHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat,
GHistRow hist);
// same, with feature grouping
void BuildBlockHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices,
const GHistIndexBlockMatrix& gmatb,
GHistRow hist);
// construct a histogram via subtraction trick
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
const RowSetCollection::Elem row_indices,
const GHistIndexBlockMatrix& gmatb,
GHistRow hist) {
constexpr int kUnroll = 8; // loop unrolling factor
const int32_t nblock = gmatb.GetNumBlock();
const size_t nrows = row_indices.end - row_indices.begin;
const size_t rest = nrows % kUnroll;
#pragma omp parallel for
for (int32_t bid = 0; bid < nblock; ++bid) {
auto gmat = gmatb[bid];
for (size_t i = 0; i < nrows - rest; i += kUnroll) {
size_t rid[kUnroll];
size_t ibegin[kUnroll];
size_t iend[kUnroll];
GradientPair stat[kUnroll];
for (int k = 0; k < kUnroll; ++k) {
rid[k] = row_indices.begin[i + k];
}
for (int k = 0; k < kUnroll; ++k) {
ibegin[k] = gmat.row_ptr[rid[k]];
iend[k] = gmat.row_ptr[rid[k] + 1];
}
for (int k = 0; k < kUnroll; ++k) {
stat[k] = gpair[rid[k]];
}
for (int k = 0; k < kUnroll; ++k) {
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
const uint32_t bin = gmat.index[j];
hist[bin].Add(stat[k]);
}
}
}
for (size_t i = nrows - rest; i < nrows; ++i) {
const size_t rid = row_indices.begin[i];
const size_t ibegin = gmat.row_ptr[rid];
const size_t iend = gmat.row_ptr[rid + 1];
const GradientPair stat = gpair[rid];
for (size_t j = ibegin; j < iend; ++j) {
const uint32_t bin = gmat.index[j];
hist[bin].Add(stat);
}
}
}
}
uint32_t GetNumBins() {
return nbins_;
@@ -292,11 +414,19 @@ class GHistBuilder {
size_t nthread_;
/*! \brief number of all bins over all features */
uint32_t nbins_;
std::vector<size_t> thread_init_;
std::vector<tree::GradStats> data_;
};
void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat);
void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat);
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_HIST_UTIL_H_

View File

@@ -27,10 +27,10 @@ class RowSetCollection {
// id of node associated with this instance set; -1 means uninitialized
Elem()
= default;
Elem(const size_t* begin,
const size_t* end,
int node_id)
: begin(begin), end(end), node_id(node_id) {}
Elem(const size_t* begin_,
const size_t* end_,
int node_id_)
: begin(begin_), end(end_), node_id(node_id_) {}
inline size_t Size() const {
return end - begin;
@@ -42,6 +42,10 @@ class RowSetCollection {
std::vector<size_t> right;
};
size_t Size(unsigned node_id) {
return elem_of_each_node_[node_id].Size();
}
inline std::vector<Elem>::const_iterator begin() const { // NOLINT
return elem_of_each_node_.begin();
}
@@ -51,12 +55,12 @@ class RowSetCollection {
}
/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr)
<< "access element that is not in the set";
inline Elem operator[](unsigned node_id) const {
const Elem e = elem_of_each_node_[node_id];
return e;
}
// clear up things
inline void Clear() {
elem_of_each_node_.clear();
@@ -81,38 +85,29 @@ class RowSetCollection {
const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
}
// split rowset into two
inline void AddSplit(unsigned node_id,
const std::vector<Split>& row_split_tloc,
size_t iLeft,
unsigned left_node_id,
unsigned right_node_id) {
const Elem e = elem_of_each_node_[node_id];
const auto nthread = static_cast<bst_omp_uint>(row_split_tloc.size());
CHECK(e.begin != nullptr);
size_t* all_begin = dmlc::BeginPtr(row_indices_);
size_t* begin = all_begin + (e.begin - all_begin);
Elem e = elem_of_each_node_[node_id];
size_t* it = begin;
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
it += row_split_tloc[tid].left.size();
}
size_t* split_pt = it;
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
it += row_split_tloc[tid].right.size();
}
CHECK(e.begin != nullptr);
size_t* begin = const_cast<size_t*>(e.begin);
size_t* split_pt = begin + iLeft;
if (left_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
elem_of_each_node_.resize((left_node_id + 1)*2, Elem(nullptr, nullptr, -1));
}
if (right_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
elem_of_each_node_.resize((right_node_id + 1)*2, Elem(nullptr, nullptr, -1));
}
elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
elem_of_each_node_[node_id] = Elem(begin, e.end, -1);
}
// stores the row indices in the set