Clean up training code. (#3825)

* Remove GHistRow, GHistEntry, GHistIndexRow in favour of Span-based aliases (see the sketch after this list).
* Remove kSimpleStats.
* Remove CheckInfo and SetLeafVec from GradStats and SKStats.
* Clean up GradStats.
* Clean up CalcGain.
* Move LossChangeMissing out of common.
* Remove operator[] from GHistIndexBlock.
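For context on the histogram changes: the removed GHistRow/GHistEntry pair is replaced below by a plain span of tree::GradStats, so per-bin accumulation and the subtraction trick now operate on contiguous GradStats entries. A minimal, self-contained sketch of that pattern (the GradStats here is a simplified stand-in for tree::GradStats, and std::vector stands in for xgboost's common::Span) might look like:

// Sketch only: simplified stand-ins, not the classes from this commit.
#include <cstddef>
#include <iostream>
#include <vector>

struct GradStats {
  double sum_grad{0};
  double sum_hess{0};
  void Add(double grad, double hess) { sum_grad += grad; sum_hess += hess; }
  // Set this entry to a - b, as the per-bin subtraction trick does.
  void SetSubstract(const GradStats& a, const GradStats& b) {
    sum_grad = a.sum_grad - b.sum_grad;
    sum_hess = a.sum_hess - b.sum_hess;
  }
};

int main() {
  // One histogram row is one contiguous run of GradStats, one entry per bin.
  std::vector<GradStats> parent(4), sibling(4), self(4);
  parent[1].Add(2.0, 1.0);    // gradient/hessian landing in bin 1 of the parent
  sibling[1].Add(0.5, 0.25);  // the part of bin 1 that belongs to one child
  // The other child's histogram is parent minus sibling, bin by bin.
  for (std::size_t bin = 0; bin < self.size(); ++bin) {
    self[bin].SetSubstract(parent[bin], sibling[bin]);
  }
  std::cout << self[1].sum_grad << " " << self[1].sum_hess << "\n";  // 1.5 0.75
  return 0;
}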
Jiaming Yuan 2019-02-07 14:22:13 +08:00 committed by GitHub
parent 325b16bccd
commit 017c97b8ce
19 changed files with 306 additions and 406 deletions


@@ -1005,7 +1005,7 @@ class AllReducer {
   */
  void Synchronize() {
#ifdef XGBOOST_USE_NCCL
-    for (int i = 0; i < device_ordinals.size(); i++) {
+    for (size_t i = 0; i < device_ordinals.size(); i++) {
      dh::safe_cuda(cudaSetDevice(device_ordinals[i]));
      dh::safe_cuda(cudaStreamSynchronize(streams[i]));
    }
@@ -1051,7 +1051,7 @@ template <typename T, typename FunctionT>
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
  SaveCudaContext{[&]() {
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
-    for (int shard = 0; shard < shards->size(); ++shard) {
+    for (size_t shard = 0; shard < shards->size(); ++shard) {
      f(shard, shards->at(shard));
    }
  }};


@@ -1,5 +1,5 @@
/*!
- * Copyright 2017 by Contributors
+ * Copyright 2017-2018 by Contributors
 * \file hist_util.h
 * \brief Utilities to store histograms
 * \author Philip Cho, Tianqi Chen
@@ -417,7 +417,7 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
  const size_t* row_ptr = gmat.row_ptr.data();
  const float* pgh = reinterpret_cast<const float*>(gpair.data());
-  double* hist_data = reinterpret_cast<double*>(hist.begin);
+  double* hist_data = reinterpret_cast<double*>(hist.data());
  double* data = reinterpret_cast<double*>(data_.data());
  const size_t block_size = 512;
@@ -432,11 +432,11 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
  size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
  no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
  for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
    dmlc::omp_uint tid = omp_get_thread_num();
    double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
                               reinterpret_cast<double*>(data_.data() + tid * nbins_));
    if (!thread_init_[tid]) {
      memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
@@ -477,7 +477,7 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
    }
  }
#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
  for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
    const size_t istart = iblock * block_size;
    const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
@@ -507,8 +507,9 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
#if defined(_OPENMP)
  const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
#endif
+  tree::GradStats* p_hist = hist.data();
#pragma omp parallel for num_threads(nthread) schedule(guided)
  for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
    auto gmat = gmatb[bid];
@@ -517,20 +518,17 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
      size_t ibegin[kUnroll];
      size_t iend[kUnroll];
      GradientPair stat[kUnroll];
      for (int k = 0; k < kUnroll; ++k) {
        rid[k] = row_indices.begin[i + k];
-      }
-      for (int k = 0; k < kUnroll; ++k) {
        ibegin[k] = gmat.row_ptr[rid[k]];
        iend[k] = gmat.row_ptr[rid[k] + 1];
-      }
-      for (int k = 0; k < kUnroll; ++k) {
        stat[k] = gpair[rid[k]];
      }
      for (int k = 0; k < kUnroll; ++k) {
        for (size_t j = ibegin[k]; j < iend[k]; ++j) {
          const uint32_t bin = gmat.index[j];
-          hist.begin[bin].Add(stat[k]);
+          p_hist[bin].Add(stat[k]);
        }
      }
    }
@@ -541,7 +539,7 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
      const GradientPair stat = gpair[rid];
      for (size_t j = ibegin; j < iend; ++j) {
        const uint32_t bin = gmat.index[j];
-        hist.begin[bin].Add(stat);
+        p_hist[bin].Add(stat);
      }
    }
  }
@@ -555,24 +553,27 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
#if defined(_OPENMP)
  const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
#endif
+  tree::GradStats* p_self = self.data();
+  tree::GradStats* p_sibling = sibling.data();
+  tree::GradStats* p_parent = parent.data();
#pragma omp parallel for num_threads(nthread) schedule(static)
  for (bst_omp_uint bin_id = 0;
       bin_id < static_cast<bst_omp_uint>(nbins - rest); bin_id += kUnroll) {
-    GHistEntry pb[kUnroll];
-    GHistEntry sb[kUnroll];
+    tree::GradStats pb[kUnroll];
+    tree::GradStats sb[kUnroll];
    for (int k = 0; k < kUnroll; ++k) {
-      pb[k] = parent.begin[bin_id + k];
+      pb[k] = p_parent[bin_id + k];
    }
    for (int k = 0; k < kUnroll; ++k) {
-      sb[k] = sibling.begin[bin_id + k];
+      sb[k] = p_sibling[bin_id + k];
    }
    for (int k = 0; k < kUnroll; ++k) {
-      self.begin[bin_id + k].SetSubtract(pb[k], sb[k]);
+      p_self[bin_id + k].SetSubstract(pb[k], sb[k]);
    }
  }
  for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
-    self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]);
+    p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
  }
}


@@ -16,45 +16,8 @@
#include "../include/rabit/rabit.h"
namespace xgboost {
namespace common {
-/*! \brief sums of gradient statistics corresponding to a histogram bin */
-struct GHistEntry {
-  /*! \brief sum of first-order gradient statistics */
-  double sum_grad{0};
-  /*! \brief sum of second-order gradient statistics */
-  double sum_hess{0};
-  GHistEntry() = default;
-  inline void Clear() {
-    sum_grad = sum_hess = 0;
-  }
-  /*! \brief add a GradientPair to the sum */
-  inline void Add(const GradientPair& e) {
-    sum_grad += e.GetGrad();
-    sum_hess += e.GetHess();
-  }
-  /*! \brief add a GHistEntry to the sum */
-  inline void Add(const GHistEntry& e) {
-    sum_grad += e.sum_grad;
-    sum_hess += e.sum_hess;
-  }
-  inline static void Reduce(GHistEntry& a, const GHistEntry& b) { // NOLINT(*)
-    a.Add(b);
-  }
-  /*! \brief set sum to be difference of two GHistEntry's */
-  inline void SetSubtract(const GHistEntry& a, const GHistEntry& b) {
-    sum_grad = a.sum_grad - b.sum_grad;
-    sum_hess = a.sum_hess - b.sum_hess;
-  }
-};
/*! \brief Cut configuration for all the features. */
struct HistCutMatrix {
  /*! \brief Unit pointer to rows by element position */
@@ -83,15 +46,7 @@ void DeviceSketch
 * \brief A single row in global histogram index.
 *  Directly represent the global index in the histogram entry.
 */
-struct GHistIndexRow {
-  /*! \brief The index of the histogram */
-  const uint32_t* index;
-  /*! \brief The size of the histogram */
-  size_t size;
-  GHistIndexRow() = default;
-  GHistIndexRow(const uint32_t* index, size_t size)
-      : index(index), size(size) {}
-};
+using GHistIndexRow = Span<uint32_t const>;
/*!
 * \brief preprocessed global index matrix, in CSR format
@@ -111,7 +66,9 @@ struct GHistIndexMatrix {
  void Init(DMatrix* p_fmat, int max_num_bins);
  // get i-th row
  inline GHistIndexRow operator[](size_t i) const {
-    return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
+    return {&index[0] + row_ptr[i],
+            static_cast<GHistIndexRow::index_type>(
+                row_ptr[i + 1] - row_ptr[i])};
  }
  inline void GetFeatureCounts(size_t* counts) const {
    auto nfeature = cut.row_ptr.size() - 1;
@@ -134,11 +91,6 @@ struct GHistIndexBlock {
  inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
      : row_ptr(row_ptr), index(index) {}
-  // get i-th row
-  inline GHistIndexRow operator[](size_t i) const {
-    return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
-  }
};
class ColumnMatrix;
@@ -171,21 +123,12 @@ class GHistIndexBlockMatrix {
};
/*!
- * \brief histogram of gradient statistics for a single node.
- * Consists of multiple GHistEntry's, each entry showing total graident statistics
+ * \brief histogram of graident statistics for a single node.
+ * Consists of multiple GradStats, each entry showing total graident statistics
 * for that particular bin
 * Uses global bin id so as to represent all features simultaneously
 */
-struct GHistRow {
-  /*! \brief base pointer to first entry */
-  GHistEntry* begin;
-  /*! \brief number of entries */
-  uint32_t size;
-  GHistRow() = default;
-  GHistRow(GHistEntry* begin, uint32_t size)
-      : begin(begin), size(size) {}
-};
+using GHistRow = Span<tree::GradStats>;
/*!
 * \brief histogram of gradient statistics for multiple nodes
@@ -193,27 +136,29 @@ struct GHistRow {
class HistCollection {
 public:
  // access histogram for i-th node
-  inline GHistRow operator[](bst_uint nid) const {
+  GHistRow operator[](bst_uint nid) const {
    constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
    CHECK_NE(row_ptr_[nid], kMax);
-    return {const_cast<GHistEntry*>(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_};
+    tree::GradStats* ptr =
+        const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
+    return {ptr, nbins_};
  }
  // have we computed a histogram for i-th node?
-  inline bool RowExists(bst_uint nid) const {
+  bool RowExists(bst_uint nid) const {
    const uint32_t k_max = std::numeric_limits<uint32_t>::max();
    return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
  }
  // initialize histogram collection
-  inline void Init(uint32_t nbins) {
+  void Init(uint32_t nbins) {
    nbins_ = nbins;
    row_ptr_.clear();
    data_.clear();
  }
  // create an empty histogram for i-th node
-  inline void AddHistRow(bst_uint nid) {
+  void AddHistRow(bst_uint nid) {
    constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
    if (nid >= row_ptr_.size()) {
      row_ptr_.resize(nid + 1, kMax);
@@ -228,7 +173,7 @@ class HistCollection {
  /*! \brief number of all bins over all features */
  uint32_t nbins_;
-  std::vector<GHistEntry> data_;
+  std::vector<tree::GradStats> data_;
  /*! \brief row_ptr_[nid] locates bin for historgram of node nid */
  std::vector<size_t> row_ptr_;
@@ -268,8 +213,8 @@ class GHistBuilder {
  size_t nthread_;
  /*! \brief number of all bins over all features */
  uint32_t nbins_;
-  std::vector<GHistEntry> data_;
  std::vector<size_t> thread_init_;
+  std::vector<tree::GradStats> data_;
};


@@ -140,7 +140,7 @@ class GPUDistribution {
    return begin;
  }
-  size_t ShardSize(size_t size, int index) const {
+  size_t ShardSize(size_t size, size_t index) const {
    if (size == 0) { return 0; }
    if (offsets_.size() > 0) {
      // explicit offsets are provided
@@ -154,7 +154,7 @@
    return end - begin;
  }
-  size_t ShardProperSize(size_t size, int index) const {
+  size_t ShardProperSize(size_t size, size_t index) const {
    if (size == 0) { return 0; }
    return ShardSize(size, index) - (devices_.Size() - 1 > index ? overlap_ : 0);
  }


@@ -554,8 +554,8 @@ class Span {
                             detail::ptrdiff_t _offset,
                             detail::ptrdiff_t _count = dynamic_extent) const {
    SPAN_CHECK(_offset >= 0 && _offset < size());
-    SPAN_CHECK(_count == dynamic_extent ||
-               _count >= 0 && _offset + _count <= size());
+    SPAN_CHECK((_count == dynamic_extent) ||
+               (_count >= 0 && _offset + _count <= size()));
    return {data() + _offset, _count ==
            dynamic_extent ? size() - _offset : _count};


@@ -58,12 +58,12 @@ class Transform {
 public:
    Evaluator(Functor func, Range range, GPUSet devices, bool reshard) :
        func_(func), range_{std::move(range)},
-        distribution_{std::move(GPUDistribution::Block(devices))},
-        reshard_{reshard} {}
+        reshard_{reshard},
+        distribution_{std::move(GPUDistribution::Block(devices))} {}
    Evaluator(Functor func, Range range, GPUDistribution dist,
              bool reshard) :
-        func_(func), range_{std::move(range)}, distribution_{std::move(dist)},
-        reshard_{reshard} {}
+        func_(func), range_{std::move(range)}, reshard_{reshard},
+        distribution_{std::move(dist)} {}
    /*!
     * \brief Evaluate the functor with input pointers to HostDeviceVector.
@@ -159,7 +159,7 @@ class Transform {
    template <typename... HDV>
    void LaunchCPU(Functor func, HDV*... vectors) const {
-      auto end = *(range_.end());
+      omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
#pragma omp parallel for schedule(static)
      for (omp_ulong idx = 0; idx < end; ++idx) {
        func(idx, UnpackHDV(vectors)...);


@@ -256,12 +256,12 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
// functions for L1 cost
template <typename T1, typename T2>
-XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 lambda) {
-  if (w > +lambda) {
-    return w - lambda;
+XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 alpha) {
+  if (w > + alpha) {
+    return w - alpha;
  }
-  if (w < -lambda) {
-    return w + lambda;
+  if (w < - alpha) {
+    return w + alpha;
  }
  return 0.0;
}
@@ -271,9 +271,9 @@ XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }
// calculate the cost of loss function
template <typename TrainingParams, typename T>
-XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad,
-                                            T sum_hess, T w) {
-  return -(2.0 * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w));
+XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p,
+                                            T sum_grad, T sum_hess, T w) {
+  return -(T(2.0) * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w));
}
// calculate the cost of loss function
@@ -281,44 +281,51 @@ template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
  if (sum_hess < p.min_child_weight) {
    return T(0.0);
  }
  if (p.max_delta_step == 0.0f) {
    if (p.reg_alpha == 0.0f) {
      return Sqr(sum_grad) / (sum_hess + p.reg_lambda);
    } else {
      return Sqr(ThresholdL1(sum_grad, p.reg_alpha)) /
          (sum_hess + p.reg_lambda);
    }
  } else {
    T w = CalcWeight(p, sum_grad, sum_hess);
-    T ret = sum_grad * w + T(0.5) * (sum_hess + p.reg_lambda) * Sqr(w);
+    T ret = CalcGainGivenWeight(p, sum_grad, sum_hess, w);
    if (p.reg_alpha == 0.0f) {
-      return T(-2.0) * ret;
+      return ret;
    } else {
-      return T(-2.0) * (ret + p.reg_alpha * std::abs(w));
+      return ret + p.reg_alpha * std::abs(w);
    }
  }
}
+template <typename TrainingParams,
+          typename StatT, typename T = decltype(StatT().GetHess())>
+XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) {
+  return CalcGain(p, stat.GetGrad(), stat.GetHess());
+}
// calculate cost of loss function with four statistics
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
                                 T test_grad, T test_hess) {
  T w = CalcWeight(sum_grad, sum_hess);
-  T ret = test_grad * w + 0.5 * (test_hess + p.reg_lambda) * Sqr(w);
+  T ret = CalcGainGivenWeight(p, test_grad, test_hess, w);
  if (p.reg_alpha == 0.0f) {
-    return -2.0 * ret;
+    return ret;
  } else {
-    return -2.0 * (ret + p.reg_alpha * std::abs(w));
+    return ret + p.reg_alpha * std::abs(w);
  }
}
// calculate weight given the statistics
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
                                   T sum_hess) {
  if (sum_hess < p.min_child_weight || sum_hess <= 0.0) {
    return 0.0;
  }
  T dw;
  if (p.reg_alpha == 0.0f) {
    dw = -sum_grad / (sum_hess + p.reg_lambda);
@@ -328,14 +335,15 @@ XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
  if (p.max_delta_step != 0.0f) {
    if (dw > p.max_delta_step) {
      dw = p.max_delta_step;
    }
    if (dw < -p.max_delta_step) {
      dw = -p.max_delta_step;
    }
  }
  return dw;
}
+// Used in gpu code where GradientPair is used for gradient sum, not GradStats.
template <typename TrainingParams, typename GpairT>
XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad) {
  return CalcWeight(p, sum_grad.GetGrad(), sum_grad.GetHess());
@@ -347,49 +355,27 @@ struct XGBOOST_ALIGNAS(16) GradStats {
  double sum_grad;
  /*! \brief sum hessian statistics */
  double sum_hess;
-  /*!
-   * \brief whether this is simply statistics and we only need to call
-   *   Add(gpair), instead of Add(gpair, info, ridx)
-   */
-  static const int kSimpleStats = 1;
-  /*! \brief constructor, the object must be cleared during construction */
-  explicit GradStats(const TrainParam& param) { this->Clear(); }
-  explicit GradStats(double sum_grad, double sum_hess)
-      : sum_grad(sum_grad), sum_hess(sum_hess) {}
+ public:
+  XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
+  XGBOOST_DEVICE double GetHess() const { return sum_hess; }
+  XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
+    static_assert(sizeof(GradStats) == 16,
+                  "Size of GradStats is not 16 bytes.");
+  }
  template <typename GpairT>
  XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
      : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
-  /*! \brief clear the statistics */
-  inline void Clear() { sum_grad = sum_hess = 0.0f; }
-  /*! \brief check if necessary information is ready */
-  inline static void CheckInfo(const MetaInfo& info) {}
+  explicit GradStats(const double grad, const double hess)
+      : sum_grad(grad), sum_hess(hess) {}
  /*!
   * \brief accumulate statistics
   * \param p the gradient pair
   */
  inline void Add(GradientPair p) { this->Add(p.GetGrad(), p.GetHess()); }
-  /*!
-   * \brief accumulate statistics, more complicated version
-   * \param gpair the vector storing the gradient statistics
-   * \param info the additional information
-   * \param ridx instance index of this instance
-   */
-  inline void Add(const std::vector<GradientPair>& gpair, const MetaInfo& info,
-                  bst_uint ridx) {
-    const GradientPair& b = gpair[ridx];
-    this->Add(b.GetGrad(), b.GetHess());
-  }
-  /*! \brief calculate leaf weight */
-  template <typename ParamT>
-  XGBOOST_DEVICE inline double CalcWeight(const ParamT &param) const {
-    return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
-  }
-  /*! \brief calculate gain of the solution */
-  template <typename ParamT>
-  inline double CalcGain(const ParamT& param) const {
-    return xgboost::tree::CalcGain(param, sum_grad, sum_hess);
-  }
  /*! \brief add statistics to the data */
  inline void Add(const GradStats& b) {
    sum_grad += b.sum_grad;
@@ -406,8 +392,6 @@ template <typename ParamT>
  }
  /*! \return whether the statistics is not used yet */
  inline bool Empty() const { return sum_hess == 0.0; }
-  // constructor to allow inheritance
-  GradStats() = default;
  /*! \brief add statistics to the data */
  inline void Add(double grad, double hess) {
    sum_grad += grad;
@@ -415,6 +399,7 @@ template <typename ParamT>
  }
};
+// TODO(trivialfis): Remove this class.
struct ValueConstraint {
  double lower_bound;
  double upper_bound;
@@ -424,9 +409,9 @@ struct ValueConstraint {
  inline static void Init(TrainParam *param, unsigned num_feature) {
    param->monotone_constraints.resize(num_feature, 0);
  }
  template <typename ParamT>
  XGBOOST_DEVICE inline double CalcWeight(const ParamT &param, GradStats stats) const {
-    double w = stats.CalcWeight(param);
+    double w = xgboost::tree::CalcWeight(param, stats);
    if (w < lower_bound) {
      return lower_bound;
    }
@@ -436,13 +421,13 @@ template <typename ParamT>
    return w;
  }
  template <typename ParamT>
  XGBOOST_DEVICE inline double CalcGain(const ParamT &param, GradStats stats) const {
    return CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess,
                               CalcWeight(param, stats));
  }
  template <typename ParamT>
  XGBOOST_DEVICE inline double CalcSplitGain(const ParamT &param, int constraint,
                                             GradStats left, GradStats right) const {
    const double negative_infinity = -std::numeric_limits<double>::infinity();
@@ -468,7 +453,7 @@ template <typename ParamT>
    *cright = *this;
    if (c == 0) {
      return;
    }
    double wleft = CalcWeight(param, left);
    double wright = CalcWeight(param, right);
    double mid = (wleft + wright) / 2;
@@ -578,13 +563,13 @@ inline std::ostream &operator<<(std::ostream &os, const std::vector<int> &t) {
  for (auto it = t.begin(); it != t.end(); ++it) {
    if (it != t.begin()) {
      os << ',';
    }
    os << *it;
  }
  // python style tuple
  if (t.size() == 1) {
    os << ',';
  }
  os << ')';
  return os;
}
@@ -603,7 +588,7 @@ inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
    is.get();
    if (ch == '(') {
      break;
    }
    if (!isspace(ch)) {
      is.setstate(std::ios::failbit);
      return is;
@@ -635,7 +620,7 @@ inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
    }
    if (ch == ')') {
      break;
    }
  } else if (ch == ')') {
    break;
  } else {


@@ -50,6 +50,7 @@ void SplitEvaluator::AddSplit(bst_uint nodeid,
                              bst_uint featureid,
                              bst_float leftweight,
                              bst_float rightweight) {}
bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid,
                                            bst_uint featureid,
                                            const GradStats& left_stats,


@@ -333,28 +333,28 @@ class BaseMaker: public TreeUpdater {
    const MetaInfo &info = fmat.Info();
    thread_temp.resize(omp_get_max_threads());
    p_node_stats->resize(tree.param.num_nodes);
    #pragma omp parallel
    {
      const int tid = omp_get_thread_num();
-      thread_temp[tid].resize(tree.param.num_nodes, TStats(param_));
+      thread_temp[tid].resize(tree.param.num_nodes, TStats());
      for (unsigned int nid : qexpand_) {
-        thread_temp[tid][nid].Clear();
+        thread_temp[tid][nid] = TStats();
      }
    }
    // setup position
    const auto ndata = static_cast<bst_omp_uint>(fmat.Info().num_row_);
    #pragma omp parallel for schedule(static)
    for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) {
      const int nid = position_[ridx];
      const int tid = omp_get_thread_num();
      if (nid >= 0) {
-        thread_temp[tid][nid].Add(gpair, info, ridx);
+        thread_temp[tid][nid].Add(gpair[ridx]);
      }
    }
    // sum the per thread statistics together
    for (int nid : qexpand_) {
      TStats &s = (*p_node_stats)[nid];
-      s.Clear();
+      s = TStats();
      for (size_t tid = 0; tid < thread_temp.size(); ++tid) {
        s.Add(thread_temp[tid][nid]);
      }


@@ -33,7 +33,6 @@ class ColMaker: public TreeUpdater {
  void Update(HostDeviceVector<GradientPair> *gpair,
              DMatrix* dmat,
              const std::vector<RegTree*> &trees) override {
-    GradStats::CheckInfo(dmat->Info());
    // rescale learning rate according to size of trees
    float lr = param_.learning_rate;
    param_.learning_rate = lr / trees.size();
@@ -66,9 +65,7 @@ class ColMaker: public TreeUpdater {
    /*! \brief current best solution */
    SplitEntry best;
    // constructor
-    explicit ThreadEntry(const TrainParam &param)
-        : stats(param), stats_extra(param) {
-    }
+    ThreadEntry() : last_fvalue{0}, first_fvalue{0} {}
  };
  struct NodeEntry {
    /*! \brief statics for node entry */
@@ -80,9 +77,7 @@ class ColMaker: public TreeUpdater {
    /*! \brief current best solution */
    SplitEntry best;
    // constructor
-    explicit NodeEntry(const TrainParam& param)
-        : stats(param), root_gain(0.0f), weight(0.0f){
-    }
+    NodeEntry() : root_gain(0.0f), weight(0.0f) {}
  };
  // actual builder that runs the algorithm
  class Builder {
@@ -200,9 +195,9 @@ class ColMaker: public TreeUpdater {
      {
        // setup statistics space for each tree node
        for (auto& i : stemp_) {
-          i.resize(tree.param.num_nodes, ThreadEntry(param_));
+          i.resize(tree.param.num_nodes, ThreadEntry());
        }
-        snode_.resize(tree.param.num_nodes, NodeEntry(param_));
+        snode_.resize(tree.param.num_nodes, NodeEntry());
      }
      const MetaInfo& info = fmat.Info();
      // setup position
@@ -211,11 +206,11 @@ class ColMaker: public TreeUpdater {
      for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) {
        const int tid = omp_get_thread_num();
        if (position_[ridx] < 0) continue;
-        stemp_[tid][position_[ridx]].stats.Add(gpair, info, ridx);
+        stemp_[tid][position_[ridx]].stats.Add(gpair[ridx]);
      }
      // sum the per thread statistics together
      for (int nid : qexpand) {
-        GradStats stats(param_);
+        GradStats stats;
        for (auto& s : stemp_) {
          stats.Add(s[nid].stats);
        }
@@ -261,7 +256,7 @@ class ColMaker: public TreeUpdater {
        std::vector<ThreadEntry> &temp = stemp_[tid];
        // cleanup temp statistics
        for (int j : qexpand) {
-          temp[j].stats.Clear();
+          temp[j].stats = GradStats();
        }
        bst_uint step = (col.size() + this->nthread_ - 1) / this->nthread_;
        bst_uint end = std::min(static_cast<bst_uint>(col.size()), step * (tid + 1));
@@ -273,7 +268,7 @@ class ColMaker: public TreeUpdater {
          if (temp[nid].stats.Empty()) {
            temp[nid].first_fvalue = fvalue;
          }
-          temp[nid].stats.Add(gpair, info, ridx);
+          temp[nid].stats.Add(gpair[ridx]);
          temp[nid].last_fvalue = fvalue;
        }
      }
@@ -282,7 +277,7 @@ class ColMaker: public TreeUpdater {
      #pragma omp parallel for schedule(static)
      for (bst_omp_uint j = 0; j < nnode; ++j) {
        const int nid = qexpand[j];
-        GradStats sum(param_), tmp(param_), c(param_);
+        GradStats sum, tmp, c;
        for (int tid = 0; tid < this->nthread_; ++tid) {
          tmp = stemp_[tid][nid].stats;
          stemp_[tid][nid].stats = sum;
@@ -342,7 +337,7 @@ class ColMaker: public TreeUpdater {
      // rescan, generate candidate split
      #pragma omp parallel
      {
-        GradStats c(param_), cright(param_);
+        GradStats c, cright;
        const int tid = omp_get_thread_num();
        std::vector<ThreadEntry> &temp = stemp_[tid];
        bst_uint step = (col.size() + this->nthread_ - 1) / this->nthread_;
@@ -355,7 +350,7 @@ class ColMaker: public TreeUpdater {
          // get the statistics of nid
          ThreadEntry &e = temp[nid];
          if (e.stats.Empty()) {
-            e.stats.Add(gpair, info, ridx);
+            e.stats.Add(gpair[ridx]);
            e.first_fvalue = fvalue;
          } else {
            // forward default right
@@ -383,7 +378,7 @@ class ColMaker: public TreeUpdater {
              }
            }
          }
-          e.stats.Add(gpair, info, ridx);
+          e.stats.Add(gpair[ridx]);
          e.first_fvalue = fvalue;
        }
      }
@@ -436,10 +431,10 @@ class ColMaker: public TreeUpdater {
      const std::vector<int> &qexpand = qexpand_;
      // clear all the temp statistics
      for (auto nid : qexpand) {
-        temp[nid].stats.Clear();
+        temp[nid].stats = GradStats();
      }
      // left statistics
-      GradStats c(param_);
+      GradStats c;
      // local cache buffer for position and gradient pair
      constexpr int kBuffer = 32;
      int buf_position[kBuffer] = {};
@@ -516,17 +511,17 @@ class ColMaker: public TreeUpdater {
                               const MetaInfo &info,
                               std::vector<ThreadEntry> &temp) { // NOLINT(*)
      // use cacheline aware optimization
-      if (GradStats::kSimpleStats != 0 && param_.cache_opt != 0) {
+      if (param_.cache_opt != 0) {
        EnumerateSplitCacheOpt(begin, end, d_step, fid, gpair, temp);
        return;
      }
      const std::vector<int> &qexpand = qexpand_;
      // clear all the temp statistics
      for (auto nid : qexpand) {
-        temp[nid].stats.Clear();
+        temp[nid].stats = GradStats();
      }
      // left statistics
-      GradStats c(param_);
+      GradStats c;
      for (const Entry *it = begin; it != end; it += d_step) {
        const bst_uint ridx = it->index;
        const int nid = position_[ridx];
@@ -537,7 +532,7 @@ class ColMaker: public TreeUpdater {
        ThreadEntry &e = temp[nid];
        // test if first hit, this is fine, because we set 0 during init
        if (e.stats.Empty()) {
-          e.stats.Add(gpair, info, ridx);
+          e.stats.Add(gpair[ridx]);
          e.last_fvalue = fvalue;
        } else {
          // try to find a split
@@ -562,7 +557,7 @@ class ColMaker: public TreeUpdater {
          }
        }
        // update the statistics
-        e.stats.Add(gpair, info, ridx);
+        e.stats.Add(gpair[ridx]);
        e.last_fvalue = fvalue;
      }
    }
@@ -783,7 +778,6 @@ class DistColMaker : public ColMaker {
  void Update(HostDeviceVector<GradientPair> *gpair,
              DMatrix* dmat,
              const std::vector<RegTree*> &trees) override {
-    GradStats::CheckInfo(dmat->Info());
    CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time";
    Builder builder(
        param_,


@@ -16,6 +16,28 @@ namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu);
+template <typename GradientPairT>
+XGBOOST_DEVICE float inline LossChangeMissing(const GradientPairT& scan,
+                                              const GradientPairT& missing,
+                                              const GradientPairT& parent_sum,
+                                              const float& parent_gain,
+                                              const GPUTrainingParam& param,
+                                              bool& missing_left_out) {  // NOLINT
+  // Put gradients of missing values to left
+  float missing_left_loss =
+      DeviceCalcLossChange(param, scan + missing, parent_sum, parent_gain);
+  float missing_right_loss =
+      DeviceCalcLossChange(param, scan, parent_sum, parent_gain);
+  if (missing_left_loss >= missing_right_loss) {
+    missing_left_out = true;
+    return missing_left_loss;
+  } else {
+    missing_left_out = false;
+    return missing_right_loss;
+  }
+}
/**
 * @brief Absolute BFS order IDs to col-wise unique IDs based on user input
 * @param tid the index of the element that this thread should access
@@ -565,7 +587,6 @@ class GPUMaker : public TreeUpdater {
  void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
              const std::vector<RegTree*>& trees) {
-    GradStats::CheckInfo(dmat->Info());
    // rescale learning rate according to size of trees
    float lr = param.learning_rate;
    param.learning_rate = lr / trees.size();
@@ -633,7 +654,7 @@ class GPUMaker : public TreeUpdater {
      // get the default direction for the current node
      GradientPair missing = n.sum_gradients - gradSum;
      LossChangeMissing(gradScan, missing, n.sum_gradients, n.root_gain,
                        gpu_param, missingLeft);
      // get the score/weight/id/gradSum for left and right child nodes
      GradientPair lGradSum = missingLeft ? gradScan + missing : gradScan;
      GradientPair rGradSum = n.sum_gradients - lGradSum;


@@ -98,7 +98,7 @@ struct DeviceSplitCandidate {
  template <typename ParamT>
  XGBOOST_DEVICE void Update(const DeviceSplitCandidate& other,
                             const ParamT& param) {
    if (other.loss_chg > loss_chg &&
        other.left_sum.GetHess() >= param.min_child_weight &&
        other.right_sum.GetHess() >= param.min_child_weight) {
@@ -213,51 +213,6 @@ XGBOOST_DEVICE inline float DeviceCalcLossChange(const GPUTrainingParam& param,
  return left_gain + right_gain - parent_gain;
}
-// Without constraints
-template <typename GradientPairT>
-XGBOOST_DEVICE float inline LossChangeMissing(const GradientPairT& scan,
-                                              const GradientPairT& missing,
-                                              const GradientPairT& parent_sum,
-                                              const float& parent_gain,
-                                              const GPUTrainingParam& param,
-                                              bool& missing_left_out) {  // NOLINT
-  // Put gradients of missing values to left
-  float missing_left_loss =
-      DeviceCalcLossChange(param, scan + missing, parent_sum, parent_gain);
-  float missing_right_loss =
-      DeviceCalcLossChange(param, scan, parent_sum, parent_gain);
-  if (missing_left_loss >= missing_right_loss) {
-    missing_left_out = true;
-    return missing_left_loss;
-  } else {
-    missing_left_out = false;
-    return missing_right_loss;
-  }
-}
-// With constraints
-template <typename GradientPairT>
-XGBOOST_DEVICE float inline LossChangeMissing(
-    const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum,
-    const float& parent_gain, const GPUTrainingParam& param, int constraint,
-    const ValueConstraint& value_constraint,
-    bool& missing_left_out) {  // NOLINT
-  float missing_left_gain = value_constraint.CalcSplitGain(
-      param, constraint, GradStats(scan + missing),
-      GradStats(parent_sum - (scan + missing)));
-  float missing_right_gain = value_constraint.CalcSplitGain(
-      param, constraint, GradStats(scan), GradStats(parent_sum - scan));
-  if (missing_left_gain >= missing_right_gain) {
-    missing_left_out = true;
-    return missing_left_gain - parent_gain;
-  } else {
-    missing_left_out = false;
-    return missing_right_gain - parent_gain;
-  }
-}
// Total number of nodes in tree, given depth
XGBOOST_DEVICE inline int MaxNodesDepth(int depth) {
  return (1 << (depth + 1)) - 1;


@@ -50,6 +50,28 @@ struct GPUHistMakerTrainParam
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
+// With constraints
+template <typename GradientPairT>
+XGBOOST_DEVICE float inline LossChangeMissing(
+    const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum,
+    const float& parent_gain, const GPUTrainingParam& param, int constraint,
+    const ValueConstraint& value_constraint,
+    bool& missing_left_out) {  // NOLINT
+  float missing_left_gain = value_constraint.CalcSplitGain(
+      param, constraint, GradStats(scan + missing),
+      GradStats(parent_sum - (scan + missing)));
+  float missing_right_gain = value_constraint.CalcSplitGain(
+      param, constraint, GradStats(scan), GradStats(parent_sum - scan));
+  if (missing_left_gain >= missing_right_gain) {
+    missing_left_out = true;
+    return missing_left_gain - parent_gain;
+  } else {
+    missing_left_out = false;
+    return missing_right_gain - parent_gain;
+  }
+}
/*!
 * \brief
 *
@@ -942,7 +964,6 @@ class GPUHistMakerSpecialised{
  void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
              const std::vector<RegTree*>& trees) {
    monitor_.Start("Update", dist_.Devices());
-    GradStats::CheckInfo(dmat->Info());
    // rescale learning rate according to size of trees
    float lr = param_.learning_rate;
    param_.learning_rate = lr / trees.size();
@@ -1183,11 +1204,12 @@ class GPUHistMakerSpecialised{
  void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
    RegTree& tree = *p_tree;
-    GradStats left_stats(param_);
+    GradStats left_stats;
    left_stats.Add(candidate.split.left_sum);
-    GradStats right_stats(param_);
+    GradStats right_stats;
    right_stats.Add(candidate.split.right_sum);
-    GradStats parent_sum(param_);
+    GradStats parent_sum;
    parent_sum.Add(left_stats);
    parent_sum.Add(right_stats);
    node_value_constraints_.resize(tree.GetNodes().size());


@ -19,13 +19,11 @@ namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_histmaker); DMLC_REGISTRY_FILE_TAG(updater_histmaker);
template<typename TStats>
class HistMaker: public BaseMaker { class HistMaker: public BaseMaker {
public: public:
void Update(HostDeviceVector<GradientPair> *gpair, void Update(HostDeviceVector<GradientPair> *gpair,
DMatrix *p_fmat, DMatrix *p_fmat,
const std::vector<RegTree*> &trees) override { const std::vector<RegTree*> &trees) override {
TStats::CheckInfo(p_fmat->Info());
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
float lr = param_.learning_rate; float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size(); param_.learning_rate = lr / trees.size();
@ -42,13 +40,13 @@ class HistMaker: public BaseMaker {
/*! \brief cutting point of histogram, contains maximum point */ /*! \brief cutting point of histogram, contains maximum point */
const bst_float *cut; const bst_float *cut;
/*! \brief content of statistics data */ /*! \brief content of statistics data */
TStats *data; GradStats *data;
/*! \brief size of histogram */ /*! \brief size of histogram */
unsigned size; unsigned size;
// default constructor // default constructor
HistUnit() = default; HistUnit() = default;
// constructor // constructor
HistUnit(const bst_float *cut, TStats *data, unsigned size) HistUnit(const bst_float *cut, GradStats *data, unsigned size)
: cut(cut), data(data), size(size) {} : cut(cut), data(data), size(size) {}
/*! \brief add a histogram to data */ /*! \brief add a histogram to data */
inline void Add(bst_float fv, inline void Add(bst_float fv,
@ -58,7 +56,7 @@ class HistMaker: public BaseMaker {
unsigned i = std::upper_bound(cut, cut + size, fv) - cut; unsigned i = std::upper_bound(cut, cut + size, fv) - cut;
CHECK_NE(size, 0U) << "try insert into size=0"; CHECK_NE(size, 0U) << "try insert into size=0";
CHECK_LT(i, size); CHECK_LT(i, size);
data[i].Add(gpair, info, ridx); data[i].Add(gpair[ridx]);
} }
}; };
/*! \brief a set of histograms from different index */ /*! \brief a set of histograms from different index */
@ -68,7 +66,7 @@ class HistMaker: public BaseMaker {
/*! \brief cutting points in each histunit */ /*! \brief cutting points in each histunit */
const bst_float *cut; const bst_float *cut;
/*! \brief data in different hist unit */ /*! \brief data in different hist unit */
std::vector<TStats> data; std::vector<GradStats> data;
/*! \brief */ /*! \brief */
inline HistUnit operator[](size_t fid) { inline HistUnit operator[](size_t fid) {
return HistUnit(cut + rptr[fid], return HistUnit(cut + rptr[fid],
@ -89,12 +87,10 @@ class HistMaker: public BaseMaker {
hset.resize(nthread); hset.resize(nthread);
// cleanup statistics // cleanup statistics
for (int tid = 0; tid < nthread; ++tid) { for (int tid = 0; tid < nthread; ++tid) {
for (size_t i = 0; i < hset[tid].data.size(); ++i) { for (auto& d : hset[tid].data) { d = GradStats(); }
hset[tid].data[i].Clear();
}
hset[tid].rptr = dmlc::BeginPtr(rptr); hset[tid].rptr = dmlc::BeginPtr(rptr);
hset[tid].cut = dmlc::BeginPtr(cut); hset[tid].cut = dmlc::BeginPtr(cut);
hset[tid].data.resize(cut.size(), TStats(param)); hset[tid].data.resize(cut.size(), GradStats());
} }
} }
// aggregate all statistics to hset[0] // aggregate all statistics to hset[0]
@ -119,7 +115,7 @@ class HistMaker: public BaseMaker {
// workspace of thread // workspace of thread
ThreadWSpace wspace_; ThreadWSpace wspace_;
// reducer for histogram // reducer for histogram
rabit::Reducer<TStats, TStats::Reduce> histred_; rabit::Reducer<GradStats, GradStats::Reduce> histred_;
// set of working features // set of working features
std::vector<bst_uint> fwork_set_; std::vector<bst_uint> fwork_set_;
// update function implementation // update function implementation
@ -147,8 +143,7 @@ class HistMaker: public BaseMaker {
// if nothing left to be expand, break // if nothing left to be expand, break
if (qexpand_.size() == 0) break; if (qexpand_.size() == 0) break;
} }
for (size_t i = 0; i < qexpand_.size(); ++i) { for (int const nid : qexpand_) {
const int nid = qexpand_[i];
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate); (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
} }
} }
@ -179,34 +174,35 @@ class HistMaker: public BaseMaker {
private: private:
inline void EnumerateSplit(const HistUnit &hist, inline void EnumerateSplit(const HistUnit &hist,
const TStats &node_sum, const GradStats &node_sum,
bst_uint fid, bst_uint fid,
SplitEntry *best, SplitEntry *best,
TStats *left_sum) { GradStats *left_sum) {
if (hist.size == 0) return; if (hist.size == 0) return;
double root_gain = node_sum.CalcGain(param_); double root_gain = CalcGain(param_, node_sum.GetGrad(), node_sum.GetHess());
TStats s(param_), c(param_); GradStats s, c;
for (bst_uint i = 0; i < hist.size; ++i) { for (bst_uint i = 0; i < hist.size; ++i) {
s.Add(hist.data[i]); s.Add(hist.data[i]);
if (s.sum_hess >= param_.min_child_weight) { if (s.sum_hess >= param_.min_child_weight) {
c.SetSubstract(node_sum, s); c.SetSubstract(node_sum, s);
if (c.sum_hess >= param_.min_child_weight) { if (c.sum_hess >= param_.min_child_weight) {
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain; double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
false, s, c)) { if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false, s, c)) {
*left_sum = s; *left_sum = s;
} }
} }
} }
} }
s.Clear(); s = GradStats();
for (bst_uint i = hist.size - 1; i != 0; --i) { for (bst_uint i = hist.size - 1; i != 0; --i) {
s.Add(hist.data[i]); s.Add(hist.data[i]);
if (s.sum_hess >= param_.min_child_weight) { if (s.sum_hess >= param_.min_child_weight) {
c.SetSubstract(node_sum, s); c.SetSubstract(node_sum, s);
if (c.sum_hess >= param_.min_child_weight) { if (c.sum_hess >= param_.min_child_weight) {
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain; double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) { if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
*left_sum = c; *left_sum = c;
} }
@ -222,14 +218,14 @@ class HistMaker: public BaseMaker {
const size_t num_feature = fset.size(); const size_t num_feature = fset.size();
// get the best split condition for each node // get the best split condition for each node
std::vector<SplitEntry> sol(qexpand_.size()); std::vector<SplitEntry> sol(qexpand_.size());
std::vector<TStats> left_sum(qexpand_.size()); std::vector<GradStats> left_sum(qexpand_.size());
auto nexpand = static_cast<bst_omp_uint>(qexpand_.size()); auto nexpand = static_cast<bst_omp_uint>(qexpand_.size());
#pragma omp parallel for schedule(dynamic, 1) #pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint wid = 0; wid < nexpand; ++wid) { for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
const int nid = qexpand_[wid]; const int nid = qexpand_[wid];
CHECK_EQ(node2workindex_[nid], static_cast<int>(wid)); CHECK_EQ(node2workindex_[nid], static_cast<int>(wid));
SplitEntry &best = sol[wid]; SplitEntry &best = sol[wid];
TStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0]; GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
for (size_t i = 0; i < fset.size(); ++i) { for (size_t i = 0; i < fset.size(); ++i) {
EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature+1)], EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature+1)],
node_sum, fset[i], &best, &left_sum[wid]); node_sum, fset[i], &best, &left_sum[wid]);
@ -239,13 +235,13 @@ class HistMaker: public BaseMaker {
for (bst_omp_uint wid = 0; wid < nexpand; ++wid) { for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
const int nid = qexpand_[wid]; const int nid = qexpand_[wid];
const SplitEntry &best = sol[wid]; const SplitEntry &best = sol[wid];
const TStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0]; const GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
this->SetStats(p_tree, nid, node_sum); this->SetStats(p_tree, nid, node_sum);
// set up the values // set up the values
p_tree->Stat(nid).loss_chg = best.loss_chg; p_tree->Stat(nid).loss_chg = best.loss_chg;
// now we know the solution in snode[nid], set split // now we know the solution in snode[nid], set split
if (best.loss_chg > kRtEps) { if (best.loss_chg > kRtEps) {
bst_float base_weight = node_sum.CalcWeight(param_); bst_float base_weight = CalcWeight(param_, node_sum);
bst_float left_leaf_weight = bst_float left_leaf_weight =
CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) * CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
param_.learning_rate; param_.learning_rate;
@ -258,7 +254,7 @@ class HistMaker: public BaseMaker {
right_leaf_weight, best.loss_chg, right_leaf_weight, best.loss_chg,
node_sum.sum_hess); node_sum.sum_hess);
// right side sum // right side sum
TStats right_sum; GradStats right_sum;
right_sum.SetSubstract(node_sum, left_sum[wid]); right_sum.SetSubstract(node_sum, left_sum[wid]);
this->SetStats(p_tree, (*p_tree)[nid].LeftChild(), left_sum[wid]); this->SetStats(p_tree, (*p_tree)[nid].LeftChild(), left_sum[wid]);
this->SetStats(p_tree, (*p_tree)[nid].RightChild(), right_sum); this->SetStats(p_tree, (*p_tree)[nid].RightChild(), right_sum);
@ -268,20 +264,20 @@ class HistMaker: public BaseMaker {
} }
} }
inline void SetStats(RegTree *p_tree, int nid, const TStats &node_sum) { inline void SetStats(RegTree *p_tree, int nid, const GradStats &node_sum) {
p_tree->Stat(nid).base_weight = static_cast<bst_float>(node_sum.CalcWeight(param_)); p_tree->Stat(nid).base_weight =
static_cast<bst_float>(CalcWeight(param_, node_sum));
p_tree->Stat(nid).sum_hess = static_cast<bst_float>(node_sum.sum_hess); p_tree->Stat(nid).sum_hess = static_cast<bst_float>(node_sum.sum_hess);
} }
}; };
-template<typename TStats>
-class CQHistMaker: public HistMaker<TStats> {
+class CQHistMaker: public HistMaker {
  public:
   CQHistMaker() = default;
  protected:
   struct HistEntry {
-    typename HistMaker<TStats>::HistUnit hist;
+    HistMaker::HistUnit hist;
     unsigned istart;
     /*!
      * \brief add a histogram to data,
@@ -293,7 +289,7 @@ class CQHistMaker: public HistMaker<TStats> {
                     const bst_uint ridx) {
       while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
       CHECK_NE(istart, hist.size);
-      hist.data[istart].Add(gpair, info, ridx);
+      hist.data[istart].Add(gpair[ridx]);
     }
     /*!
      * \brief add a histogram to data,
@@ -352,7 +348,7 @@ class CQHistMaker: public HistMaker<TStats> {
     for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
       // start enumeration
       const auto nsize = static_cast<bst_omp_uint>(fset.size());
       #pragma omp parallel for schedule(dynamic, 1)
       for (bst_omp_uint i = 0; i < nsize; ++i) {
         int fid = fset[i];
         int offset = feat2workindex_[fid];
@@ -366,8 +362,7 @@ class CQHistMaker: public HistMaker<TStats> {
     // update node statistics.
     this->GetNodeStats(gpair, *p_fmat, tree,
                        &thread_stats_, &node_stats_);
-    for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-      const int nid = this->qexpand_[i];
+    for (int const nid : this->qexpand_) {
       const int wid = this->node2workindex_[nid];
       this->wspace_.hset[0][fset.size() + wid * (fset.size() + 1)]
           .data[0] = node_stats_[nid];
@@ -403,8 +398,8 @@ class CQHistMaker: public HistMaker<TStats> {
     const size_t work_set_size = work_set_.size();
     sketchs_.resize(this->qexpand_.size() * work_set_size);
-    for (size_t i = 0; i < sketchs_.size(); ++i) {
-      sketchs_[i].Init(info.num_row_, this->param_.sketch_eps);
+    for (auto& sketch : sketchs_) {
+      sketch.Init(info.num_row_, this->param_.sketch_eps);
     }
     // intitialize the summary array
     summary_array_.resize(sketchs_.size());
@@ -501,13 +496,12 @@ class CQHistMaker: public HistMaker<TStats> {
     // initialize sbuilder for use
     std::vector<HistEntry> &hbuilder = *p_temp;
     hbuilder.resize(tree.param.num_nodes);
-    for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-      const unsigned nid = this->qexpand_[i];
+    for (int const nid : this->qexpand_) {
       const unsigned wid = this->node2workindex_[nid];
       hbuilder[nid].istart = 0;
       hbuilder[nid].hist = this->wspace_.hset[0][fid_offset + wid * (fset.size()+1)];
     }
-    if (TStats::kSimpleStats != 0 && this->param_.cache_opt != 0) {
+    if (this->param_.cache_opt != 0) {
       constexpr bst_uint kBuffer = 32;
       bst_uint align_length = col.size() / kBuffer * kBuffer;
       int buf_position[kBuffer];
@@ -552,13 +546,11 @@ class CQHistMaker: public HistMaker<TStats> {
     // initialize sbuilder for use
     std::vector<BaseMaker::SketchEntry> &sbuilder = *p_temp;
     sbuilder.resize(tree.param.num_nodes);
-    for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-      const unsigned nid = this->qexpand_[i];
+    for (int const nid : this->qexpand_) {
       const unsigned wid = this->node2workindex_[nid];
       sbuilder[nid].sum_total = 0.0f;
       sbuilder[nid].sketch = &sketchs_[wid * work_set_size + offset];
     }
     // first pass, get sum of weight, TODO, optimization to skip first pass
     for (const auto& c : col) {
       const bst_uint ridx = c.index;
@@ -569,20 +561,19 @@ class CQHistMaker: public HistMaker<TStats> {
     }
     // if only one value, no need to do second pass
     if (col[0].fvalue == col[col.size()-1].fvalue) {
-      for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-        const int nid = this->qexpand_[i];
-        sbuilder[nid].sketch->Push(col[0].fvalue, static_cast<bst_float>(sbuilder[nid].sum_total));
+      for (int const nid : this->qexpand_) {
+        sbuilder[nid].sketch->Push(
+            col[0].fvalue, static_cast<bst_float>(sbuilder[nid].sum_total));
       }
       return;
     }
     // two pass scan
     unsigned max_size = this->param_.MaxSketchSize();
-    for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-      const int nid = this->qexpand_[i];
+    for (int const nid : this->qexpand_) {
       sbuilder[nid].Init(max_size);
     }
     // second pass, build the sketch
-    if (TStats::kSimpleStats != 0 && this->param_.cache_opt != 0) {
+    if (this->param_.cache_opt != 0) {
       constexpr bst_uint kBuffer = 32;
       bst_uint align_length = col.size() / kBuffer * kBuffer;
       int buf_position[kBuffer];
@@ -616,10 +607,7 @@ class CQHistMaker: public HistMaker<TStats> {
         }
       }
     }
-    for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-      const int nid = this->qexpand_[i];
-      sbuilder[nid].Finalize(max_size);
-    }
+    for (int const nid : this->qexpand_) { sbuilder[nid].Finalize(max_size); }
   }
   // cached dmatrix where we initialized the feature on.
   const DMatrix* cache_dmatrix_{nullptr};
@@ -634,11 +622,11 @@ class CQHistMaker: public HistMaker<TStats> {
   // thread temp data
   std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch_;
   // used to hold statistics
-  std::vector<std::vector<TStats> > thread_stats_;
+  std::vector<std::vector<GradStats> > thread_stats_;
   // used to hold start pointer
   std::vector<std::vector<HistEntry> > thread_hist_;
   // node statistics
-  std::vector<TStats> node_stats_;
+  std::vector<GradStats> node_stats_;
   // summary array
   std::vector<WXQSketch::SummaryContainer> summary_array_;
   // reducer for summary
@@ -648,8 +636,7 @@ class CQHistMaker: public HistMaker<TStats> {
 };
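Two clean-ups repeat through CQHistMaker: index loops over qexpand_ become range-based for loops, and the per-bin Add(gpair, info, ridx) overload becomes Add(gpair[ridx]). The binning itself is unchanged: HistEntry::Add scans the cut points until the feature value falls below one and accumulates the gradient pair in that bin. A minimal, hedged sketch of that idea follows; the types and the per-entry linear scan are simplifications (the real code carries istart across a sorted column so the scan is amortised), and it assumes the last cut point exceeds every feature value, as the real cut matrix guarantees.

// Sketch of the binning behind HistEntry::Add: map a feature value to the
// first cut point it falls below, then accumulate the gradient pair there.
#include <cstdio>
#include <utility>
#include <vector>

struct GradientPair { float grad, hess; };
struct Bin {
  double sum_grad{0}, sum_hess{0};
  void Add(const GradientPair& p) { sum_grad += p.grad; sum_hess += p.hess; }
};

int main() {
  std::vector<float> cut = {0.5f, 1.5f, 3.0f};   // per-feature cut points (bin upper bounds)
  std::vector<Bin> hist(cut.size());
  std::vector<std::pair<float, GradientPair>> column = {
      {0.2f, {0.5f, 1.0f}}, {1.1f, {-0.3f, 1.0f}}, {2.7f, {0.8f, 1.0f}}};
  for (const auto& e : column) {
    size_t istart = 0;
    // assumes cut.back() exceeds every feature value, so istart stays in range
    while (istart < cut.size() && !(e.first < cut[istart])) ++istart;
    hist[istart].Add(e.second);
  }
  for (size_t i = 0; i < hist.size(); ++i)
    std::printf("bin %zu: grad=%f hess=%f\n", i, hist[i].sum_grad, hist[i].sum_hess);
}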
 // global proposal
-template<typename TStats>
-class GlobalProposalHistMaker: public CQHistMaker<TStats> {
+class GlobalProposalHistMaker: public CQHistMaker {
  protected:
   void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
                           DMatrix *p_fmat,
@@ -661,7 +648,7 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
     }
     if (cached_rptr_.size() == 0) {
       CHECK_EQ(this->qexpand_.size(), 1U);
-      CQHistMaker<TStats>::ResetPosAndPropose(gpair, p_fmat, fset, tree);
+      CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree);
       cached_rptr_ = this->wspace_.rptr;
       cached_cut_ = this->wspace_.cut;
     } else {
@@ -730,8 +717,7 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
     // update node statistics.
     this->GetNodeStats(gpair, *p_fmat, tree,
                        &(this->thread_stats_), &(this->node_stats_));
-    for (size_t i = 0; i < this->qexpand_.size(); ++i) {
-      const int nid = this->qexpand_[i];
+    for (const int nid : this->qexpand_) {
       const int wid = this->node2workindex_[nid];
       this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
           .data[0] = this->node_stats_[nid];
@@ -750,19 +736,19 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
 XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
 .describe("Tree constructor that uses approximate histogram construction.")
 .set_body([]() {
-    return new CQHistMaker<GradStats>();
+    return new CQHistMaker();
   });
 XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_histmaker")
 .describe("Tree constructor that uses approximate global proposal of histogram construction.")
 .set_body([]() {
-    return new GlobalProposalHistMaker<GradStats>();
+    return new GlobalProposalHistMaker();
   });
 XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
 .describe("Tree constructor that uses approximate global of histogram construction.")
 .set_body([]() {
-    return new GlobalProposalHistMaker<GradStats>();
+    return new GlobalProposalHistMaker();
   });
 } // namespace tree
 } // namespace xgboost

View File

@@ -50,9 +50,8 @@ void QuantileHistMaker::Init(const std::vector<std::pair<std::string, std::strin
 }
 void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
                                DMatrix *dmat,
                                const std::vector<RegTree *> &trees) {
-  GradStats::CheckInfo(dmat->Info());
   if (is_gmat_initialized_ == false) {
     double tstart = dmlc::GetTime();
     gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
@@ -91,11 +90,11 @@ bool QuantileHistMaker::UpdatePredictionCache(
 }
 void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
                                         const GHistIndexBlockMatrix& gmatb,
                                         const ColumnMatrix& column_matrix,
                                         HostDeviceVector<GradientPair>* gpair,
                                         DMatrix* p_fmat,
                                         RegTree* p_tree) {
   double gstart = dmlc::GetTime();
   int num_leaves = 0;
@@ -280,9 +279,9 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
 }
 void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
                                           const std::vector<GradientPair>& gpair,
                                           const DMatrix& fmat,
                                           const RegTree& tree) {
   CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
       << "ColMakerHist: can only grow new tree";
   CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
@@ -395,11 +394,11 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
   }
 }
-void QuantileHistMaker::Builder::EvaluateSplit(int nid,
+void QuantileHistMaker::Builder::EvaluateSplit(const int nid,
                                                const GHistIndexMatrix& gmat,
                                                const HistCollection& hist,
                                                const DMatrix& fmat,
                                                const RegTree& tree) {
   // start enumeration
   const MetaInfo& info = fmat.Info();
   auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
@@ -411,13 +410,14 @@ void QuantileHistMaker::Builder::EvaluateSplit(int nid,
   for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
     best_split_tloc_[tid] = snode_[nid].best;
   }
+  GHistRow node_hist = hist[nid];
 #pragma omp parallel for schedule(dynamic) num_threads(nthread)
   for (bst_omp_uint i = 0; i < nfeature; ++i) {
     const bst_uint fid = feature_set[i];
     const unsigned tid = omp_get_thread_num();
-    this->EnumerateSplit(-1, gmat, hist[nid], snode_[nid], info,
+    this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info,
                          &best_split_tloc_[tid], fid, nid);
-    this->EnumerateSplit(+1, gmat, hist[nid], snode_[nid], info,
+    this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info,
                          &best_split_tloc_[tid], fid, nid);
   }
   for (unsigned tid = 0; tid < nthread; ++tid) {
@@ -426,11 +426,11 @@ void QuantileHistMaker::Builder::EvaluateSplit(int nid,
 }
 void QuantileHistMaker::Builder::ApplySplit(int nid,
                                             const GHistIndexMatrix& gmat,
                                             const ColumnMatrix& column_matrix,
                                             const HistCollection& hist,
                                             const DMatrix& fmat,
                                             RegTree* p_tree) {
   // TODO(hcho3): support feature sampling by levels
   /* 1. Create child nodes */
@@ -613,10 +613,10 @@ void QuantileHistMaker::Builder::ApplySplitSparseData(
 }
 void QuantileHistMaker::Builder::InitNewNode(int nid,
                                              const GHistIndexMatrix& gmat,
                                              const std::vector<GradientPair>& gpair,
                                              const DMatrix& fmat,
                                              const RegTree& tree) {
   {
     snode_.resize(tree.param.num_nodes, NodeEntry(param_));
   }
@@ -628,22 +628,24 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
     // in distributed mode, the node's stats should be calculated from histogram, otherwise,
     // we will have wrong results in EnumerateSplit()
     // here we take the last feature in cut
+    auto begin = hist.data();
     for (size_t i = gmat.cut.row_ptr[0]; i < gmat.cut.row_ptr[1]; i++) {
-      stats.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
+      stats.Add(begin[i].sum_grad, begin[i].sum_hess);
     }
   } else {
     if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased ||
         rabit::IsDistributed()) {
       /* specialized code for dense data
          For dense data (with no missing value),
          the sum of gradient histogram is equal to snode[nid]
          GHistRow hist = hist_[nid];*/
       const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
       const uint32_t ibegin = row_ptr[fid_least_bins_];
       const uint32_t iend = row_ptr[fid_least_bins_ + 1];
+      auto begin = hist.data();
       for (uint32_t i = ibegin; i < iend; ++i) {
-        const GHistEntry et = hist.begin[i];
+        const GradStats et = begin[i];
         stats.Add(et.sum_grad, et.sum_hess);
       }
     } else {
@@ -653,27 +655,27 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
       }
     }
   }
+  }
   // calculating the weights
   {
     bst_uint parentid = tree[nid].Parent();
     snode_[nid].weight = static_cast<float>(
         spliteval_->ComputeWeight(parentid, snode_[nid].stats));
     snode_[nid].root_gain = static_cast<float>(
         spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight));
-  }
   }
 }
 // enumerate the split values of specific feature
 void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
                                                 const GHistIndexMatrix& gmat,
                                                 const GHistRow& hist,
                                                 const NodeEntry& snode,
                                                 const MetaInfo& info,
                                                 SplitEntry* p_best,
                                                 bst_uint fid,
                                                 bst_uint nodeID) {
   CHECK(d_step == +1 || d_step == -1);
   // aliases
@@ -681,8 +683,8 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
   const std::vector<bst_float>& cut_val = gmat.cut.cut;
   // statistics on both sides of split
-  GradStats c(param_);
-  GradStats e(param_);
+  GradStats c;
+  GradStats e;
   // best split so far
   SplitEntry best;
@@ -708,7 +710,7 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
   for (int32_t i = ibegin; i != iend; i += d_step) {
     // start working
     // try to find a split
-    e.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
+    e.Add(hist[i].GetGrad(), hist[i].GetHess());
     if (e.sum_hess >= param_.min_child_weight) {
       c.SetSubstract(snode.stats, e);
       if (c.sum_hess >= param_.min_child_weight) {
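The EnumerateSplit body above is the usual histogram split search: walk the bins of one feature (in both directions, one of which handles the missing-value side), accumulate the left-side statistics e bin by bin, and recover the right side c by subtracting from the node total, so every candidate threshold is scored in constant time. A self-contained sketch under simplified assumptions (single forward pass, gain = G^2/(H+lambda), no min_child_weight or default-direction handling):

// Sketch of the split enumeration: accumulate left-side statistics bin by bin
// and derive the right side by subtraction from the node total.
#include <cstdio>
#include <vector>

struct GradStats {
  double sum_grad{0}, sum_hess{0};
  void Add(double g, double h) { sum_grad += g; sum_hess += h; }
  void SetSubstract(const GradStats& a, const GradStats& b) {
    sum_grad = a.sum_grad - b.sum_grad;
    sum_hess = a.sum_hess - b.sum_hess;
  }
};

double Gain(const GradStats& s, double lambda = 1.0) {
  return s.sum_grad * s.sum_grad / (s.sum_hess + lambda);
}

int main() {
  // One histogram row: per-bin gradient/hessian sums for a single feature.
  std::vector<GradStats> hist(4);
  hist[0].Add(1.0, 2.0); hist[1].Add(-3.0, 2.0); hist[2].Add(0.5, 2.0); hist[3].Add(2.0, 2.0);

  GradStats node;                        // node total = sum over all bins
  for (const auto& b : hist) node.Add(b.sum_grad, b.sum_hess);

  GradStats e, c;                        // left / right accumulators
  double best_gain = 0.0; int best_bin = -1;
  for (size_t i = 0; i + 1 < hist.size(); ++i) {   // split between bin i and i+1
    e.Add(hist[i].sum_grad, hist[i].sum_hess);
    c.SetSubstract(node, e);
    double gain = Gain(e) + Gain(c) - Gain(node);
    if (gain > best_gain) { best_gain = gain; best_bin = static_cast<int>(i); }
  }
  std::printf("best split after bin %d, gain %f\n", best_bin, best_gain);
}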

View File

@@ -30,7 +30,6 @@ using xgboost::common::HistCutMatrix;
 using xgboost::common::GHistIndexMatrix;
 using xgboost::common::GHistIndexBlockMatrix;
 using xgboost::common::GHistIndexRow;
-using xgboost::common::GHistEntry;
 using xgboost::common::HistCollection;
 using xgboost::common::RowSetCollection;
 using xgboost::common::GHistRow;
@@ -73,8 +72,7 @@ class QuantileHistMaker: public TreeUpdater {
     SplitEntry best;
     // constructor
     explicit NodeEntry(const TrainParam& param)
-        : stats(param), root_gain(0.0f), weight(0.0f) {
-    }
+        : root_gain(0.0f), weight(0.0f) {}
   };
   // actual builder that runs the algorithm
@@ -105,7 +103,7 @@ class QuantileHistMaker: public TreeUpdater {
     } else {
       hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
     }
-    this->histred_.Allreduce(hist.begin, hist_builder_.GetNumBins());
+    this->histred_.Allreduce(hist.data(), hist_builder_.GetNumBins());
   }
   inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
@@ -122,7 +120,7 @@ class QuantileHistMaker: public TreeUpdater {
                 const DMatrix& fmat,
                 const RegTree& tree);
-  void EvaluateSplit(int nid,
+  void EvaluateSplit(const int nid,
                      const GHistIndexMatrix& gmat,
                      const HistCollection& hist,
                      const DMatrix& fmat,
@@ -227,7 +225,7 @@ class QuantileHistMaker: public TreeUpdater {
   enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
   DataLayout data_layout_;
-  rabit::Reducer<GHistEntry, GHistEntry::Reduce> histred_;
+  rabit::Reducer<GradStats, GradStats::Reduce> histred_;
 };
 std::unique_ptr<Builder> builder_;
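SubtractionTrick(self, sibling, parent) in the builder header names the standard shortcut: once the parent's histogram and one child's histogram exist, the other child's histogram is their element-wise difference, saving one full build per split. A hedged sketch with plain vectors standing in for GHistRow:

// Sketch of the subtraction trick: the sibling histogram is parent minus the
// already-built child, bin by bin.  Simplified stand-in types, not GHistRow.
#include <cstdio>
#include <vector>

struct GradStats { double sum_grad; double sum_hess; };

void SubtractionTrick(std::vector<GradStats>* self,
                      const std::vector<GradStats>& sibling,
                      const std::vector<GradStats>& parent) {
  for (size_t i = 0; i < parent.size(); ++i) {
    (*self)[i].sum_grad = parent[i].sum_grad - sibling[i].sum_grad;
    (*self)[i].sum_hess = parent[i].sum_hess - sibling[i].sum_hess;
  }
}

int main() {
  std::vector<GradStats> parent = {{3.0, 6.0}, {-1.0, 4.0}};
  std::vector<GradStats> left   = {{1.0, 2.0}, {-2.0, 1.0}};   // built directly
  std::vector<GradStats> right(parent.size());
  SubtractionTrick(&right, left, parent);                      // derived for free
  std::printf("right bin0: %f %f\n", right[0].sum_grad, right[0].sum_hess);
}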

View File

@@ -19,7 +19,6 @@ namespace tree {
 DMLC_REGISTRY_FILE_TAG(updater_refresh);
 /*! \brief pruner that prunes a tree after growing finishs */
-template<typename TStats>
 class TreeRefresher: public TreeUpdater {
  public:
   void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
@@ -31,14 +30,13 @@ class TreeRefresher: public TreeUpdater {
               const std::vector<RegTree*> &trees) override {
     if (trees.size() == 0) return;
     const std::vector<GradientPair> &gpair_h = gpair->ConstHostVector();
-    // number of threads
     // thread temporal space
-    std::vector<std::vector<TStats> > stemp;
+    std::vector<std::vector<GradStats> > stemp;
     std::vector<RegTree::FVec> fvec_temp;
     // setup temp space for each thread
     const int nthread = omp_get_max_threads();
     fvec_temp.resize(nthread, RegTree::FVec());
-    stemp.resize(nthread, std::vector<TStats>());
+    stemp.resize(nthread, std::vector<GradStats>());
     #pragma omp parallel
     {
       int tid = omp_get_thread_num();
@@ -46,8 +44,8 @@ class TreeRefresher: public TreeUpdater {
       for (auto tree : trees) {
         num_nodes += tree->param.num_nodes;
       }
-      stemp[tid].resize(num_nodes, TStats(param_));
-      std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param_));
+      stemp[tid].resize(num_nodes, GradStats());
+      std::fill(stemp[tid].begin(), stemp[tid].end(), GradStats());
       fvec_temp[tid].Init(trees[0]->param.num_feature);
     }
     // if it is C++11, use lazy evaluation for Allreduce,
@@ -104,21 +102,22 @@ class TreeRefresher: public TreeUpdater {
                               const std::vector<GradientPair> &gpair,
                               const MetaInfo &info,
                               const bst_uint ridx,
-                              TStats *gstats) {
+                              GradStats *gstats) {
     // start from groups that belongs to current data
     auto pid = static_cast<int>(info.GetRoot(ridx));
-    gstats[pid].Add(gpair, info, ridx);
+    gstats[pid].Add(gpair[ridx]);
     // tranverse tree
     while (!tree[pid].IsLeaf()) {
       unsigned split_index = tree[pid].SplitIndex();
       pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
-      gstats[pid].Add(gpair, info, ridx);
+      gstats[pid].Add(gpair[ridx]);
     }
   }
-  inline void Refresh(const TStats *gstats,
+  inline void Refresh(const GradStats *gstats,
                       int nid, RegTree *p_tree) {
     RegTree &tree = *p_tree;
-    tree.Stat(nid).base_weight = static_cast<bst_float>(gstats[nid].CalcWeight(param_));
+    tree.Stat(nid).base_weight =
+        static_cast<bst_float>(CalcWeight(param_, gstats[nid]));
     tree.Stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
     if (tree[nid].IsLeaf()) {
       if (param_.refresh_leaf) {
@@ -126,9 +125,9 @@ class TreeRefresher: public TreeUpdater {
       }
     } else {
       tree.Stat(nid).loss_chg = static_cast<bst_float>(
-          gstats[tree[nid].LeftChild()].CalcGain(param_) +
-          gstats[tree[nid].RightChild()].CalcGain(param_) -
-          gstats[nid].CalcGain(param_));
+          xgboost::tree::CalcGain(param_, gstats[tree[nid].LeftChild()]) +
+          xgboost::tree::CalcGain(param_, gstats[tree[nid].RightChild()]) -
+          xgboost::tree::CalcGain(param_, gstats[nid]));
       this->Refresh(gstats, tree[nid].LeftChild(), p_tree);
       this->Refresh(gstats, tree[nid].RightChild(), p_tree);
     }
@@ -136,13 +135,13 @@ class TreeRefresher: public TreeUpdater {
   // training parameter
   TrainParam param_;
   // reducer
-  rabit::Reducer<TStats, TStats::Reduce> reducer_;
+  rabit::Reducer<GradStats, GradStats::Reduce> reducer_;
 };
 XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
 .describe("Refresher that refreshes the weight and statistics according to data.")
 .set_body([]() {
-    return new TreeRefresher<GradStats>();
+    return new TreeRefresher();
   });
 } // namespace tree
 } // namespace xgboost
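TreeRefresher::AddStats walks each row from the root to its leaf and adds the row's gradient pair to every node on the path, so after the pass (and the allreduce over workers) each node holds the statistics of the rows that reach it; Refresh then recomputes weights and loss_chg from those sums. A toy sketch of the walk, with a three-node tree and simplified types standing in for RegTree and GradientPair:

// Sketch of the root-to-leaf statistics walk used by the refresher.
#include <cstdio>
#include <vector>

struct GradStats {
  double sum_grad{0}, sum_hess{0};
  void Add(double g, double h) { sum_grad += g; sum_hess += h; }
};

struct Node { int left; int right; int split_index; float split_value; };

int main() {
  // root (node 0) splits on feature 0 at 0.5; nodes 1 and 2 are leaves (left = -1).
  std::vector<Node> tree = {{1, 2, 0, 0.5f}, {-1, -1, 0, 0.0f}, {-1, -1, 0, 0.0f}};
  std::vector<GradStats> gstats(tree.size());

  struct Row { float fvalue; double grad, hess; };
  std::vector<Row> rows = {{0.2f, 1.0, 1.0}, {0.9f, -0.5, 1.0}, {0.4f, 0.25, 1.0}};

  for (const Row& r : rows) {
    int pid = 0;
    gstats[pid].Add(r.grad, r.hess);
    while (tree[pid].left != -1) {               // descend until a leaf
      pid = (r.fvalue < tree[pid].split_value) ? tree[pid].left : tree[pid].right;
      gstats[pid].Add(r.grad, r.hess);
    }
  }
  for (size_t nid = 0; nid < gstats.size(); ++nid)
    std::printf("node %zu: G=%f H=%f\n", nid, gstats[nid].sum_grad, gstats[nid].sum_hess);
}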

View File

@@ -83,26 +83,17 @@ class SketchMaker: public BaseMaker {
     double neg_grad;
     /*! \brief sum of hessian statistics */
     double sum_hess;
-    SKStats() = default;
-    // constructor
-    explicit SKStats(const TrainParam &param) {
-      this->Clear();
-    }
-    /*! \brief clear the statistics */
-    inline void Clear() {
-      neg_grad = pos_grad = sum_hess = 0.0f;
-    }
+    SKStats() : pos_grad{0}, neg_grad{0}, sum_hess{0} {}
     // accumulate statistics
-    inline void Add(const std::vector<GradientPair> &gpair,
-                    const MetaInfo &info,
-                    bst_uint ridx) {
-      const GradientPair &b = gpair[ridx];
-      if (b.GetGrad() >= 0.0f) {
-        pos_grad += b.GetGrad();
+    void Add(const GradientPair& gpair) {
+      if (gpair.GetGrad() >= 0.0f) {
+        pos_grad += gpair.GetGrad();
       } else {
-        neg_grad -= b.GetGrad();
+        neg_grad -= gpair.GetGrad();
       }
-      sum_hess += b.GetHess();
+      sum_hess += gpair.GetHess();
     }
     /*! \brief calculate gain of the solution */
     inline double CalcGain(const TrainParam &param) const {
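The new SKStats::Add takes the GradientPair directly and keeps the positive and negative gradient sums separately (neg_grad stores a magnitude), which the sketch-based maker uses when bounding weights. A quick sketch of just that accumulation, with a simplified GradientPair standing in for the real class:

// Sketch of SKStats accumulation: pos_grad - neg_grad recovers the plain
// gradient sum while both magnitudes stay available separately.
#include <cstdio>
#include <vector>

struct GradientPair {
  float grad, hess;
  float GetGrad() const { return grad; }
  float GetHess() const { return hess; }
};

struct SKStats {
  double pos_grad{0}, neg_grad{0}, sum_hess{0};
  void Add(const GradientPair& g) {
    if (g.GetGrad() >= 0.0f) pos_grad += g.GetGrad();
    else                     neg_grad -= g.GetGrad();   // stored as a positive magnitude
    sum_hess += g.GetHess();
  }
};

int main() {
  std::vector<GradientPair> gpair = {{0.5f, 1.0f}, {-0.75f, 1.0f}, {0.25f, 1.0f}};
  SKStats s;
  for (const auto& g : gpair) s.Add(g);
  std::printf("pos=%f neg=%f net=%f hess=%f\n",
              s.pos_grad, s.neg_grad, s.pos_grad - s.neg_grad, s.sum_hess);
}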

View File

@@ -57,10 +57,10 @@ class QuantileHistMock : public QuantileHistMaker {
         {0.26f, 0.27f}, {0.23f, 0.24f}, {0.27f, 0.28f},
         {0.57f, 0.59f}, {0.23f, 0.24f}, {0.47f, 0.49f}};
-    for (size_t i = 0; i < hist_[nid].size; ++i) {
+    for (size_t i = 0; i < hist_[nid].size(); ++i) {
       GradientPairPrecise sol = solution[i];
-      ASSERT_NEAR(sol.GetGrad(), hist_[nid].begin[i].sum_grad, kEps);
-      ASSERT_NEAR(sol.GetHess(), hist_[nid].begin[i].sum_hess, kEps);
+      ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
+      ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
     }
   }