Testing hist_util (#5251)

* Rank tests

* Remove categorical split specialisation

* Extend tests to multiple features, switch to WQSketch

* Add tests for SparseCuts

* Add external memory quantile tests, fix some existing tests
This commit is contained in:
Rory Mitchell
2020-02-14 14:36:43 +13:00
committed by GitHub
parent 911a902835
commit 24ad9dec0b
10 changed files with 354 additions and 93 deletions

View File

@@ -70,7 +70,6 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
bool const use_group_ind,
uint32_t beg_col, uint32_t end_col,
uint32_t thread_id) {
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
CHECK_GE(end_col, beg_col);
constexpr float kFactor = 8;
@@ -80,7 +79,7 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) {
// Using a local variable makes things easier, but at the cost of memory trashing.
WXQSketch sketch;
WQSketch sketch;
common::Span<xgboost::Entry const> const column = page[col_id];
uint32_t const n_bins = std::min(static_cast<uint32_t>(column.size()),
max_num_bins);
@@ -104,18 +103,18 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
sketch.Push(entry.fvalue, info.GetWeight(weight_ind));
}
WXQSketch::SummaryContainer out_summary;
WQSketch::SummaryContainer out_summary;
sketch.GetSummary(&out_summary);
WXQSketch::SummaryContainer summary;
summary.Reserve(n_bins);
summary.SetPrune(out_summary, n_bins);
WQSketch::SummaryContainer summary;
summary.Reserve(n_bins + 1);
summary.SetPrune(out_summary, n_bins + 1);
// Can be use data[1] as the min values so that we don't need to
// store another array?
float mval = summary.data[0].value;
p_cuts_->min_vals_[col_id - beg_col] = mval - (fabs(mval) + 1e-5);
this->AddCutPoint(summary);
this->AddCutPoint(summary, max_num_bins);
bst_float cpt = (summary.size > 0) ?
summary.data[summary.size - 1].value :
@@ -234,7 +233,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
// safe factor for better accuracy
constexpr int kFactor = 8;
std::vector<WXQSketch> sketchs;
std::vector<WQSketch> sketchs;
const int nthread = omp_get_max_threads();
@@ -292,34 +291,34 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
}
void DenseCuts::Init
(std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins) {
monitor_.Start(__func__);
std::vector<WXQSketch>& sketchs = *in_sketchs;
std::vector<WQSketch>& sketchs = *in_sketchs;
constexpr int kFactor = 8;
// gather the histogram data
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
std::vector<WXQSketch::SummaryContainer> summary_array;
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
std::vector<WQSketch::SummaryContainer> summary_array;
summary_array.resize(sketchs.size());
for (size_t i = 0; i < sketchs.size(); ++i) {
WXQSketch::SummaryContainer out;
WQSketch::SummaryContainer out;
sketchs[i].GetSummary(&out);
summary_array[i].Reserve(max_num_bins * kFactor);
summary_array[i].SetPrune(out, max_num_bins * kFactor);
}
CHECK_EQ(summary_array.size(), in_sketchs->size());
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
// TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint
// we need to move this allreduce before loadcheckpoint call in future
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
p_cuts_->min_vals_.resize(sketchs.size());
for (size_t fid = 0; fid < summary_array.size(); ++fid) {
WXQSketch::SummaryContainer a;
a.Reserve(max_num_bins);
a.SetPrune(summary_array[fid], max_num_bins);
WQSketch::SummaryContainer a;
a.Reserve(max_num_bins + 1);
a.SetPrune(summary_array[fid], max_num_bins + 1);
const bst_float mval = a.data[0].value;
p_cuts_->min_vals_[fid] = mval - (fabs(mval) + 1e-5);
AddCutPoint(a);
AddCutPoint(a, max_num_bins);
// push a value that is greater than anything
const bst_float cpt
= (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_[fid];