Optimized BuildHist function (#5156)

This commit is contained in:
Egor Smirnov 2020-01-30 10:32:57 +03:00 committed by GitHub
parent 4240daed4e
commit c67163250e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 610 additions and 184 deletions

View File

@ -659,13 +659,59 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
} }
} }
/*!
* \brief fill a histogram by zeroes
*/
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats));
}
/*!
* \brief Increment hist as dst += add in range [begin, end)
*/
void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* padd = reinterpret_cast<const FPType*>(add.data());
for (size_t i = 2 * begin; i < 2 * end; ++i) {
pdst[i] += padd[i];
}
}
/*!
* \brief Copy hist from src to dst in range [begin, end)
*/
void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* psrc = reinterpret_cast<const FPType*>(src.data());
for (size_t i = 2 * begin; i < 2 * end; ++i) {
pdst[i] = psrc[i];
}
}
/*!
* \brief Compute Subtraction: dst = src1 - src2 in range [begin, end)
*/
void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
FPType* pdst = reinterpret_cast<FPType*>(dst.data());
const FPType* psrc1 = reinterpret_cast<const FPType*>(src1.data());
const FPType* psrc2 = reinterpret_cast<const FPType*>(src2.data());
for (size_t i = 2 * begin; i < 2 * end; ++i) {
pdst[i] = psrc1[i] - psrc2[i];
}
}
void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair, void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat, const GHistIndexMatrix& gmat,
GHistRow hist) { GHistRow hist) {
const size_t nthread = static_cast<size_t>(this->nthread_);
data_.resize(nbins_ * nthread_);
const size_t* rid = row_indices.begin; const size_t* rid = row_indices.begin;
const size_t nrows = row_indices.Size(); const size_t nrows = row_indices.Size();
const uint32_t* index = gmat.index.data(); const uint32_t* index = gmat.index.data();
@ -673,34 +719,13 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
const float* pgh = reinterpret_cast<const float*>(gpair.data()); const float* pgh = reinterpret_cast<const float*>(gpair.data());
double* hist_data = reinterpret_cast<double*>(hist.data()); double* hist_data = reinterpret_cast<double*>(hist.data());
double* data = reinterpret_cast<double*>(data_.data());
const size_t block_size = 512;
size_t n_blocks = nrows/block_size;
n_blocks += !!(nrows - n_blocks*block_size);
const size_t nthread_to_process = std::min(nthread, n_blocks);
memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));
const size_t cache_line_size = 64; const size_t cache_line_size = 64;
const size_t prefetch_offset = 10; const size_t prefetch_offset = 10;
size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid); size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size; no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
#pragma omp parallel for num_threads(nthread_to_process) schedule(guided) for (size_t i = 0; i < nrows; ++i) {
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
dmlc::omp_uint tid = omp_get_thread_num();
double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
reinterpret_cast<double*>(data_.data() + tid * nbins_));
if (!thread_init_[tid]) {
memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
thread_init_[tid] = true;
}
const size_t istart = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
for (size_t i = istart; i < iend; ++i) {
const size_t icol_start = row_ptr[rid[i]]; const size_t icol_start = row_ptr[rid[i]];
const size_t icol_end = row_ptr[rid[i]+1]; const size_t icol_end = row_ptr[rid[i]+1];
@ -713,39 +738,8 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
const uint32_t idx_bin = 2*index[j]; const uint32_t idx_bin = 2*index[j];
const size_t idx_gh = 2*rid[i]; const size_t idx_gh = 2*rid[i];
data_local_hist[idx_bin] += pgh[idx_gh]; hist_data[idx_bin] += pgh[idx_gh];
data_local_hist[idx_bin+1] += pgh[idx_gh+1]; hist_data[idx_bin+1] += pgh[idx_gh+1];
}
}
}
if (nthread_to_process > 1) {
const size_t size = (2*nbins_);
const size_t block_size = 1024;
size_t n_blocks = size/block_size;
n_blocks += !!(size - n_blocks*block_size);
size_t n_worked_bins = 0;
for (size_t i = 0; i < nthread_to_process; ++i) {
if (thread_init_[i]) {
thread_init_[n_worked_bins++] = i;
}
}
#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
const size_t istart = iblock * block_size;
const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
const size_t bin = 2 * thread_init_[0] * nbins_;
memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
for (size_t i = istart; i < iend; i++) {
hist_data[i] += data[bin + i];
}
}
} }
} }
} }
@ -801,10 +795,6 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
} }
void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) { void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
tree::GradStats* p_self = self.data();
tree::GradStats* p_sibling = sibling.data();
tree::GradStats* p_parent = parent.data();
const size_t size = self.size(); const size_t size = self.size();
CHECK_EQ(sibling.size(), size); CHECK_EQ(sibling.size(), size);
CHECK_EQ(parent.size(), size); CHECK_EQ(parent.size(), size);
@ -816,9 +806,7 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
for (omp_ulong iblock = 0; iblock < n_blocks; ++iblock) { for (omp_ulong iblock = 0; iblock < n_blocks; ++iblock) {
const size_t ibegin = iblock*block_size; const size_t ibegin = iblock*block_size;
const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size); const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) { SubtractionHist(self, parent, sibling, ibegin, iend);
p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
}
} }
} }

View File

@ -14,8 +14,10 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <map>
#include "row_set.h" #include "row_set.h"
#include "threading_utils.h"
#include "../tree/param.h" #include "../tree/param.h"
#include "./quantile.h" #include "./quantile.h"
#include "./timer.h" #include "./timer.h"
@ -343,13 +345,34 @@ class GHistIndexBlockMatrix {
}; };
/*! /*!
* \brief histogram of graident statistics for a single node. * \brief histogram of gradient statistics for a single node.
* Consists of multiple GradStats, each entry showing total graident statistics * Consists of multiple GradStats, each entry showing total gradient statistics
* for that particular bin * for that particular bin
* Uses global bin id so as to represent all features simultaneously * Uses global bin id so as to represent all features simultaneously
*/ */
using GHistRow = Span<tree::GradStats>; using GHistRow = Span<tree::GradStats>;
/*!
* \brief fill a histogram by zeros
*/
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end);
/*!
* \brief Increment hist as dst += add in range [begin, end)
*/
void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end);
/*!
* \brief Copy hist from src to dst in range [begin, end)
*/
void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end);
/*!
* \brief Compute Subtraction: dst = src1 - src2 in range [begin, end)
*/
void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
size_t begin, size_t end);
/*! /*!
* \brief histogram of gradient statistics for multiple nodes * \brief histogram of gradient statistics for multiple nodes
*/ */
@ -372,10 +395,14 @@ class HistCollection {
// initialize histogram collection // initialize histogram collection
void Init(uint32_t nbins) { void Init(uint32_t nbins) {
if (nbins_ != nbins) {
nbins_ = nbins; nbins_ = nbins;
row_ptr_.clear(); // quite expensive operation, so let's do this only once
data_.clear(); data_.clear();
} }
row_ptr_.clear();
n_nodes_added_ = 0;
}
// create an empty histogram for i-th node // create an empty histogram for i-th node
void AddHistRow(bst_uint nid) { void AddHistRow(bst_uint nid) {
@ -385,20 +412,201 @@ class HistCollection {
} }
CHECK_EQ(row_ptr_[nid], kMax); CHECK_EQ(row_ptr_[nid], kMax);
row_ptr_[nid] = data_.size(); if (data_.size() < nbins_ * (nid + 1)) {
data_.resize(data_.size() + nbins_); data_.resize(nbins_ * (nid + 1));
}
row_ptr_[nid] = nbins_ * n_nodes_added_;
n_nodes_added_++;
} }
private: private:
/*! \brief number of all bins over all features */ /*! \brief number of all bins over all features */
uint32_t nbins_; uint32_t nbins_ = 0;
/*! \brief amount of active nodes in hist collection */
uint32_t n_nodes_added_ = 0;
std::vector<tree::GradStats> data_; std::vector<tree::GradStats> data_;
/*! \brief row_ptr_[nid] locates bin for historgram of node nid */ /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
std::vector<size_t> row_ptr_; std::vector<size_t> row_ptr_;
}; };
/*!
* \brief Stores temporary histograms to compute them in parallel
* Supports processing multiple tree-nodes for nested parallelism
* Able to reduce histograms across threads in efficient way
*/
class ParallelGHistBuilder {
public:
void Init(size_t nbins) {
if (nbins != nbins_) {
hist_buffer_.Init(nbins);
nbins_ = nbins;
}
}
// Add new elements if needed, mark all hists as unused
// targeted_hists - already allocated hists which should contain final results after Reduce() call
void Reset(size_t nthreads, size_t nodes, const BlockedSpace2d& space,
const std::vector<GHistRow>& targeted_hists) {
hist_buffer_.Init(nbins_);
tid_nid_to_hist_.clear();
hist_memory_.clear();
threads_to_nids_map_.clear();
targeted_hists_ = targeted_hists;
CHECK_EQ(nodes, targeted_hists.size());
nodes_ = nodes;
nthreads_ = nthreads;
MatchThreadsToNodes(space);
AllocateAdditionalHistograms();
MatchNodeNidPairToHist();
hist_was_used_.resize(nthreads * nodes_);
std::fill(hist_was_used_.begin(), hist_was_used_.end(), static_cast<int>(false));
}
// Get specified hist, initialize hist by zeros if it wasn't used before
GHistRow GetInitializedHist(size_t tid, size_t nid) {
CHECK_LT(nid, nodes_);
CHECK_LT(tid, nthreads_);
size_t idx = tid_nid_to_hist_.at({tid, nid});
GHistRow hist = hist_memory_[idx];
if (!hist_was_used_[tid * nodes_ + nid]) {
InitilizeHistByZeroes(hist, 0, hist.size());
hist_was_used_[tid * nodes_ + nid] = static_cast<int>(true);
}
return hist;
}
// Reduce following bins (begin, end] for nid-node in dst across threads
void ReduceHist(size_t nid, size_t begin, size_t end) {
CHECK_GT(end, begin);
CHECK_LT(nid, nodes_);
GHistRow dst = targeted_hists_[nid];
bool is_updated = false;
for (size_t tid = 0; tid < nthreads_; ++tid) {
if (hist_was_used_[tid * nodes_ + nid]) {
is_updated = true;
const size_t idx = tid_nid_to_hist_.at({tid, nid});
GHistRow src = hist_memory_[idx];
if (dst.data() != src.data()) {
IncrementHist(dst, src, begin, end);
}
}
}
if (!is_updated) {
// In distributed mode - some tree nodes can be empty on local machines,
// So we need just set local hist by zeros in this case
InitilizeHistByZeroes(dst, begin, end);
}
}
protected:
void MatchThreadsToNodes(const BlockedSpace2d& space) {
const size_t space_size = space.Size();
const size_t chunck_size = space_size / nthreads_ + !!(space_size % nthreads_);
threads_to_nids_map_.resize(nthreads_ * nodes_, false);
for (size_t tid = 0; tid < nthreads_; ++tid) {
size_t begin = chunck_size * tid;
size_t end = std::min(begin + chunck_size, space_size);
if (begin < space_size) {
size_t nid_begin = space.GetFirstDimension(begin);
size_t nid_end = space.GetFirstDimension(end-1);
for (size_t nid = nid_begin; nid <= nid_end; ++nid) {
// true - means thread 'tid' will work to compute partial hist for node 'nid'
threads_to_nids_map_[tid * nodes_ + nid] = true;
}
}
}
}
void AllocateAdditionalHistograms() {
size_t hist_allocated_additionally = 0;
for (size_t nid = 0; nid < nodes_; ++nid) {
int nthreads_for_nid = 0;
for (size_t tid = 0; tid < nthreads_; ++tid) {
if (threads_to_nids_map_[tid * nodes_ + nid]) {
nthreads_for_nid++;
}
}
// In distributed mode - some tree nodes can be empty on local machines,
// set nthreads_for_nid to 0 in this case.
// In another case - allocate additional (nthreads_for_nid - 1) histograms,
// because one is already allocated externally (will store final result for the node).
hist_allocated_additionally += std::max<int>(0, nthreads_for_nid - 1);
}
for (size_t i = 0; i < hist_allocated_additionally; ++i) {
hist_buffer_.AddHistRow(i);
}
}
void MatchNodeNidPairToHist() {
size_t hist_total = 0;
size_t hist_allocated_additionally = 0;
for (size_t nid = 0; nid < nodes_; ++nid) {
bool first_hist = true;
for (size_t tid = 0; tid < nthreads_; ++tid) {
if (threads_to_nids_map_[tid * nodes_ + nid]) {
if (first_hist) {
hist_memory_.push_back(targeted_hists_[nid]);
first_hist = false;
} else {
hist_memory_.push_back(hist_buffer_[hist_allocated_additionally]);
hist_allocated_additionally++;
}
// map pair {tid, nid} to index of allocated histogram from hist_memory_
tid_nid_to_hist_[{tid, nid}] = hist_total++;
CHECK_EQ(hist_total, hist_memory_.size());
}
}
}
}
/*! \brief number of bins in each histogram */
size_t nbins_ = 0;
/*! \brief number of threads for parallel computation */
size_t nthreads_ = 0;
/*! \brief number of nodes which will be processed in parallel */
size_t nodes_ = 0;
/*! \brief Buffer for additional histograms for Parallel processing */
HistCollection hist_buffer_;
/*!
* \brief Marks which hists were used, it means that they should be merged.
* Contains only {true or false} values
* but 'int' is used instead of 'bool', because std::vector<bool> isn't thread safe
*/
std::vector<int> hist_was_used_;
/*! \brief Buffer for additional histograms for Parallel processing */
std::vector<bool> threads_to_nids_map_;
/*! \brief Contains histograms for final results */
std::vector<GHistRow> targeted_hists_;
/*! \brief Allocated memory for histograms used for construction */
std::vector<GHistRow> hist_memory_;
/*! \brief map pair {tid, nid} to index of allocated histogram from hist_memory_ */
std::map<std::pair<size_t, size_t>, size_t> tid_nid_to_hist_;
};
/*! /*!
* \brief builder for histograms of gradient statistics * \brief builder for histograms of gradient statistics
*/ */
@ -408,7 +616,6 @@ class GHistBuilder {
inline void Init(size_t nthread, uint32_t nbins) { inline void Init(size_t nthread, uint32_t nbins) {
nthread_ = nthread; nthread_ = nthread;
nbins_ = nbins; nbins_ = nbins;
thread_init_.resize(nthread_);
} }
// construct a histogram via histogram aggregation // construct a histogram via histogram aggregation
@ -433,8 +640,6 @@ class GHistBuilder {
size_t nthread_; size_t nthread_;
/*! \brief number of all bins over all features */ /*! \brief number of all bins over all features */
uint32_t nbins_; uint32_t nbins_;
std::vector<size_t> thread_init_;
std::vector<tree::GradStats> data_;
}; };

View File

@ -108,13 +108,20 @@ class BlockedSpace2d {
// Wrapper to implement nested parallelism with simple omp parallel for // Wrapper to implement nested parallelism with simple omp parallel for
template<typename Func> template<typename Func>
void ParallelFor2d(const BlockedSpace2d& space, Func func) { void ParallelFor2d(const BlockedSpace2d& space, const int nthreads, Func func) {
const int num_blocks_in_space = static_cast<int>(space.Size()); const size_t num_blocks_in_space = space.Size();
#pragma omp parallel for #pragma omp parallel num_threads(nthreads)
for (auto i = 0; i < num_blocks_in_space; i++) { {
size_t tid = omp_get_thread_num();
size_t chunck_size = num_blocks_in_space / nthreads + !!(num_blocks_in_space % nthreads);
size_t begin = chunck_size * tid;
size_t end = std::min(begin + chunck_size, num_blocks_in_space);
for (auto i = begin; i < end; i++) {
func(space.GetFirstDimension(i), space.GetRange(i)); func(space.GetFirstDimension(i), space.GetRange(i));
} }
}
} }
} // namespace common } // namespace common

View File

@ -2,7 +2,7 @@
* Copyright 2017-2018 by Contributors * Copyright 2017-2018 by Contributors
* \file updater_quantile_hist.cc * \file updater_quantile_hist.cc
* \brief use quantized feature values to construct a tree * \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Checn * \author Philip Cho, Tianqi Checn, Egor Smirnov
*/ */
#include <dmlc/timer.h> #include <dmlc/timer.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
@ -44,7 +44,7 @@ void QuantileHistMaker::Configure(const Args& args) {
pruner_->Configure(args); pruner_->Configure(args);
param_.UpdateAllowUnknown(args); param_.UpdateAllowUnknown(args);
// initialise the split evaluator // initialize the split evaluator
if (!spliteval_) { if (!spliteval_) {
spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator)); spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
} }
@ -100,66 +100,121 @@ void QuantileHistMaker::Builder::SyncHistograms(
int sync_count, int sync_count,
RegTree *p_tree) { RegTree *p_tree) {
builder_monitor_.Start("SyncHistograms"); builder_monitor_.Start("SyncHistograms");
const bool isDistributed = rabit::IsDistributed();
const size_t nbins = hist_builder_.GetNumBins();
common::BlockedSpace2d space(nodes_for_explicit_hist_build_.size(), [&](size_t node) {
return nbins;
}, 1024);
common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
const auto entry = nodes_for_explicit_hist_build_[node];
auto this_hist = hist_[entry.nid];
// Merging histograms from each thread into once
hist_buffer_.ReduceHist(node, r.begin(), r.end());
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1 && !isDistributed) {
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[entry.sibling_nid];
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
}
});
if (isDistributed) {
this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count); this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count);
// use Subtraction Trick // use Subtraction Trick
for (auto const& node_pair : nodes_for_subtraction_trick_) { for (auto const& node : nodes_for_subtraction_trick_) {
hist_.AddHistRow(node_pair.first); SubtractionTrick(hist_[node.nid], hist_[node.sibling_nid],
SubtractionTrick(hist_[node_pair.first], hist_[node_pair.second], hist_[(*p_tree)[node.nid].Parent()]);
hist_[(*p_tree)[node_pair.first].Parent()]);
} }
}
builder_monitor_.Stop("SyncHistograms"); builder_monitor_.Stop("SyncHistograms");
} }
void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
ExpandEntry entry,
const GHistIndexMatrix &gmat,
const GHistIndexBlockMatrix &gmatb,
RegTree *p_tree,
const std::vector<GradientPair> &gpair_h) {
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
nodes_for_explicit_hist_build_.push_back(entry);
if (entry.sibling_nid > -1) {
nodes_for_subtraction_trick_.emplace_back(entry.sibling_nid, entry.nid,
p_tree->GetDepth(entry.sibling_nid), 0.0f, 0);
}
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
AddHistRows(&starting_index, &sync_count);
BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
SyncHistograms(starting_index, sync_count, p_tree);
}
void QuantileHistMaker::Builder::AddHistRows(int *starting_index, int *sync_count) {
builder_monitor_.Start("AddHistRows");
for (auto const& entry : nodes_for_explicit_hist_build_) {
int nid = entry.nid;
hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
}
(*sync_count) = nodes_for_explicit_hist_build_.size();
for (auto const& node : nodes_for_subtraction_trick_) {
hist_.AddHistRow(node.nid);
}
builder_monitor_.Stop("AddHistRows");
}
void QuantileHistMaker::Builder::BuildLocalHistograms( void QuantileHistMaker::Builder::BuildLocalHistograms(
int *starting_index,
int *sync_count,
const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat,
const GHistIndexBlockMatrix &gmatb, const GHistIndexBlockMatrix &gmatb,
RegTree *p_tree, RegTree *p_tree,
const std::vector<GradientPair> &gpair_h) { const std::vector<GradientPair> &gpair_h) {
builder_monitor_.Start("BuildLocalHistograms"); builder_monitor_.Start("BuildLocalHistograms");
for (auto const& entry : qexpand_depth_wise_) {
int nid = entry.nid; const size_t n_nodes = nodes_for_explicit_hist_build_.size();
RegTree::Node &node = (*p_tree)[nid];
if (rabit::IsDistributed()) { // create space of size (# rows in each node)
if (node.IsRoot() || node.IsLeftChild()) { common::BlockedSpace2d space(n_nodes, [&](size_t node) {
hist_.AddHistRow(nid); const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
// in distributed setting, we always calculate from left child or root node return row_set_collection_[nid].Size();
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false); }, 256);
if (!node.IsRoot()) {
nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].RightChild()] = nid; std::vector<GHistRow> target_hists(n_nodes);
} for (size_t i = 0; i < n_nodes; ++i) {
(*sync_count)++; const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
(*starting_index) = std::min((*starting_index), nid); target_hists[i] = hist_[nid];
}
} else {
if (!node.IsRoot() && node.IsLeftChild() &&
(row_set_collection_[nid].Size() <
row_set_collection_[(*p_tree)[node.Parent()].RightChild()].Size())) {
hist_.AddHistRow(nid);
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].RightChild()] = nid;
(*sync_count)++;
(*starting_index) = std::min((*starting_index), nid);
} else if (!node.IsRoot() && !node.IsLeftChild() &&
(row_set_collection_[nid].Size() <=
row_set_collection_[(*p_tree)[node.Parent()].LeftChild()].Size())) {
hist_.AddHistRow(nid);
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
nodes_for_subtraction_trick_[(*p_tree)[node.Parent()].LeftChild()] = nid;
(*sync_count)++;
(*starting_index) = std::min((*starting_index), nid);
} else if (node.IsRoot()) {
hist_.AddHistRow(nid);
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, hist_[nid], false);
(*sync_count)++;
(*starting_index) = std::min((*starting_index), nid);
}
}
} }
hist_buffer_.Reset(this->nthread_, n_nodes, space, target_hists);
// Parallel processing by nodes and data in each node
common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
const auto tid = static_cast<unsigned>(omp_get_thread_num());
const int32_t nid = nodes_for_explicit_hist_build_[nid_in_set].nid;
auto start_of_row_set = row_set_collection_[nid].begin;
auto rid_set = RowSetCollection::Elem(start_of_row_set + r.begin(),
start_of_row_set + r.end(),
nid);
BuildHist(gpair_h, rid_set, gmat, gmatb, hist_buffer_.GetInitializedHist(tid, nid_in_set));
});
builder_monitor_.Stop("BuildLocalHistograms"); builder_monitor_.Stop("BuildLocalHistograms");
} }
void QuantileHistMaker::Builder::BuildNodeStats( void QuantileHistMaker::Builder::BuildNodeStats(
const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat,
DMatrix *p_fmat, DMatrix *p_fmat,
@ -193,7 +248,7 @@ void QuantileHistMaker::Builder::EvaluateSplits(
int depth, int depth,
unsigned *timestamp, unsigned *timestamp,
std::vector<ExpandEntry> *temp_qexpand_depth) { std::vector<ExpandEntry> *temp_qexpand_depth) {
this->EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree); EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree);
for (auto const& entry : qexpand_depth_wise_) { for (auto const& entry : qexpand_depth_wise_) {
int nid = entry.nid; int nid = entry.nid;
@ -206,9 +261,9 @@ void QuantileHistMaker::Builder::EvaluateSplits(
this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree); this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
int left_id = (*p_tree)[nid].LeftChild(); int left_id = (*p_tree)[nid].LeftChild();
int right_id = (*p_tree)[nid].RightChild(); int right_id = (*p_tree)[nid].RightChild();
temp_qexpand_depth->push_back(ExpandEntry(left_id, temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id,
p_tree->GetDepth(left_id), 0.0, (*timestamp)++)); p_tree->GetDepth(left_id), 0.0, (*timestamp)++));
temp_qexpand_depth->push_back(ExpandEntry(right_id, temp_qexpand_depth->push_back(ExpandEntry(right_id, left_id,
p_tree->GetDepth(right_id), 0.0, (*timestamp)++)); p_tree->GetDepth(right_id), 0.0, (*timestamp)++));
// - 1 parent + 2 new children // - 1 parent + 2 new children
(*num_leaves)++; (*num_leaves)++;
@ -216,6 +271,43 @@ void QuantileHistMaker::Builder::EvaluateSplits(
} }
} }
// Split nodes to 2 sets depending on amount of rows in each node
// Histograms for small nodes will be built explicitly
// Histograms for big nodes will be built by 'Subtraction Trick'
// Exception: in distributed setting, we always build the histogram for the left child node
// and use 'Subtraction Trick' to built the histogram for the right child node.
// This ensures that the workers operate on the same set of tree nodes.
void QuantileHistMaker::Builder::SplitSiblings(const std::vector<ExpandEntry>& nodes,
std::vector<ExpandEntry>* small_siblings,
std::vector<ExpandEntry>* big_siblings,
RegTree *p_tree) {
for (auto const& entry : nodes) {
int nid = entry.nid;
RegTree::Node &node = (*p_tree)[nid];
if (rabit::IsDistributed()) {
if (node.IsRoot() || node.IsLeftChild()) {
small_siblings->push_back(entry);
} else {
big_siblings->push_back(entry);
}
} else {
if (!node.IsRoot() && node.IsLeftChild() &&
(row_set_collection_[nid].Size() <
row_set_collection_[(*p_tree)[node.Parent()].RightChild()].Size())) {
small_siblings->push_back(entry);
} else if (!node.IsRoot() && !node.IsLeftChild() &&
(row_set_collection_[nid].Size() <=
row_set_collection_[(*p_tree)[node.Parent()].LeftChild()].Size())) {
small_siblings->push_back(entry);
} else if (node.IsRoot()) {
small_siblings->push_back(entry);
} else {
big_siblings->push_back(entry);
}
}
}
}
void QuantileHistMaker::Builder::ExpandWithDepthWise( void QuantileHistMaker::Builder::ExpandWithDepthWise(
const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat,
const GHistIndexBlockMatrix &gmatb, const GHistIndexBlockMatrix &gmatb,
@ -227,21 +319,28 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
int num_leaves = 0; int num_leaves = 0;
// in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway // in depth_wise growing, we feed loss_chg with 0.0 since it is not used anyway
qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid, qexpand_depth_wise_.emplace_back(ExpandEntry(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++)); p_tree->GetDepth(ExpandEntry::kRootNid), 0.0, timestamp++));
++num_leaves; ++num_leaves;
for (int depth = 0; depth < param_.max_depth + 1; depth++) { for (int depth = 0; depth < param_.max_depth + 1; depth++) {
int starting_index = std::numeric_limits<int>::max(); int starting_index = std::numeric_limits<int>::max();
int sync_count = 0; int sync_count = 0;
std::vector<ExpandEntry> temp_qexpand_depth; std::vector<ExpandEntry> temp_qexpand_depth;
BuildLocalHistograms(&starting_index, &sync_count, gmat, gmatb, p_tree, gpair_h);
SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
&nodes_for_subtraction_trick_, p_tree);
AddHistRows(&starting_index, &sync_count);
BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
SyncHistograms(starting_index, sync_count, p_tree); SyncHistograms(starting_index, sync_count, p_tree);
BuildNodeStats(gmat, p_fmat, p_tree, gpair_h); BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);
EvaluateSplits(gmat, column_matrix, p_fmat, p_tree, &num_leaves, depth, &timestamp, EvaluateSplits(gmat, column_matrix, p_fmat, p_tree, &num_leaves, depth, &timestamp,
&temp_qexpand_depth); &temp_qexpand_depth);
// clean up // clean up
qexpand_depth_wise_.clear(); qexpand_depth_wise_.clear();
nodes_for_subtraction_trick_.clear(); nodes_for_subtraction_trick_.clear();
nodes_for_explicit_hist_build_.clear();
if (temp_qexpand_depth.empty()) { if (temp_qexpand_depth.empty()) {
break; break;
} else { } else {
@ -262,14 +361,12 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
unsigned timestamp = 0; unsigned timestamp = 0;
int num_leaves = 0; int num_leaves = 0;
hist_.AddHistRow(ExpandEntry::kRootNid); ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
BuildHist(gpair_h, row_set_collection_[ExpandEntry::kRootNid], gmat, gmatb, p_tree->GetDepth(0), 0.0f, timestamp++);
hist_[ExpandEntry::kRootNid], true); BuildHistogramsLossGuide(node, gmat, gmatb, p_tree, gpair_h);
this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair_h, *p_fmat, *p_tree); this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair_h, *p_fmat, *p_tree);
ExpandEntry node(ExpandEntry::kRootNid, p_tree->GetDepth(ExpandEntry::kRootNid),
snode_[ExpandEntry::kRootNid].best.loss_chg, timestamp++);
this->EvaluateSplit({node}, gmat, hist_, *p_fmat, *p_tree); this->EvaluateSplit({node}, gmat, hist_, *p_fmat, *p_tree);
node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg; node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg;
@ -289,20 +386,20 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
const int cleft = (*p_tree)[nid].LeftChild(); const int cleft = (*p_tree)[nid].LeftChild();
const int cright = (*p_tree)[nid].RightChild(); const int cright = (*p_tree)[nid].RightChild();
hist_.AddHistRow(cleft);
hist_.AddHistRow(cright); ExpandEntry left_node(cleft, cright, p_tree->GetDepth(cleft),
0.0f, timestamp++);
ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright),
0.0f, timestamp++);
if (rabit::IsDistributed()) { if (rabit::IsDistributed()) {
// in distributed mode, we need to keep consistent across workers // in distributed mode, we need to keep consistent across workers
BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft], true); BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
} else { } else {
if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) { if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft], true); BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
} else { } else {
BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright], true); BuildHistogramsLossGuide(right_node, gmat, gmatb, p_tree, gpair_h);
SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
} }
} }
@ -313,11 +410,6 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
snode_[cleft].weight, snode_[cright].weight); snode_[cleft].weight, snode_[cright].weight);
interaction_constraints_.Split(nid, featureid, cleft, cright); interaction_constraints_.Split(nid, featureid, cleft, cright);
ExpandEntry left_node(cleft, p_tree->GetDepth(cleft),
snode_[cleft].best.loss_chg, timestamp++);
ExpandEntry right_node(cright, p_tree->GetDepth(cright),
snode_[cright].best.loss_chg, timestamp++);
this->EvaluateSplit({left_node, right_node}, gmat, hist_, *p_fmat, *p_tree); this->EvaluateSplit({left_node, right_node}, gmat, hist_, *p_fmat, *p_tree);
left_node.loss_chg = snode_[cleft].best.loss_chg; left_node.loss_chg = snode_[cleft].best.loss_chg;
right_node.loss_chg = snode_[cright].best.loss_chg; right_node.loss_chg = snode_[cright].best.loss_chg;
@ -427,6 +519,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
// initialize histogram collection // initialize histogram collection
uint32_t nbins = gmat.cut.Ptrs().back(); uint32_t nbins = gmat.cut.Ptrs().back();
hist_.Init(nbins); hist_.Init(nbins);
hist_buffer_.Init(nbins);
// initialize histogram builder // initialize histogram builder
#pragma omp parallel #pragma omp parallel
@ -586,7 +679,7 @@ void QuantileHistMaker::Builder::EvaluateSplit(const std::vector<ExpandEntry>& n
builder_monitor_.Start("EvaluateSplit"); builder_monitor_.Start("EvaluateSplit");
const size_t n_nodes_in_set = nodes_set.size(); const size_t n_nodes_in_set = nodes_set.size();
const auto nthread = static_cast<bst_omp_uint>(this->nthread_); const size_t nthread = std::max(1, this->nthread_);
using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>; using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
std::vector<FeatureSetType> features_sets(n_nodes_in_set); std::vector<FeatureSetType> features_sets(n_nodes_in_set);
@ -604,12 +697,13 @@ void QuantileHistMaker::Builder::EvaluateSplit(const std::vector<ExpandEntry>& n
// Create 2D space (# of nodes to process x # of features to process) // Create 2D space (# of nodes to process x # of features to process)
// to process them in parallel // to process them in parallel
const size_t grain_size = std::max<size_t>(1, features_sets[0]->Size() / nthread);
common::BlockedSpace2d space(n_nodes_in_set, [&](size_t nid_in_set) { common::BlockedSpace2d space(n_nodes_in_set, [&](size_t nid_in_set) {
return features_sets[nid_in_set]->Size(); return features_sets[nid_in_set]->Size();
}, 1); }, grain_size);
// Start parallel enumeration for all tree nodes in the set and all features // Start parallel enumeration for all tree nodes in the set and all features
common::ParallelFor2d(space, [&](size_t nid_in_set, common::Range1d r) { common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
const int32_t nid = nodes_set[nid_in_set].nid; const int32_t nid = nodes_set[nid_in_set].nid;
const auto tid = static_cast<unsigned>(omp_get_thread_num()); const auto tid = static_cast<unsigned>(omp_get_thread_num());
GHistRow node_hist = hist[nid]; GHistRow node_hist = hist[nid];

View File

@ -2,7 +2,7 @@
* Copyright 2017-2018 by Contributors * Copyright 2017-2018 by Contributors
* \file updater_quantile_hist.h * \file updater_quantile_hist.h
* \brief use quantized feature values to construct a tree * \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Chen * \author Philip Cho, Tianqi Chen, Egor Smirnov
*/ */
#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ #ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_ #define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
@ -157,18 +157,12 @@ class QuantileHistMaker: public TreeUpdater {
const RowSetCollection::Elem row_indices, const RowSetCollection::Elem row_indices,
const GHistIndexMatrix& gmat, const GHistIndexMatrix& gmat,
const GHistIndexBlockMatrix& gmatb, const GHistIndexBlockMatrix& gmatb,
GHistRow hist, GHistRow hist) {
bool sync_hist) {
builder_monitor_.Start("BuildHist");
if (param_.enable_feature_grouping > 0) { if (param_.enable_feature_grouping > 0) {
hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist); hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist);
} else { } else {
hist_builder_.BuildHist(gpair, row_indices, gmat, hist); hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
} }
if (sync_hist) {
this->histred_.Allreduce(hist.data(), hist_builder_.GetNumBins());
}
builder_monitor_.Stop("BuildHist");
} }
inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) { inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
@ -184,12 +178,14 @@ class QuantileHistMaker: public TreeUpdater {
/* tree growing policies */ /* tree growing policies */
struct ExpandEntry { struct ExpandEntry {
static const int kRootNid = 0; static const int kRootNid = 0;
static const int kEmptyNid = -1;
int nid; int nid;
int sibling_nid;
int depth; int depth;
bst_float loss_chg; bst_float loss_chg;
unsigned timestamp; unsigned timestamp;
ExpandEntry(int nid, int depth, bst_float loss_chg, unsigned tstmp) ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg, unsigned tstmp):
: nid(nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {} nid(nid), sibling_nid(sibling_nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {}
}; };
// initialize temp data structure // initialize temp data structure
@ -259,13 +255,28 @@ class QuantileHistMaker: public TreeUpdater {
RegTree *p_tree, RegTree *p_tree,
const std::vector<GradientPair> &gpair_h); const std::vector<GradientPair> &gpair_h);
void BuildLocalHistograms(int *starting_index, void BuildLocalHistograms(const GHistIndexMatrix &gmat,
int *sync_count, const GHistIndexBlockMatrix &gmatb,
RegTree *p_tree,
const std::vector<GradientPair> &gpair_h);
void AddHistRows(int *starting_index, int *sync_count);
void BuildHistogramsLossGuide(
ExpandEntry entry,
const GHistIndexMatrix &gmat, const GHistIndexMatrix &gmat,
const GHistIndexBlockMatrix &gmatb, const GHistIndexBlockMatrix &gmatb,
RegTree *p_tree, RegTree *p_tree,
const std::vector<GradientPair> &gpair_h); const std::vector<GradientPair> &gpair_h);
// Split nodes to 2 sets depending on amount of rows in each node
// Histograms for small nodes will be built explicitly
// Histograms for big nodes will be built by 'Subtraction Trick'
void SplitSiblings(const std::vector<ExpandEntry>& nodes,
std::vector<ExpandEntry>* small_siblings,
std::vector<ExpandEntry>* big_siblings,
RegTree *p_tree);
void SyncHistograms(int starting_index, void SyncHistograms(int starting_index,
int sync_count, int sync_count,
RegTree *p_tree); RegTree *p_tree);
@ -336,12 +347,15 @@ class QuantileHistMaker: public TreeUpdater {
std::vector<ExpandEntry> qexpand_depth_wise_; std::vector<ExpandEntry> qexpand_depth_wise_;
// key is the node id which should be calculated by Subtraction Trick, value is the node which // key is the node id which should be calculated by Subtraction Trick, value is the node which
// provides the evidence for substracts // provides the evidence for substracts
std::unordered_map<int, int> nodes_for_subtraction_trick_; std::vector<ExpandEntry> nodes_for_subtraction_trick_;
// list of nodes whose histograms would be built explicitly.
std::vector<ExpandEntry> nodes_for_explicit_hist_build_;
enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
DataLayout data_layout_; DataLayout data_layout_;
common::Monitor builder_monitor_; common::Monitor builder_monitor_;
common::ParallelGHistBuilder hist_buffer_;
rabit::Reducer<GradStats, GradStats::Reduce> histred_; rabit::Reducer<GradStats, GradStats::Reduce> histred_;
}; };

View File

@ -9,6 +9,123 @@
namespace xgboost { namespace xgboost {
namespace common { namespace common {
size_t GetNThreads() {
size_t nthreads;
#pragma omp parallel
{
#pragma omp master
nthreads = omp_get_num_threads();
}
return nthreads;
}
TEST(ParallelGHistBuilder, Reset) {
constexpr size_t kBins = 10;
constexpr size_t kNodes = 5;
constexpr size_t kNodesExtended = 10;
constexpr size_t kTasksPerNode = 10;
constexpr double kValue = 1.0;
const size_t nthreads = GetNThreads();
HistCollection collection;
collection.Init(kBins);
for(size_t inode = 0; inode < kNodesExtended; inode++) {
collection.AddHistRow(inode);
}
ParallelGHistBuilder hist_builder;
hist_builder.Init(kBins);
std::vector<GHistRow> target_hist(kNodes);
for(size_t i = 0; i < target_hist.size(); ++i) {
target_hist[i] = collection[i];
}
common::BlockedSpace2d space(kNodes, [&](size_t node) { return kTasksPerNode; }, 1);
hist_builder.Reset(nthreads, kNodes, space, target_hist);
common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) {
const size_t itask = r.begin();
const size_t tid = omp_get_thread_num();
GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
// fill hist by some non-null values
for(size_t j = 0; j < kBins; ++j) {
hist[j].Add(kValue, kValue);
}
});
// reset and extend buffer
target_hist.resize(kNodesExtended);
for(size_t i = 0; i < target_hist.size(); ++i) {
target_hist[i] = collection[i];
}
common::BlockedSpace2d space2(kNodesExtended, [&](size_t node) { return kTasksPerNode; }, 1);
hist_builder.Reset(nthreads, kNodesExtended, space2, target_hist);
common::ParallelFor2d(space2, nthreads, [&](size_t inode, common::Range1d r) {
const size_t itask = r.begin();
const size_t tid = omp_get_thread_num();
GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
// fill hist by some non-null values
for(size_t j = 0; j < kBins; ++j) {
ASSERT_EQ(0.0, hist[j].GetGrad());
ASSERT_EQ(0.0, hist[j].GetHess());
}
});
}
TEST(ParallelGHistBuilder, ReduceHist) {
constexpr size_t kBins = 10;
constexpr size_t kNodes = 5;
constexpr size_t kNodesExtended = 10;
constexpr size_t kTasksPerNode = 10;
constexpr double kValue = 1.0;
const size_t nthreads = GetNThreads();
HistCollection collection;
collection.Init(kBins);
for(size_t inode = 0; inode < kNodes; inode++) {
collection.AddHistRow(inode);
}
ParallelGHistBuilder hist_builder;
hist_builder.Init(kBins);
std::vector<GHistRow> target_hist(kNodes);
for(size_t i = 0; i < target_hist.size(); ++i) {
target_hist[i] = collection[i];
}
common::BlockedSpace2d space(kNodes, [&](size_t node) { return kTasksPerNode; }, 1);
hist_builder.Reset(nthreads, kNodes, space, target_hist);
// Simple analog of BuildHist function, works in parallel for both tree-nodes and data in node
common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) {
const size_t itask = r.begin();
const size_t tid = omp_get_thread_num();
GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
for(size_t i = 0; i < kBins; ++i) {
hist[i].Add(kValue, kValue);
}
});
for(size_t inode = 0; inode < kNodes; inode++) {
hist_builder.ReduceHist(inode, 0, kBins);
// We had kTasksPerNode tasks to add kValue to each bin for each node
// So, after reducing we expect to have (kValue * kTasksPerNode) in each node
for(size_t i = 0; i < kBins; ++i) {
ASSERT_EQ(kValue * kTasksPerNode, collection[inode][i].GetGrad());
ASSERT_EQ(kValue * kTasksPerNode, collection[inode][i].GetHess());
}
}
}
TEST(CutsBuilder, SearchGroupInd) { TEST(CutsBuilder, SearchGroupInd) {
size_t constexpr kNumGroups = 4; size_t constexpr kNumGroups = 4;
size_t constexpr kRows = 17; size_t constexpr kRows = 17;

View File

@ -37,7 +37,7 @@ TEST(ParallelFor2d, Test) {
return kDim2; return kDim2;
}, kGrainSize); }, kGrainSize);
ParallelFor2d(space, [&](size_t i, Range1d r) { ParallelFor2d(space, 4, [&](size_t i, Range1d r) {
for (auto j = r.begin(); j < r.end(); ++j) { for (auto j = r.begin(); j < r.end(); ++j) {
matrix[i*kDim2 + j] += 1; matrix[i*kDim2 + j] += 1;
} }
@ -65,7 +65,7 @@ TEST(ParallelFor2dNonUniform, Test) {
working_space[i].resize(dim2[i], 0); working_space[i].resize(dim2[i], 0);
} }
ParallelFor2d(space, [&](size_t i, Range1d r) { ParallelFor2d(space, 4, [&](size_t i, Range1d r) {
for (auto j = r.begin(); j < r.end(); ++j) { for (auto j = r.begin(); j < r.end(); ++j) {
working_space[i][j] += 1; working_space[i][j] += 1;
} }

View File

@ -107,7 +107,7 @@ class QuantileHistMock : public QuantileHistMaker {
GHistIndexBlockMatrix dummy; GHistIndexBlockMatrix dummy;
hist_.AddHistRow(nid); hist_.AddHistRow(nid);
BuildHist(gpair, row_set_collection_[nid], BuildHist(gpair, row_set_collection_[nid],
gmat, dummy, hist_[nid], false); gmat, dummy, hist_[nid]);
// Check if number of histogram bins is correct // Check if number of histogram bins is correct
ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back()); ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back());
@ -149,7 +149,7 @@ class QuantileHistMock : public QuantileHistMaker {
hist_.AddHistRow(0); hist_.AddHistRow(0);
BuildHist(row_gpairs, row_set_collection_[0], BuildHist(row_gpairs, row_set_collection_[0],
gmat, quantile_index_block, hist_[0], false); gmat, quantile_index_block, hist_[0]);
RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), tree); RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), tree);
@ -211,7 +211,8 @@ class QuantileHistMock : public QuantileHistMaker {
} }
/* Now compare against result given by EvaluateSplit() */ /* Now compare against result given by EvaluateSplit() */
ExpandEntry node(0, tree.GetDepth(0), snode_[0].best.loss_chg, 0); ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
tree.GetDepth(0), snode_[0].best.loss_chg, 0);
RealImpl::EvaluateSplit({node}, gmat, hist_, *(*dmat), tree); RealImpl::EvaluateSplit({node}, gmat, hist_, *(*dmat), tree);
ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature); ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]); ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);