Add float32 histogram (#5624)

* a new `single_precision_histogram` parameter was added.

Co-authored-by: SHVETS, KIRILL <kirill.shvets@intel.com>
Co-authored-by: fis <jm.yuan@outlook.com>

parent e49607af19
commit cd3d14ad0e
@@ -225,12 +225,15 @@ Parameters for Tree Booster
   list is a group of indices of features that are allowed to interact with each other.
   See tutorial for more information

-Additional parameters for ``gpu_hist`` tree method
+Additional parameters for ``hist`` and ``gpu_hist`` tree method
 ================================================

 * ``single_precision_histogram``, [default=``false``]

-  - Use single precision to build histograms. See document for GPU support for more details.
+  - Use single precision to build histograms instead of double precision.
+
+Additional parameters for ``gpu_hist`` tree method
+================================================

 * ``deterministic_histogram``, [default=``true``]
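The new flag is an ordinary training parameter, so it can be toggled through any existing parameter interface. A minimal sketch using XGBoost's public C API (the calls themselves are pre-existing API; only the parameter value is new in this commit):

```cpp
#include <xgboost/c_api.h>

// Sketch: enable the new float32 histogram path on the CPU `hist` updater.
void EnableSinglePrecisionHist(BoosterHandle booster) {
  XGBoosterSetParam(booster, "tree_method", "hist");
  XGBoosterSetParam(booster, "single_precision_histogram", "true");
}
```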
@@ -141,6 +141,15 @@ class GradientPairInternal {
  public:
   using ValueT = T;

+  inline void Add(const ValueT& grad, const ValueT& hess) {
+    grad_ += grad;
+    hess_ += hess;
+  }
+
+  inline static void Reduce(GradientPairInternal<T>& a, const GradientPairInternal<T>& b) { // NOLINT(*)
+    a += b;
+  }
+
   XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {}

   XGBOOST_DEVICE GradientPairInternal(T grad, T hess) {
@@ -148,9 +157,8 @@ class GradientPairInternal {
     SetHess(hess);
   }

-  // Copy constructor if of same value type
-  XGBOOST_DEVICE GradientPairInternal(const GradientPairInternal<T> &g)
-      : grad_(g.grad_), hess_(g.hess_) {}  // NOLINT
+  // Copy constructor if of same value type, marked as default to be trivially_copyable
+  GradientPairInternal(const GradientPairInternal<T> &g) = default;

   // Copy constructor if different value type - use getters and setters to
   // perform conversion
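The new `Add()` and `Reduce()` members are what let a histogram bin be accumulated and merged at either precision; `Reduce()` also has the reducer-function shape `rabit::Reducer` expects, which the updater relies on later in this diff. A hedged sketch of their use (the types come from this diff; the include path and harness are assumptions):

```cpp
#include <xgboost/base.h>

void AccumulateSketch() {
  using xgboost::detail::GradientPairInternal;
  GradientPairInternal<float> bin(0.0f, 0.0f);   // one float32 histogram bin
  bin.Add(0.5f, 1.0f);                           // per-row accumulation
  bin.Add(0.25f, 0.5f);
  GradientPairInternal<float> partial(1.0f, 2.0f);
  GradientPairInternal<float>::Reduce(bin, partial);  // thread-level merge: bin += partial
}
```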
@@ -830,54 +830,78 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
 /*!
  * \brief fill a histogram by zeros in range [begin, end)
  */
-void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
+template<typename GradientSumT>
+void InitilizeHistByZeroes(GHistRow<GradientSumT> hist, size_t begin, size_t end) {
 #if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
-  std::fill(hist.begin() + begin, hist.begin() + end, tree::GradStats());
+  std::fill(hist.begin() + begin, hist.begin() + end,
+            xgboost::detail::GradientPairInternal<GradientSumT>());
 #else  // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
-  memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats));
+  memset(hist.data() + begin, '\0', (end-begin)*
+         sizeof(xgboost::detail::GradientPairInternal<GradientSumT>));
 #endif  // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
 }
+template void InitilizeHistByZeroes(GHistRow<float> hist, size_t begin,
+                                    size_t end);
+template void InitilizeHistByZeroes(GHistRow<double> hist, size_t begin,
+                                    size_t end);

 /*!
  * \brief Increment hist as dst += add in range [begin, end)
  */
-void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
-  using FPType = decltype(tree::GradStats::sum_grad);
-  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
-  const FPType* padd = reinterpret_cast<const FPType*>(add.data());
+template<typename GradientSumT>
+void IncrementHist(GHistRow<GradientSumT> dst, const GHistRow<GradientSumT> add,
+                   size_t begin, size_t end) {
+  GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst.data());
+  const GradientSumT* padd = reinterpret_cast<const GradientSumT*>(add.data());

   for (size_t i = 2 * begin; i < 2 * end; ++i) {
     pdst[i] += padd[i];
   }
 }
+template void IncrementHist(GHistRow<float> dst, const GHistRow<float> add,
+                            size_t begin, size_t end);
+template void IncrementHist(GHistRow<double> dst, const GHistRow<double> add,
+                            size_t begin, size_t end);

 /*!
  * \brief Copy hist from src to dst in range [begin, end)
  */
-void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) {
-  using FPType = decltype(tree::GradStats::sum_grad);
-  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
-  const FPType* psrc = reinterpret_cast<const FPType*>(src.data());
+template<typename GradientSumT>
+void CopyHist(GHistRow<GradientSumT> dst, const GHistRow<GradientSumT> src,
+              size_t begin, size_t end) {
+  GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst.data());
+  const GradientSumT* psrc = reinterpret_cast<const GradientSumT*>(src.data());

   for (size_t i = 2 * begin; i < 2 * end; ++i) {
     pdst[i] = psrc[i];
   }
 }
+template void CopyHist(GHistRow<float> dst, const GHistRow<float> src,
+                       size_t begin, size_t end);
+template void CopyHist(GHistRow<double> dst, const GHistRow<double> src,
+                       size_t begin, size_t end);

 /*!
  * \brief Compute Subtraction: dst = src1 - src2 in range [begin, end)
  */
-void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
+template<typename GradientSumT>
+void SubtractionHist(GHistRow<GradientSumT> dst, const GHistRow<GradientSumT> src1,
+                     const GHistRow<GradientSumT> src2,
                      size_t begin, size_t end) {
-  using FPType = decltype(tree::GradStats::sum_grad);
-  FPType* pdst = reinterpret_cast<FPType*>(dst.data());
-  const FPType* psrc1 = reinterpret_cast<const FPType*>(src1.data());
-  const FPType* psrc2 = reinterpret_cast<const FPType*>(src2.data());
+  GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst.data());
+  const GradientSumT* psrc1 = reinterpret_cast<const GradientSumT*>(src1.data());
+  const GradientSumT* psrc2 = reinterpret_cast<const GradientSumT*>(src2.data());

   for (size_t i = 2 * begin; i < 2 * end; ++i) {
     pdst[i] = psrc1[i] - psrc2[i];
   }
 }
+template void SubtractionHist(GHistRow<float> dst, const GHistRow<float> src1,
+                              const GHistRow<float> src2,
+                              size_t begin, size_t end);
+template void SubtractionHist(GHistRow<double> dst, const GHistRow<double> src1,
+                              const GHistRow<double> src2,
+                              size_t begin, size_t end);

 struct Prefetch {
  public:
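These kernels treat each bin as two consecutive scalars (gradient sum, then hessian sum), which is why the loops run over [2*begin, 2*end). A hedged, self-contained sketch of what `SubtractionHist` enables, the subtraction trick: the sibling histogram is derived as parent minus child instead of being rebuilt from rows (plain vectors stand in for the `GHistRow` spans here):

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: right = parent - left, per scalar
// (grad and hess interleaved per bin).
void SubtractionSketch(const std::vector<float>& parent,
                       const std::vector<float>& left_child,
                       std::vector<float>* right_child) {
  right_child->resize(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    (*right_child)[i] = parent[i] - left_child[i];
  }
}
```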
@@ -908,7 +932,7 @@ void BuildHistDenseKernel(const std::vector<GradientPair>& gpair,
                           const RowSetCollection::Elem row_indices,
                           const GHistIndexMatrix& gmat,
                           const size_t n_features,
-                          GHistRow hist) {
+                          GHistRow<FPType> hist) {
   const size_t size = row_indices.Size();
   const size_t* rid = row_indices.begin;
   const float* pgh = reinterpret_cast<const float*>(gpair.data());
@@ -948,7 +972,7 @@ template<typename FPType, bool do_prefetch>
 void BuildHistSparseKernel(const std::vector<GradientPair>& gpair,
                            const RowSetCollection::Elem row_indices,
                            const GHistIndexMatrix& gmat,
-                           GHistRow hist) {
+                           GHistRow<FPType> hist) {
   const size_t size = row_indices.Size();
   const size_t* rid = row_indices.begin;
   const float* pgh = reinterpret_cast<const float*>(gpair.data());
@@ -987,7 +1011,7 @@ void BuildHistSparseKernel(const std::vector<GradientPair>& gpair,
 template<typename FPType, bool do_prefetch, typename BinIdxType>
 void BuildHistDispatchKernel(const std::vector<GradientPair>& gpair,
                              const RowSetCollection::Elem row_indices,
-                             const GHistIndexMatrix& gmat, GHistRow hist, bool isDense) {
+                             const GHistIndexMatrix& gmat, GHistRow<FPType> hist, bool isDense) {
   if (isDense) {
     const size_t* row_ptr = gmat.row_ptr.data();
     const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]];
@@ -1002,7 +1026,7 @@ void BuildHistDispatchKernel(const std::vector<GradientPair>& gpair,
 template<typename FPType, bool do_prefetch>
 void BuildHistKernel(const std::vector<GradientPair>& gpair,
                      const RowSetCollection::Elem row_indices,
-                     const GHistIndexMatrix& gmat, const bool isDense, GHistRow hist) {
+                     const GHistIndexMatrix& gmat, const bool isDense, GHistRow<FPType> hist) {
   const bool is_dense = row_indices.Size() && isDense;
   switch (gmat.index.GetBinTypeSize()) {
     case kUint8BinsTypeSize:
@@ -1022,12 +1046,12 @@ void BuildHistKernel(const std::vector<GradientPair>& gpair,
   }
 }

-void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
+template<typename GradientSumT>
+void GHistBuilder<GradientSumT>::BuildHist(const std::vector<GradientPair>& gpair,
                              const RowSetCollection::Elem row_indices,
                              const GHistIndexMatrix& gmat,
-                             GHistRow hist,
+                             GHistRowT hist,
                              bool isDense) {
-  using FPType = decltype(tree::GradStats::sum_grad);
   const size_t nrows = row_indices.Size();
   const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);

@@ -1036,21 +1060,34 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,

   if (contiguousBlock) {
     // contiguous memory access, built-in HW prefetching is enough
-    BuildHistKernel<FPType, false>(gpair, row_indices, gmat, isDense, hist);
+    BuildHistKernel<GradientSumT, false>(gpair, row_indices, gmat, isDense, hist);
   } else {
     const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size);
     const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end);

-    BuildHistKernel<FPType, true>(gpair, span1, gmat, isDense, hist);
+    BuildHistKernel<GradientSumT, true>(gpair, span1, gmat, isDense, hist);
     // no prefetching to avoid loading extra memory
-    BuildHistKernel<FPType, false>(gpair, span2, gmat, isDense, hist);
+    BuildHistKernel<GradientSumT, false>(gpair, span2, gmat, isDense, hist);
   }
 }
+template
+void GHistBuilder<float>::BuildHist(const std::vector<GradientPair>& gpair,
+                                    const RowSetCollection::Elem row_indices,
+                                    const GHistIndexMatrix& gmat,
+                                    GHistRow<float> hist,
+                                    bool isDense);
+template
+void GHistBuilder<double>::BuildHist(const std::vector<GradientPair>& gpair,
+                                     const RowSetCollection::Elem row_indices,
+                                     const GHistIndexMatrix& gmat,
+                                     GHistRow<double> hist,
+                                     bool isDense);

-void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
+template<typename GradientSumT>
+void GHistBuilder<GradientSumT>::BuildBlockHist(const std::vector<GradientPair>& gpair,
                                   const RowSetCollection::Elem row_indices,
                                   const GHistIndexBlockMatrix& gmatb,
-                                  GHistRow hist) {
+                                  GHistRowT hist) {
   constexpr int kUnroll = 8;  // loop unrolling factor
   const size_t nblock = gmatb.GetNumBlock();
   const size_t nrows = row_indices.end - row_indices.begin;
@@ -1058,7 +1095,7 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
 #if defined(_OPENMP)
   const auto nthread = static_cast<bst_omp_uint>(this->nthread_);  // NOLINT
 #endif  // defined(_OPENMP)
-  tree::GradStats* p_hist = hist.data();
+  xgboost::detail::GradientPairInternal<GradientSumT>* p_hist = hist.data();

 #pragma omp parallel for num_threads(nthread) schedule(guided)
   for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
@@ -1079,7 +1116,7 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
       for (int k = 0; k < kUnroll; ++k) {
         for (size_t j = ibegin[k]; j < iend[k]; ++j) {
           const uint32_t bin = gmat.index[j];
-          p_hist[bin].Add(stat[k]);
+          p_hist[bin].Add(stat[k].GetGrad(), stat[k].GetHess());
         }
       }
     }
@@ -1090,13 +1127,27 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
       const GradientPair stat = gpair[rid];
       for (size_t j = ibegin; j < iend; ++j) {
         const uint32_t bin = gmat.index[j];
-        p_hist[bin].Add(stat);
+        p_hist[bin].Add(stat.GetGrad(), stat.GetHess());
       }
     }
   }
 }
+template
+void GHistBuilder<float>::BuildBlockHist(const std::vector<GradientPair>& gpair,
+                                         const RowSetCollection::Elem row_indices,
+                                         const GHistIndexBlockMatrix& gmatb,
+                                         GHistRow<float> hist);
+template
+void GHistBuilder<double>::BuildBlockHist(const std::vector<GradientPair>& gpair,
+                                          const RowSetCollection::Elem row_indices,
+                                          const GHistIndexBlockMatrix& gmatb,
+                                          GHistRow<double> hist);

-void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
+template<typename GradientSumT>
+void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT self,
+                                                  GHistRowT sibling,
+                                                  GHistRowT parent) {
   const size_t size = self.size();
   CHECK_EQ(sibling.size(), size);
   CHECK_EQ(parent.size(), size);
@@ -1111,6 +1162,14 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
     SubtractionHist(self, parent, sibling, ibegin, iend);
   }
 }
+template
+void GHistBuilder<float>::SubtractionTrick(GHistRow<float> self,
+                                           GHistRow<float> sibling,
+                                           GHistRow<float> parent);
+template
+void GHistBuilder<double>::SubtractionTrick(GHistRow<double> self,
+                                            GHistRow<double> sibling,
+                                            GHistRow<double> parent);

 }  // namespace common
 }  // namespace xgboost
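Because these templates are defined in a .cc file, the commit relies on explicit instantiation to emit exactly the `float` and `double` versions callers need; that is what all the `template void ...<float>/<double>` lines above are doing. A minimal, generic sketch of the pattern (hypothetical names, not from this diff):

```cpp
// widget.h: declaration only; the definition lives in the .cc file.
template <typename T>
void Frobnicate(T value);

// widget.cc: definition plus the two instantiations the linker will find.
template <typename T>
void Frobnicate(T value) { /* ... */ }

template void Frobnicate<float>(float value);
template void Frobnicate<double>(double value);
```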
@@ -391,46 +391,52 @@ class GHistIndexBlockMatrix {
   std::vector<Block> blocks_;
 };

 /*!
  * \brief histogram of gradient statistics for a single node.
  *  Consists of multiple GradStats, each entry showing total gradient statistics
  *  for that particular bin
  *  Uses global bin id so as to represent all features simultaneously
  */
-using GHistRow = Span<tree::GradStats>;
+template<typename GradientSumT>
+using GHistRow = Span<xgboost::detail::GradientPairInternal<GradientSumT> >;

 /*!
  * \brief fill a histogram by zeros
  */
-void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end);
+template<typename GradientSumT>
+void InitilizeHistByZeroes(GHistRow<GradientSumT> hist, size_t begin, size_t end);

 /*!
  * \brief Increment hist as dst += add in range [begin, end)
  */
-void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end);
+template<typename GradientSumT>
+void IncrementHist(GHistRow<GradientSumT> dst, const GHistRow<GradientSumT> add,
+                   size_t begin, size_t end);

 /*!
  * \brief Copy hist from src to dst in range [begin, end)
  */
-void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end);
+template<typename GradientSumT>
+void CopyHist(GHistRow<GradientSumT> dst, const GHistRow<GradientSumT> src,
+              size_t begin, size_t end);

 /*!
  * \brief Compute Subtraction: dst = src1 - src2 in range [begin, end)
  */
-void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
+template<typename GradientSumT>
+void SubtractionHist(GHistRow<GradientSumT> dst, const GHistRow<GradientSumT> src1,
+                     const GHistRow<GradientSumT> src2,
                      size_t begin, size_t end);

 /*!
  * \brief histogram of gradient statistics for multiple nodes
  */
+template<typename GradientSumT>
 class HistCollection {
  public:
+  using GHistRowT = GHistRow<GradientSumT>;
+  using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;

   // access histogram for i-th node
-  GHistRow operator[](bst_uint nid) const {
+  GHistRowT operator[](bst_uint nid) const {
     constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
     CHECK_NE(row_ptr_[nid], kMax);
-    tree::GradStats* ptr =
-        const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
+    GradientPairT* ptr =
+        const_cast<GradientPairT*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
     return {ptr, nbins_};
   }
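The switch from `Span<tree::GradStats>` to a templated span is where the memory saving comes from: with `GradientSumT = float` each bin holds two 4-byte sums instead of two 8-byte ones. A hedged check of that layout assumption (the include path and the no-padding assumption are mine, not part of the diff):

```cpp
#include <xgboost/base.h>

// Two same-sized scalar members, so no padding is expected: the float32
// histogram bin should be exactly half the size of the float64 bin.
static_assert(sizeof(xgboost::detail::GradientPairInternal<float>) ==
                  sizeof(xgboost::detail::GradientPairInternal<double>) / 2,
              "float32 histogram bins are half the size of float64 bins");
```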
@@ -473,7 +479,7 @@ class HistCollection {
   /*! \brief amount of active nodes in hist collection */
   uint32_t n_nodes_added_ = 0;

-  std::vector<tree::GradStats> data_;
+  std::vector<GradientPairT> data_;

   /*! \brief row_ptr_[nid] locates bin for histogram of node nid */
   std::vector<size_t> row_ptr_;
@@ -484,8 +490,11 @@ class HistCollection {
  * Supports processing multiple tree-nodes for nested parallelism
  * Able to reduce histograms across threads in efficient way
  */
+template<typename GradientSumT>
 class ParallelGHistBuilder {
  public:
+  using GHistRowT = GHistRow<GradientSumT>;

   void Init(size_t nbins) {
     if (nbins != nbins_) {
       hist_buffer_.Init(nbins);
@@ -496,7 +505,7 @@ class ParallelGHistBuilder {
   // Add new elements if needed, mark all hists as unused
   // targeted_hists - already allocated hists which should contain final results after Reduce() call
   void Reset(size_t nthreads, size_t nodes, const BlockedSpace2d& space,
-             const std::vector<GHistRow>& targeted_hists) {
+             const std::vector<GHistRowT>& targeted_hists) {
     hist_buffer_.Init(nbins_);
     tid_nid_to_hist_.clear();
     hist_memory_.clear();
@@ -518,12 +527,12 @@ class ParallelGHistBuilder {
   }

   // Get specified hist, initialize hist by zeros if it wasn't used before
-  GHistRow GetInitializedHist(size_t tid, size_t nid) {
+  GHistRowT GetInitializedHist(size_t tid, size_t nid) {
     CHECK_LT(nid, nodes_);
     CHECK_LT(tid, nthreads_);

     size_t idx = tid_nid_to_hist_.at({tid, nid});
-    GHistRow hist = hist_memory_[idx];
+    GHistRowT hist = hist_memory_[idx];

     if (!hist_was_used_[tid * nodes_ + nid]) {
       InitilizeHistByZeroes(hist, 0, hist.size());
@@ -538,14 +547,14 @@ class ParallelGHistBuilder {
     CHECK_GT(end, begin);
     CHECK_LT(nid, nodes_);

-    GHistRow dst = targeted_hists_[nid];
+    GHistRowT dst = targeted_hists_[nid];

     bool is_updated = false;
     for (size_t tid = 0; tid < nthreads_; ++tid) {
       if (hist_was_used_[tid * nodes_ + nid]) {
         is_updated = true;
         const size_t idx = tid_nid_to_hist_.at({tid, nid});
-        GHistRow src = hist_memory_[idx];
+        GHistRowT src = hist_memory_[idx];

         if (dst.data() != src.data()) {
           IncrementHist(dst, src, begin, end);
@@ -636,7 +645,7 @@ class ParallelGHistBuilder {
   /*! \brief number of nodes which will be processed in parallel */
   size_t nodes_ = 0;
   /*! \brief Buffer for additional histograms for Parallel processing */
-  HistCollection hist_buffer_;
+  HistCollection<GradientSumT> hist_buffer_;
   /*!
    * \brief Marks which hists were used, it means that they should be merged.
    * Contains only {true or false} values
@@ -647,9 +656,9 @@ class ParallelGHistBuilder {
   /*! \brief Buffer for additional histograms for Parallel processing */
   std::vector<bool> threads_to_nids_map_;
   /*! \brief Contains histograms for final results */
-  std::vector<GHistRow> targeted_hists_;
+  std::vector<GHistRowT> targeted_hists_;
   /*! \brief Allocated memory for histograms used for construction */
-  std::vector<GHistRow> hist_memory_;
+  std::vector<GHistRowT> hist_memory_;
   /*! \brief map pair {tid, nid} to index of allocated histogram from hist_memory_ */
   std::map<std::pair<size_t, size_t>, size_t> tid_nid_to_hist_;
 };
@@ -657,8 +666,11 @@ class ParallelGHistBuilder {
 /*!
  * \brief builder for histograms of gradient statistics
  */
+template<typename GradientSumT>
 class GHistBuilder {
  public:
+  using GHistRowT = GHistRow<GradientSumT>;

   GHistBuilder() = default;
   GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {}

@@ -666,15 +678,17 @@ class GHistBuilder {
   void BuildHist(const std::vector<GradientPair>& gpair,
                  const RowSetCollection::Elem row_indices,
                  const GHistIndexMatrix& gmat,
-                 GHistRow hist,
+                 GHistRowT hist,
                  bool isDense);
   // same, with feature grouping
   void BuildBlockHist(const std::vector<GradientPair>& gpair,
                       const RowSetCollection::Elem row_indices,
                       const GHistIndexBlockMatrix& gmatb,
-                      GHistRow hist);
+                      GHistRowT hist);
   // construct a histogram via subtraction trick
-  void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
+  void SubtractionTrick(GHistRowT self,
+                        GHistRowT sibling,
+                        GHistRowT parent);

   uint32_t GetNumBins() const {
     return nbins_;
@@ -332,14 +332,15 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad)

 /*! \brief core statistics used for tree construction */
 struct XGBOOST_ALIGNAS(16) GradStats {
+  using GradType = double;
   /*! \brief sum gradient statistics */
-  double sum_grad { 0 };
+  GradType sum_grad { 0 };
   /*! \brief sum hessian statistics */
-  double sum_hess { 0 };
+  GradType sum_hess { 0 };

  public:
-  XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
-  XGBOOST_DEVICE double GetHess() const { return sum_hess; }
+  XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
+  XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }

   friend std::ostream& operator<<(std::ostream& os, GradStats s) {
     os << s.GetGrad() << "/" << s.GetHess();
@@ -354,7 +355,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   template <typename GpairT>
   XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
       : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
-  explicit GradStats(const double grad, const double hess)
+  explicit GradStats(const GradType grad, const GradType hess)
       : sum_grad(grad), sum_hess(hess) {}
   /*!
    * \brief accumulate statistics
@@ -379,7 +380,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   /*! \return whether the statistics is not used yet */
   inline bool Empty() const { return sum_hess == 0.0; }
   /*! \brief add statistics to the data */
-  inline void Add(double grad, double hess) {
+  inline void Add(GradType grad, GradType hess) {
     sum_grad += grad;
     sum_hess += hess;
   }
@@ -425,7 +426,11 @@ struct SplitEntryContainer {
    * \param split_index the feature index where the split is on
    */
   bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
-    if (this->SplitIndex() <= split_index) {
+    if (std::isinf(new_loss_chg)) {  // in some cases new_loss_chg can be NaN or Inf,
+                                     // for example when lambda = 0 & min_child_weight = 0
+                                     // skip value in this case
+      return false;
+    } else if (this->SplitIndex() <= split_index) {
      return new_loss_chg > this->loss_chg;
     } else {
       return !(this->loss_chg > new_loss_chg);
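The `NeedReplace` guard is a behavioral fix riding along with this commit: an infinite gain (possible, per the comment above, when lambda = 0 and min_child_weight = 0) must never displace the current best split, because `Inf > anything` would otherwise always win. A standalone sketch of the same guard:

```cpp
#include <cmath>

// Mirrors the guard above: reject infinite candidate gains outright
// (the commit's comment notes NaN can occur in the same situations).
bool ShouldReplaceBest(float new_loss_chg, float best_loss_chg) {
  if (std::isinf(new_loss_chg)) {
    return false;
  }
  return new_loss_chg > best_loss_chg;
}
```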
@@ -35,6 +35,8 @@ namespace tree {

 DMLC_REGISTRY_FILE_TAG(updater_quantile_hist);

+DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
+
 void QuantileHistMaker::Configure(const Args& args) {
   // initialize pruner
   if (!pruner_) {
@@ -42,7 +44,7 @@ void QuantileHistMaker::Configure(const Args& args) {
   }
   pruner_->Configure(args);
   param_.UpdateAllowUnknown(args);
-
+  hist_maker_param_.UpdateAllowUnknown(args);
   // initialize the split evaluator
   if (!spliteval_) {
     spliteval_.reset(SplitEvaluator::Create(param_.split_evaluator));
@@ -51,6 +53,32 @@ void QuantileHistMaker::Configure(const Args& args) {
   spliteval_->Init(&param_);
 }

+template<typename GradientSumT>
+void QuantileHistMaker::SetBuilder(std::unique_ptr<Builder<GradientSumT>>* builder,
+                                   DMatrix *dmat) {
+  builder->reset(new Builder<GradientSumT>(
+                param_,
+                std::move(pruner_),
+                std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
+                int_constraint_, dmat));
+  if (rabit::IsDistributed()) {
+    (*builder)->SetHistSynchronizer(new DistributedHistSynchronizer<GradientSumT>());
+    (*builder)->SetHistRowsAdder(new DistributedHistRowsAdder<GradientSumT>());
+  } else {
+    (*builder)->SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+    (*builder)->SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+  }
+}
+
+template<typename GradientSumT>
+void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
+                                          HostDeviceVector<GradientPair> *gpair,
+                                          DMatrix *dmat,
+                                          const std::vector<RegTree *> &trees) {
+  for (auto tree : trees) {
+    builder->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree);
+  }
+}
 void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
                                DMatrix *dmat,
                                const std::vector<RegTree *> &trees) {
@@ -71,22 +99,16 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
   param_.learning_rate = lr / trees.size();
   int_constraint_.Configure(param_, dmat->Info().num_col_);
   // build tree
-  if (!builder_) {
-    builder_.reset(new Builder(
-        param_,
-        std::move(pruner_),
-        std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
-        int_constraint_, dmat));
-    if (rabit::IsDistributed()) {
-      builder_->SetHistSynchronizer(new DistributedHistSynchronizer());
-      builder_->SetHistRowsAdder(new DistributedHistRowsAdder());
-    } else {
-      builder_->SetHistSynchronizer(new BatchHistSynchronizer());
-      builder_->SetHistRowsAdder(new BatchHistRowsAdder());
+  if (hist_maker_param_.single_precision_histogram) {
+    if (!float_builder_) {
+      SetBuilder(&float_builder_, dmat);
     }
-  }
-  for (auto tree : trees) {
-    builder_->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree);
+    CallBuilderUpdate(float_builder_, gpair, dmat, trees);
+  } else {
+    if (!double_builder_) {
+      SetBuilder(&double_builder_, dmat);
+    }
+    CallBuilderUpdate(double_builder_, gpair, dmat, trees);
   }

   param_.learning_rate = lr;
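`Update()` now lazily constructs one builder per precision and routes every call through the flag. A condensed sketch of that dispatch shape, with hypothetical names standing in for the real classes:

```cpp
#include <memory>

template <typename GradientSumT>
struct BuilderSketch {
  void Update() { /* templated tree construction */ }
};

struct DispatchSketch {
  bool single_precision = false;  // mirrors hist_maker_param_.single_precision_histogram
  std::unique_ptr<BuilderSketch<float>> float_builder;
  std::unique_ptr<BuilderSketch<double>> double_builder;

  void Update() {
    if (single_precision) {
      if (!float_builder) { float_builder.reset(new BuilderSketch<float>()); }
      float_builder->Update();
    } else {
      if (!double_builder) { double_builder.reset(new BuilderSketch<double>()); }
      double_builder->Update();
    }
  }
};
```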
@@ -97,14 +119,21 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
 bool QuantileHistMaker::UpdatePredictionCache(
     const DMatrix* data,
     HostDeviceVector<bst_float>* out_preds) {
-  if (!builder_ || param_.subsample < 1.0f) {
+  if (param_.subsample < 1.0f) {
     return false;
   } else {
-    return builder_->UpdatePredictionCache(data, out_preds);
+    if (hist_maker_param_.single_precision_histogram && float_builder_) {
+      return float_builder_->UpdatePredictionCache(data, out_preds);
+    } else if (double_builder_) {
+      return double_builder_->UpdatePredictionCache(data, out_preds);
+    } else {
+      return false;
+    }
   }
 }

-void BatchHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* builder,
+template <typename GradientSumT>
+void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT* builder,
                                            int starting_index,
                                            int sync_count,
                                            RegTree *p_tree) {
@@ -130,7 +159,8 @@ void BatchHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* builder,
   builder->builder_monitor_.Stop("SyncHistograms");
 }

-void DistributedHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* builder,
+template <typename GradientSumT>
+void DistributedHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT* builder,
                                                  int starting_index,
                                                  int sync_count,
                                                  RegTree *p_tree) {
@@ -172,9 +202,11 @@ void DistributedHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* bui
   builder->builder_monitor_.Stop("SyncHistograms");
 }

-void DistributedHistSynchronizer::ParallelSubtractionHist(QuantileHistMaker::Builder* builder,
+template <typename GradientSumT>
+void DistributedHistSynchronizer<GradientSumT>::ParallelSubtractionHist(
+    BuilderT* builder,
     const common::BlockedSpace2d& space,
-    const std::vector<QuantileHistMaker::Builder::ExpandEntry>& nodes,
+    const std::vector<ExpandEntryT>& nodes,
     const RegTree * p_tree) {
   common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) {
     const auto entry = nodes[node];
@@ -190,7 +222,8 @@ void DistributedHistSynchronizer::ParallelSubtractionHist(QuantileHistMaker::Bui
   });
 }

-void BatchHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder,
+template <typename GradientSumT>
+void BatchHistRowsAdder<GradientSumT>::AddHistRows(BuilderT* builder,
                                      int *starting_index, int *sync_count,
                                      RegTree *p_tree) {
   builder->builder_monitor_.Start("AddHistRows");
@@ -209,7 +242,8 @@ void BatchHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder,
   builder->builder_monitor_.Stop("AddHistRows");
 }

-void DistributedHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder,
+template <typename GradientSumT>
+void DistributedHistRowsAdder<GradientSumT>::AddHistRows(BuilderT* builder,
                                            int *starting_index, int *sync_count,
                                            RegTree *p_tree) {
   builder->builder_monitor_.Start("AddHistRows");
@@ -243,15 +277,28 @@ void DistributedHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder,
   builder->builder_monitor_.Stop("AddHistRows");
 }

-void QuantileHistMaker::Builder::SetHistSynchronizer(HistSynchronizer* sync) {
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::SetHistSynchronizer(
+    HistSynchronizer<GradientSumT>* sync) {
   hist_synchronizer_.reset(sync);
 }
+template void QuantileHistMaker::Builder<double>::SetHistSynchronizer(
+    HistSynchronizer<double>* sync);
+template void QuantileHistMaker::Builder<float>::SetHistSynchronizer(
+    HistSynchronizer<float>* sync);

-void QuantileHistMaker::Builder::SetHistRowsAdder(HistRowsAdder* adder) {
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::SetHistRowsAdder(
+    HistRowsAdder<GradientSumT>* adder) {
   hist_rows_adder_.reset(adder);
 }
+template void QuantileHistMaker::Builder<double>::SetHistRowsAdder(
+    HistRowsAdder<double>* sync);
+template void QuantileHistMaker::Builder<float>::SetHistRowsAdder(
+    HistRowsAdder<float>* sync);

-void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::BuildHistogramsLossGuide(
     ExpandEntry entry,
     const GHistIndexMatrix &gmat,
     const GHistIndexBlockMatrix &gmatb,
@@ -274,7 +321,8 @@ void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
   hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
 }

-void QuantileHistMaker::Builder::BuildLocalHistograms(
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::BuildLocalHistograms(
     const GHistIndexMatrix &gmat,
     const GHistIndexBlockMatrix &gmatb,
     RegTree *p_tree,
@@ -289,7 +337,7 @@ void QuantileHistMaker::Builder::BuildLocalHistograms(
     return row_set_collection_[nid].Size();
   }, 256);

-  std::vector<GHistRow> target_hists(n_nodes);
+  std::vector<GHistRowT> target_hists(n_nodes);
   for (size_t i = 0; i < n_nodes; ++i) {
     const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
     target_hists[i] = hist_[nid];
@@ -312,8 +360,8 @@ void QuantileHistMaker::Builder::BuildLocalHistograms(
   builder_monitor_.Stop("BuildLocalHistograms");
 }

-
-void QuantileHistMaker::Builder::BuildNodeStats(
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::BuildNodeStats(
     const GHistIndexMatrix &gmat,
     DMatrix *p_fmat,
     RegTree *p_tree,
@@ -336,8 +384,8 @@ void QuantileHistMaker::Builder::BuildNodeStats(
   }
   builder_monitor_.Stop("BuildNodeStats");
 }

-
-void QuantileHistMaker::Builder::AddSplitsToTree(
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToTree(
     const GHistIndexMatrix &gmat,
     RegTree *p_tree,
     int *num_leaves,
@@ -377,8 +425,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
   }
 }

-
-void QuantileHistMaker::Builder::EvaluateAndApplySplits(
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::EvaluateAndApplySplits(
     const GHistIndexMatrix &gmat,
     const ColumnMatrix &column_matrix,
     RegTree *p_tree,
@@ -400,7 +448,8 @@ void QuantileHistMaker::Builder::EvaluateAndApplySplits(
 // Exception: in distributed setting, we always build the histogram for the left child node
 // and use 'Subtraction Trick' to built the histogram for the right child node.
 // This ensures that the workers operate on the same set of tree nodes.
-void QuantileHistMaker::Builder::SplitSiblings(const std::vector<ExpandEntry>& nodes,
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(const std::vector<ExpandEntry>& nodes,
                                                std::vector<ExpandEntry>* small_siblings,
                                                std::vector<ExpandEntry>* big_siblings,
                                                RegTree *p_tree) {
@@ -427,8 +476,8 @@ void QuantileHistMaker::Builder::SplitSiblings(const std::vector<ExpandEntry>& n
   }
   builder_monitor_.Stop("SplitSiblings");
 }

-void QuantileHistMaker::Builder::ExpandWithDepthWise(
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::ExpandWithDepthWise(
     const GHistIndexMatrix &gmat,
     const GHistIndexBlockMatrix &gmatb,
     const ColumnMatrix &column_matrix,
@@ -468,8 +517,8 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
     }
   }
 }

-void QuantileHistMaker::Builder::ExpandWithLossGuide(
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::ExpandWithLossGuide(
     const GHistIndexMatrix& gmat,
     const GHistIndexBlockMatrix& gmatb,
     const ColumnMatrix& column_matrix,
@@ -545,7 +594,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
   builder_monitor_.Stop("ExpandWithLossGuide");
 }

-void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::Update(const GHistIndexMatrix& gmat,
                                         const GHistIndexBlockMatrix& gmatb,
                                         const ColumnMatrix& column_matrix,
                                         HostDeviceVector<GradientPair>* gpair,
@@ -574,8 +624,8 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,

   builder_monitor_.Stop("Update");
 }

-bool QuantileHistMaker::Builder::UpdatePredictionCache(
+template<typename GradientSumT>
+bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
     const DMatrix* data,
     HostDeviceVector<bst_float>* p_out_preds) {
   // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
@@ -624,8 +674,8 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
   builder_monitor_.Stop("UpdatePredictionCache");
   return true;
 }

-void QuantileHistMaker::Builder::InitSampling(const std::vector<GradientPair>& gpair,
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const std::vector<GradientPair>& gpair,
                                               const DMatrix& fmat,
                                               std::vector<size_t>* row_indices) {
   const auto& info = fmat.Info();
@@ -682,7 +732,8 @@ void QuantileHistMaker::Builder::InitSampling(const std::vector<GradientPair>& g
   row_indices_local.resize(prefix_sum);
 #endif  // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
 }

-void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix& gmat,
                                           const std::vector<GradientPair>& gpair,
                                           const DMatrix& fmat,
                                           const RegTree& tree) {
@@ -712,7 +763,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
   {
     this->nthread_ = omp_get_num_threads();
   }
-  hist_builder_ = GHistBuilder(this->nthread_, nbins);
+  hist_builder_ = GHistBuilder<GradientSumT>(this->nthread_, nbins);

   std::vector<size_t>& row_indices = *row_set_collection_.Data();
   row_indices.resize(info.num_row_);
@@ -842,7 +893,8 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
 // is equal to sum of statistics for all values:
 // then - there are no missing values
 // else - there are missing values
-bool QuantileHistMaker::Builder::SplitContainsMissingValues(const GradStats e,
+template<typename GradientSumT>
+bool QuantileHistMaker::Builder<GradientSumT>::SplitContainsMissingValues(const GradStats e,
                                                             const NodeEntry& snode) {
   if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
     return false;
@@ -852,9 +904,11 @@ bool QuantileHistMaker::Builder::SplitContainsMissingValues(const GradStats e,
 }

 // nodes_set - set of nodes to be processed in parallel
-void QuantileHistMaker::Builder::EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
+template<typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::EvaluateSplits(
+    const std::vector<ExpandEntry>& nodes_set,
     const GHistIndexMatrix& gmat,
-    const HistCollection& hist,
+    const HistCollection<GradientSumT>& hist,
     const RegTree& tree) {
   builder_monitor_.Start("EvaluateSplits");

@@ -886,7 +940,7 @@ void QuantileHistMaker::Builder::EvaluateSplits(const std::vector<ExpandEntry>&
   common::ParallelFor2d(space, this->nthread_, [&](size_t nid_in_set, common::Range1d r) {
     const int32_t nid = nodes_set[nid_in_set].nid;
     const auto tid = static_cast<unsigned>(omp_get_thread_num());
-    GHistRow node_hist = hist[nid];
+    GHistRowT node_hist = hist[nid];

     for (auto idx_in_feature_set = r.begin(); idx_in_feature_set < r.end(); ++idx_in_feature_set) {
       const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set];
@@ -1014,8 +1068,9 @@ inline std::pair<size_t, size_t> PartitionSparseKernel(
   return {nleft_elems, nright_elems};
 }

+template <typename GradientSumT>
 template <typename BinIdxType>
-void QuantileHistMaker::Builder::PartitionKernel(
+void QuantileHistMaker::Builder<GradientSumT>::PartitionKernel(
     const size_t node_in_set, const size_t nid, common::Range1d range,
     const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree) {
   const size_t* rid = row_set_collection_[nid].begin;
@@ -1068,8 +1123,9 @@ void QuantileHistMaker::Builder::PartitionKernel(
   partition_builder_.SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
 }

-
-void QuantileHistMaker::Builder::FindSplitConditions(const std::vector<ExpandEntry>& nodes,
+template <typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::FindSplitConditions(
+    const std::vector<ExpandEntry>& nodes,
     const RegTree& tree,
     const GHistIndexMatrix& gmat,
     std::vector<int32_t>* split_conditions) {
@@ -1095,9 +1151,10 @@ void QuantileHistMaker::Builder::FindSplitConditions(const std::vector<ExpandEnt
     (*split_conditions)[i] = split_cond;
   }
 }

-void QuantileHistMaker::Builder::AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes,
-                                                   RegTree* p_tree) {
+template <typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToRowSet(
+    const std::vector<ExpandEntry>& nodes,
+    RegTree* p_tree) {
   const size_t n_nodes = nodes.size();
   for (size_t i = 0; i < n_nodes; ++i) {
     const int32_t nid = nodes[i].nid;
@@ -1109,11 +1166,11 @@ void QuantileHistMaker::Builder::AddSplitsToRowSet(const std::vector<ExpandEntry
   }
 }

-void QuantileHistMaker::Builder::ApplySplit(const std::vector<ExpandEntry> nodes,
+template <typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::ApplySplit(const std::vector<ExpandEntry> nodes,
                                             const GHistIndexMatrix& gmat,
                                             const ColumnMatrix& column_matrix,
-                                            const HistCollection& hist,
+                                            const HistCollection<GradientSumT>& hist,
                                             RegTree* p_tree) {
   builder_monitor_.Start("ApplySplit");
   // 1. Find split condition for each split
@@ -1169,8 +1226,8 @@ void QuantileHistMaker::Builder::ApplySplit(const std::vector<ExpandEntry> nodes
   AddSplitsToRowSet(nodes, p_tree);
   builder_monitor_.Stop("ApplySplit");
 }

-void QuantileHistMaker::Builder::InitNewNode(int nid,
+template <typename GradientSumT>
+void QuantileHistMaker::Builder<GradientSumT>::InitNewNode(int nid,
                                              const GHistIndexMatrix& gmat,
                                              const std::vector<GradientPair>& gpair,
                                              const DMatrix& fmat,
@@ -1181,8 +1238,8 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
   }

   {
-    auto& stats = snode_[nid].stats;
-    GHistRow hist = hist_[nid];
+    GHistRowT hist = hist_[nid];
+    GradientPairT grad_stat;
     if (tree[nid].IsRoot()) {
       if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
         const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
@@ -1190,16 +1247,17 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
         const uint32_t iend = row_ptr[fid_least_bins_ + 1];
         auto begin = hist.data();
         for (uint32_t i = ibegin; i < iend; ++i) {
-          const GradStats et = begin[i];
-          stats.Add(et.sum_grad, et.sum_hess);
+          const GradientPairT et = begin[i];
+          grad_stat.Add(et.GetGrad(), et.GetHess());
         }
       } else {
         const RowSetCollection::Elem e = row_set_collection_[nid];
         for (const size_t* it = e.begin; it < e.end; ++it) {
-          stats.Add(gpair[*it]);
+          grad_stat.Add(gpair[*it].GetGrad(), gpair[*it].GetHess());
         }
       }
-      histred_.Allreduce(&snode_[nid].stats, 1);
+      histred_.Allreduce(&grad_stat, 1);
+      snode_[nid].stats = tree::GradStats(grad_stat.GetGrad(), grad_stat.GetHess());
     } else {
       int parent_id = tree[nid].Parent();
       if (tree[nid].IsLeftChild()) {
@@ -1225,9 +1283,10 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
 // Enumerate the split values of specific feature.
 // Returns the sum of gradients corresponding to the data points that contains a non-missing value
 // for the particular feature fid.
+template<typename GradientSumT>
 template <int d_step>
-GradStats QuantileHistMaker::Builder::EnumerateSplit(
-    const GHistIndexMatrix &gmat, const GHistRow &hist, const NodeEntry &snode,
+GradStats QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(
+    const GHistIndexMatrix &gmat, const GHistRowT &hist, const NodeEntry &snode,
     SplitEntry *p_best, bst_uint fid, bst_uint nodeID) const {
   CHECK(d_step == +1 || d_step == -1);
@@ -78,8 +78,35 @@ using xgboost::common::GHistBuilder;
 using xgboost::common::ColumnMatrix;
 using xgboost::common::Column;

+template <typename GradientSumT>
+class HistSynchronizer;
+
+template <typename GradientSumT>
+class BatchHistSynchronizer;
+
+template <typename GradientSumT>
+class DistributedHistSynchronizer;
+
+template <typename GradientSumT>
+class HistRowsAdder;
+
+template <typename GradientSumT>
+class BatchHistRowsAdder;
+
+template <typename GradientSumT>
+class DistributedHistRowsAdder;
+
+// training parameters specific to this algorithm
+struct CPUHistMakerTrainParam
+    : public XGBoostParameter<CPUHistMakerTrainParam> {
+  bool single_precision_histogram;
+  // declare parameters
+  DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
+    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
+        "Use single precision to build histograms.");
+  }
+};
+
 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker: public TreeUpdater {
  public:
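`CPUHistMakerTrainParam` is an ordinary dmlc parameter struct, so it picks the flag up from the generic `Args` pipeline, exactly as `Configure()` does earlier in this diff. A hedged sketch of that flow (`UpdateAllowUnknown` is the existing `XGBoostParameter` API; the literal `Args` value is illustrative):

```cpp
#include <xgboost/base.h>

void ParamFlowSketch() {
  CPUHistMakerTrainParam p;
  // Same call Configure() makes; unknown keys are passed over.
  p.UpdateAllowUnknown(xgboost::Args{{"single_precision_histogram", "true"}});
  // p.single_precision_histogram is now true and drives builder selection.
}
```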
@@ -98,10 +125,12 @@ class QuantileHistMaker: public TreeUpdater {
   void LoadConfig(Json const& in) override {
     auto const& config = get<Object const>(in);
     FromJson(config.at("train_param"), &this->param_);
+    FromJson(config.at("cpu_hist_train_param"), &this->hist_maker_param_);
   }
   void SaveConfig(Json* p_out) const override {
     auto& out = *p_out;
     out["train_param"] = ToJson(param_);
+    out["cpu_hist_train_param"] = ToJson(hist_maker_param_);
   }

   char const* Name() const override {
@@ -109,12 +138,21 @@ class QuantileHistMaker: public TreeUpdater {
   }

  protected:
+  template <typename GradientSumT>
   friend class HistSynchronizer;
+  template <typename GradientSumT>
   friend class BatchHistSynchronizer;
+  template <typename GradientSumT>
   friend class DistributedHistSynchronizer;
+
+  template <typename GradientSumT>
   friend class HistRowsAdder;
+  template <typename GradientSumT>
   friend class BatchHistRowsAdder;
+  template <typename GradientSumT>
   friend class DistributedHistRowsAdder;
+
+  CPUHistMakerTrainParam hist_maker_param_;
   // training parameter
   TrainParam param_;
   // quantized data matrix
@@ -142,8 +180,11 @@ class QuantileHistMaker: public TreeUpdater {
   };
   // actual builder that runs the algorithm
+  template<typename GradientSumT>
   struct Builder {
    public:
+    using GHistRowT = GHistRow<GradientSumT>;
+    using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
     // constructor
     explicit Builder(const TrainParam& param,
                      std::unique_ptr<TreeUpdater> pruner,
@@ -168,7 +209,7 @@ class QuantileHistMaker: public TreeUpdater {
                           const RowSetCollection::Elem row_indices,
                           const GHistIndexMatrix& gmat,
                           const GHistIndexBlockMatrix& gmatb,
-                          GHistRow hist) {
+                          GHistRowT hist) {
       if (param_.enable_feature_grouping > 0) {
         hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist);
       } else {
@@ -176,7 +217,9 @@ class QuantileHistMaker: public TreeUpdater {
       }
     }

-    inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
+    inline void SubtractionTrick(GHistRowT self,
+                                 GHistRowT sibling,
+                                 GHistRowT parent) {
       builder_monitor_.Start("SubtractionTrick");
       hist_builder_.SubtractionTrick(self, sibling, parent);
       builder_monitor_.Stop("SubtractionTrick");
@@ -184,16 +227,17 @@ class QuantileHistMaker: public TreeUpdater {

     bool UpdatePredictionCache(const DMatrix* data,
                                HostDeviceVector<bst_float>* p_out_preds);
-    void SetHistSynchronizer(HistSynchronizer* sync);
-    void SetHistRowsAdder(HistRowsAdder* adder);
+    void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
+    void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);

    protected:
-    friend class HistSynchronizer;
-    friend class BatchHistSynchronizer;
-    friend class DistributedHistSynchronizer;
-    friend class HistRowsAdder;
-    friend class BatchHistRowsAdder;
-    friend class DistributedHistRowsAdder;
+    friend class HistSynchronizer<GradientSumT>;
+    friend class BatchHistSynchronizer<GradientSumT>;
+    friend class DistributedHistSynchronizer<GradientSumT>;
+    friend class HistRowsAdder<GradientSumT>;
+    friend class BatchHistRowsAdder<GradientSumT>;
+    friend class DistributedHistRowsAdder<GradientSumT>;

     /* tree growing policies */
     struct ExpandEntry {
       static const int kRootNid = 0;
@@ -225,13 +269,13 @@ class QuantileHistMaker: public TreeUpdater {

     void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
                         const GHistIndexMatrix& gmat,
-                        const HistCollection& hist,
+                        const HistCollection<GradientSumT>& hist,
                         const RegTree& tree);

     void ApplySplit(std::vector<ExpandEntry> nodes,
                     const GHistIndexMatrix& gmat,
                     const ColumnMatrix& column_matrix,
-                    const HistCollection& hist,
+                    const HistCollection<GradientSumT>& hist,
                     RegTree* p_tree);

     template <typename BinIdxType>
@@ -255,7 +299,7 @@ class QuantileHistMaker: public TreeUpdater {
     // Returns the sum of gradients corresponding to the data points that contains a non-missing
     // value for the particular feature fid.
     template <int d_step>
-    GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRow &hist,
+    GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRowT &hist,
                              const NodeEntry &snode, SplitEntry *p_best,
                              bst_uint fid, bst_uint nodeID) const;

@@ -345,16 +389,16 @@ class QuantileHistMaker: public TreeUpdater {
     /*! \brief TreeNode Data: statistics for each constructed node */
     std::vector<NodeEntry> snode_;
     /*! \brief culmulative histogram of gradients. */
-    HistCollection hist_;
+    HistCollection<GradientSumT> hist_;
    /*! \brief culmulative local parent histogram of gradients. */
-    HistCollection hist_local_worker_;
+    HistCollection<GradientSumT> hist_local_worker_;
     /*! \brief feature with least # of bins. to be used for dense specialization
                of InitNewNode() */
     uint32_t fid_least_bins_;
     /*! \brief local prediction cache; maps node id to leaf value */
     std::vector<float> leaf_value_cache_;

-    GHistBuilder hist_builder_;
+    GHistBuilder<GradientSumT> hist_builder_;
     std::unique_ptr<TreeUpdater> pruner_;
     std::unique_ptr<SplitEvaluator> spliteval_;
     FeatureInteractionConstraintHost interaction_constraints_;
|
||||
@ -382,61 +426,92 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
DataLayout data_layout_;
|
||||
|
||||
common::Monitor builder_monitor_;
|
||||
common::ParallelGHistBuilder hist_buffer_;
|
||||
rabit::Reducer<GradStats, GradStats::Reduce> histred_;
|
||||
std::unique_ptr<HistSynchronizer> hist_synchronizer_;
|
||||
std::unique_ptr<HistRowsAdder> hist_rows_adder_;
|
||||
common::ParallelGHistBuilder<GradientSumT> hist_buffer_;
|
||||
rabit::Reducer<GradientPairT, GradientPairT::Reduce> histred_;
|
||||
std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
|
||||
std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;
|
||||
};
|
||||
common::Monitor updater_monitor_;
|
||||
std::unique_ptr<Builder> builder_;
|
||||
|
||||
template<typename GradientSumT>
|
||||
void SetBuilder(std::unique_ptr<Builder<GradientSumT>>*, DMatrix *dmat);
|
||||
|
||||
template<typename GradientSumT>
|
||||
void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
|
||||
HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
const std::vector<RegTree *> &trees);
|
||||
|
||||
protected:
|
||||
std::unique_ptr<Builder<float>> float_builder_;
|
||||
std::unique_ptr<Builder<double>> double_builder_;
|
||||
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
std::unique_ptr<SplitEvaluator> spliteval_;
|
||||
FeatureInteractionConstraintHost int_constraint_;
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
class HistSynchronizer {
 public:
  virtual void SyncHistograms(QuantileHistMaker::Builder* builder,
  using BuilderT = QuantileHistMaker::Builder<GradientSumT>;

  virtual void SyncHistograms(BuilderT* builder,
                              int starting_index,
                              int sync_count,
                              RegTree *p_tree) = 0;
  virtual ~HistSynchronizer() = default;
};

class BatchHistSynchronizer: public HistSynchronizer {
template <typename GradientSumT>
class BatchHistSynchronizer: public HistSynchronizer<GradientSumT> {
 public:
  void SyncHistograms(QuantileHistMaker::Builder* builder,
  using BuilderT = QuantileHistMaker::Builder<GradientSumT>;
  void SyncHistograms(BuilderT* builder,
                      int starting_index,
                      int sync_count,
                      RegTree *p_tree) override;
};

class DistributedHistSynchronizer: public HistSynchronizer {
template <typename GradientSumT>
class DistributedHistSynchronizer: public HistSynchronizer<GradientSumT> {
 public:
  void SyncHistograms(QuantileHistMaker::Builder* builder_,
                      int starting_index, int sync_count, RegTree *p_tree) override;
  using BuilderT = QuantileHistMaker::Builder<GradientSumT>;
  using ExpandEntryT = typename BuilderT::ExpandEntry;

  void ParallelSubtractionHist(QuantileHistMaker::Builder* builder,
  void SyncHistograms(BuilderT* builder, int starting_index,
                      int sync_count, RegTree *p_tree) override;

  void ParallelSubtractionHist(BuilderT* builder,
                               const common::BlockedSpace2d& space,
                               const std::vector<QuantileHistMaker::Builder::ExpandEntry>& nodes,
                               const std::vector<ExpandEntryT>& nodes,
                               const RegTree * p_tree);
};

template <typename GradientSumT>
class HistRowsAdder {
 public:
  virtual void AddHistRows(QuantileHistMaker::Builder* builder,
                           int *starting_index, int *sync_count, RegTree *p_tree) = 0;
  using BuilderT = QuantileHistMaker::Builder<GradientSumT>;

  virtual void AddHistRows(BuilderT* builder, int *starting_index,
                           int *sync_count, RegTree *p_tree) = 0;
  virtual ~HistRowsAdder() = default;
};

class BatchHistRowsAdder: public HistRowsAdder {
template <typename GradientSumT>
class BatchHistRowsAdder: public HistRowsAdder<GradientSumT> {
 public:
  void AddHistRows(QuantileHistMaker::Builder* builder,
                   int *starting_index, int *sync_count, RegTree *p_tree) override;
  using BuilderT = QuantileHistMaker::Builder<GradientSumT>;
  void AddHistRows(BuilderT*, int *starting_index,
                   int *sync_count, RegTree *p_tree) override;
};

class DistributedHistRowsAdder: public HistRowsAdder {
template <typename GradientSumT>
class DistributedHistRowsAdder: public HistRowsAdder<GradientSumT> {
 public:
  void AddHistRows(QuantileHistMaker::Builder* builder,
                   int *starting_index, int *sync_count, RegTree *p_tree) override;
  using BuilderT = QuantileHistMaker::Builder<GradientSumT>;
  void AddHistRows(BuilderT*, int *starting_index,
                   int *sync_count, RegTree *p_tree) override;
};

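HistSynchronizer and HistRowsAdder turn the batch-vs-distributed behaviour into pluggable policies on the builder, now templated on the accumulation type. A minimal sketch of the pattern (illustrative names, not the real interfaces):

```cpp
#include <iostream>
#include <memory>

// Policy interface parameterized on the histogram value type.
template <typename GradientSumT>
struct Synchronizer {
  virtual ~Synchronizer() = default;
  virtual void Sync() = 0;
};

template <typename GradientSumT>
struct BatchSync : Synchronizer<GradientSumT> {
  void Sync() override { std::cout << "local reduce only\n"; }
};

template <typename GradientSumT>
struct DistributedSync : Synchronizer<GradientSumT> {
  void Sync() override { std::cout << "local reduce + allreduce\n"; }
};

template <typename GradientSumT>
struct Builder {
  std::unique_ptr<Synchronizer<GradientSumT>> sync_;
  void SetSynchronizer(Synchronizer<GradientSumT>* s) { sync_.reset(s); }
  void BuildLocalHistograms() { sync_->Sync(); }  // policy decides what "sync" means
};

int main() {
  Builder<float> b;
  b.SetSynchronizer(new BatchSync<float>());  // swapped out for distributed runs or tests
  b.BuildLocalHistograms();
  return 0;
}
```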
@ -20,8 +20,8 @@ size_t GetNThreads() {
  return nthreads;
}

TEST(ParallelGHistBuilder, Reset) {
template <typename GradientSumT>
void ParallelGHistBuilderReset() {
  constexpr size_t kBins = 10;
  constexpr size_t kNodes = 5;
  constexpr size_t kNodesExtended = 10;
@ -29,16 +29,16 @@ TEST(ParallelGHistBuilder, Reset) {
  constexpr double kValue = 1.0;
  const size_t nthreads = GetNThreads();

  HistCollection collection;
  HistCollection<GradientSumT> collection;
  collection.Init(kBins);

  for(size_t inode = 0; inode < kNodesExtended; inode++) {
    collection.AddHistRow(inode);
  }

  ParallelGHistBuilder hist_builder;
  ParallelGHistBuilder<GradientSumT> hist_builder;
  hist_builder.Init(kBins);
  std::vector<GHistRow> target_hist(kNodes);
  std::vector<GHistRow<GradientSumT>> target_hist(kNodes);
  for(size_t i = 0; i < target_hist.size(); ++i) {
    target_hist[i] = collection[i];
  }
@ -49,7 +49,7 @@ TEST(ParallelGHistBuilder, Reset) {
  common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) {
    const size_t tid = omp_get_thread_num();

    GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
    GHistRow<GradientSumT> hist = hist_builder.GetInitializedHist(tid, inode);
    // fill hist by some non-null values
    for(size_t j = 0; j < kBins; ++j) {
      hist[j].Add(kValue, kValue);
@ -67,7 +67,7 @@ TEST(ParallelGHistBuilder, Reset) {
  common::ParallelFor2d(space2, nthreads, [&](size_t inode, common::Range1d r) {
    const size_t tid = omp_get_thread_num();

    GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
    GHistRow<GradientSumT> hist = hist_builder.GetInitializedHist(tid, inode);
    // fill hist by some non-null values
    for(size_t j = 0; j < kBins; ++j) {
      ASSERT_EQ(0.0, hist[j].GetGrad());
@ -76,23 +76,25 @@ TEST(ParallelGHistBuilder, Reset) {
  });
}

TEST(ParallelGHistBuilder, ReduceHist) {

template <typename GradientSumT>
void ParallelGHistBuilderReduceHist(){
  constexpr size_t kBins = 10;
  constexpr size_t kNodes = 5;
  constexpr size_t kTasksPerNode = 10;
  constexpr double kValue = 1.0;
  const size_t nthreads = GetNThreads();

  HistCollection collection;
  HistCollection<GradientSumT> collection;
  collection.Init(kBins);

  for(size_t inode = 0; inode < kNodes; inode++) {
    collection.AddHistRow(inode);
  }

  ParallelGHistBuilder hist_builder;
  ParallelGHistBuilder<GradientSumT> hist_builder;
  hist_builder.Init(kBins);
  std::vector<GHistRow> target_hist(kNodes);
  std::vector<GHistRow<GradientSumT>> target_hist(kNodes);
  for(size_t i = 0; i < target_hist.size(); ++i) {
    target_hist[i] = collection[i];
  }
@ -104,7 +106,7 @@ TEST(ParallelGHistBuilder, ReduceHist) {
  common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) {
    const size_t tid = omp_get_thread_num();

    GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
    GHistRow<GradientSumT> hist = hist_builder.GetInitializedHist(tid, inode);
    for(size_t i = 0; i < kBins; ++i) {
      hist[i].Add(kValue, kValue);
    }
@ -122,6 +124,21 @@ TEST(ParallelGHistBuilder, ReduceHist) {
  }
}

TEST(ParallelGHistBuilder, ResetDouble) {
  ParallelGHistBuilderReset<double>();
}

TEST(ParallelGHistBuilder, ResetFloat) {
  ParallelGHistBuilderReset<float>();
}

TEST(ParallelGHistBuilder, ReduceHistDouble) {
  ParallelGHistBuilderReduceHist<double>();
}

TEST(ParallelGHistBuilder, ReduceHistFloat) {
  ParallelGHistBuilderReduceHist<float>();
}

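The refactor above moves each test body into a function template and instantiates one named TEST per precision; GoogleTest's TYPED_TEST would achieve the same with generated names, but the explicit form keeps test names readable. A tiny sketch of the pattern:

```cpp
#include <gtest/gtest.h>

template <typename GradientSumT>
void CheckAddIsExactForSmallInts() {
  GradientSumT sum = 0;
  for (int i = 0; i < 16; ++i) sum += GradientSumT(1);
  ASSERT_EQ(GradientSumT(16), sum);  // exact in both float and double
}

TEST(HistAccumulator, SmallIntsDouble) { CheckAddIsExactForSmallInts<double>(); }
TEST(HistAccumulator, SmallIntsFloat)  { CheckAddIsExactForSmallInts<float>(); }
```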
TEST(CutsBuilder, SearchGroupInd) {
  size_t constexpr kNumGroups = 4;

@ -21,8 +21,11 @@ namespace tree {
class QuantileHistMock : public QuantileHistMaker {
  static double constexpr kEps = 1e-6;

  struct BuilderMock : public QuantileHistMaker::Builder {
    using RealImpl = QuantileHistMaker::Builder;
  template <typename GradientSumT>
  struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
    using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
    using ExpandEntryT = typename RealImpl::ExpandEntry;
    using GHistRowT = typename RealImpl::GHistRowT;

    BuilderMock(const TrainParam& param,
                std::unique_ptr<TreeUpdater> pruner,
@ -30,7 +33,7 @@ class QuantileHistMock : public QuantileHistMaker {
                FeatureInteractionConstraintHost int_constraint,
                DMatrix const* fmat)
        : RealImpl(param, std::move(pruner), std::move(spliteval),
                   std::move(int_constraint), fmat) {}
                   std::move(int_constraint), fmat) {}

   public:
    void TestInitData(const GHistIndexMatrix& gmat,
@ -38,7 +41,7 @@ class QuantileHistMock : public QuantileHistMaker {
                      DMatrix* p_fmat,
                      const RegTree& tree) {
      RealImpl::InitData(gmat, gpair, *p_fmat, tree);
      ASSERT_EQ(data_layout_, kSparseData);
      ASSERT_EQ(this->data_layout_, RealImpl::kSparseData);

      /* The creation of HistCutMatrix and GHistIndexMatrix are not technically
       * part of QuantileHist updater logic, but we include it here because
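Most of the churn in this file is mechanical: once BuilderMock derives from a class template, members inherited from the dependent base (`data_layout_`, `hist_`, and friends) are no longer found by unqualified lookup and must be spelled `this->member`. A minimal reproduction:

```cpp
template <typename T>
struct Base {
  int data_layout_ = 0;
};

template <typename T>
struct Derived : Base<T> {
  int Get() {
    // return data_layout_;      // error: name not looked up in dependent base
    return this->data_layout_;   // OK: lookup deferred until instantiation
  }
};

int main() { return Derived<float>{}.Get(); }
```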
@ -105,14 +108,14 @@ class QuantileHistMock : public QuantileHistMaker {
      // save state of global rng engine
      auto initial_rnd = common::GlobalRandom();
      RealImpl::InitData(gmat, gpair, *p_fmat, tree);
      std::vector<size_t> row_indices_initial = *row_set_collection_.Data();
      std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());

      for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
        omp_set_num_threads(i_nthreads);
        // return initial state of global rng engine
        common::GlobalRandom() = initial_rnd;
        RealImpl::InitData(gmat, gpair, *p_fmat, tree);
        std::vector<size_t>& row_indices = *row_set_collection_.Data();
        std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
        ASSERT_EQ(row_indices_initial.size(), row_indices.size());
        for (size_t i = 0; i < row_indices_initial.size(); ++i) {
          ASSERT_EQ(row_indices_initial[i], row_indices[i]);
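The sampling test above pins down determinism by snapshotting the global engine and replaying it under different thread counts. The same idea in isolation, with std::mt19937 standing in for common::GlobalRandom():

```cpp
#include <cassert>
#include <random>

int main() {
  std::mt19937 rng(42);
  const std::mt19937 initial_rnd = rng;  // save state of the rng engine
  std::uniform_int_distribution<int> dist(0, 99);
  const int first_run = dist(rng);

  rng = initial_rnd;                     // restore before the second run
  const int second_run = dist(rng);
  assert(first_run == second_run);       // identical draws, as the test expects
  return 0;
}
```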
@ -129,26 +132,26 @@ class QuantileHistMock : public QuantileHistMaker {

      int starting_index = std::numeric_limits<int>::max();
      int sync_count = 0;
      nodes_for_explicit_hist_build_.clear();
      nodes_for_subtraction_trick_.clear();
      this->nodes_for_explicit_hist_build_.clear();
      this->nodes_for_subtraction_trick_.clear();

      tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
      tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
      tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
      nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
      nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
      nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
      nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
      this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
      this->nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
      this->nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
      this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);

      hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
      this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
      ASSERT_EQ(sync_count, 2);
      ASSERT_EQ(starting_index, 3);

      for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
        ASSERT_EQ(hist_.RowExists(node.nid), true);
      for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
        ASSERT_EQ(this->hist_.RowExists(node.nid), true);
      }
      for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
        ASSERT_EQ(hist_.RowExists(node.nid), true);
      for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
        ASSERT_EQ(this->hist_.RowExists(node.nid), true);
      }
    }

@ -162,60 +165,61 @@ class QuantileHistMock : public QuantileHistMaker {

      int starting_index = std::numeric_limits<int>::max();
      int sync_count = 0;
      nodes_for_explicit_hist_build_.clear();
      nodes_for_subtraction_trick_.clear();
      this->nodes_for_explicit_hist_build_.clear();
      this->nodes_for_subtraction_trick_.clear();
      // level 0
      nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
      hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
      this->nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
      this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
      tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);

      nodes_for_explicit_hist_build_.clear();
      nodes_for_subtraction_trick_.clear();
      this->nodes_for_explicit_hist_build_.clear();
      this->nodes_for_subtraction_trick_.clear();
      // level 1
      nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(), (*tree)[0].RightChild(),
      this->nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(),
                                                        (*tree)[0].RightChild(),
                                                        tree->GetDepth(1), 0.0f, 0);
      nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(), (*tree)[0].LeftChild(),
      this->nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(),
                                                      (*tree)[0].LeftChild(),
                                                      tree->GetDepth(2), 0.0f, 0);
      hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
      this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
      tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
      tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);

      nodes_for_explicit_hist_build_.clear();
      nodes_for_subtraction_trick_.clear();
      this->nodes_for_explicit_hist_build_.clear();
      this->nodes_for_subtraction_trick_.clear();
      // level 2
      nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
      nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
      nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
      nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
      hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
      this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
      this->nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
      this->nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
      this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
      this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);

      const size_t n_nodes = nodes_for_explicit_hist_build_.size();
      const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
      ASSERT_EQ(n_nodes, 2);
      row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
      this->row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
                                         (*tree)[0].RightChild(), 4, 4);
      row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
      this->row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
                                         (*tree)[1].RightChild(), 2, 2);
      row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(),
      this->row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(),
                                         (*tree)[2].RightChild(), 2, 2);

      common::BlockedSpace2d space(n_nodes, [&](size_t node) {
        const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
        return row_set_collection_[nid].Size();
        const int32_t nid = this->nodes_for_explicit_hist_build_[node].nid;
        return this->row_set_collection_[nid].Size();
      }, 256);

      std::vector<GHistRow> target_hists(n_nodes);
      for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) {
        const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
        target_hists[i] = hist_[nid];
      std::vector<GHistRowT> target_hists(n_nodes);
      for (size_t i = 0; i < this->nodes_for_explicit_hist_build_.size(); ++i) {
        const int32_t nid = this->nodes_for_explicit_hist_build_[i].nid;
        target_hists[i] = this->hist_[nid];
      }

      const size_t nbins = hist_builder_.GetNumBins();
      const size_t nbins = this->hist_builder_.GetNumBins();
      // set values to specific nodes hist
      std::vector<size_t> n_ids = {1, 2};
      for (size_t i : n_ids) {
        auto this_hist = hist_[i];
        using FPType = decltype(tree::GradStats::sum_grad);
        FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
        auto this_hist = this->hist_[i];
        GradientSumT* p_hist = reinterpret_cast<GradientSumT*>(this_hist.data());
        for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) {
          p_hist[bin_id] = 2*bin_id;
        }
@ -223,41 +227,39 @@ class QuantileHistMock : public QuantileHistMaker {
      n_ids[0] = 3;
      n_ids[1] = 5;
      for (size_t i : n_ids) {
        auto this_hist = hist_[i];
        using FPType = decltype(tree::GradStats::sum_grad);
        FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
        auto this_hist = this->hist_[i];
        GradientSumT* p_hist = reinterpret_cast<GradientSumT*>(this_hist.data());
        for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) {
          p_hist[bin_id] = bin_id;
        }
      }

      hist_buffer_.Reset(1, n_nodes, space, target_hists);
      this->hist_buffer_.Reset(1, n_nodes, space, target_hists);
      // sync hist
      hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, tree);
      this->hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, tree);

      auto check_hist = [] (const GHistRow parent, const GHistRow left,
                            const GHistRow right, size_t begin, size_t end) {
        using FPType = decltype(tree::GradStats::sum_grad);
        const FPType* p_parent = reinterpret_cast<const FPType*>(parent.data());
        const FPType* p_left = reinterpret_cast<const FPType*>(left.data());
        const FPType* p_right = reinterpret_cast<const FPType*>(right.data());
      auto check_hist = [] (const GHistRowT parent, const GHistRowT left,
                            const GHistRowT right, size_t begin, size_t end) {
        const GradientSumT* p_parent = reinterpret_cast<const GradientSumT*>(parent.data());
        const GradientSumT* p_left = reinterpret_cast<const GradientSumT*>(left.data());
        const GradientSumT* p_right = reinterpret_cast<const GradientSumT*>(right.data());
        for (size_t i = 2 * begin; i < 2 * end; ++i) {
          ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
        }
      };
      for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
        auto this_hist = hist_[node.nid];
      for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
        auto this_hist = this->hist_[node.nid];
        const size_t parent_id = (*tree)[node.nid].Parent();
        auto parent_hist = hist_[parent_id];
        auto sibling_hist = hist_[node.sibling_nid];
        auto parent_hist = this->hist_[parent_id];
        auto sibling_hist = this->hist_[node.sibling_nid];

        check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
      }
      for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
        auto this_hist = hist_[node.nid];
      for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
        auto this_hist = this->hist_[node.nid];
        const size_t parent_id = (*tree)[node.nid].Parent();
        auto parent_hist = hist_[parent_id];
        auto sibling_hist = hist_[node.sibling_nid];
        auto parent_hist = this->hist_[parent_id];
        auto sibling_hist = this->hist_[node.sibling_nid];

        check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
      }
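check_hist verifies the invariant behind the subtraction trick: after synchronization, a parent's histogram equals the sum of its children's, so only one child needs an explicit build and the sibling falls out by subtraction. The same idea in isolation:

```cpp
#include <cassert>
#include <vector>

int main() {
  std::vector<double> parent = {4, 6, 8};
  std::vector<double> built_child = {1, 2, 3};       // built explicitly
  std::vector<double> sibling(parent.size());
  for (size_t i = 0; i < parent.size(); ++i) {
    sibling[i] = parent[i] - built_child[i];         // subtraction trick
    assert(parent[i] == built_child[i] + sibling[i]);  // what check_hist asserts
  }
  return 0;
}
```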
@ -272,13 +274,13 @@ class QuantileHistMock : public QuantileHistMaker {
          {0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f} };
      RealImpl::InitData(gmat, gpair, fmat, tree);
      GHistIndexBlockMatrix dummy;
      hist_.AddHistRow(nid);
      BuildHist(gpair, row_set_collection_[nid],
                gmat, dummy, hist_[nid]);
      this->hist_.AddHistRow(nid);
      this->BuildHist(gpair, this->row_set_collection_[nid],
                      gmat, dummy, this->hist_[nid]);

      // Check if number of histogram bins is correct
      ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back());
      std::vector<GradientPairPrecise> histogram_expected(hist_[nid].size());
      ASSERT_EQ(this->hist_[nid].size(), gmat.cut.Ptrs().back());
      std::vector<GradientPairPrecise> histogram_expected(this->hist_[nid].size());

      // Compute the correct histogram (histogram_expected)
      const size_t num_row = fmat.Info().num_row_;
@ -293,10 +295,10 @@ class QuantileHistMock : public QuantileHistMaker {
      }

      // Now validate the computed histogram returned by BuildHist
      for (size_t i = 0; i < hist_[nid].size(); ++i) {
      for (size_t i = 0; i < this->hist_[nid].size(); ++i) {
        GradientPairPrecise sol = histogram_expected[i];
        ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
        ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
        ASSERT_NEAR(sol.GetGrad(), this->hist_[nid][i].GetGrad(), kEps);
        ASSERT_NEAR(sol.GetHess(), this->hist_[nid][i].GetHess(), kEps);
      }
    }

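With a float histogram in play, the reference is still accumulated in GradientPairPrecise (double) and compared bin by bin under kEps, which is what keeps the test precision-agnostic. The comparison boils down to a helper like this hypothetical one:

```cpp
#include <cmath>

// Accept a (possibly single-precision) accumulated value if it matches the
// double-precision reference within eps; mirrors the ASSERT_NEAR calls above.
template <typename GradientSumT>
bool NearReference(GradientSumT got, double want, double eps = 1e-6) {
  return std::fabs(static_cast<double>(got) - want) <= eps;
}
```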
@ -313,10 +315,10 @@ class QuantileHistMock : public QuantileHistMaker {
      gmat.Init(dmat.get(), kMaxBins);

      RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
      hist_.AddHistRow(0);
      this->hist_.AddHistRow(0);

      BuildHist(row_gpairs, row_set_collection_[0],
                gmat, quantile_index_block, hist_[0]);
      this->BuildHist(row_gpairs, this->row_set_collection_[0],
                      gmat, quantile_index_block, this->hist_[0]);

      RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);

@ -331,7 +333,7 @@ class QuantileHistMock : public QuantileHistMaker {
      }
      // Initialize split evaluator
      std::unique_ptr<SplitEvaluator> evaluator(SplitEvaluator::Create("elastic_net"));
      evaluator->Init(&param_);
      evaluator->Init(&this->param_);

      // Now enumerate all feature*threshold combination to get best split
      // To simplify logic, we make some assumptions:
@ -378,11 +380,13 @@ class QuantileHistMock : public QuantileHistMaker {
      }

      /* Now compare against result given by EvaluateSplit() */
      ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
                       tree.GetDepth(0), snode_[0].best.loss_chg, 0);
      RealImpl::EvaluateSplits({node}, gmat, hist_, tree);
      ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
      ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
      typename RealImpl::ExpandEntry node(RealImpl::ExpandEntry::kRootNid,
                                          RealImpl::ExpandEntry::kEmptyNid,
                                          tree.GetDepth(0),
                                          this->snode_[0].best.loss_chg, 0);
      RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
      ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
      ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
    }

    void TestEvaluateSplitParallel(const GHistIndexBlockMatrix &quantile_index_block,
@ -411,7 +415,7 @@ class QuantileHistMock : public QuantileHistMaker {
      // treat everything as dense, as this is what we intend to test here
      cm.Init(gmat, 0.0);
      RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
      hist_.AddHistRow(0);
      this->hist_.AddHistRow(0);

      RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);

@ -430,9 +434,9 @@ class QuantileHistMock : public QuantileHistMaker {
        const size_t bin_id = gmat.index[offset];
        if (bin_id >= bin_id_min && bin_id < bin_id_max) {
          if (bin_id <= split) {
            left_cnt ++;
            left_cnt++;
          } else {
            right_cnt ++;
            right_cnt++;
          }
        }
      }
@ -450,7 +454,8 @@ class QuantileHistMock : public QuantileHistMaker {
      RealImpl::partition_builder_.Init(1, 1, [&](size_t node_in_set) {
        return 1;
      });
      RealImpl::PartitionKernel<uint8_t>(0, 0, common::Range1d(0, kNRows), split, cm, tree);
      this->template PartitionKernel<uint8_t>(0, 0, common::Range1d(0, kNRows),
                                              split, cm, tree);
      RealImpl::partition_builder_.CalculateRowOffsets();
      ASSERT_EQ(RealImpl::partition_builder_.GetNLeftElems(0), left_cnt);
      ASSERT_EQ(RealImpl::partition_builder_.GetNRightElems(0), right_cnt);
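`this->template PartitionKernel<uint8_t>(...)` needs the `template` keyword for the same dependent-name reason as the `this->` changes: without it, the `<` after a dependent member name would parse as less-than. Reduced:

```cpp
template <typename T>
struct Base {
  template <typename BinIdxType>
  void PartitionKernel() {}
};

template <typename T>
struct Derived : Base<T> {
  void Run() {
    this->template PartitionKernel<unsigned char>();  // `template` disambiguates the `<`
  }
};

int main() {
  Derived<float>{}.Run();
  return 0;
}
```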
@ -462,28 +467,47 @@ class QuantileHistMock : public QuantileHistMaker {
  int static constexpr kNRows = 8, kNCols = 16;
  std::shared_ptr<xgboost::DMatrix> dmat_;
  const std::vector<std::pair<std::string, std::string> > cfg_;
  std::shared_ptr<BuilderMock> builder_;
  std::shared_ptr<BuilderMock<float> > float_builder_;
  std::shared_ptr<BuilderMock<double> > double_builder_;

 public:
  explicit QuantileHistMock(
      const std::vector<std::pair<std::string, std::string> >& args, bool batch = true) :
      const std::vector<std::pair<std::string, std::string> >& args,
      const bool single_precision_histogram = false, bool batch = true) :
      cfg_{args} {
    QuantileHistMaker::Configure(args);
    spliteval_->Init(&param_);
    dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
    builder_.reset(
        new BuilderMock(
            param_,
            std::move(pruner_),
            std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
            int_constraint_,
            dmat_.get()));
    if (batch) {
      builder_->SetHistSynchronizer(new BatchHistSynchronizer());
      builder_->SetHistRowsAdder(new BatchHistRowsAdder());
    if (single_precision_histogram) {
      float_builder_.reset(
          new BuilderMock<float>(
              param_,
              std::move(pruner_),
              std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
              int_constraint_,
              dmat_.get()));
      if (batch) {
        float_builder_->SetHistSynchronizer(new BatchHistSynchronizer<float>());
        float_builder_->SetHistRowsAdder(new BatchHistRowsAdder<float>());
      } else {
        float_builder_->SetHistSynchronizer(new DistributedHistSynchronizer<float>());
        float_builder_->SetHistRowsAdder(new DistributedHistRowsAdder<float>());
      }
    } else {
      builder_->SetHistSynchronizer(new DistributedHistSynchronizer());
      builder_->SetHistRowsAdder(new DistributedHistRowsAdder());
      double_builder_.reset(
          new BuilderMock<double>(
              param_,
              std::move(pruner_),
              std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
              int_constraint_,
              dmat_.get()));
      if (batch) {
        double_builder_->SetHistSynchronizer(new BatchHistSynchronizer<double>());
        double_builder_->SetHistRowsAdder(new BatchHistRowsAdder<double>());
      } else {
        double_builder_->SetHistSynchronizer(new DistributedHistSynchronizer<double>());
        double_builder_->SetHistRowsAdder(new DistributedHistRowsAdder<double>());
      }
    }
  }
  ~QuantileHistMock() override = default;
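The constructor now owns exactly one of the two builders, chosen by the flag, and every Test* wrapper below dispatches on which pointer is non-null. The shape of that dispatch, reduced to a sketch (assumed simplification, not the real classes):

```cpp
#include <memory>

template <typename GradientSumT> struct Builder { void Update() {} };

struct Updater {
  std::unique_ptr<Builder<float>> float_builder_;
  std::unique_ptr<Builder<double>> double_builder_;

  explicit Updater(bool single_precision_histogram) {
    if (single_precision_histogram) {
      float_builder_.reset(new Builder<float>());
    } else {
      double_builder_.reset(new Builder<double>());
    }
  }
  // Mirrors the `if (double_builder_) { ... } else { ... }` wrappers below.
  void Update() {
    if (double_builder_) { double_builder_->Update(); } else { float_builder_->Update(); }
  }
};
```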
@ -501,8 +525,11 @@ class QuantileHistMock : public QuantileHistMaker {
    std::vector<GradientPair> gpair =
        { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
          {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };

    builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
    if (double_builder_) {
      double_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
    } else {
      float_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
    }
  }

  void TestInitDataSampling() {
@ -516,8 +543,11 @@ class QuantileHistMock : public QuantileHistMaker {
    std::vector<GradientPair> gpair =
        { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
          {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };

    builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
    if (double_builder_) {
      double_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
    } else {
      float_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
    }
  }

  void TestAddHistRows() {
@ -530,7 +560,11 @@ class QuantileHistMock : public QuantileHistMaker {
    std::vector<GradientPair> gpair =
        { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
          {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
    builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
    if (double_builder_) {
      double_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
    } else {
      float_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
    }
  }

  void TestSyncHistograms() {
@ -543,7 +577,11 @@ class QuantileHistMock : public QuantileHistMaker {
    std::vector<GradientPair> gpair =
        { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
          {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
    builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
    if (double_builder_) {
      double_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
    } else {
      float_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
    }
  }


@ -554,22 +592,31 @@ class QuantileHistMock : public QuantileHistMaker {
    size_t constexpr kMaxBins = 4;
    common::GHistIndexMatrix gmat;
    gmat.Init(dmat_.get(), kMaxBins);

    builder_->TestBuildHist(0, gmat, *dmat_, tree);
    if (double_builder_) {
      double_builder_->TestBuildHist(0, gmat, *dmat_, tree);
    } else {
      float_builder_->TestBuildHist(0, gmat, *dmat_, tree);
    }
  }

  void TestEvaluateSplit() {
    RegTree tree = RegTree();
    tree.param.UpdateAllowUnknown(cfg_);

    builder_->TestEvaluateSplit(gmatb_, tree);
    if (double_builder_) {
      double_builder_->TestEvaluateSplit(gmatb_, tree);
    } else {
      float_builder_->TestEvaluateSplit(gmatb_, tree);
    }
  }

  void TestApplySplit() {
    RegTree tree = RegTree();
    tree.param.UpdateAllowUnknown(cfg_);

    builder_->TestApplySplit(gmatb_, tree);
    if (double_builder_) {
      double_builder_->TestApplySplit(gmatb_, tree);
    } else {
      float_builder_->TestEvaluateSplit(gmatb_, tree);
    }
  }
};

@ -578,6 +625,9 @@ TEST(QuantileHist, InitData) {
      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
  QuantileHistMock maker(cfg);
  maker.TestInitData();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestInitData();
}

TEST(QuantileHist, InitDataSampling) {
@ -587,6 +637,9 @@ TEST(QuantileHist, InitDataSampling) {
       {"subsample", std::to_string(subsample)}};
  QuantileHistMock maker(cfg);
  maker.TestInitDataSampling();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestInitDataSampling();
}

TEST(QuantileHist, AddHistRows) {
@ -594,6 +647,9 @@ TEST(QuantileHist, AddHistRows) {
      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
  QuantileHistMock maker(cfg);
  maker.TestAddHistRows();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestAddHistRows();
}

TEST(QuantileHist, SyncHistograms) {
@ -601,6 +657,9 @@ TEST(QuantileHist, SyncHistograms) {
      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
  QuantileHistMock maker(cfg);
  maker.TestSyncHistograms();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestSyncHistograms();
}

TEST(QuantileHist, DistributedAddHistRows) {
@ -608,6 +667,9 @@ TEST(QuantileHist, DistributedAddHistRows) {
      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
  QuantileHistMock maker(cfg, false);
  maker.TestAddHistRows();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestAddHistRows();
}

TEST(QuantileHist, DistributedSyncHistograms) {
@ -615,6 +677,9 @@ TEST(QuantileHist, DistributedSyncHistograms) {
      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
  QuantileHistMock maker(cfg, false);
  maker.TestSyncHistograms();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestSyncHistograms();
}

TEST(QuantileHist, BuildHist) {
@ -624,6 +689,9 @@ TEST(QuantileHist, BuildHist) {
       {"enable_feature_grouping", std::to_string(0)}};
  QuantileHistMock maker(cfg);
  maker.TestBuildHist();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestBuildHist();
}

TEST(QuantileHist, EvalSplits) {
@ -634,6 +702,9 @@ TEST(QuantileHist, EvalSplits) {
       {"min_child_weight", "0"}};
  QuantileHistMock maker(cfg);
  maker.TestEvaluateSplit();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestEvaluateSplit();
}

TEST(QuantileHist, ApplySplit) {
@ -644,6 +715,9 @@ TEST(QuantileHist, ApplySplit) {
       {"min_child_weight", "0"}};
  QuantileHistMock maker(cfg);
  maker.TestApplySplit();
  const bool single_precision_histogram = true;
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestApplySplit();
}

} // namespace tree

@ -57,7 +57,8 @@ class TestUpdaters(unittest.TestCase):
                          'max_bin': [2, 256],
                          'grow_policy': ['depthwise', 'lossguide'],
                          'max_leaves': [64, 0],
                          'verbosity': [0]}
                          'verbosity': [0],
                          'single_precision_histogram': [True, False]}
        for param in parameter_combinations(variable_param):
            result = run_suite(param)
            assert_results_non_increasing(result, 1e-2)

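For completeness, a hedged end-to-end sketch of enabling the new switch through the XGBoost C API (error-code checks elided): the parameter is an ordinary string key/value on the booster, so no other code changes are needed.

```cpp
#include <xgboost/c_api.h>

int main() {
  const float data[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  const float labels[] = {0.f, 1.f, 0.f};
  DMatrixHandle dmat;
  XGDMatrixCreateFromMat(data, 3, 2, -1.f, &dmat);
  XGDMatrixSetFloatInfo(dmat, "label", labels, 3);

  BoosterHandle booster;
  XGBoosterCreate(&dmat, 1, &booster);
  XGBoosterSetParam(booster, "tree_method", "hist");
  XGBoosterSetParam(booster, "single_precision_histogram", "true");  // new in this commit
  for (int iter = 0; iter < 10; ++iter) {
    XGBoosterUpdateOneIter(booster, iter, dmat);
  }

  XGBoosterFree(booster);
  XGDMatrixFree(dmat);
  return 0;
}
```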