Testing hist_util (#5251)
* Rank tests * Remove categorical split specialisation * Extend tests to multiple features, switch to WQSketch * Add tests for SparseCuts * Add external memory quantile tests, fix some existing tests
This commit is contained in:
@@ -70,7 +70,6 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
|
||||
bool const use_group_ind,
|
||||
uint32_t beg_col, uint32_t end_col,
|
||||
uint32_t thread_id) {
|
||||
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
|
||||
CHECK_GE(end_col, beg_col);
|
||||
constexpr float kFactor = 8;
|
||||
|
||||
@@ -80,7 +79,7 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
|
||||
|
||||
for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) {
|
||||
// Using a local variable makes things easier, but at the cost of memory trashing.
|
||||
WXQSketch sketch;
|
||||
WQSketch sketch;
|
||||
common::Span<xgboost::Entry const> const column = page[col_id];
|
||||
uint32_t const n_bins = std::min(static_cast<uint32_t>(column.size()),
|
||||
max_num_bins);
|
||||
@@ -104,18 +103,18 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
|
||||
sketch.Push(entry.fvalue, info.GetWeight(weight_ind));
|
||||
}
|
||||
|
||||
WXQSketch::SummaryContainer out_summary;
|
||||
WQSketch::SummaryContainer out_summary;
|
||||
sketch.GetSummary(&out_summary);
|
||||
WXQSketch::SummaryContainer summary;
|
||||
summary.Reserve(n_bins);
|
||||
summary.SetPrune(out_summary, n_bins);
|
||||
WQSketch::SummaryContainer summary;
|
||||
summary.Reserve(n_bins + 1);
|
||||
summary.SetPrune(out_summary, n_bins + 1);
|
||||
|
||||
// Can be use data[1] as the min values so that we don't need to
|
||||
// store another array?
|
||||
float mval = summary.data[0].value;
|
||||
p_cuts_->min_vals_[col_id - beg_col] = mval - (fabs(mval) + 1e-5);
|
||||
|
||||
this->AddCutPoint(summary);
|
||||
this->AddCutPoint(summary, max_num_bins);
|
||||
|
||||
bst_float cpt = (summary.size > 0) ?
|
||||
summary.data[summary.size - 1].value :
|
||||
@@ -234,7 +233,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
|
||||
// safe factor for better accuracy
|
||||
constexpr int kFactor = 8;
|
||||
std::vector<WXQSketch> sketchs;
|
||||
std::vector<WQSketch> sketchs;
|
||||
|
||||
const int nthread = omp_get_max_threads();
|
||||
|
||||
@@ -292,34 +291,34 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
}
|
||||
|
||||
void DenseCuts::Init
|
||||
(std::vector<WXQSketch>* in_sketchs, uint32_t max_num_bins) {
|
||||
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins) {
|
||||
monitor_.Start(__func__);
|
||||
std::vector<WXQSketch>& sketchs = *in_sketchs;
|
||||
std::vector<WQSketch>& sketchs = *in_sketchs;
|
||||
constexpr int kFactor = 8;
|
||||
// gather the histogram data
|
||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
std::vector<WXQSketch::SummaryContainer> summary_array;
|
||||
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
|
||||
std::vector<WQSketch::SummaryContainer> summary_array;
|
||||
summary_array.resize(sketchs.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
WXQSketch::SummaryContainer out;
|
||||
WQSketch::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array[i].Reserve(max_num_bins * kFactor);
|
||||
summary_array[i].SetPrune(out, max_num_bins * kFactor);
|
||||
}
|
||||
CHECK_EQ(summary_array.size(), in_sketchs->size());
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
|
||||
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
|
||||
// TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint
|
||||
// we need to move this allreduce before loadcheckpoint call in future
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
p_cuts_->min_vals_.resize(sketchs.size());
|
||||
|
||||
for (size_t fid = 0; fid < summary_array.size(); ++fid) {
|
||||
WXQSketch::SummaryContainer a;
|
||||
a.Reserve(max_num_bins);
|
||||
a.SetPrune(summary_array[fid], max_num_bins);
|
||||
WQSketch::SummaryContainer a;
|
||||
a.Reserve(max_num_bins + 1);
|
||||
a.SetPrune(summary_array[fid], max_num_bins + 1);
|
||||
const bst_float mval = a.data[0].value;
|
||||
p_cuts_->min_vals_[fid] = mval - (fabs(mval) + 1e-5);
|
||||
AddCutPoint(a);
|
||||
AddCutPoint(a, max_num_bins);
|
||||
// push a value that is greater than anything
|
||||
const bst_float cpt
|
||||
= (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_[fid];
|
||||
|
||||
Reference in New Issue
Block a user