Clean up training code. (#3825)
* Remove GHistRow, GHistEntry, GHistIndexRow. * Remove kSimpleStats. * Remove CheckInfo, SetLeafVec in GradStats and in SKStats. * Clean up the GradStats. * Cleanup calcgain. * Move LossChangeMissing out of common. * Remove [] operator from GHistIndexBlock.
This commit is contained in:
parent
325b16bccd
commit
017c97b8ce
@ -1005,7 +1005,7 @@ class AllReducer {
|
||||
*/
|
||||
void Synchronize() {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
for (int i = 0; i < device_ordinals.size(); i++) {
|
||||
for (size_t i = 0; i < device_ordinals.size(); i++) {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinals[i]));
|
||||
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
|
||||
}
|
||||
@ -1051,7 +1051,7 @@ template <typename T, typename FunctionT>
|
||||
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
|
||||
SaveCudaContext{[&]() {
|
||||
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
|
||||
for (int shard = 0; shard < shards->size(); ++shard) {
|
||||
for (size_t shard = 0; shard < shards->size(); ++shard) {
|
||||
f(shard, shards->at(shard));
|
||||
}
|
||||
}};
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017 by Contributors
|
||||
* Copyright 2017-2018 by Contributors
|
||||
* \file hist_util.h
|
||||
* \brief Utilities to store histograms
|
||||
* \author Philip Cho, Tianqi Chen
|
||||
@ -417,7 +417,7 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
|
||||
const size_t* row_ptr = gmat.row_ptr.data();
|
||||
const float* pgh = reinterpret_cast<const float*>(gpair.data());
|
||||
|
||||
double* hist_data = reinterpret_cast<double*>(hist.begin);
|
||||
double* hist_data = reinterpret_cast<double*>(hist.data());
|
||||
double* data = reinterpret_cast<double*>(data_.data());
|
||||
|
||||
const size_t block_size = 512;
|
||||
@ -432,11 +432,11 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
|
||||
size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
|
||||
no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
|
||||
|
||||
#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
|
||||
#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
|
||||
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
|
||||
dmlc::omp_uint tid = omp_get_thread_num();
|
||||
double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
|
||||
reinterpret_cast<double*>(data_.data() + tid * nbins_));
|
||||
reinterpret_cast<double*>(data_.data() + tid * nbins_));
|
||||
|
||||
if (!thread_init_[tid]) {
|
||||
memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
|
||||
@ -477,7 +477,7 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
|
||||
#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
|
||||
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
|
||||
const size_t istart = iblock * block_size;
|
||||
const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
|
||||
@ -507,8 +507,9 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
|
||||
#if defined(_OPENMP)
|
||||
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
|
||||
#endif
|
||||
tree::GradStats* p_hist = hist.data();
|
||||
|
||||
#pragma omp parallel for num_threads(nthread) schedule(guided)
|
||||
#pragma omp parallel for num_threads(nthread) schedule(guided)
|
||||
for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
|
||||
auto gmat = gmatb[bid];
|
||||
|
||||
@ -517,20 +518,17 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
|
||||
size_t ibegin[kUnroll];
|
||||
size_t iend[kUnroll];
|
||||
GradientPair stat[kUnroll];
|
||||
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
rid[k] = row_indices.begin[i + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
ibegin[k] = gmat.row_ptr[rid[k]];
|
||||
iend[k] = gmat.row_ptr[rid[k] + 1];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
stat[k] = gpair[rid[k]];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
|
||||
const uint32_t bin = gmat.index[j];
|
||||
hist.begin[bin].Add(stat[k]);
|
||||
p_hist[bin].Add(stat[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -541,7 +539,7 @@ void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
|
||||
const GradientPair stat = gpair[rid];
|
||||
for (size_t j = ibegin; j < iend; ++j) {
|
||||
const uint32_t bin = gmat.index[j];
|
||||
hist.begin[bin].Add(stat);
|
||||
p_hist[bin].Add(stat);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -555,24 +553,27 @@ void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow pa
|
||||
#if defined(_OPENMP)
|
||||
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
|
||||
#endif
|
||||
tree::GradStats* p_self = self.data();
|
||||
tree::GradStats* p_sibling = sibling.data();
|
||||
tree::GradStats* p_parent = parent.data();
|
||||
|
||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||
for (bst_omp_uint bin_id = 0;
|
||||
bin_id < static_cast<bst_omp_uint>(nbins - rest); bin_id += kUnroll) {
|
||||
GHistEntry pb[kUnroll];
|
||||
GHistEntry sb[kUnroll];
|
||||
tree::GradStats pb[kUnroll];
|
||||
tree::GradStats sb[kUnroll];
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
pb[k] = parent.begin[bin_id + k];
|
||||
pb[k] = p_parent[bin_id + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
sb[k] = sibling.begin[bin_id + k];
|
||||
sb[k] = p_sibling[bin_id + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
self.begin[bin_id + k].SetSubtract(pb[k], sb[k]);
|
||||
p_self[bin_id + k].SetSubstract(pb[k], sb[k]);
|
||||
}
|
||||
}
|
||||
for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
|
||||
self.begin[bin_id].SetSubtract(parent.begin[bin_id], sibling.begin[bin_id]);
|
||||
p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -16,45 +16,8 @@
|
||||
#include "../include/rabit/rabit.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
namespace common {
|
||||
|
||||
/*! \brief sums of gradient statistics corresponding to a histogram bin */
|
||||
struct GHistEntry {
|
||||
/*! \brief sum of first-order gradient statistics */
|
||||
double sum_grad{0};
|
||||
/*! \brief sum of second-order gradient statistics */
|
||||
double sum_hess{0};
|
||||
|
||||
GHistEntry() = default;
|
||||
|
||||
inline void Clear() {
|
||||
sum_grad = sum_hess = 0;
|
||||
}
|
||||
|
||||
/*! \brief add a GradientPair to the sum */
|
||||
inline void Add(const GradientPair& e) {
|
||||
sum_grad += e.GetGrad();
|
||||
sum_hess += e.GetHess();
|
||||
}
|
||||
|
||||
/*! \brief add a GHistEntry to the sum */
|
||||
inline void Add(const GHistEntry& e) {
|
||||
sum_grad += e.sum_grad;
|
||||
sum_hess += e.sum_hess;
|
||||
}
|
||||
|
||||
inline static void Reduce(GHistEntry& a, const GHistEntry& b) { // NOLINT(*)
|
||||
a.Add(b);
|
||||
}
|
||||
|
||||
/*! \brief set sum to be difference of two GHistEntry's */
|
||||
inline void SetSubtract(const GHistEntry& a, const GHistEntry& b) {
|
||||
sum_grad = a.sum_grad - b.sum_grad;
|
||||
sum_hess = a.sum_hess - b.sum_hess;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Cut configuration for all the features. */
|
||||
struct HistCutMatrix {
|
||||
/*! \brief Unit pointer to rows by element position */
|
||||
@ -83,15 +46,7 @@ void DeviceSketch
|
||||
* \brief A single row in global histogram index.
|
||||
* Directly represent the global index in the histogram entry.
|
||||
*/
|
||||
struct GHistIndexRow {
|
||||
/*! \brief The index of the histogram */
|
||||
const uint32_t* index;
|
||||
/*! \brief The size of the histogram */
|
||||
size_t size;
|
||||
GHistIndexRow() = default;
|
||||
GHistIndexRow(const uint32_t* index, size_t size)
|
||||
: index(index), size(size) {}
|
||||
};
|
||||
using GHistIndexRow = Span<uint32_t const>;
|
||||
|
||||
/*!
|
||||
* \brief preprocessed global index matrix, in CSR format
|
||||
@ -111,7 +66,9 @@ struct GHistIndexMatrix {
|
||||
void Init(DMatrix* p_fmat, int max_num_bins);
|
||||
// get i-th row
|
||||
inline GHistIndexRow operator[](size_t i) const {
|
||||
return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
|
||||
return {&index[0] + row_ptr[i],
|
||||
static_cast<GHistIndexRow::index_type>(
|
||||
row_ptr[i + 1] - row_ptr[i])};
|
||||
}
|
||||
inline void GetFeatureCounts(size_t* counts) const {
|
||||
auto nfeature = cut.row_ptr.size() - 1;
|
||||
@ -134,11 +91,6 @@ struct GHistIndexBlock {
|
||||
|
||||
inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
|
||||
: row_ptr(row_ptr), index(index) {}
|
||||
|
||||
// get i-th row
|
||||
inline GHistIndexRow operator[](size_t i) const {
|
||||
return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
|
||||
}
|
||||
};
|
||||
|
||||
class ColumnMatrix;
|
||||
@ -171,21 +123,12 @@ class GHistIndexBlockMatrix {
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief histogram of gradient statistics for a single node.
|
||||
* Consists of multiple GHistEntry's, each entry showing total graident statistics
|
||||
* \brief histogram of graident statistics for a single node.
|
||||
* Consists of multiple GradStats, each entry showing total graident statistics
|
||||
* for that particular bin
|
||||
* Uses global bin id so as to represent all features simultaneously
|
||||
*/
|
||||
struct GHistRow {
|
||||
/*! \brief base pointer to first entry */
|
||||
GHistEntry* begin;
|
||||
/*! \brief number of entries */
|
||||
uint32_t size;
|
||||
|
||||
GHistRow() = default;
|
||||
GHistRow(GHistEntry* begin, uint32_t size)
|
||||
: begin(begin), size(size) {}
|
||||
};
|
||||
using GHistRow = Span<tree::GradStats>;
|
||||
|
||||
/*!
|
||||
* \brief histogram of gradient statistics for multiple nodes
|
||||
@ -193,27 +136,29 @@ struct GHistRow {
|
||||
class HistCollection {
|
||||
public:
|
||||
// access histogram for i-th node
|
||||
inline GHistRow operator[](bst_uint nid) const {
|
||||
GHistRow operator[](bst_uint nid) const {
|
||||
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
|
||||
CHECK_NE(row_ptr_[nid], kMax);
|
||||
return {const_cast<GHistEntry*>(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_};
|
||||
tree::GradStats* ptr =
|
||||
const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
|
||||
return {ptr, nbins_};
|
||||
}
|
||||
|
||||
// have we computed a histogram for i-th node?
|
||||
inline bool RowExists(bst_uint nid) const {
|
||||
bool RowExists(bst_uint nid) const {
|
||||
const uint32_t k_max = std::numeric_limits<uint32_t>::max();
|
||||
return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
|
||||
}
|
||||
|
||||
// initialize histogram collection
|
||||
inline void Init(uint32_t nbins) {
|
||||
void Init(uint32_t nbins) {
|
||||
nbins_ = nbins;
|
||||
row_ptr_.clear();
|
||||
data_.clear();
|
||||
}
|
||||
|
||||
// create an empty histogram for i-th node
|
||||
inline void AddHistRow(bst_uint nid) {
|
||||
void AddHistRow(bst_uint nid) {
|
||||
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
|
||||
if (nid >= row_ptr_.size()) {
|
||||
row_ptr_.resize(nid + 1, kMax);
|
||||
@ -228,7 +173,7 @@ class HistCollection {
|
||||
/*! \brief number of all bins over all features */
|
||||
uint32_t nbins_;
|
||||
|
||||
std::vector<GHistEntry> data_;
|
||||
std::vector<tree::GradStats> data_;
|
||||
|
||||
/*! \brief row_ptr_[nid] locates bin for historgram of node nid */
|
||||
std::vector<size_t> row_ptr_;
|
||||
@ -268,8 +213,8 @@ class GHistBuilder {
|
||||
size_t nthread_;
|
||||
/*! \brief number of all bins over all features */
|
||||
uint32_t nbins_;
|
||||
std::vector<GHistEntry> data_;
|
||||
std::vector<size_t> thread_init_;
|
||||
std::vector<tree::GradStats> data_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
@ -140,7 +140,7 @@ class GPUDistribution {
|
||||
return begin;
|
||||
}
|
||||
|
||||
size_t ShardSize(size_t size, int index) const {
|
||||
size_t ShardSize(size_t size, size_t index) const {
|
||||
if (size == 0) { return 0; }
|
||||
if (offsets_.size() > 0) {
|
||||
// explicit offsets are provided
|
||||
@ -154,7 +154,7 @@ class GPUDistribution {
|
||||
return end - begin;
|
||||
}
|
||||
|
||||
size_t ShardProperSize(size_t size, int index) const {
|
||||
size_t ShardProperSize(size_t size, size_t index) const {
|
||||
if (size == 0) { return 0; }
|
||||
return ShardSize(size, index) - (devices_.Size() - 1 > index ? overlap_ : 0);
|
||||
}
|
||||
|
||||
@ -554,8 +554,8 @@ class Span {
|
||||
detail::ptrdiff_t _offset,
|
||||
detail::ptrdiff_t _count = dynamic_extent) const {
|
||||
SPAN_CHECK(_offset >= 0 && _offset < size());
|
||||
SPAN_CHECK(_count == dynamic_extent ||
|
||||
_count >= 0 && _offset + _count <= size());
|
||||
SPAN_CHECK((_count == dynamic_extent) ||
|
||||
(_count >= 0 && _offset + _count <= size()));
|
||||
|
||||
return {data() + _offset, _count ==
|
||||
dynamic_extent ? size() - _offset : _count};
|
||||
|
||||
@ -58,12 +58,12 @@ class Transform {
|
||||
public:
|
||||
Evaluator(Functor func, Range range, GPUSet devices, bool reshard) :
|
||||
func_(func), range_{std::move(range)},
|
||||
distribution_{std::move(GPUDistribution::Block(devices))},
|
||||
reshard_{reshard} {}
|
||||
reshard_{reshard},
|
||||
distribution_{std::move(GPUDistribution::Block(devices))} {}
|
||||
Evaluator(Functor func, Range range, GPUDistribution dist,
|
||||
bool reshard) :
|
||||
func_(func), range_{std::move(range)}, distribution_{std::move(dist)},
|
||||
reshard_{reshard} {}
|
||||
func_(func), range_{std::move(range)}, reshard_{reshard},
|
||||
distribution_{std::move(dist)} {}
|
||||
|
||||
/*!
|
||||
* \brief Evaluate the functor with input pointers to HostDeviceVector.
|
||||
@ -159,7 +159,7 @@ class Transform {
|
||||
|
||||
template <typename... HDV>
|
||||
void LaunchCPU(Functor func, HDV*... vectors) const {
|
||||
auto end = *(range_.end());
|
||||
omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong idx = 0; idx < end; ++idx) {
|
||||
func(idx, UnpackHDV(vectors)...);
|
||||
|
||||
117
src/tree/param.h
117
src/tree/param.h
@ -256,12 +256,12 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
|
||||
// functions for L1 cost
|
||||
template <typename T1, typename T2>
|
||||
XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 lambda) {
|
||||
if (w > +lambda) {
|
||||
return w - lambda;
|
||||
XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 alpha) {
|
||||
if (w > + alpha) {
|
||||
return w - alpha;
|
||||
}
|
||||
if (w < -lambda) {
|
||||
return w + lambda;
|
||||
if (w < - alpha) {
|
||||
return w + alpha;
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
@ -271,9 +271,9 @@ XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }
|
||||
|
||||
// calculate the cost of loss function
|
||||
template <typename TrainingParams, typename T>
|
||||
XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad,
|
||||
T sum_hess, T w) {
|
||||
return -(2.0 * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w));
|
||||
XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p,
|
||||
T sum_grad, T sum_hess, T w) {
|
||||
return -(T(2.0) * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w));
|
||||
}
|
||||
|
||||
// calculate the cost of loss function
|
||||
@ -281,44 +281,51 @@ template <typename TrainingParams, typename T>
|
||||
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
|
||||
if (sum_hess < p.min_child_weight) {
|
||||
return T(0.0);
|
||||
}
|
||||
}
|
||||
if (p.max_delta_step == 0.0f) {
|
||||
if (p.reg_alpha == 0.0f) {
|
||||
return Sqr(sum_grad) / (sum_hess + p.reg_lambda);
|
||||
} else {
|
||||
return Sqr(ThresholdL1(sum_grad, p.reg_alpha)) /
|
||||
(sum_hess + p.reg_lambda);
|
||||
(sum_hess + p.reg_lambda);
|
||||
}
|
||||
} else {
|
||||
T w = CalcWeight(p, sum_grad, sum_hess);
|
||||
T ret = sum_grad * w + T(0.5) * (sum_hess + p.reg_lambda) * Sqr(w);
|
||||
T ret = CalcGainGivenWeight(p, sum_grad, sum_hess, w);
|
||||
if (p.reg_alpha == 0.0f) {
|
||||
return T(-2.0) * ret;
|
||||
return ret;
|
||||
} else {
|
||||
return T(-2.0) * (ret + p.reg_alpha * std::abs(w));
|
||||
return ret + p.reg_alpha * std::abs(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TrainingParams,
|
||||
typename StatT, typename T = decltype(StatT().GetHess())>
|
||||
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) {
|
||||
return CalcGain(p, stat.GetGrad(), stat.GetHess());
|
||||
}
|
||||
|
||||
// calculate cost of loss function with four statistics
|
||||
template <typename TrainingParams, typename T>
|
||||
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
|
||||
T test_grad, T test_hess) {
|
||||
T test_grad, T test_hess) {
|
||||
T w = CalcWeight(sum_grad, sum_hess);
|
||||
T ret = test_grad * w + 0.5 * (test_hess + p.reg_lambda) * Sqr(w);
|
||||
T ret = CalcGainGivenWeight(p, test_grad, test_hess);
|
||||
if (p.reg_alpha == 0.0f) {
|
||||
return -2.0 * ret;
|
||||
return ret;
|
||||
} else {
|
||||
return -2.0 * (ret + p.reg_alpha * std::abs(w));
|
||||
return ret + p.reg_alpha * std::abs(w);
|
||||
}
|
||||
}
|
||||
|
||||
// calculate weight given the statistics
|
||||
template <typename TrainingParams, typename T>
|
||||
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
|
||||
T sum_hess) {
|
||||
T sum_hess) {
|
||||
if (sum_hess < p.min_child_weight || sum_hess <= 0.0) {
|
||||
return 0.0;
|
||||
}
|
||||
}
|
||||
T dw;
|
||||
if (p.reg_alpha == 0.0f) {
|
||||
dw = -sum_grad / (sum_hess + p.reg_lambda);
|
||||
@ -328,14 +335,15 @@ XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
|
||||
if (p.max_delta_step != 0.0f) {
|
||||
if (dw > p.max_delta_step) {
|
||||
dw = p.max_delta_step;
|
||||
}
|
||||
}
|
||||
if (dw < -p.max_delta_step) {
|
||||
dw = -p.max_delta_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
return dw;
|
||||
}
|
||||
|
||||
// Used in gpu code where GradientPair is used for gradient sum, not GradStats.
|
||||
template <typename TrainingParams, typename GpairT>
|
||||
XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad) {
|
||||
return CalcWeight(p, sum_grad.GetGrad(), sum_grad.GetHess());
|
||||
@ -347,49 +355,27 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
||||
double sum_grad;
|
||||
/*! \brief sum hessian statistics */
|
||||
double sum_hess;
|
||||
/*!
|
||||
* \brief whether this is simply statistics and we only need to call
|
||||
* Add(gpair), instead of Add(gpair, info, ridx)
|
||||
*/
|
||||
static const int kSimpleStats = 1;
|
||||
/*! \brief constructor, the object must be cleared during construction */
|
||||
explicit GradStats(const TrainParam& param) { this->Clear(); }
|
||||
explicit GradStats(double sum_grad, double sum_hess)
|
||||
: sum_grad(sum_grad), sum_hess(sum_hess) {}
|
||||
|
||||
public:
|
||||
XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
|
||||
XGBOOST_DEVICE double GetHess() const { return sum_hess; }
|
||||
|
||||
XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
|
||||
static_assert(sizeof(GradStats) == 16,
|
||||
"Size of GradStats is not 16 bytes.");
|
||||
}
|
||||
|
||||
template <typename GpairT>
|
||||
XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
|
||||
: sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
|
||||
/*! \brief clear the statistics */
|
||||
inline void Clear() { sum_grad = sum_hess = 0.0f; }
|
||||
/*! \brief check if necessary information is ready */
|
||||
inline static void CheckInfo(const MetaInfo& info) {}
|
||||
explicit GradStats(const double grad, const double hess)
|
||||
: sum_grad(grad), sum_hess(hess) {}
|
||||
/*!
|
||||
* \brief accumulate statistics
|
||||
* \param p the gradient pair
|
||||
*/
|
||||
inline void Add(GradientPair p) { this->Add(p.GetGrad(), p.GetHess()); }
|
||||
/*!
|
||||
* \brief accumulate statistics, more complicated version
|
||||
* \param gpair the vector storing the gradient statistics
|
||||
* \param info the additional information
|
||||
* \param ridx instance index of this instance
|
||||
*/
|
||||
inline void Add(const std::vector<GradientPair>& gpair, const MetaInfo& info,
|
||||
bst_uint ridx) {
|
||||
const GradientPair& b = gpair[ridx];
|
||||
this->Add(b.GetGrad(), b.GetHess());
|
||||
}
|
||||
/*! \brief calculate leaf weight */
|
||||
template <typename ParamT>
|
||||
XGBOOST_DEVICE inline double CalcWeight(const ParamT ¶m) const {
|
||||
return xgboost::tree::CalcWeight(param, sum_grad, sum_hess);
|
||||
}
|
||||
/*! \brief calculate gain of the solution */
|
||||
template <typename ParamT>
|
||||
inline double CalcGain(const ParamT& param) const {
|
||||
return xgboost::tree::CalcGain(param, sum_grad, sum_hess);
|
||||
}
|
||||
|
||||
/*! \brief add statistics to the data */
|
||||
inline void Add(const GradStats& b) {
|
||||
sum_grad += b.sum_grad;
|
||||
@ -406,8 +392,6 @@ template <typename ParamT>
|
||||
}
|
||||
/*! \return whether the statistics is not used yet */
|
||||
inline bool Empty() const { return sum_hess == 0.0; }
|
||||
// constructor to allow inheritance
|
||||
GradStats() = default;
|
||||
/*! \brief add statistics to the data */
|
||||
inline void Add(double grad, double hess) {
|
||||
sum_grad += grad;
|
||||
@ -415,6 +399,7 @@ template <typename ParamT>
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(trivialfis): Remove this class.
|
||||
struct ValueConstraint {
|
||||
double lower_bound;
|
||||
double upper_bound;
|
||||
@ -424,9 +409,9 @@ struct ValueConstraint {
|
||||
inline static void Init(TrainParam *param, unsigned num_feature) {
|
||||
param->monotone_constraints.resize(num_feature, 0);
|
||||
}
|
||||
template <typename ParamT>
|
||||
template <typename ParamT>
|
||||
XGBOOST_DEVICE inline double CalcWeight(const ParamT ¶m, GradStats stats) const {
|
||||
double w = stats.CalcWeight(param);
|
||||
double w = xgboost::tree::CalcWeight(param, stats);
|
||||
if (w < lower_bound) {
|
||||
return lower_bound;
|
||||
}
|
||||
@ -436,13 +421,13 @@ template <typename ParamT>
|
||||
return w;
|
||||
}
|
||||
|
||||
template <typename ParamT>
|
||||
template <typename ParamT>
|
||||
XGBOOST_DEVICE inline double CalcGain(const ParamT ¶m, GradStats stats) const {
|
||||
return CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess,
|
||||
CalcWeight(param, stats));
|
||||
}
|
||||
|
||||
template <typename ParamT>
|
||||
template <typename ParamT>
|
||||
XGBOOST_DEVICE inline double CalcSplitGain(const ParamT ¶m, int constraint,
|
||||
GradStats left, GradStats right) const {
|
||||
const double negative_infinity = -std::numeric_limits<double>::infinity();
|
||||
@ -468,7 +453,7 @@ template <typename ParamT>
|
||||
*cright = *this;
|
||||
if (c == 0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
double wleft = CalcWeight(param, left);
|
||||
double wright = CalcWeight(param, right);
|
||||
double mid = (wleft + wright) / 2;
|
||||
@ -578,13 +563,13 @@ inline std::ostream &operator<<(std::ostream &os, const std::vector<int> &t) {
|
||||
for (auto it = t.begin(); it != t.end(); ++it) {
|
||||
if (it != t.begin()) {
|
||||
os << ',';
|
||||
}
|
||||
}
|
||||
os << *it;
|
||||
}
|
||||
// python style tuple
|
||||
if (t.size() == 1) {
|
||||
os << ',';
|
||||
}
|
||||
}
|
||||
os << ')';
|
||||
return os;
|
||||
}
|
||||
@ -603,7 +588,7 @@ inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
|
||||
is.get();
|
||||
if (ch == '(') {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!isspace(ch)) {
|
||||
is.setstate(std::ios::failbit);
|
||||
return is;
|
||||
@ -635,7 +620,7 @@ inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
|
||||
}
|
||||
if (ch == ')') {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (ch == ')') {
|
||||
break;
|
||||
} else {
|
||||
|
||||
@ -50,6 +50,7 @@ void SplitEvaluator::AddSplit(bst_uint nodeid,
|
||||
bst_uint featureid,
|
||||
bst_float leftweight,
|
||||
bst_float rightweight) {}
|
||||
|
||||
bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid,
|
||||
bst_uint featureid,
|
||||
const GradStats& left_stats,
|
||||
|
||||
@ -333,28 +333,28 @@ class BaseMaker: public TreeUpdater {
|
||||
const MetaInfo &info = fmat.Info();
|
||||
thread_temp.resize(omp_get_max_threads());
|
||||
p_node_stats->resize(tree.param.num_nodes);
|
||||
#pragma omp parallel
|
||||
#pragma omp parallel
|
||||
{
|
||||
const int tid = omp_get_thread_num();
|
||||
thread_temp[tid].resize(tree.param.num_nodes, TStats(param_));
|
||||
thread_temp[tid].resize(tree.param.num_nodes, TStats());
|
||||
for (unsigned int nid : qexpand_) {
|
||||
thread_temp[tid][nid].Clear();
|
||||
thread_temp[tid][nid] = TStats();
|
||||
}
|
||||
}
|
||||
// setup position
|
||||
const auto ndata = static_cast<bst_omp_uint>(fmat.Info().num_row_);
|
||||
#pragma omp parallel for schedule(static)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) {
|
||||
const int nid = position_[ridx];
|
||||
const int tid = omp_get_thread_num();
|
||||
if (nid >= 0) {
|
||||
thread_temp[tid][nid].Add(gpair, info, ridx);
|
||||
thread_temp[tid][nid].Add(gpair[ridx]);
|
||||
}
|
||||
}
|
||||
// sum the per thread statistics together
|
||||
for (int nid : qexpand_) {
|
||||
TStats &s = (*p_node_stats)[nid];
|
||||
s.Clear();
|
||||
s = TStats();
|
||||
for (size_t tid = 0; tid < thread_temp.size(); ++tid) {
|
||||
s.Add(thread_temp[tid][nid]);
|
||||
}
|
||||
|
||||
@ -33,7 +33,6 @@ class ColMaker: public TreeUpdater {
|
||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix* dmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
GradStats::CheckInfo(dmat->Info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
@ -66,9 +65,7 @@ class ColMaker: public TreeUpdater {
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
// constructor
|
||||
explicit ThreadEntry(const TrainParam ¶m)
|
||||
: stats(param), stats_extra(param) {
|
||||
}
|
||||
ThreadEntry() : last_fvalue{0}, first_fvalue{0} {}
|
||||
};
|
||||
struct NodeEntry {
|
||||
/*! \brief statics for node entry */
|
||||
@ -80,9 +77,7 @@ class ColMaker: public TreeUpdater {
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
// constructor
|
||||
explicit NodeEntry(const TrainParam& param)
|
||||
: stats(param), root_gain(0.0f), weight(0.0f){
|
||||
}
|
||||
NodeEntry() : root_gain(0.0f), weight(0.0f) {}
|
||||
};
|
||||
// actual builder that runs the algorithm
|
||||
class Builder {
|
||||
@ -200,9 +195,9 @@ class ColMaker: public TreeUpdater {
|
||||
{
|
||||
// setup statistics space for each tree node
|
||||
for (auto& i : stemp_) {
|
||||
i.resize(tree.param.num_nodes, ThreadEntry(param_));
|
||||
i.resize(tree.param.num_nodes, ThreadEntry());
|
||||
}
|
||||
snode_.resize(tree.param.num_nodes, NodeEntry(param_));
|
||||
snode_.resize(tree.param.num_nodes, NodeEntry());
|
||||
}
|
||||
const MetaInfo& info = fmat.Info();
|
||||
// setup position
|
||||
@ -211,11 +206,11 @@ class ColMaker: public TreeUpdater {
|
||||
for (bst_omp_uint ridx = 0; ridx < ndata; ++ridx) {
|
||||
const int tid = omp_get_thread_num();
|
||||
if (position_[ridx] < 0) continue;
|
||||
stemp_[tid][position_[ridx]].stats.Add(gpair, info, ridx);
|
||||
stemp_[tid][position_[ridx]].stats.Add(gpair[ridx]);
|
||||
}
|
||||
// sum the per thread statistics together
|
||||
for (int nid : qexpand) {
|
||||
GradStats stats(param_);
|
||||
GradStats stats;
|
||||
for (auto& s : stemp_) {
|
||||
stats.Add(s[nid].stats);
|
||||
}
|
||||
@ -261,7 +256,7 @@ class ColMaker: public TreeUpdater {
|
||||
std::vector<ThreadEntry> &temp = stemp_[tid];
|
||||
// cleanup temp statistics
|
||||
for (int j : qexpand) {
|
||||
temp[j].stats.Clear();
|
||||
temp[j].stats = GradStats();
|
||||
}
|
||||
bst_uint step = (col.size() + this->nthread_ - 1) / this->nthread_;
|
||||
bst_uint end = std::min(static_cast<bst_uint>(col.size()), step * (tid + 1));
|
||||
@ -273,7 +268,7 @@ class ColMaker: public TreeUpdater {
|
||||
if (temp[nid].stats.Empty()) {
|
||||
temp[nid].first_fvalue = fvalue;
|
||||
}
|
||||
temp[nid].stats.Add(gpair, info, ridx);
|
||||
temp[nid].stats.Add(gpair[ridx]);
|
||||
temp[nid].last_fvalue = fvalue;
|
||||
}
|
||||
}
|
||||
@ -282,7 +277,7 @@ class ColMaker: public TreeUpdater {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint j = 0; j < nnode; ++j) {
|
||||
const int nid = qexpand[j];
|
||||
GradStats sum(param_), tmp(param_), c(param_);
|
||||
GradStats sum, tmp, c;
|
||||
for (int tid = 0; tid < this->nthread_; ++tid) {
|
||||
tmp = stemp_[tid][nid].stats;
|
||||
stemp_[tid][nid].stats = sum;
|
||||
@ -342,7 +337,7 @@ class ColMaker: public TreeUpdater {
|
||||
// rescan, generate candidate split
|
||||
#pragma omp parallel
|
||||
{
|
||||
GradStats c(param_), cright(param_);
|
||||
GradStats c, cright;
|
||||
const int tid = omp_get_thread_num();
|
||||
std::vector<ThreadEntry> &temp = stemp_[tid];
|
||||
bst_uint step = (col.size() + this->nthread_ - 1) / this->nthread_;
|
||||
@ -355,7 +350,7 @@ class ColMaker: public TreeUpdater {
|
||||
// get the statistics of nid
|
||||
ThreadEntry &e = temp[nid];
|
||||
if (e.stats.Empty()) {
|
||||
e.stats.Add(gpair, info, ridx);
|
||||
e.stats.Add(gpair[ridx]);
|
||||
e.first_fvalue = fvalue;
|
||||
} else {
|
||||
// forward default right
|
||||
@ -383,7 +378,7 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
}
|
||||
e.stats.Add(gpair, info, ridx);
|
||||
e.stats.Add(gpair[ridx]);
|
||||
e.first_fvalue = fvalue;
|
||||
}
|
||||
}
|
||||
@ -436,10 +431,10 @@ class ColMaker: public TreeUpdater {
|
||||
const std::vector<int> &qexpand = qexpand_;
|
||||
// clear all the temp statistics
|
||||
for (auto nid : qexpand) {
|
||||
temp[nid].stats.Clear();
|
||||
temp[nid].stats = GradStats();
|
||||
}
|
||||
// left statistics
|
||||
GradStats c(param_);
|
||||
GradStats c;
|
||||
// local cache buffer for position and gradient pair
|
||||
constexpr int kBuffer = 32;
|
||||
int buf_position[kBuffer] = {};
|
||||
@ -516,17 +511,17 @@ class ColMaker: public TreeUpdater {
|
||||
const MetaInfo &info,
|
||||
std::vector<ThreadEntry> &temp) { // NOLINT(*)
|
||||
// use cacheline aware optimization
|
||||
if (GradStats::kSimpleStats != 0 && param_.cache_opt != 0) {
|
||||
if (param_.cache_opt != 0) {
|
||||
EnumerateSplitCacheOpt(begin, end, d_step, fid, gpair, temp);
|
||||
return;
|
||||
}
|
||||
const std::vector<int> &qexpand = qexpand_;
|
||||
// clear all the temp statistics
|
||||
for (auto nid : qexpand) {
|
||||
temp[nid].stats.Clear();
|
||||
temp[nid].stats = GradStats();
|
||||
}
|
||||
// left statistics
|
||||
GradStats c(param_);
|
||||
GradStats c;
|
||||
for (const Entry *it = begin; it != end; it += d_step) {
|
||||
const bst_uint ridx = it->index;
|
||||
const int nid = position_[ridx];
|
||||
@ -537,7 +532,7 @@ class ColMaker: public TreeUpdater {
|
||||
ThreadEntry &e = temp[nid];
|
||||
// test if first hit, this is fine, because we set 0 during init
|
||||
if (e.stats.Empty()) {
|
||||
e.stats.Add(gpair, info, ridx);
|
||||
e.stats.Add(gpair[ridx]);
|
||||
e.last_fvalue = fvalue;
|
||||
} else {
|
||||
// try to find a split
|
||||
@ -562,7 +557,7 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
// update the statistics
|
||||
e.stats.Add(gpair, info, ridx);
|
||||
e.stats.Add(gpair[ridx]);
|
||||
e.last_fvalue = fvalue;
|
||||
}
|
||||
}
|
||||
@ -783,7 +778,6 @@ class DistColMaker : public ColMaker {
|
||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix* dmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
GradStats::CheckInfo(dmat->Info());
|
||||
CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time";
|
||||
Builder builder(
|
||||
param_,
|
||||
|
||||
@ -16,6 +16,28 @@ namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu);
|
||||
|
||||
template <typename GradientPairT>
|
||||
XGBOOST_DEVICE float inline LossChangeMissing(const GradientPairT& scan,
|
||||
const GradientPairT& missing,
|
||||
const GradientPairT& parent_sum,
|
||||
const float& parent_gain,
|
||||
const GPUTrainingParam& param,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
// Put gradients of missing values to left
|
||||
float missing_left_loss =
|
||||
DeviceCalcLossChange(param, scan + missing, parent_sum, parent_gain);
|
||||
float missing_right_loss =
|
||||
DeviceCalcLossChange(param, scan, parent_sum, parent_gain);
|
||||
|
||||
if (missing_left_loss >= missing_right_loss) {
|
||||
missing_left_out = true;
|
||||
return missing_left_loss;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_loss;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Absolute BFS order IDs to col-wise unique IDs based on user input
|
||||
* @param tid the index of the element that this thread should access
|
||||
@ -565,7 +587,6 @@ class GPUMaker : public TreeUpdater {
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
GradStats::CheckInfo(dmat->Info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
@ -633,7 +654,7 @@ class GPUMaker : public TreeUpdater {
|
||||
// get the default direction for the current node
|
||||
GradientPair missing = n.sum_gradients - gradSum;
|
||||
LossChangeMissing(gradScan, missing, n.sum_gradients, n.root_gain,
|
||||
gpu_param, missingLeft);
|
||||
gpu_param, missingLeft);
|
||||
// get the score/weight/id/gradSum for left and right child nodes
|
||||
GradientPair lGradSum = missingLeft ? gradScan + missing : gradScan;
|
||||
GradientPair rGradSum = n.sum_gradients - lGradSum;
|
||||
|
||||
@ -98,7 +98,7 @@ struct DeviceSplitCandidate {
|
||||
|
||||
template <typename ParamT>
|
||||
XGBOOST_DEVICE void Update(const DeviceSplitCandidate& other,
|
||||
const ParamT& param) {
|
||||
const ParamT& param) {
|
||||
if (other.loss_chg > loss_chg &&
|
||||
other.left_sum.GetHess() >= param.min_child_weight &&
|
||||
other.right_sum.GetHess() >= param.min_child_weight) {
|
||||
@ -213,51 +213,6 @@ XGBOOST_DEVICE inline float DeviceCalcLossChange(const GPUTrainingParam& param,
|
||||
return left_gain + right_gain - parent_gain;
|
||||
}
|
||||
|
||||
// Without constraints
|
||||
template <typename GradientPairT>
|
||||
XGBOOST_DEVICE float inline LossChangeMissing(const GradientPairT& scan,
|
||||
const GradientPairT& missing,
|
||||
const GradientPairT& parent_sum,
|
||||
const float& parent_gain,
|
||||
const GPUTrainingParam& param,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
// Put gradients of missing values to left
|
||||
float missing_left_loss =
|
||||
DeviceCalcLossChange(param, scan + missing, parent_sum, parent_gain);
|
||||
float missing_right_loss =
|
||||
DeviceCalcLossChange(param, scan, parent_sum, parent_gain);
|
||||
|
||||
if (missing_left_loss >= missing_right_loss) {
|
||||
missing_left_out = true;
|
||||
return missing_left_loss;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_loss;
|
||||
}
|
||||
}
|
||||
|
||||
// With constraints
|
||||
template <typename GradientPairT>
|
||||
XGBOOST_DEVICE float inline LossChangeMissing(
|
||||
const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum,
|
||||
const float& parent_gain, const GPUTrainingParam& param, int constraint,
|
||||
const ValueConstraint& value_constraint,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
float missing_left_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan + missing),
|
||||
GradStats(parent_sum - (scan + missing)));
|
||||
float missing_right_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan), GradStats(parent_sum - scan));
|
||||
|
||||
if (missing_left_gain >= missing_right_gain) {
|
||||
missing_left_out = true;
|
||||
return missing_left_gain - parent_gain;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_gain - parent_gain;
|
||||
}
|
||||
}
|
||||
|
||||
// Total number of nodes in tree, given depth
|
||||
XGBOOST_DEVICE inline int MaxNodesDepth(int depth) {
|
||||
return (1 << (depth + 1)) - 1;
|
||||
|
||||
@ -50,6 +50,28 @@ struct GPUHistMakerTrainParam
|
||||
|
||||
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
|
||||
|
||||
// With constraints
|
||||
template <typename GradientPairT>
|
||||
XGBOOST_DEVICE float inline LossChangeMissing(
|
||||
const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum,
|
||||
const float& parent_gain, const GPUTrainingParam& param, int constraint,
|
||||
const ValueConstraint& value_constraint,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
float missing_left_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan + missing),
|
||||
GradStats(parent_sum - (scan + missing)));
|
||||
float missing_right_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan), GradStats(parent_sum - scan));
|
||||
|
||||
if (missing_left_gain >= missing_right_gain) {
|
||||
missing_left_out = true;
|
||||
return missing_left_gain - parent_gain;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_gain - parent_gain;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief
|
||||
*
|
||||
@ -942,7 +964,6 @@ class GPUHistMakerSpecialised{
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
monitor_.Start("Update", dist_.Devices());
|
||||
GradStats::CheckInfo(dmat->Info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
@ -1183,11 +1204,12 @@ class GPUHistMakerSpecialised{
|
||||
|
||||
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
|
||||
RegTree& tree = *p_tree;
|
||||
GradStats left_stats(param_);
|
||||
|
||||
GradStats left_stats;
|
||||
left_stats.Add(candidate.split.left_sum);
|
||||
GradStats right_stats(param_);
|
||||
GradStats right_stats;
|
||||
right_stats.Add(candidate.split.right_sum);
|
||||
GradStats parent_sum(param_);
|
||||
GradStats parent_sum;
|
||||
parent_sum.Add(left_stats);
|
||||
parent_sum.Add(right_stats);
|
||||
node_value_constraints_.resize(tree.GetNodes().size());
|
||||
|
||||
@ -19,13 +19,11 @@ namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_histmaker);
|
||||
|
||||
template<typename TStats>
|
||||
class HistMaker: public BaseMaker {
|
||||
public:
|
||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
TStats::CheckInfo(p_fmat->Info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
@ -42,13 +40,13 @@ class HistMaker: public BaseMaker {
|
||||
/*! \brief cutting point of histogram, contains maximum point */
|
||||
const bst_float *cut;
|
||||
/*! \brief content of statistics data */
|
||||
TStats *data;
|
||||
GradStats *data;
|
||||
/*! \brief size of histogram */
|
||||
unsigned size;
|
||||
// default constructor
|
||||
HistUnit() = default;
|
||||
// constructor
|
||||
HistUnit(const bst_float *cut, TStats *data, unsigned size)
|
||||
HistUnit(const bst_float *cut, GradStats *data, unsigned size)
|
||||
: cut(cut), data(data), size(size) {}
|
||||
/*! \brief add a histogram to data */
|
||||
inline void Add(bst_float fv,
|
||||
@ -58,7 +56,7 @@ class HistMaker: public BaseMaker {
|
||||
unsigned i = std::upper_bound(cut, cut + size, fv) - cut;
|
||||
CHECK_NE(size, 0U) << "try insert into size=0";
|
||||
CHECK_LT(i, size);
|
||||
data[i].Add(gpair, info, ridx);
|
||||
data[i].Add(gpair[ridx]);
|
||||
}
|
||||
};
|
||||
/*! \brief a set of histograms from different index */
|
||||
@ -68,7 +66,7 @@ class HistMaker: public BaseMaker {
|
||||
/*! \brief cutting points in each histunit */
|
||||
const bst_float *cut;
|
||||
/*! \brief data in different hist unit */
|
||||
std::vector<TStats> data;
|
||||
std::vector<GradStats> data;
|
||||
/*! \brief */
|
||||
inline HistUnit operator[](size_t fid) {
|
||||
return HistUnit(cut + rptr[fid],
|
||||
@ -89,12 +87,10 @@ class HistMaker: public BaseMaker {
|
||||
hset.resize(nthread);
|
||||
// cleanup statistics
|
||||
for (int tid = 0; tid < nthread; ++tid) {
|
||||
for (size_t i = 0; i < hset[tid].data.size(); ++i) {
|
||||
hset[tid].data[i].Clear();
|
||||
}
|
||||
for (auto& d : hset[tid].data) { d = GradStats(); }
|
||||
hset[tid].rptr = dmlc::BeginPtr(rptr);
|
||||
hset[tid].cut = dmlc::BeginPtr(cut);
|
||||
hset[tid].data.resize(cut.size(), TStats(param));
|
||||
hset[tid].data.resize(cut.size(), GradStats());
|
||||
}
|
||||
}
|
||||
// aggregate all statistics to hset[0]
|
||||
@ -119,7 +115,7 @@ class HistMaker: public BaseMaker {
|
||||
// workspace of thread
|
||||
ThreadWSpace wspace_;
|
||||
// reducer for histogram
|
||||
rabit::Reducer<TStats, TStats::Reduce> histred_;
|
||||
rabit::Reducer<GradStats, GradStats::Reduce> histred_;
|
||||
// set of working features
|
||||
std::vector<bst_uint> fwork_set_;
|
||||
// update function implementation
|
||||
@ -147,8 +143,7 @@ class HistMaker: public BaseMaker {
|
||||
// if nothing left to be expand, break
|
||||
if (qexpand_.size() == 0) break;
|
||||
}
|
||||
for (size_t i = 0; i < qexpand_.size(); ++i) {
|
||||
const int nid = qexpand_[i];
|
||||
for (int const nid : qexpand_) {
|
||||
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
|
||||
}
|
||||
}
|
||||
@ -179,34 +174,35 @@ class HistMaker: public BaseMaker {
|
||||
|
||||
private:
|
||||
inline void EnumerateSplit(const HistUnit &hist,
|
||||
const TStats &node_sum,
|
||||
const GradStats &node_sum,
|
||||
bst_uint fid,
|
||||
SplitEntry *best,
|
||||
TStats *left_sum) {
|
||||
GradStats *left_sum) {
|
||||
if (hist.size == 0) return;
|
||||
|
||||
double root_gain = node_sum.CalcGain(param_);
|
||||
TStats s(param_), c(param_);
|
||||
double root_gain = CalcGain(param_, node_sum.GetGrad(), node_sum.GetHess());
|
||||
GradStats s, c;
|
||||
for (bst_uint i = 0; i < hist.size; ++i) {
|
||||
s.Add(hist.data[i]);
|
||||
if (s.sum_hess >= param_.min_child_weight) {
|
||||
c.SetSubstract(node_sum, s);
|
||||
if (c.sum_hess >= param_.min_child_weight) {
|
||||
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
|
||||
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i],
|
||||
false, s, c)) {
|
||||
double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
|
||||
CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
|
||||
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false, s, c)) {
|
||||
*left_sum = s;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.Clear();
|
||||
s = GradStats();
|
||||
for (bst_uint i = hist.size - 1; i != 0; --i) {
|
||||
s.Add(hist.data[i]);
|
||||
if (s.sum_hess >= param_.min_child_weight) {
|
||||
c.SetSubstract(node_sum, s);
|
||||
if (c.sum_hess >= param_.min_child_weight) {
|
||||
double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
|
||||
double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
|
||||
CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
|
||||
if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
|
||||
*left_sum = c;
|
||||
}
|
||||
@ -222,14 +218,14 @@ class HistMaker: public BaseMaker {
|
||||
const size_t num_feature = fset.size();
|
||||
// get the best split condition for each node
|
||||
std::vector<SplitEntry> sol(qexpand_.size());
|
||||
std::vector<TStats> left_sum(qexpand_.size());
|
||||
std::vector<GradStats> left_sum(qexpand_.size());
|
||||
auto nexpand = static_cast<bst_omp_uint>(qexpand_.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
|
||||
const int nid = qexpand_[wid];
|
||||
CHECK_EQ(node2workindex_[nid], static_cast<int>(wid));
|
||||
SplitEntry &best = sol[wid];
|
||||
TStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
|
||||
GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
|
||||
for (size_t i = 0; i < fset.size(); ++i) {
|
||||
EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature+1)],
|
||||
node_sum, fset[i], &best, &left_sum[wid]);
|
||||
@ -239,13 +235,13 @@ class HistMaker: public BaseMaker {
|
||||
for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
|
||||
const int nid = qexpand_[wid];
|
||||
const SplitEntry &best = sol[wid];
|
||||
const TStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
|
||||
const GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
|
||||
this->SetStats(p_tree, nid, node_sum);
|
||||
// set up the values
|
||||
p_tree->Stat(nid).loss_chg = best.loss_chg;
|
||||
// now we know the solution in snode[nid], set split
|
||||
if (best.loss_chg > kRtEps) {
|
||||
bst_float base_weight = node_sum.CalcWeight(param_);
|
||||
bst_float base_weight = CalcWeight(param_, node_sum);
|
||||
bst_float left_leaf_weight =
|
||||
CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
|
||||
param_.learning_rate;
|
||||
@ -258,7 +254,7 @@ class HistMaker: public BaseMaker {
|
||||
right_leaf_weight, best.loss_chg,
|
||||
node_sum.sum_hess);
|
||||
// right side sum
|
||||
TStats right_sum;
|
||||
GradStats right_sum;
|
||||
right_sum.SetSubstract(node_sum, left_sum[wid]);
|
||||
this->SetStats(p_tree, (*p_tree)[nid].LeftChild(), left_sum[wid]);
|
||||
this->SetStats(p_tree, (*p_tree)[nid].RightChild(), right_sum);
|
||||
@ -268,20 +264,20 @@ class HistMaker: public BaseMaker {
|
||||
}
|
||||
}
|
||||
|
||||
inline void SetStats(RegTree *p_tree, int nid, const TStats &node_sum) {
|
||||
p_tree->Stat(nid).base_weight = static_cast<bst_float>(node_sum.CalcWeight(param_));
|
||||
inline void SetStats(RegTree *p_tree, int nid, const GradStats &node_sum) {
|
||||
p_tree->Stat(nid).base_weight =
|
||||
static_cast<bst_float>(CalcWeight(param_, node_sum));
|
||||
p_tree->Stat(nid).sum_hess = static_cast<bst_float>(node_sum.sum_hess);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename TStats>
|
||||
class CQHistMaker: public HistMaker<TStats> {
|
||||
class CQHistMaker: public HistMaker {
|
||||
public:
|
||||
CQHistMaker() = default;
|
||||
|
||||
protected:
|
||||
struct HistEntry {
|
||||
typename HistMaker<TStats>::HistUnit hist;
|
||||
HistMaker::HistUnit hist;
|
||||
unsigned istart;
|
||||
/*!
|
||||
* \brief add a histogram to data,
|
||||
@ -293,7 +289,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
const bst_uint ridx) {
|
||||
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
|
||||
CHECK_NE(istart, hist.size);
|
||||
hist.data[istart].Add(gpair, info, ridx);
|
||||
hist.data[istart].Add(gpair[ridx]);
|
||||
}
|
||||
/*!
|
||||
* \brief add a histogram to data,
|
||||
@ -352,7 +348,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
|
||||
// start enumeration
|
||||
const auto nsize = static_cast<bst_omp_uint>(fset.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
int fid = fset[i];
|
||||
int offset = feat2workindex_[fid];
|
||||
@ -366,8 +362,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// update node statistics.
|
||||
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||
&thread_stats_, &node_stats_);
|
||||
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
|
||||
const int nid = this->qexpand_[i];
|
||||
for (int const nid : this->qexpand_) {
|
||||
const int wid = this->node2workindex_[nid];
|
||||
this->wspace_.hset[0][fset.size() + wid * (fset.size() + 1)]
|
||||
.data[0] = node_stats_[nid];
|
||||
@ -403,8 +398,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
const size_t work_set_size = work_set_.size();
|
||||
|
||||
sketchs_.resize(this->qexpand_.size() * work_set_size);
|
||||
for (size_t i = 0; i < sketchs_.size(); ++i) {
|
||||
sketchs_[i].Init(info.num_row_, this->param_.sketch_eps);
|
||||
for (auto& sketch : sketchs_) {
|
||||
sketch.Init(info.num_row_, this->param_.sketch_eps);
|
||||
}
|
||||
// intitialize the summary array
|
||||
summary_array_.resize(sketchs_.size());
|
||||
@ -501,13 +496,12 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// initialize sbuilder for use
|
||||
std::vector<HistEntry> &hbuilder = *p_temp;
|
||||
hbuilder.resize(tree.param.num_nodes);
|
||||
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
|
||||
const unsigned nid = this->qexpand_[i];
|
||||
for (int const nid : this->qexpand_) {
|
||||
const unsigned wid = this->node2workindex_[nid];
|
||||
hbuilder[nid].istart = 0;
|
||||
hbuilder[nid].hist = this->wspace_.hset[0][fid_offset + wid * (fset.size()+1)];
|
||||
}
|
||||
if (TStats::kSimpleStats != 0 && this->param_.cache_opt != 0) {
|
||||
if (this->param_.cache_opt != 0) {
|
||||
constexpr bst_uint kBuffer = 32;
|
||||
bst_uint align_length = col.size() / kBuffer * kBuffer;
|
||||
int buf_position[kBuffer];
|
||||
@ -552,13 +546,11 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// initialize sbuilder for use
|
||||
std::vector<BaseMaker::SketchEntry> &sbuilder = *p_temp;
|
||||
sbuilder.resize(tree.param.num_nodes);
|
||||
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
|
||||
const unsigned nid = this->qexpand_[i];
|
||||
for (int const nid : this->qexpand_) {
|
||||
const unsigned wid = this->node2workindex_[nid];
|
||||
sbuilder[nid].sum_total = 0.0f;
|
||||
sbuilder[nid].sketch = &sketchs_[wid * work_set_size + offset];
|
||||
}
|
||||
|
||||
// first pass, get sum of weight, TODO, optimization to skip first pass
|
||||
for (const auto& c : col) {
|
||||
const bst_uint ridx = c.index;
|
||||
@ -569,20 +561,19 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
// if only one value, no need to do second pass
|
||||
if (col[0].fvalue == col[col.size()-1].fvalue) {
|
||||
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
|
||||
const int nid = this->qexpand_[i];
|
||||
sbuilder[nid].sketch->Push(col[0].fvalue, static_cast<bst_float>(sbuilder[nid].sum_total));
|
||||
for (int const nid : this->qexpand_) {
|
||||
sbuilder[nid].sketch->Push(
|
||||
col[0].fvalue, static_cast<bst_float>(sbuilder[nid].sum_total));
|
||||
}
|
||||
return;
|
||||
}
|
||||
// two pass scan
|
||||
unsigned max_size = this->param_.MaxSketchSize();
|
||||
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
|
||||
const int nid = this->qexpand_[i];
|
||||
for (int const nid : this->qexpand_) {
|
||||
sbuilder[nid].Init(max_size);
|
||||
}
|
||||
// second pass, build the sketch
|
||||
if (TStats::kSimpleStats != 0 && this->param_.cache_opt != 0) {
|
||||
if (this->param_.cache_opt != 0) {
|
||||
constexpr bst_uint kBuffer = 32;
|
||||
bst_uint align_length = col.size() / kBuffer * kBuffer;
|
||||
int buf_position[kBuffer];
|
||||
@ -616,10 +607,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
|
||||
const int nid = this->qexpand_[i];
|
||||
sbuilder[nid].Finalize(max_size);
|
||||
}
|
||||
for (int const nid : this->qexpand_) { sbuilder[nid].Finalize(max_size); }
|
||||
}
|
||||
// cached dmatrix where we initialized the feature on.
|
||||
const DMatrix* cache_dmatrix_{nullptr};
|
||||
@ -634,11 +622,11 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// thread temp data
|
||||
std::vector<std::vector<BaseMaker::SketchEntry> > thread_sketch_;
|
||||
// used to hold statistics
|
||||
std::vector<std::vector<TStats> > thread_stats_;
|
||||
std::vector<std::vector<GradStats> > thread_stats_;
|
||||
// used to hold start pointer
|
||||
std::vector<std::vector<HistEntry> > thread_hist_;
|
||||
// node statistics
|
||||
std::vector<TStats> node_stats_;
|
||||
std::vector<GradStats> node_stats_;
|
||||
// summary array
|
||||
std::vector<WXQSketch::SummaryContainer> summary_array_;
|
||||
// reducer for summary
|
||||
@ -648,8 +636,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
};
|
||||
|
||||
// global proposal
|
||||
template<typename TStats>
|
||||
class GlobalProposalHistMaker: public CQHistMaker<TStats> {
|
||||
class GlobalProposalHistMaker: public CQHistMaker {
|
||||
protected:
|
||||
void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
@ -661,7 +648,7 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
|
||||
}
|
||||
if (cached_rptr_.size() == 0) {
|
||||
CHECK_EQ(this->qexpand_.size(), 1U);
|
||||
CQHistMaker<TStats>::ResetPosAndPropose(gpair, p_fmat, fset, tree);
|
||||
CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree);
|
||||
cached_rptr_ = this->wspace_.rptr;
|
||||
cached_cut_ = this->wspace_.cut;
|
||||
} else {
|
||||
@ -730,8 +717,7 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
|
||||
// update node statistics.
|
||||
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||
&(this->thread_stats_), &(this->node_stats_));
|
||||
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
|
||||
const int nid = this->qexpand_[i];
|
||||
for (const int nid : this->qexpand_) {
|
||||
const int wid = this->node2workindex_[nid];
|
||||
this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
|
||||
.data[0] = this->node_stats_[nid];
|
||||
@ -750,19 +736,19 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
|
||||
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
||||
.describe("Tree constructor that uses approximate histogram construction.")
|
||||
.set_body([]() {
|
||||
return new CQHistMaker<GradStats>();
|
||||
return new CQHistMaker();
|
||||
});
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_histmaker")
|
||||
.describe("Tree constructor that uses approximate global proposal of histogram construction.")
|
||||
.set_body([]() {
|
||||
return new GlobalProposalHistMaker<GradStats>();
|
||||
return new GlobalProposalHistMaker();
|
||||
});
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
|
||||
.describe("Tree constructor that uses approximate global of histogram construction.")
|
||||
.set_body([]() {
|
||||
return new GlobalProposalHistMaker<GradStats>();
|
||||
return new GlobalProposalHistMaker();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -50,9 +50,8 @@ void QuantileHistMaker::Init(const std::vector<std::pair<std::string, std::strin
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *dmat,
|
||||
const std::vector<RegTree *> &trees) {
|
||||
GradStats::CheckInfo(dmat->Info());
|
||||
DMatrix *dmat,
|
||||
const std::vector<RegTree *> &trees) {
|
||||
if (is_gmat_initialized_ == false) {
|
||||
double tstart = dmlc::GetTime();
|
||||
gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
|
||||
@ -91,11 +90,11 @@ bool QuantileHistMaker::UpdatePredictionCache(
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
|
||||
const GHistIndexBlockMatrix& gmatb,
|
||||
const ColumnMatrix& column_matrix,
|
||||
HostDeviceVector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
RegTree* p_tree) {
|
||||
const GHistIndexBlockMatrix& gmatb,
|
||||
const ColumnMatrix& column_matrix,
|
||||
HostDeviceVector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
RegTree* p_tree) {
|
||||
double gstart = dmlc::GetTime();
|
||||
|
||||
int num_leaves = 0;
|
||||
@ -280,9 +279,9 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
CHECK_EQ(tree.param.num_nodes, tree.param.num_roots)
|
||||
<< "ColMakerHist: can only grow new tree";
|
||||
CHECK((param_.max_depth > 0 || param_.max_leaves > 0))
|
||||
@ -395,11 +394,11 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
|
||||
}
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Builder::EvaluateSplit(int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
void QuantileHistMaker::Builder::EvaluateSplit(const int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
// start enumeration
|
||||
const MetaInfo& info = fmat.Info();
|
||||
auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
|
||||
@ -411,13 +410,14 @@ void QuantileHistMaker::Builder::EvaluateSplit(int nid,
|
||||
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
|
||||
best_split_tloc_[tid] = snode_[nid].best;
|
||||
}
|
||||
GHistRow node_hist = hist[nid];
|
||||
#pragma omp parallel for schedule(dynamic) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < nfeature; ++i) {
|
||||
const bst_uint fid = feature_set[i];
|
||||
const unsigned tid = omp_get_thread_num();
|
||||
this->EnumerateSplit(-1, gmat, hist[nid], snode_[nid], info,
|
||||
this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info,
|
||||
&best_split_tloc_[tid], fid, nid);
|
||||
this->EnumerateSplit(+1, gmat, hist[nid], snode_[nid], info,
|
||||
this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info,
|
||||
&best_split_tloc_[tid], fid, nid);
|
||||
}
|
||||
for (unsigned tid = 0; tid < nthread; ++tid) {
|
||||
@ -426,11 +426,11 @@ void QuantileHistMaker::Builder::EvaluateSplit(int nid,
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Builder::ApplySplit(int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const ColumnMatrix& column_matrix,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
RegTree* p_tree) {
|
||||
const GHistIndexMatrix& gmat,
|
||||
const ColumnMatrix& column_matrix,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
RegTree* p_tree) {
|
||||
// TODO(hcho3): support feature sampling by levels
|
||||
|
||||
/* 1. Create child nodes */
|
||||
@ -613,10 +613,10 @@ void QuantileHistMaker::Builder::ApplySplitSparseData(
|
||||
}
|
||||
|
||||
void QuantileHistMaker::Builder::InitNewNode(int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
{
|
||||
snode_.resize(tree.param.num_nodes, NodeEntry(param_));
|
||||
}
|
||||
@ -628,22 +628,24 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
|
||||
// in distributed mode, the node's stats should be calculated from histogram, otherwise,
|
||||
// we will have wrong results in EnumerateSplit()
|
||||
// here we take the last feature in cut
|
||||
auto begin = hist.data();
|
||||
for (size_t i = gmat.cut.row_ptr[0]; i < gmat.cut.row_ptr[1]; i++) {
|
||||
stats.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
|
||||
stats.Add(begin[i].sum_grad, begin[i].sum_hess);
|
||||
}
|
||||
} else {
|
||||
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased ||
|
||||
rabit::IsDistributed()) {
|
||||
/* specialized code for dense data
|
||||
For dense data (with no missing value),
|
||||
the sum of gradient histogram is equal to snode[nid]
|
||||
GHistRow hist = hist_[nid];*/
|
||||
For dense data (with no missing value),
|
||||
the sum of gradient histogram is equal to snode[nid]
|
||||
GHistRow hist = hist_[nid];*/
|
||||
const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
|
||||
|
||||
const uint32_t ibegin = row_ptr[fid_least_bins_];
|
||||
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
|
||||
auto begin = hist.data();
|
||||
for (uint32_t i = ibegin; i < iend; ++i) {
|
||||
const GHistEntry et = hist.begin[i];
|
||||
const GradStats et = begin[i];
|
||||
stats.Add(et.sum_grad, et.sum_hess);
|
||||
}
|
||||
} else {
|
||||
@ -653,27 +655,27 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculating the weights
|
||||
{
|
||||
bst_uint parentid = tree[nid].Parent();
|
||||
snode_[nid].weight = static_cast<float>(
|
||||
spliteval_->ComputeWeight(parentid, snode_[nid].stats));
|
||||
snode_[nid].root_gain = static_cast<float>(
|
||||
spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight));
|
||||
// calculating the weights
|
||||
{
|
||||
bst_uint parentid = tree[nid].Parent();
|
||||
snode_[nid].weight = static_cast<float>(
|
||||
spliteval_->ComputeWeight(parentid, snode_[nid].stats));
|
||||
snode_[nid].root_gain = static_cast<float>(
|
||||
spliteval_->ComputeScore(parentid, snode_[nid].stats, snode_[nid].weight));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// enumerate the split values of specific feature
|
||||
void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const GHistRow& hist,
|
||||
const NodeEntry& snode,
|
||||
const MetaInfo& info,
|
||||
SplitEntry* p_best,
|
||||
bst_uint fid,
|
||||
bst_uint nodeID) {
|
||||
const GHistIndexMatrix& gmat,
|
||||
const GHistRow& hist,
|
||||
const NodeEntry& snode,
|
||||
const MetaInfo& info,
|
||||
SplitEntry* p_best,
|
||||
bst_uint fid,
|
||||
bst_uint nodeID) {
|
||||
CHECK(d_step == +1 || d_step == -1);
|
||||
|
||||
// aliases
|
||||
@ -681,8 +683,8 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
|
||||
const std::vector<bst_float>& cut_val = gmat.cut.cut;
|
||||
|
||||
// statistics on both sides of split
|
||||
GradStats c(param_);
|
||||
GradStats e(param_);
|
||||
GradStats c;
|
||||
GradStats e;
|
||||
// best split so far
|
||||
SplitEntry best;
|
||||
|
||||
@ -708,7 +710,7 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
|
||||
for (int32_t i = ibegin; i != iend; i += d_step) {
|
||||
// start working
|
||||
// try to find a split
|
||||
e.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
|
||||
e.Add(hist[i].GetGrad(), hist[i].GetHess());
|
||||
if (e.sum_hess >= param_.min_child_weight) {
|
||||
c.SetSubstract(snode.stats, e);
|
||||
if (c.sum_hess >= param_.min_child_weight) {
|
||||
|
||||
@ -30,7 +30,6 @@ using xgboost::common::HistCutMatrix;
|
||||
using xgboost::common::GHistIndexMatrix;
|
||||
using xgboost::common::GHistIndexBlockMatrix;
|
||||
using xgboost::common::GHistIndexRow;
|
||||
using xgboost::common::GHistEntry;
|
||||
using xgboost::common::HistCollection;
|
||||
using xgboost::common::RowSetCollection;
|
||||
using xgboost::common::GHistRow;
|
||||
@ -73,8 +72,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
SplitEntry best;
|
||||
// constructor
|
||||
explicit NodeEntry(const TrainParam& param)
|
||||
: stats(param), root_gain(0.0f), weight(0.0f) {
|
||||
}
|
||||
: root_gain(0.0f), weight(0.0f) {}
|
||||
};
|
||||
// actual builder that runs the algorithm
|
||||
|
||||
@ -105,7 +103,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
} else {
|
||||
hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
|
||||
}
|
||||
this->histred_.Allreduce(hist.begin, hist_builder_.GetNumBins());
|
||||
this->histred_.Allreduce(hist.data(), hist_builder_.GetNumBins());
|
||||
}
|
||||
|
||||
inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
|
||||
@ -122,7 +120,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree);
|
||||
|
||||
void EvaluateSplit(int nid,
|
||||
void EvaluateSplit(const int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const HistCollection& hist,
|
||||
const DMatrix& fmat,
|
||||
@ -227,7 +225,7 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
|
||||
DataLayout data_layout_;
|
||||
|
||||
rabit::Reducer<GHistEntry, GHistEntry::Reduce> histred_;
|
||||
rabit::Reducer<GradStats, GradStats::Reduce> histred_;
|
||||
};
|
||||
|
||||
std::unique_ptr<Builder> builder_;
|
||||
|
||||
@ -19,7 +19,6 @@ namespace tree {
|
||||
DMLC_REGISTRY_FILE_TAG(updater_refresh);
|
||||
|
||||
/*! \brief pruner that prunes a tree after growing finishs */
|
||||
template<typename TStats>
|
||||
class TreeRefresher: public TreeUpdater {
|
||||
public:
|
||||
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
@ -31,14 +30,13 @@ class TreeRefresher: public TreeUpdater {
|
||||
const std::vector<RegTree*> &trees) override {
|
||||
if (trees.size() == 0) return;
|
||||
const std::vector<GradientPair> &gpair_h = gpair->ConstHostVector();
|
||||
// number of threads
|
||||
// thread temporal space
|
||||
std::vector<std::vector<TStats> > stemp;
|
||||
std::vector<std::vector<GradStats> > stemp;
|
||||
std::vector<RegTree::FVec> fvec_temp;
|
||||
// setup temp space for each thread
|
||||
const int nthread = omp_get_max_threads();
|
||||
fvec_temp.resize(nthread, RegTree::FVec());
|
||||
stemp.resize(nthread, std::vector<TStats>());
|
||||
stemp.resize(nthread, std::vector<GradStats>());
|
||||
#pragma omp parallel
|
||||
{
|
||||
int tid = omp_get_thread_num();
|
||||
@ -46,8 +44,8 @@ class TreeRefresher: public TreeUpdater {
|
||||
for (auto tree : trees) {
|
||||
num_nodes += tree->param.num_nodes;
|
||||
}
|
||||
stemp[tid].resize(num_nodes, TStats(param_));
|
||||
std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param_));
|
||||
stemp[tid].resize(num_nodes, GradStats());
|
||||
std::fill(stemp[tid].begin(), stemp[tid].end(), GradStats());
|
||||
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
||||
}
|
||||
// if it is C++11, use lazy evaluation for Allreduce,
|
||||
@ -104,21 +102,22 @@ class TreeRefresher: public TreeUpdater {
|
||||
const std::vector<GradientPair> &gpair,
|
||||
const MetaInfo &info,
|
||||
const bst_uint ridx,
|
||||
TStats *gstats) {
|
||||
GradStats *gstats) {
|
||||
// start from groups that belongs to current data
|
||||
auto pid = static_cast<int>(info.GetRoot(ridx));
|
||||
gstats[pid].Add(gpair, info, ridx);
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
// tranverse tree
|
||||
while (!tree[pid].IsLeaf()) {
|
||||
unsigned split_index = tree[pid].SplitIndex();
|
||||
pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
gstats[pid].Add(gpair, info, ridx);
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
}
|
||||
}
|
||||
inline void Refresh(const TStats *gstats,
|
||||
inline void Refresh(const GradStats *gstats,
|
||||
int nid, RegTree *p_tree) {
|
||||
RegTree &tree = *p_tree;
|
||||
tree.Stat(nid).base_weight = static_cast<bst_float>(gstats[nid].CalcWeight(param_));
|
||||
tree.Stat(nid).base_weight =
|
||||
static_cast<bst_float>(CalcWeight(param_, gstats[nid]));
|
||||
tree.Stat(nid).sum_hess = static_cast<bst_float>(gstats[nid].sum_hess);
|
||||
if (tree[nid].IsLeaf()) {
|
||||
if (param_.refresh_leaf) {
|
||||
@ -126,9 +125,9 @@ class TreeRefresher: public TreeUpdater {
|
||||
}
|
||||
} else {
|
||||
tree.Stat(nid).loss_chg = static_cast<bst_float>(
|
||||
gstats[tree[nid].LeftChild()].CalcGain(param_) +
|
||||
gstats[tree[nid].RightChild()].CalcGain(param_) -
|
||||
gstats[nid].CalcGain(param_));
|
||||
xgboost::tree::CalcGain(param_, gstats[tree[nid].LeftChild()]) +
|
||||
xgboost::tree::CalcGain(param_, gstats[tree[nid].RightChild()]) -
|
||||
xgboost::tree::CalcGain(param_, gstats[nid]));
|
||||
this->Refresh(gstats, tree[nid].LeftChild(), p_tree);
|
||||
this->Refresh(gstats, tree[nid].RightChild(), p_tree);
|
||||
}
|
||||
@ -136,13 +135,13 @@ class TreeRefresher: public TreeUpdater {
|
||||
// training parameter
|
||||
TrainParam param_;
|
||||
// reducer
|
||||
rabit::Reducer<TStats, TStats::Reduce> reducer_;
|
||||
rabit::Reducer<GradStats, GradStats::Reduce> reducer_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh")
|
||||
.describe("Refresher that refreshes the weight and statistics according to data.")
|
||||
.set_body([]() {
|
||||
return new TreeRefresher<GradStats>();
|
||||
return new TreeRefresher();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -83,26 +83,17 @@ class SketchMaker: public BaseMaker {
|
||||
double neg_grad;
|
||||
/*! \brief sum of hessian statistics */
|
||||
double sum_hess;
|
||||
SKStats() = default;
|
||||
// constructor
|
||||
explicit SKStats(const TrainParam ¶m) {
|
||||
this->Clear();
|
||||
}
|
||||
/*! \brief clear the statistics */
|
||||
inline void Clear() {
|
||||
neg_grad = pos_grad = sum_hess = 0.0f;
|
||||
}
|
||||
|
||||
SKStats() : pos_grad{0}, neg_grad{0}, sum_hess{0} {}
|
||||
|
||||
// accumulate statistics
|
||||
inline void Add(const std::vector<GradientPair> &gpair,
|
||||
const MetaInfo &info,
|
||||
bst_uint ridx) {
|
||||
const GradientPair &b = gpair[ridx];
|
||||
if (b.GetGrad() >= 0.0f) {
|
||||
pos_grad += b.GetGrad();
|
||||
void Add(const GradientPair& gpair) {
|
||||
if (gpair.GetGrad() >= 0.0f) {
|
||||
pos_grad += gpair.GetGrad();
|
||||
} else {
|
||||
neg_grad -= b.GetGrad();
|
||||
neg_grad -= gpair.GetGrad();
|
||||
}
|
||||
sum_hess += b.GetHess();
|
||||
sum_hess += gpair.GetHess();
|
||||
}
|
||||
/*! \brief calculate gain of the solution */
|
||||
inline double CalcGain(const TrainParam ¶m) const {
|
||||
|
||||
@ -57,10 +57,10 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{0.26f, 0.27f}, {0.23f, 0.24f}, {0.27f, 0.28f},
|
||||
{0.57f, 0.59f}, {0.23f, 0.24f}, {0.47f, 0.49f}};
|
||||
|
||||
for (size_t i = 0; i < hist_[nid].size; ++i) {
|
||||
for (size_t i = 0; i < hist_[nid].size(); ++i) {
|
||||
GradientPairPrecise sol = solution[i];
|
||||
ASSERT_NEAR(sol.GetGrad(), hist_[nid].begin[i].sum_grad, kEps);
|
||||
ASSERT_NEAR(sol.GetHess(), hist_[nid].begin[i].sum_hess, kEps);
|
||||
ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
|
||||
ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user