Testing hist_util (#5251)
* Rank tests * Remove categorical split specialisation * Extend tests to multiple features, switch to WQSketch * Add tests for SparseCuts * Add external memory quantile tests, fix some existing tests
This commit is contained in:
@@ -101,6 +101,7 @@ struct SimpleArray {
|
||||
using GHistIndexRow = Span<uint32_t const>;
|
||||
|
||||
// A CSC matrix representing histogram cuts, used in CPU quantile hist.
|
||||
// The cut values represent upper bounds of bins containing approximately equal numbers of elements
|
||||
class HistogramCuts {
|
||||
// Using friends to avoid creating a virtual class, since HistogramCuts is used as value
|
||||
// object in many places.
|
||||
@@ -147,7 +148,9 @@ class HistogramCuts {
|
||||
|
||||
size_t TotalBins() const { return cut_ptrs_.back(); }
|
||||
|
||||
BinIdx SearchBin(float value, uint32_t column_id) {
|
||||
// Return the index of a cut point that is strictly greater than the input
|
||||
// value, or the last available index if none exists
|
||||
BinIdx SearchBin(float value, uint32_t column_id) const {
|
||||
auto beg = cut_ptrs_.at(column_id);
|
||||
auto end = cut_ptrs_.at(column_id + 1);
|
||||
auto it = std::upper_bound(cut_values_.cbegin() + beg, cut_values_.cbegin() + end, value);
|
||||
@@ -171,7 +174,7 @@ class HistogramCuts {
|
||||
*/
|
||||
class CutsBuilder {
|
||||
public:
|
||||
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
|
||||
using WQSketch = common::WQuantileSketch<bst_float, bst_float>;
|
||||
|
||||
protected:
|
||||
HistogramCuts* p_cuts_;
|
||||
@@ -195,21 +198,12 @@ class CutsBuilder {
|
||||
return group_ind;
|
||||
}
|
||||
|
||||
void AddCutPoint(WXQSketch::SummaryContainer const& summary) {
|
||||
if (summary.size > 1 && summary.size <= 16) {
|
||||
/* specialized code for categorical / ordinal data -- use midpoints */
|
||||
for (size_t i = 1; i < summary.size; ++i) {
|
||||
bst_float cpt = (summary.data[i].value + summary.data[i - 1].value) / 2.0f;
|
||||
if (i == 1 || cpt > p_cuts_->cut_values_.back()) {
|
||||
p_cuts_->cut_values_.push_back(cpt);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 2; i < summary.size; ++i) {
|
||||
bst_float cpt = summary.data[i - 1].value;
|
||||
if (i == 2 || cpt > p_cuts_->cut_values_.back()) {
|
||||
p_cuts_->cut_values_.push_back(cpt);
|
||||
}
|
||||
/*!
 * \brief Append cut points taken from a quantile summary to the cut container.
 *
 * Walks summary entries 1 .. min(summary.size, max_bin) - 1 and pushes each
 * entry's value onto p_cuts_->cut_values_.  The first visited entry is always
 * pushed; later entries are pushed only when strictly greater than the most
 * recently stored cut value.  Entry 0 of the summary is never visited.
 *
 * \param summary Quantile summary whose entries supply the candidate values.
 * \param max_bin Upper limit on the number of summary entries considered.
 *                Values <= 0 result in no cut points being added.
 */
void AddCutPoint(WQSketch::SummaryContainer const& summary, int max_bin) {
  // Do the bound arithmetic in size_t: comparing a size_t index against a
  // negative int would convert the int to a huge unsigned value and let the
  // loop run past the end of summary.data.
  size_t const bin_limit = max_bin > 0 ? static_cast<size_t>(max_bin) : 0;
  size_t const required_cuts =
      std::min(static_cast<size_t>(summary.size), bin_limit);
  for (size_t i = 1; i < required_cuts; ++i) {
    bst_float const cpt = summary.data[i].value;
    // i == 1 short-circuits so back() is never evaluated before this call has
    // pushed its first value.
    if (i == 1 || cpt > p_cuts_->cut_values_.back()) {
      p_cuts_->cut_values_.push_back(cpt);
    }
  }
}
|
||||
@@ -250,7 +244,7 @@ class DenseCuts : public CutsBuilder {
|
||||
CutsBuilder(container) {
|
||||
monitor_.Init(__FUNCTION__);
|
||||
}
|
||||
void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
|
||||
void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins);
|
||||
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user