Clean up training code. (#3825)
* Remove GHistRow, GHistEntry, GHistIndexRow. * Remove kSimpleStats. * Remove CheckInfo, SetLeafVec in GradStats and in SKStats. * Clean up the GradStats. * Cleanup calcgain. * Move LossChangeMissing out of common. * Remove [] operator from GHistIndexBlock.
This commit is contained in:
@@ -16,45 +16,8 @@
|
||||
#include "../include/rabit/rabit.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
namespace common {
|
||||
|
||||
/*! \brief sums of gradient statistics corresponding to a histogram bin */
|
||||
struct GHistEntry {
|
||||
/*! \brief sum of first-order gradient statistics */
|
||||
double sum_grad{0};
|
||||
/*! \brief sum of second-order gradient statistics */
|
||||
double sum_hess{0};
|
||||
|
||||
GHistEntry() = default;
|
||||
|
||||
inline void Clear() {
|
||||
sum_grad = sum_hess = 0;
|
||||
}
|
||||
|
||||
/*! \brief add a GradientPair to the sum */
|
||||
inline void Add(const GradientPair& e) {
|
||||
sum_grad += e.GetGrad();
|
||||
sum_hess += e.GetHess();
|
||||
}
|
||||
|
||||
/*! \brief add a GHistEntry to the sum */
|
||||
inline void Add(const GHistEntry& e) {
|
||||
sum_grad += e.sum_grad;
|
||||
sum_hess += e.sum_hess;
|
||||
}
|
||||
|
||||
inline static void Reduce(GHistEntry& a, const GHistEntry& b) { // NOLINT(*)
|
||||
a.Add(b);
|
||||
}
|
||||
|
||||
/*! \brief set sum to be difference of two GHistEntry's */
|
||||
inline void SetSubtract(const GHistEntry& a, const GHistEntry& b) {
|
||||
sum_grad = a.sum_grad - b.sum_grad;
|
||||
sum_hess = a.sum_hess - b.sum_hess;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Cut configuration for all the features. */
|
||||
struct HistCutMatrix {
|
||||
/*! \brief Unit pointer to rows by element position */
|
||||
@@ -83,15 +46,7 @@ void DeviceSketch
|
||||
* \brief A single row in global histogram index.
|
||||
* Directly represent the global index in the histogram entry.
|
||||
*/
|
||||
struct GHistIndexRow {
|
||||
/*! \brief The index of the histogram */
|
||||
const uint32_t* index;
|
||||
/*! \brief The size of the histogram */
|
||||
size_t size;
|
||||
GHistIndexRow() = default;
|
||||
GHistIndexRow(const uint32_t* index, size_t size)
|
||||
: index(index), size(size) {}
|
||||
};
|
||||
using GHistIndexRow = Span<uint32_t const>;
|
||||
|
||||
/*!
|
||||
* \brief preprocessed global index matrix, in CSR format
|
||||
@@ -111,7 +66,9 @@ struct GHistIndexMatrix {
|
||||
void Init(DMatrix* p_fmat, int max_num_bins);
|
||||
// get i-th row
|
||||
inline GHistIndexRow operator[](size_t i) const {
|
||||
return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
|
||||
return {&index[0] + row_ptr[i],
|
||||
static_cast<GHistIndexRow::index_type>(
|
||||
row_ptr[i + 1] - row_ptr[i])};
|
||||
}
|
||||
inline void GetFeatureCounts(size_t* counts) const {
|
||||
auto nfeature = cut.row_ptr.size() - 1;
|
||||
@@ -134,11 +91,6 @@ struct GHistIndexBlock {
|
||||
|
||||
inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
|
||||
: row_ptr(row_ptr), index(index) {}
|
||||
|
||||
// get i-th row
|
||||
inline GHistIndexRow operator[](size_t i) const {
|
||||
return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
|
||||
}
|
||||
};
|
||||
|
||||
class ColumnMatrix;
|
||||
@@ -171,21 +123,12 @@ class GHistIndexBlockMatrix {
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief histogram of gradient statistics for a single node.
|
||||
* Consists of multiple GHistEntry's, each entry showing total graident statistics
|
||||
* \brief histogram of graident statistics for a single node.
|
||||
* Consists of multiple GradStats, each entry showing total graident statistics
|
||||
* for that particular bin
|
||||
* Uses global bin id so as to represent all features simultaneously
|
||||
*/
|
||||
struct GHistRow {
|
||||
/*! \brief base pointer to first entry */
|
||||
GHistEntry* begin;
|
||||
/*! \brief number of entries */
|
||||
uint32_t size;
|
||||
|
||||
GHistRow() = default;
|
||||
GHistRow(GHistEntry* begin, uint32_t size)
|
||||
: begin(begin), size(size) {}
|
||||
};
|
||||
using GHistRow = Span<tree::GradStats>;
|
||||
|
||||
/*!
|
||||
* \brief histogram of gradient statistics for multiple nodes
|
||||
@@ -193,27 +136,29 @@ struct GHistRow {
|
||||
class HistCollection {
|
||||
public:
|
||||
// access histogram for i-th node
|
||||
inline GHistRow operator[](bst_uint nid) const {
|
||||
GHistRow operator[](bst_uint nid) const {
|
||||
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
|
||||
CHECK_NE(row_ptr_[nid], kMax);
|
||||
return {const_cast<GHistEntry*>(dmlc::BeginPtr(data_) + row_ptr_[nid]), nbins_};
|
||||
tree::GradStats* ptr =
|
||||
const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
|
||||
return {ptr, nbins_};
|
||||
}
|
||||
|
||||
// have we computed a histogram for i-th node?
|
||||
inline bool RowExists(bst_uint nid) const {
|
||||
bool RowExists(bst_uint nid) const {
|
||||
const uint32_t k_max = std::numeric_limits<uint32_t>::max();
|
||||
return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
|
||||
}
|
||||
|
||||
// initialize histogram collection
|
||||
inline void Init(uint32_t nbins) {
|
||||
void Init(uint32_t nbins) {
|
||||
nbins_ = nbins;
|
||||
row_ptr_.clear();
|
||||
data_.clear();
|
||||
}
|
||||
|
||||
// create an empty histogram for i-th node
|
||||
inline void AddHistRow(bst_uint nid) {
|
||||
void AddHistRow(bst_uint nid) {
|
||||
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
|
||||
if (nid >= row_ptr_.size()) {
|
||||
row_ptr_.resize(nid + 1, kMax);
|
||||
@@ -228,7 +173,7 @@ class HistCollection {
|
||||
/*! \brief number of all bins over all features */
|
||||
uint32_t nbins_;
|
||||
|
||||
std::vector<GHistEntry> data_;
|
||||
std::vector<tree::GradStats> data_;
|
||||
|
||||
/*! \brief row_ptr_[nid] locates bin for historgram of node nid */
|
||||
std::vector<size_t> row_ptr_;
|
||||
@@ -268,8 +213,8 @@ class GHistBuilder {
|
||||
size_t nthread_;
|
||||
/*! \brief number of all bins over all features */
|
||||
uint32_t nbins_;
|
||||
std::vector<GHistEntry> data_;
|
||||
std::vector<size_t> thread_init_;
|
||||
std::vector<tree::GradStats> data_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user