Optimized BuildHist function (#5156)
parent 4240daed4e
commit c67163250e
@@ -659,13 +659,59 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   }
 }
 
+/*!
+ * \brief fill a histogram by zeroes
+ */
+void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
+  memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats));
+}
+
+/*!
+ * \brief Increment hist as dst += add in range [begin, end)
+ */
+void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
+  using FPType = decltype(tree::GradStats::sum_grad);
+  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
+  const FPType* padd = reinterpret_cast<const FPType*>(add.data());
+
+  for (size_t i = 2 * begin; i < 2 * end; ++i) {
+    pdst[i] += padd[i];
+  }
+}
+
+/*!
+ * \brief Copy hist from src to dst in range [begin, end)
+ */
+void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) {
+  using FPType = decltype(tree::GradStats::sum_grad);
+  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
+  const FPType* psrc = reinterpret_cast<const FPType*>(src.data());
+
+  for (size_t i = 2 * begin; i < 2 * end; ++i) {
+    pdst[i] = psrc[i];
+  }
+}
+
+/*!
+ * \brief Compute Subtraction: dst = src1 - src2 in range [begin, end)
+ */
+void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
+                     size_t begin, size_t end) {
+  using FPType = decltype(tree::GradStats::sum_grad);
+  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
+  const FPType* psrc1 = reinterpret_cast<const FPType*>(src1.data());
+  const FPType* psrc2 = reinterpret_cast<const FPType*>(src2.data());
+
+  for (size_t i = 2 * begin; i < 2 * end; ++i) {
+    pdst[i] = psrc1[i] - psrc2[i];
+  }
+}
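The three helpers above all exploit the same layout trick: `tree::GradStats` is a pair of floating-point sums stored contiguously, so a histogram of n bins can be walked as a flat array of 2n scalars with plain loops over `[2*begin, 2*end)` that the compiler can auto-vectorize. A minimal standalone sketch of that pattern, assuming a GradStats-like struct of exactly two doubles (names here are illustrative, not the real xgboost types):

```cpp
#include <cstddef>
#include <vector>

struct GradStatsLike {   // stand-in for tree::GradStats: two contiguous doubles
  double sum_grad{0.0};
  double sum_hess{0.0};
};

// dst += add over bins [begin, end): each bin contributes two scalars,
// so the loop is a straight-line add over 2*(end - begin) doubles.
void IncrementSketch(std::vector<GradStatsLike>* dst,
                     const std::vector<GradStatsLike>& add,
                     std::size_t begin, std::size_t end) {
  double* pdst = reinterpret_cast<double*>(dst->data());
  const double* padd = reinterpret_cast<const double*>(add.data());
  for (std::size_t i = 2 * begin; i < 2 * end; ++i) {
    pdst[i] += padd[i];
  }
}
```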
 
 void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
                              const RowSetCollection::Elem row_indices,
                              const GHistIndexMatrix& gmat,
                              GHistRow hist) {
-  const size_t nthread = static_cast<size_t>(this->nthread_);
-  data_.resize(nbins_ * nthread_);
-
   const size_t* rid = row_indices.begin;
   const size_t nrows = row_indices.Size();
   const uint32_t* index = gmat.index.data();
@@ -673,79 +719,27 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
   const float* pgh = reinterpret_cast<const float*>(gpair.data());
 
   double* hist_data = reinterpret_cast<double*>(hist.data());
-  double* data = reinterpret_cast<double*>(data_.data());
-
-  const size_t block_size = 512;
-  size_t n_blocks = nrows/block_size;
-  n_blocks += !!(nrows - n_blocks*block_size);
-
-  const size_t nthread_to_process = std::min(nthread, n_blocks);
-  memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));
 
   const size_t cache_line_size = 64;
   const size_t prefetch_offset = 10;
   size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
   no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
 
-#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
-  for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-    dmlc::omp_uint tid = omp_get_thread_num();
-    double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
-                               reinterpret_cast<double*>(data_.data() + tid * nbins_));
+  for (size_t i = 0; i < nrows; ++i) {
+    const size_t icol_start = row_ptr[rid[i]];
+    const size_t icol_end = row_ptr[rid[i]+1];
 
-    if (!thread_init_[tid]) {
-      memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
-      thread_init_[tid] = true;
+    if (i < nrows - no_prefetch_size) {
+      PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
+      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
     }
 
-    const size_t istart = iblock*block_size;
-    const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
-    for (size_t i = istart; i < iend; ++i) {
-      const size_t icol_start = row_ptr[rid[i]];
-      const size_t icol_end = row_ptr[rid[i]+1];
-
-      if (i < nrows - no_prefetch_size) {
-        PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
-        PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
-      }
-
-      for (size_t j = icol_start; j < icol_end; ++j) {
-        const uint32_t idx_bin = 2*index[j];
-        const size_t idx_gh = 2*rid[i];
-
-        data_local_hist[idx_bin] += pgh[idx_gh];
-        data_local_hist[idx_bin+1] += pgh[idx_gh+1];
-      }
-    }
-  }
-
-  if (nthread_to_process > 1) {
-    const size_t size = (2*nbins_);
-    const size_t block_size = 1024;
-    size_t n_blocks = size/block_size;
-    n_blocks += !!(size - n_blocks*block_size);
-
-    size_t n_worked_bins = 0;
-    for (size_t i = 0; i < nthread_to_process; ++i) {
-      if (thread_init_[i]) {
-        thread_init_[n_worked_bins++] = i;
-      }
-    }
-
-#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
-    for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-      const size_t istart = iblock * block_size;
-      const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
-
-      const size_t bin = 2 * thread_init_[0] * nbins_;
-      memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
-
-      for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
-        const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
-        for (size_t i = istart; i < iend; i++) {
-          hist_data[i] += data[bin + i];
-        }
-      }
+    for (size_t j = icol_start; j < icol_end; ++j) {
+      const uint32_t idx_bin = 2*index[j];
+      const size_t idx_gh = 2*rid[i];
+
+      hist_data[idx_bin] += pgh[idx_gh];
+      hist_data[idx_bin+1] += pgh[idx_gh+1];
     }
   }
 }
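Both the removed multi-threaded path and SubtractionTrick below lean on the same ceil-division idiom for splitting a range into fixed-size blocks: `n_blocks += !!(n - n_blocks*block_size)` adds one extra block exactly when a remainder exists. A standalone sketch of the arithmetic (values chosen for illustration):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t nrows = 1300, block_size = 512;
  std::size_t n_blocks = nrows / block_size;       // 2 full blocks
  n_blocks += !!(nrows - n_blocks * block_size);   // +1 for the 276-row tail
  // Block i covers [i*block_size, min((i+1)*block_size, nrows)).
  for (std::size_t i = 0; i < n_blocks; ++i) {
    const std::size_t begin = i * block_size;
    const std::size_t end =
        (i + 1) * block_size > nrows ? nrows : begin + block_size;
    std::printf("block %zu: [%zu, %zu)\n", i, begin, end);
  }
  return 0;
}
```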
@@ -801,10 +795,6 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
 }
 
 void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
-  tree::GradStats* p_self = self.data();
-  tree::GradStats* p_sibling = sibling.data();
-  tree::GradStats* p_parent = parent.data();
-
   const size_t size = self.size();
   CHECK_EQ(sibling.size(), size);
   CHECK_EQ(parent.size(), size);
@@ -816,9 +806,7 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
   for (omp_ulong iblock = 0; iblock < n_blocks; ++iblock) {
     const size_t ibegin = iblock*block_size;
     const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
-    for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) {
-      p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
-    }
+    SubtractionHist(self, parent, sibling, ibegin, iend);
   }
 }
 
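The SubtractionHist call above is the heart of the "Subtraction Trick": every row of a parent node goes to exactly one child, so parent_hist = left_hist + right_hist bin by bin, and the second child's histogram comes for free as parent minus the built sibling. A toy sketch on plain doubles:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<double> parent = {10, 4, 6, 8};  // histogram built for the parent
  std::vector<double> left   = { 3, 1, 2, 5};  // built explicitly (smaller child)
  std::vector<double> right(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    right[i] = parent[i] - left[i];            // derived without scanning rows
  }
  assert(right[0] == 7 && right[3] == 3);
  return 0;
}
```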
@@ -14,8 +14,10 @@
 #include <algorithm>
 #include <memory>
 #include <utility>
+#include <map>
 
 #include "row_set.h"
+#include "threading_utils.h"
 #include "../tree/param.h"
 #include "./quantile.h"
 #include "./timer.h"
@@ -254,7 +256,7 @@ class DenseCuts : public CutsBuilder {
 
 // FIXME(trivialfis): Merge this into generic cut builder.
 /*! \brief Builds the cut matrix on the GPU.
  *
  *
  * \return The row stride across the entire dataset.
  */
 size_t DeviceSketch(int device,
@@ -343,13 +345,34 @@ class GHistIndexBlockMatrix {
 };
 
 /*!
- * \brief histogram of graident statistics for a single node.
- * Consists of multiple GradStats, each entry showing total graident statistics
+ * \brief histogram of gradient statistics for a single node.
+ * Consists of multiple GradStats, each entry showing total gradient statistics
  * for that particular bin
  * Uses global bin id so as to represent all features simultaneously
  */
 using GHistRow = Span<tree::GradStats>;
 
+/*!
+ * \brief fill a histogram by zeros
+ */
+void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end);
+
+/*!
+ * \brief Increment hist as dst += add in range [begin, end)
+ */
+void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end);
+
+/*!
+ * \brief Copy hist from src to dst in range [begin, end)
+ */
+void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end);
+
+/*!
+ * \brief Compute Subtraction: dst = src1 - src2 in range [begin, end)
+ */
+void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
+                     size_t begin, size_t end);
+
 /*!
  * \brief histogram of gradient statistics for multiple nodes
  */
@@ -372,9 +395,13 @@ class HistCollection {
 
   // initialize histogram collection
   void Init(uint32_t nbins) {
-    nbins_ = nbins;
+    if (nbins_ != nbins) {
+      nbins_ = nbins;
+      // quite expensive operation, so let's do this only once
+      data_.clear();
+    }
     row_ptr_.clear();
-    data_.clear();
+    n_nodes_added_ = 0;
   }
 
   // create an empty histogram for i-th node
@@ -385,20 +412,201 @@ class HistCollection {
     }
     CHECK_EQ(row_ptr_[nid], kMax);
 
-    row_ptr_[nid] = data_.size();
-    data_.resize(data_.size() + nbins_);
+    if (data_.size() < nbins_ * (nid + 1)) {
+      data_.resize(nbins_ * (nid + 1));
+    }
+
+    row_ptr_[nid] = nbins_ * n_nodes_added_;
+    n_nodes_added_++;
   }
 
  private:
   /*! \brief number of all bins over all features */
-  uint32_t nbins_;
+  uint32_t nbins_ = 0;
+  /*! \brief amount of active nodes in hist collection */
+  uint32_t n_nodes_added_ = 0;
 
   std::vector<tree::GradStats> data_;
 
-  /*! \brief row_ptr_[nid] locates bin for historgram of node nid */
+  /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
   std::vector<size_t> row_ptr_;
 };
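The point of the reworked Init()/AddHistRow() pair is to stop paying an allocation per node per iteration: data_ is cleared only when nbins changes, and AddHistRow grows the vector only when existing storage is too small, so after the first tree the capacity is simply reused (stale values are later overwritten or lazily zeroed). A standalone sketch of the same reuse pattern, with illustrative names:

```cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> data;                        // plays the role of data_
  const std::size_t nbins = 8;
  for (int iter = 0; iter < 3; ++iter) {
    std::size_t n_nodes_added = 0;                 // "Init": logically empty, keep memory
    for (std::size_t nid = 0; nid < 4; ++nid) {    // "AddHistRow" for 4 nodes
      if (data.size() < nbins * (nid + 1)) {
        data.resize(nbins * (nid + 1));            // grows only on the first iteration
      }
      ++n_nodes_added;
    }
    std::printf("iter %d: nodes=%zu size=%zu capacity=%zu\n",
                iter, n_nodes_added, data.size(), data.capacity());
  }
  return 0;
}
```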
 
+/*!
+ * \brief Stores temporary histograms to compute them in parallel
+ * Supports processing multiple tree-nodes for nested parallelism
+ * Able to reduce histograms across threads in an efficient way
+ */
+class ParallelGHistBuilder {
+ public:
+  void Init(size_t nbins) {
+    if (nbins != nbins_) {
+      hist_buffer_.Init(nbins);
+      nbins_ = nbins;
+    }
+  }
+
+  // Add new elements if needed, mark all hists as unused
+  // targeted_hists - already allocated hists which should contain final results after Reduce() call
+  void Reset(size_t nthreads, size_t nodes, const BlockedSpace2d& space,
+             const std::vector<GHistRow>& targeted_hists) {
+    hist_buffer_.Init(nbins_);
+    tid_nid_to_hist_.clear();
+    hist_memory_.clear();
+    threads_to_nids_map_.clear();
+
+    targeted_hists_ = targeted_hists;
+
+    CHECK_EQ(nodes, targeted_hists.size());
+
+    nodes_ = nodes;
+    nthreads_ = nthreads;
+
+    MatchThreadsToNodes(space);
+    AllocateAdditionalHistograms();
+    MatchNodeNidPairToHist();
+
+    hist_was_used_.resize(nthreads * nodes_);
+    std::fill(hist_was_used_.begin(), hist_was_used_.end(), static_cast<int>(false));
+  }
+
+  // Get specified hist, initialize hist by zeros if it wasn't used before
+  GHistRow GetInitializedHist(size_t tid, size_t nid) {
+    CHECK_LT(nid, nodes_);
+    CHECK_LT(tid, nthreads_);
+
+    size_t idx = tid_nid_to_hist_.at({tid, nid});
+    GHistRow hist = hist_memory_[idx];
+
+    if (!hist_was_used_[tid * nodes_ + nid]) {
+      InitilizeHistByZeroes(hist, 0, hist.size());
+      hist_was_used_[tid * nodes_ + nid] = static_cast<int>(true);
+    }
+
+    return hist;
+  }
+
+  // Reduce bins [begin, end) for nid-node in dst across threads
+  void ReduceHist(size_t nid, size_t begin, size_t end) {
+    CHECK_GT(end, begin);
+    CHECK_LT(nid, nodes_);
+
+    GHistRow dst = targeted_hists_[nid];
+
+    bool is_updated = false;
+    for (size_t tid = 0; tid < nthreads_; ++tid) {
+      if (hist_was_used_[tid * nodes_ + nid]) {
+        is_updated = true;
+        const size_t idx = tid_nid_to_hist_.at({tid, nid});
+        GHistRow src = hist_memory_[idx];
+
+        if (dst.data() != src.data()) {
+          IncrementHist(dst, src, begin, end);
+        }
+      }
+    }
+    if (!is_updated) {
+      // In distributed mode some tree nodes can be empty on local machines,
+      // so we just need to set the local hist to zeros in this case
+      InitilizeHistByZeroes(dst, begin, end);
+    }
+  }
+
+ protected:
+  void MatchThreadsToNodes(const BlockedSpace2d& space) {
+    const size_t space_size = space.Size();
+    const size_t chunck_size = space_size / nthreads_ + !!(space_size % nthreads_);
+
+    threads_to_nids_map_.resize(nthreads_ * nodes_, false);
+
+    for (size_t tid = 0; tid < nthreads_; ++tid) {
+      size_t begin = chunck_size * tid;
+      size_t end = std::min(begin + chunck_size, space_size);
+
+      if (begin < space_size) {
+        size_t nid_begin = space.GetFirstDimension(begin);
+        size_t nid_end = space.GetFirstDimension(end-1);
+
+        for (size_t nid = nid_begin; nid <= nid_end; ++nid) {
+          // true means thread 'tid' will compute a partial hist for node 'nid'
+          threads_to_nids_map_[tid * nodes_ + nid] = true;
+        }
+      }
+    }
+  }
+
+  void AllocateAdditionalHistograms() {
+    size_t hist_allocated_additionally = 0;
+
+    for (size_t nid = 0; nid < nodes_; ++nid) {
+      int nthreads_for_nid = 0;
+
+      for (size_t tid = 0; tid < nthreads_; ++tid) {
+        if (threads_to_nids_map_[tid * nodes_ + nid]) {
+          nthreads_for_nid++;
+        }
+      }
+
+      // In distributed mode some tree nodes can be empty on local machines,
+      // set nthreads_for_nid to 0 in this case.
+      // Otherwise allocate additional (nthreads_for_nid - 1) histograms,
+      // because one is already allocated externally (it stores the final result for the node).
+      hist_allocated_additionally += std::max<int>(0, nthreads_for_nid - 1);
+    }
+
+    for (size_t i = 0; i < hist_allocated_additionally; ++i) {
+      hist_buffer_.AddHistRow(i);
+    }
+  }
+
+  void MatchNodeNidPairToHist() {
+    size_t hist_total = 0;
+    size_t hist_allocated_additionally = 0;
+
+    for (size_t nid = 0; nid < nodes_; ++nid) {
+      bool first_hist = true;
+      for (size_t tid = 0; tid < nthreads_; ++tid) {
+        if (threads_to_nids_map_[tid * nodes_ + nid]) {
+          if (first_hist) {
+            hist_memory_.push_back(targeted_hists_[nid]);
+            first_hist = false;
+          } else {
+            hist_memory_.push_back(hist_buffer_[hist_allocated_additionally]);
+            hist_allocated_additionally++;
+          }
+          // map pair {tid, nid} to index of allocated histogram from hist_memory_
+          tid_nid_to_hist_[{tid, nid}] = hist_total++;
+          CHECK_EQ(hist_total, hist_memory_.size());
+        }
+      }
+    }
+  }
+
+  /*! \brief number of bins in each histogram */
+  size_t nbins_ = 0;
+  /*! \brief number of threads for parallel computation */
+  size_t nthreads_ = 0;
+  /*! \brief number of nodes which will be processed in parallel */
+  size_t nodes_ = 0;
+  /*! \brief Buffer for additional histograms for parallel processing */
+  HistCollection hist_buffer_;
+  /*!
+   * \brief Marks which hists were used, it means that they should be merged.
+   * Contains only {true or false} values
+   * but 'int' is used instead of 'bool', because std::vector<bool> isn't thread safe
+   */
+  std::vector<int> hist_was_used_;
+
+  /*! \brief threads_to_nids_map_[tid * nodes_ + nid] is true if thread tid computes a partial hist for node nid */
+  std::vector<bool> threads_to_nids_map_;
+  /*! \brief Contains histograms for final results */
+  std::vector<GHistRow> targeted_hists_;
+  /*! \brief Allocated memory for histograms used for construction */
+  std::vector<GHistRow> hist_memory_;
+  /*! \brief map pair {tid, nid} to index of allocated histogram from hist_memory_ */
+  std::map<std::pair<size_t, size_t>, size_t> tid_nid_to_hist_;
+};
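A simplified standalone model of what ParallelGHistBuilder coordinates: each (thread, node) pair that actually does work gets its own scratch histogram, zeroed lazily on first use, and the reduce step sums only the used copies into the target histogram. The sketch below is serial and uses illustrative names, not the real xgboost API:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::size_t nthreads = 4, nbins = 8;
  std::vector<std::vector<double>> scratch(nthreads);  // per-thread histograms
  std::vector<int> used(nthreads, 0);                  // lazy-init markers

  // Each thread adds its partial contribution (serial here for clarity).
  for (std::size_t tid = 0; tid < nthreads; ++tid) {
    if (!used[tid]) {                 // zero only histograms a thread touches
      scratch[tid].assign(nbins, 0.0);
      used[tid] = 1;
    }
    for (std::size_t bin = 0; bin < nbins; ++bin) scratch[tid][bin] += 1.0;
  }

  // Reduce: dst += every used scratch histogram.
  std::vector<double> dst(nbins, 0.0);
  for (std::size_t tid = 0; tid < nthreads; ++tid) {
    if (!used[tid]) continue;
    for (std::size_t bin = 0; bin < nbins; ++bin) dst[bin] += scratch[tid][bin];
  }
  std::printf("dst[0] = %.1f (expect %zu)\n", dst[0], nthreads);
  return 0;
}
```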
 
 /*!
  * \brief builder for histograms of gradient statistics
  */
@@ -408,7 +616,6 @@ class GHistBuilder {
   inline void Init(size_t nthread, uint32_t nbins) {
     nthread_ = nthread;
     nbins_ = nbins;
-    thread_init_.resize(nthread_);
   }
 
   // construct a histogram via histogram aggregation
@@ -433,8 +640,6 @@ class GHistBuilder {
   size_t nthread_;
   /*! \brief number of all bins over all features */
   uint32_t nbins_;
-  std::vector<size_t> thread_init_;
-  std::vector<tree::GradStats> data_;
 };
 
@@ -108,12 +108,19 @@ class BlockedSpace2d {
 
 // Wrapper to implement nested parallelism with simple omp parallel for
 template<typename Func>
-void ParallelFor2d(const BlockedSpace2d& space, Func func) {
-  const int num_blocks_in_space = static_cast<int>(space.Size());
+void ParallelFor2d(const BlockedSpace2d& space, const int nthreads, Func func) {
+  const size_t num_blocks_in_space = space.Size();
 
-#pragma omp parallel for
-  for (auto i = 0; i < num_blocks_in_space; i++) {
-    func(space.GetFirstDimension(i), space.GetRange(i));
+#pragma omp parallel num_threads(nthreads)
+  {
+    size_t tid = omp_get_thread_num();
+    size_t chunck_size = num_blocks_in_space / nthreads + !!(num_blocks_in_space % nthreads);
+
+    size_t begin = chunck_size * tid;
+    size_t end = std::min(begin + chunck_size, num_blocks_in_space);
+    for (auto i = begin; i < end; i++) {
+      func(space.GetFirstDimension(i), space.GetRange(i));
+    }
   }
 }
 
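The new ParallelFor2d replaces the OpenMP-scheduled loop with an explicit static partition: each thread takes a contiguous chunk of ceil(num_blocks / nthreads) blocks, which is what lets MatchThreadsToNodes in hist_util.h predict exactly which nodes a given thread will touch. A standalone sketch of the partition arithmetic:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t num_blocks = 10, nthreads = 4;
  // ceil(10 / 4) = 3 blocks per thread; the last thread gets the short tail.
  const std::size_t chunk = num_blocks / nthreads + !!(num_blocks % nthreads);
  for (std::size_t tid = 0; tid < nthreads; ++tid) {
    const std::size_t begin = chunk * tid;
    const std::size_t end = std::min(begin + chunk, num_blocks);
    if (begin < num_blocks) {
      std::printf("thread %zu -> blocks [%zu, %zu)\n", tid, begin, end);
    }
  }
  return 0;
}
```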
@@ -2,7 +2,7 @@
  * Copyright 2017-2018 by Contributors
  * \file updater_quantile_hist.cc
  * \brief use quantized feature values to construct a tree
- * \author Philip Cho, Tianqi Chen
+ * \author Philip Cho, Tianqi Chen, Egor Smirnov
  */
 #include <dmlc/timer.h>
 #include <rabit/rabit.h>
@@ -44,7 +44,7 @@ void QuantileHistMaker::Configure(const Args& args) {
   pruner_->Configure(args);
   param_.UpdateAllowUnknown(args);
 
-  // initialise the split evaluator
+  // initialize the split evaluator
   if (!spliteval_) {
     spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
   }
@@ -100,66 +100,121 @@ void QuantileHistMaker::Builder::SyncHistograms(
     int sync_count,
     RegTree *p_tree) {
   builder_monitor_.Start("SyncHistograms");
-  this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count);
-  // use Subtraction Trick
-  for (auto const& node_pair : nodes_for_subtraction_trick_) {
-    hist_.AddHistRow(node_pair.first);
-    SubtractionTrick(hist_[node_pair.first], hist_[node_pair.second],
-                     hist_[(*p_tree)[node_pair.first].Parent()]);
+
+  const bool isDistributed = rabit::IsDistributed();
+
+  const size_t nbins = hist_builder_.GetNumBins();
+  common::BlockedSpace2d space(nodes_for_explicit_hist_build_.size(), [&](size_t node) {
+    return nbins;
+  }, 1024);
+
+  common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
+    const auto entry = nodes_for_explicit_hist_build_[node];
+    auto this_hist = hist_[entry.nid];
+    // Merging histograms from each thread into one
+    hist_buffer_.ReduceHist(node, r.begin(), r.end());
+
+    if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1 && !isDistributed) {
+      auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
+      auto sibling_hist = hist_[entry.sibling_nid];
+
+      SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
+    }
+  });
+
+  if (isDistributed) {
+    this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count);
+    // use Subtraction Trick
+    for (auto const& node : nodes_for_subtraction_trick_) {
+      SubtractionTrick(hist_[node.nid], hist_[node.sibling_nid],
+                       hist_[(*p_tree)[node.nid].Parent()]);
+    }
   }
+
   builder_monitor_.Stop("SyncHistograms");
 }
 
+void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
+    ExpandEntry entry,
+    const GHistIndexMatrix &gmat,
+    const GHistIndexBlockMatrix &gmatb,
+    RegTree *p_tree,
+    const std::vector<GradientPair> &gpair_h) {
+  nodes_for_explicit_hist_build_.clear();
+  nodes_for_subtraction_trick_.clear();
+  nodes_for_explicit_hist_build_.push_back(entry);
+
+  if (entry.sibling_nid > -1) {
+    nodes_for_subtraction_trick_.emplace_back(entry.sibling_nid, entry.nid,
+        p_tree->GetDepth(entry.sibling_nid), 0.0f, 0);
+  }
+
+  int starting_index = std::numeric_limits<int>::max();
+  int sync_count = 0;
+
+  AddHistRows(&starting_index, &sync_count);
+  BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
+  SyncHistograms(starting_index, sync_count, p_tree);
+}
+
+void QuantileHistMaker::Builder::AddHistRows(int *starting_index, int *sync_count) {
+  builder_monitor_.Start("AddHistRows");
+
+  for (auto const& entry : nodes_for_explicit_hist_build_) {
+    int nid = entry.nid;
+    hist_.AddHistRow(nid);
+    (*starting_index) = std::min(nid, (*starting_index));
+  }
+  (*sync_count) = nodes_for_explicit_hist_build_.size();
+
+  for (auto const& node : nodes_for_subtraction_trick_) {
+    hist_.AddHistRow(node.nid);
+  }
+
+  builder_monitor_.Stop("AddHistRows");
+}
 
 void QuantileHistMaker::Builder::BuildLocalHistograms(
-    int *starting_index,
-    int *sync_count,
     const GHistIndexMatrix &gmat,
     const GHistIndexBlockMatrix &gmatb,
     RegTree *p_tree,
     const std::vector<GradientPair> &gpair_h) {
   builder_monitor_.Start("BuildLocalHistograms");
-  for (auto const& entry : qexpand_depth_wise_) {
-    int nid = entry.nid;
-    RegTree::Node &node = (*p_tree)[nid];
-    if (rabit::IsDistributed()) {
-      if (node.IsRoot() || node.IsLeftChild()) {
-        hist_.AddHistRow(nid);
-        // in distributed setting, we always calculate from left child or root node
-        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
-        if (!node.IsRoot()) {
-          nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].RightChild()] = nid;
-        }
-        (*sync_count)++;
-        (*starting_index) = std::min((*starting_index), nid);
-      }
-    } else {
-      if (!node.IsRoot() && node.IsLeftChild() &&
-          (row_set_collection_[nid].Size() <
-           row_set_collection_[(*p_tree)[node.Parent()].RightChild()].Size())) {
-        hist_.AddHistRow(nid);
-        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
-        nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].RightChild()] = nid;
-        (*sync_count)++;
-        (*starting_index) = std::min((*starting_index), nid);
-      } else if (!node.IsRoot() && !node.IsLeftChild() &&
-                 (row_set_collection_[nid].Size() <=
-                  row_set_collection_[(*p_tree)[node.Parent()].LeftChild()].Size())) {
-        hist_.AddHistRow(nid);
-        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
-        nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].LeftChild()] = nid;
-        (*sync_count)++;
-        (*starting_index) = std::min((*starting_index), nid);
-      } else if (node.IsRoot()) {
-        hist_.AddHistRow(nid);
-        BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
-        (*sync_count)++;
-        (*starting_index) = std::min((*starting_index), nid);
-      }
-    }
+
+  const size_t n_nodes = nodes_for_explicit_hist_build_.size();
+
+  // create space of size (# rows in each node)
+  common::BlockedSpace2d space(n_nodes, [&](size_t node) {
+    const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
+    return row_set_collection_[nid].Size();
+  }, 256);
+
+  std::vector<GHistRow> target_hists(n_nodes);
+  for (size_t i = 0; i < n_nodes; ++i) {
+    const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
+    target_hists[i] = hist_[nid];
+  }
+
+  hist_buffer_.Reset(this->nthread_, n_nodes, space, target_hists);
+
+  // Parallel processing by nodes and data in each node
+  common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
+    const auto tid = static_cast<unsigned>(omp_get_thread_num());
+    const int32_t nid = nodes_for_explicit_hist_build_[nid_in_set].nid;
+
+    auto start_of_row_set = row_set_collection_[nid].begin;
+    auto rid_set = RowSetCollection::Elem(start_of_row_set + r.begin(),
+                                          start_of_row_set + r.end(),
+                                          nid);
+    BuildHist(gpair_h, rid_set, gmat, gmatb, hist_buffer_.GetInitializedHist(tid, nid_in_set));
+  });
+
   builder_monitor_.Stop("BuildLocalHistograms");
 }
 
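BuildLocalHistograms now drives everything through a BlockedSpace2d: a two-dimensional space (node x rows in that node) flattened into grain-sized (node, row-range) tasks, so one parallel loop covers all nodes at once instead of parallelizing separately inside each node. A standalone sketch of the flattening, with illustrative names:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const std::vector<std::size_t> rows_per_node = {600, 100, 300};
  const std::size_t grain = 256;
  // Each task is (node, [begin, end)) over that node's rows.
  std::vector<std::pair<std::size_t, std::pair<std::size_t, std::size_t>>> tasks;
  for (std::size_t node = 0; node < rows_per_node.size(); ++node) {
    for (std::size_t begin = 0; begin < rows_per_node[node]; begin += grain) {
      const std::size_t end = std::min(begin + grain, rows_per_node[node]);
      tasks.push_back({node, {begin, end}});
    }
  }
  for (const auto& t : tasks) {
    std::printf("node %zu rows [%zu, %zu)\n", t.first, t.second.first, t.second.second);
  }
  return 0;
}
```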
void QuantileHistMaker::Builder::BuildNodeStats(
    const GHistIndexMatrix &gmat,
    DMatrix *p_fmat,
@@ -193,7 +248,7 @@ void QuantileHistMaker::Builder::EvaluateSplits(
     int depth,
     unsigned *timestamp,
     std::vector<ExpandEntry> *temp_qexpand_depth) {
-  this->EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree);
+  EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree);
 
   for (auto const& entry : qexpand_depth_wise_) {
     int nid = entry.nid;
@@ -206,9 +261,9 @@ void QuantileHistMaker::Builder::EvaluateSplits(
       this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
       int left_id = (*p_tree)[nid].LeftChild();
       int right_id = (*p_tree)[nid].RightChild();
-      temp_qexpand_depth->push_back(ExpandEntry(left_id,
+      temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id,
                                                 p_tree->GetDepth(left_id), 0.0, (*timestamp)++));
-      temp_qexpand_depth->push_back(ExpandEntry(right_id,
+      temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id,
                                                 p_tree->GetDepth(right_id), 0.0, (*timestamp)++));
       // - 1 parent + 2 new children
       (*num_leaves)++;
@@ -216,6 +271,43 @@ void QuantileHistMaker::Builder::EvaluateSplits(
     }
   }
 }
+
+// Split nodes into 2 sets depending on the number of rows in each node.
+// Histograms for small nodes will be built explicitly;
+// histograms for big nodes will be built by the 'Subtraction Trick'.
+// Exception: in the distributed setting, we always build the histogram for the left child node
+// and use the 'Subtraction Trick' to build the histogram for the right child node.
+// This ensures that the workers operate on the same set of tree nodes.
+void QuantileHistMaker::Builder::SplitSiblings(const std::vector<ExpandEntry>& nodes,
+                                               std::vector<ExpandEntry>* small_siblings,
+                                               std::vector<ExpandEntry>* big_siblings,
+                                               RegTree *p_tree) {
+  for (auto const& entry : nodes) {
+    int nid = entry.nid;
+    RegTree::Node &node = (*p_tree)[nid];
+    if (rabit::IsDistributed()) {
+      if (node.IsRoot() || node.IsLeftChild()) {
+        small_siblings->push_back(entry);
+      } else {
+        big_siblings->push_back(entry);
+      }
+    } else {
+      if (!node.IsRoot() && node.IsLeftChild() &&
+          (row_set_collection_[nid].Size() <
+           row_set_collection_[(*p_tree)[node.Parent()].RightChild()].Size())) {
+        small_siblings->push_back(entry);
+      } else if (!node.IsRoot() && !node.IsLeftChild() &&
+                 (row_set_collection_[nid].Size() <=
+                  row_set_collection_[(*p_tree)[node.Parent()].LeftChild()].Size())) {
+        small_siblings->push_back(entry);
+      } else if (node.IsRoot()) {
+        small_siblings->push_back(entry);
+      } else {
+        big_siblings->push_back(entry);
+      }
+    }
+  }
+}
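The local (non-distributed) branch of SplitSiblings encodes a cost argument: building a histogram costs time proportional to the rows scanned, and the subtraction trick makes the second sibling nearly free, so always building the smaller child caps the per-split histogram cost at min(|left|, |right|) rows instead of |left| + |right|. A tiny sketch of the decision rule (ties go to the right child, matching the `<=` above):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t left_rows = 120, right_rows = 880;
  // Build the smaller sibling explicitly; derive the other by subtraction.
  const bool build_left = left_rows < right_rows;  // '<' : on a tie, build the right child
  std::printf("build %s explicitly (%zu rows), subtract to get the other\n",
              build_left ? "left" : "right", build_left ? left_rows : right_rows);
  return 0;
}
```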
 
 void QuantileHistMaker::Builder::ExpandWithDepthWise(
     const GHistIndexMatrix &gmat,
     const GHistIndexBlockMatrix &gmatb,
@@ -227,21 +319,28 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
   int num_leaves = 0;
 
   // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
-  qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid,
+  qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
       p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++));
   ++num_leaves;
   for (int depth = 0; depth < param_.max_depth + 1; depth++) {
     int starting_index = std::numeric_limits<int>::max();
     int sync_count = 0;
     std::vector<ExpandEntry> temp_qexpand_depth;
-    BuildLocalHistograms(&starting_index, &sync_count, gmat, gmatb, p_tree, gpair_h);
+
+    SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
+                  &nodes_for_subtraction_trick_, p_tree);
+    AddHistRows(&starting_index, &sync_count);
+
+    BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
     SyncHistograms(starting_index, sync_count, p_tree);
+
     BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);
     EvaluateSplits(gmat, column_matrix, p_fmat, p_tree, &num_leaves, depth, &timestamp,
                    &temp_qexpand_depth);
     // clean up
     qexpand_depth_wise_.clear();
     nodes_for_subtraction_trick_.clear();
+    nodes_for_explicit_hist_build_.clear();
     if (temp_qexpand_depth.empty()) {
       break;
     } else {
@@ -262,14 +361,12 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
   unsigned timestamp = 0;
   int num_leaves = 0;
 
-  hist_.AddHistRow(ExpandEntry::kRootNid);
-  BuildHist(gpair_h, row_set_collection_[ExpandEntry::kRootNid], gmat, gmatb,
-            hist_[ExpandEntry::kRootNid], true);
+  ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
+                   p_tree->GetDepth(0), 0.0f, timestamp++);
+  BuildHistogramsLossGuide(node, gmat, gmatb, p_tree, gpair_h);
 
   this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair_h, *p_fmat, *p_tree);
 
-  ExpandEntry node(ExpandEntry::kRootNid, p_tree->GetDepth(ExpandEntry::kRootNid),
-                   snode_[ExpandEntry::kRootNid].best.loss_chg, timestamp++);
   this->EvaluateSplit({node}, gmat, hist_, *p_fmat, *p_tree);
   node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg;
 
@@ -289,20 +386,20 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
 
     const int cleft = (*p_tree)[nid].LeftChild();
     const int cright = (*p_tree)[nid].RightChild();
-    hist_.AddHistRow(cleft);
-    hist_.AddHistRow(cright);
+
+    ExpandEntry left_node(cleft, cright, p_tree->GetDepth(cleft),
+                          0.0f, timestamp++);
+    ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright),
+                           0.0f, timestamp++);
 
     if (rabit::IsDistributed()) {
       // in distributed mode, we need to keep consistent across workers
-      BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft], true);
-      SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
+      BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
     } else {
       if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
-        BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft], true);
-        SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
+        BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
       } else {
-        BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright], true);
-        SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
+        BuildHistogramsLossGuide(right_node, gmat, gmatb, p_tree, gpair_h);
       }
     }
 
@@ -313,11 +410,6 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
                       snode_[cleft].weight, snode_[cright].weight);
     interaction_constraints_.Split(nid, featureid, cleft, cright);
 
-    ExpandEntry left_node(cleft, p_tree->GetDepth(cleft),
-                          snode_[cleft].best.loss_chg, timestamp++);
-    ExpandEntry right_node(cright, p_tree->GetDepth(cright),
-                           snode_[cright].best.loss_chg, timestamp++);
-
     this->EvaluateSplit({left_node, right_node}, gmat, hist_, *p_fmat, *p_tree);
     left_node.loss_chg = snode_[cleft].best.loss_chg;
     right_node.loss_chg = snode_[cright].best.loss_chg;
@@ -427,6 +519,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
   // initialize histogram collection
   uint32_t nbins = gmat.cut.Ptrs().back();
   hist_.Init(nbins);
+  hist_buffer_.Init(nbins);
 
   // initialize histogram builder
 #pragma omp parallel
@@ -586,7 +679,7 @@ void QuantileHistMaker::Builder::EvaluateSplit(const std::vector<ExpandEntry>& n
   builder_monitor_.Start("EvaluateSplit");
 
   const size_t n_nodes_in_set = nodes_set.size();
-  const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
+  const size_t nthread = std::max(1, this->nthread_);
 
   using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
   std::vector<FeatureSetType> features_sets(n_nodes_in_set);
@@ -604,12 +697,13 @@ void QuantileHistMaker::Builder::EvaluateSplit(const std::vector<ExpandEntry>& n
 
   // Create 2D space (# of nodes to process x # of features to process)
   // to process them in parallel
+  const size_t grain_size = std::max<size_t>(1, features_sets[0]->Size() / nthread);
   common::BlockedSpace2d space(n_nodes_in_set, [&](size_t nid_in_set) {
     return features_sets[nid_in_set]->Size();
-  }, 1);
+  }, grain_size);
 
   // Start parallel enumeration for all tree nodes in the set and all features
-  common::ParallelFor2d(space, [&](size_t nid_in_set, common::Range1d r) {
+  common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
     const int32_t nid = nodes_set[nid_in_set].nid;
     const auto tid = static_cast<unsigned>(omp_get_thread_num());
     GHistRow node_hist = hist[nid];
 
@@ -2,7 +2,7 @@
  * Copyright 2017-2018 by Contributors
  * \file updater_quantile_hist.h
  * \brief use quantized feature values to construct a tree
- * \author Philip Cho, Tianqi Chen
+ * \author Philip Cho, Tianqi Chen, Egor Smirnov
  */
 #ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
 #define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
@@ -157,18 +157,12 @@ class QuantileHistMaker: public TreeUpdater {
                  const RowSetCollection::Elem row_indices,
                  const GHistIndexMatrix& gmat,
                  const GHistIndexBlockMatrix& gmatb,
-                 GHistRow hist,
-                 bool sync_hist) {
-    builder_monitor_.Start("BuildHist");
+                 GHistRow hist) {
     if (param_.enable_feature_grouping > 0) {
      hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist);
    } else {
      hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
    }
-    if (sync_hist) {
-      this->histred_.Allreduce(hist.data(), hist_builder_.GetNumBins());
-    }
-    builder_monitor_.Stop("BuildHist");
   }
 
   inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
@@ -183,13 +177,15 @@ class QuantileHistMaker: public TreeUpdater {
  protected:
   /* tree growing policies */
   struct ExpandEntry {
-    static const int kRootNid = 0;
+    static const int kRootNid  = 0;
+    static const int kEmptyNid = -1;
     int nid;
+    int sibling_nid;
     int depth;
     bst_float loss_chg;
     unsigned timestamp;
-    ExpandEntry(int nid, int depth, bst_float loss_chg, unsigned tstmp)
-        : nid(nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {}
+    ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg, unsigned tstmp):
+        nid(nid), sibling_nid(sibling_nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {}
   };
 
   // initialize temp data structure
@@ -259,13 +255,28 @@ class QuantileHistMaker: public TreeUpdater {
                       RegTree *p_tree,
                       const std::vector<GradientPair> &gpair_h);
 
-  void BuildLocalHistograms(int *starting_index,
-                            int *sync_count,
-                            const GHistIndexMatrix &gmat,
+  void BuildLocalHistograms(const GHistIndexMatrix &gmat,
                             const GHistIndexBlockMatrix &gmatb,
                             RegTree *p_tree,
                             const std::vector<GradientPair> &gpair_h);
 
+  void AddHistRows(int *starting_index, int *sync_count);
+
+  void BuildHistogramsLossGuide(
+                      ExpandEntry entry,
+                      const GHistIndexMatrix &gmat,
+                      const GHistIndexBlockMatrix &gmatb,
+                      RegTree *p_tree,
+                      const std::vector<GradientPair> &gpair_h);
+
+  // Split nodes into 2 sets depending on the number of rows in each node.
+  // Histograms for small nodes will be built explicitly;
+  // histograms for big nodes will be built by the 'Subtraction Trick'.
+  void SplitSiblings(const std::vector<ExpandEntry>& nodes,
+                     std::vector<ExpandEntry>* small_siblings,
+                     std::vector<ExpandEntry>* big_siblings,
+                     RegTree *p_tree);
+
   void SyncHistograms(int starting_index,
                       int sync_count,
                       RegTree *p_tree);
@@ -336,12 +347,15 @@ class QuantileHistMaker: public TreeUpdater {
   std::vector<ExpandEntry> qexpand_depth_wise_;
-  // key is the node id which should be calculated by Subtraction Trick, value is the node which
-  // provides the evidence for subtraction
-  std::unordered_map<int, int> nodes_for_subtraction_trick_;
+  // list of nodes whose histograms would be built by the Subtraction Trick.
+  std::vector<ExpandEntry> nodes_for_subtraction_trick_;
+  // list of nodes whose histograms would be built explicitly.
+  std::vector<ExpandEntry> nodes_for_explicit_hist_build_;
 
   enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
   DataLayout data_layout_;
 
   common::Monitor builder_monitor_;
+  common::ParallelGHistBuilder hist_buffer_;
   rabit::Reducer<GradStats, GradStats::Reduce> histred_;
 };
 
@@ -9,6 +9,123 @@
 namespace xgboost {
 namespace common {
 
+size_t GetNThreads() {
+  size_t nthreads;
+#pragma omp parallel
+  {
+#pragma omp master
+    nthreads = omp_get_num_threads();
+  }
+  return nthreads;
+}
+
+TEST(ParallelGHistBuilder, Reset) {
+  constexpr size_t kBins = 10;
+  constexpr size_t kNodes = 5;
+  constexpr size_t kNodesExtended = 10;
+  constexpr size_t kTasksPerNode = 10;
+  constexpr double kValue = 1.0;
+  const size_t nthreads = GetNThreads();
+
+  HistCollection collection;
+  collection.Init(kBins);
+
+  for (size_t inode = 0; inode < kNodesExtended; inode++) {
+    collection.AddHistRow(inode);
+  }
+
+  ParallelGHistBuilder hist_builder;
+  hist_builder.Init(kBins);
+  std::vector<GHistRow> target_hist(kNodes);
+  for (size_t i = 0; i < target_hist.size(); ++i) {
+    target_hist[i] = collection[i];
+  }
+
+  common::BlockedSpace2d space(kNodes, [&](size_t node) { return kTasksPerNode; }, 1);
+  hist_builder.Reset(nthreads, kNodes, space, target_hist);
+
+  common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) {
+    const size_t itask = r.begin();
+    const size_t tid = omp_get_thread_num();
+
+    GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
+    // fill hist with some non-null values
+    for (size_t j = 0; j < kBins; ++j) {
+      hist[j].Add(kValue, kValue);
+    }
+  });
+
+  // reset and extend the buffer
+  target_hist.resize(kNodesExtended);
+  for (size_t i = 0; i < target_hist.size(); ++i) {
+    target_hist[i] = collection[i];
+  }
+  common::BlockedSpace2d space2(kNodesExtended, [&](size_t node) { return kTasksPerNode; }, 1);
+  hist_builder.Reset(nthreads, kNodesExtended, space2, target_hist);
+
+  common::ParallelFor2d(space2, nthreads, [&](size_t inode, common::Range1d r) {
+    const size_t itask = r.begin();
+    const size_t tid = omp_get_thread_num();
+
+    GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
+    // every histogram should be zero-initialized again after Reset
+    for (size_t j = 0; j < kBins; ++j) {
+      ASSERT_EQ(0.0, hist[j].GetGrad());
+      ASSERT_EQ(0.0, hist[j].GetHess());
+    }
+  });
+}
+
+TEST(ParallelGHistBuilder, ReduceHist) {
+  constexpr size_t kBins = 10;
+  constexpr size_t kNodes = 5;
+  constexpr size_t kNodesExtended = 10;
+  constexpr size_t kTasksPerNode = 10;
+  constexpr double kValue = 1.0;
+  const size_t nthreads = GetNThreads();
+
+  HistCollection collection;
+  collection.Init(kBins);
+
+  for (size_t inode = 0; inode < kNodes; inode++) {
+    collection.AddHistRow(inode);
+  }
+
+  ParallelGHistBuilder hist_builder;
+  hist_builder.Init(kBins);
+  std::vector<GHistRow> target_hist(kNodes);
+  for (size_t i = 0; i < target_hist.size(); ++i) {
+    target_hist[i] = collection[i];
+  }
+
+  common::BlockedSpace2d space(kNodes, [&](size_t node) { return kTasksPerNode; }, 1);
+  hist_builder.Reset(nthreads, kNodes, space, target_hist);
+
+  // Simple analog of the BuildHist function: works in parallel over both tree nodes and data in each node
+  common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) {
+    const size_t itask = r.begin();
+    const size_t tid = omp_get_thread_num();
+
+    GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
+    for (size_t i = 0; i < kBins; ++i) {
+      hist[i].Add(kValue, kValue);
+    }
+  });
+
+  for (size_t inode = 0; inode < kNodes; inode++) {
+    hist_builder.ReduceHist(inode, 0, kBins);
+
+    // We had kTasksPerNode tasks adding kValue to each bin for each node,
+    // so after reducing we expect to have (kValue * kTasksPerNode) in each bin
+    for (size_t i = 0; i < kBins; ++i) {
+      ASSERT_EQ(kValue * kTasksPerNode, collection[inode][i].GetGrad());
+      ASSERT_EQ(kValue * kTasksPerNode, collection[inode][i].GetHess());
+    }
+  }
+}
+
 TEST(CutsBuilder, SearchGroupInd) {
   size_t constexpr kNumGroups = 4;
   size_t constexpr kRows = 17;
 
@@ -37,7 +37,7 @@ TEST(ParallelFor2d, Test) {
     return kDim2;
   }, kGrainSize);
 
-  ParallelFor2d(space, [&](size_t i, Range1d r) {
+  ParallelFor2d(space, 4, [&](size_t i, Range1d r) {
     for (auto j = r.begin(); j < r.end(); ++j) {
       matrix[i*kDim2 + j] += 1;
     }
@@ -65,7 +65,7 @@ TEST(ParallelFor2dNonUniform, Test) {
     working_space[i].resize(dim2[i], 0);
   }
 
-  ParallelFor2d(space, [&](size_t i, Range1d r) {
+  ParallelFor2d(space, 4, [&](size_t i, Range1d r) {
     for (auto j = r.begin(); j < r.end(); ++j) {
       working_space[i][j] += 1;
     }
 
@@ -107,7 +107,7 @@ class QuantileHistMock : public QuantileHistMaker {
       GHistIndexBlockMatrix dummy;
       hist_.AddHistRow(nid);
       BuildHist(gpair, row_set_collection_[nid],
-                gmat, dummy, hist_[nid], false);
+                gmat, dummy, hist_[nid]);
 
       // Check if number of histogram bins is correct
       ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back());
@@ -149,7 +149,7 @@ class QuantileHistMock : public QuantileHistMaker {
       hist_.AddHistRow(0);
 
       BuildHist(row_gpairs, row_set_collection_[0],
-                gmat, quantile_index_block, hist_[0], false);
+                gmat, quantile_index_block, hist_[0]);
 
       RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), tree);
 
@@ -211,7 +211,8 @@ class QuantileHistMock : public QuantileHistMaker {
       }
 
       /* Now compare against result given by EvaluateSplit() */
-      ExpandEntry node(0, tree.GetDepth(0), snode_[0].best.loss_chg, 0);
+      ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
+                       tree.GetDepth(0), snode_[0].best.loss_chg, 0);
       RealImpl::EvaluateSplit({node}, gmat, hist_, *(*dmat), tree);
       ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
       ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
 