Remove obsoleted QuantileHistMaker. (#3761)

Fix #3755.
This commit is contained in:
trivialfis 2018-10-07 07:22:15 +13:00 committed by Philip Hyunsu Cho
parent 785094db53
commit 4b892c2b30

View File

@ -746,130 +746,6 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
std::vector<bst_float> cached_cut_; std::vector<bst_float> cached_cut_;
}; };
template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> {
protected:
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
const std::vector <bst_uint> &fset,
const RegTree &tree) override {
const MetaInfo &info = p_fmat->Info();
// initialize the data structure
const int nthread = omp_get_max_threads();
sketchs_.resize(this->qexpand_.size() * tree.param.num_feature);
for (size_t i = 0; i < sketchs_.size(); ++i) {
sketchs_[i].Init(info.num_row_, this->param_.sketch_eps);
}
// start accumulating statistics
for (const auto &batch : p_fmat->GetRowBatches()) {
// parallel convert to column major format
common::ParallelGroupBuilder<Entry>
builder(&col_ptr_, &col_data_, &thread_col_ptr_);
builder.InitBudget(tree.param.num_feature, nthread);
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.Size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
SparsePage::Inst inst = batch[i];
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
int nid = this->position_[ridx];
if (nid >= 0) {
if (!tree[nid].IsLeaf()) {
this->position_[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
}
if (this->node2workindex_[nid] < 0) {
this->position_[ridx] = ~nid;
} else {
for (auto& ins : inst) {
builder.AddBudget(ins.index, omp_get_thread_num());
}
}
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nbatch; ++i) {
SparsePage::Inst inst = batch[i];
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
const int nid = this->position_[ridx];
if (nid >= 0) {
for (auto& ins : inst) {
builder.Push(ins.index,
Entry(nid, ins.fvalue),
omp_get_thread_num());
}
}
}
// start putting things into sketch
const bst_omp_uint nfeat = col_ptr_.size() - 1;
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint k = 0; k < nfeat; ++k) {
for (size_t i = col_ptr_[k]; i < col_ptr_[k+1]; ++i) {
const Entry &e = col_data_[i];
const int wid = this->node2workindex_[e.index];
sketchs_[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].GetHess());
}
}
}
// setup maximum size
unsigned max_size = this->param_.MaxSketchSize();
// synchronize sketch
summary_array_.resize(sketchs_.size());
for (size_t i = 0; i < sketchs_.size(); ++i) {
common::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
sketchs_[i].GetSummary(&out);
summary_array_[i].Reserve(max_size);
summary_array_[i].SetPrune(out, max_size);
}
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
sreducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size());
// now we get the final result of sketch, setup the cut
this->wspace_.cut.clear();
this->wspace_.rptr.clear();
this->wspace_.rptr.push_back(0);
for (size_t wid = 0; wid < this->qexpand_.size(); ++wid) {
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
const WXQSketch::Summary &a = summary_array_[wid * tree.param.num_feature + fid];
for (size_t i = 1; i < a.size; ++i) {
bst_float cpt = a.data[i].value - kRtEps;
if (i == 1 || cpt > this->wspace_.cut.back()) {
this->wspace_.cut.push_back(cpt);
}
}
// push a value that is greater than anything
if (a.size != 0) {
bst_float cpt = a.data[a.size - 1].value;
// this must be bigger than last value in a scale
bst_float last = cpt + fabs(cpt) + kRtEps;
this->wspace_.cut.push_back(last);
}
this->wspace_.rptr.push_back(this->wspace_.cut.size());
}
// reserve last value for global statistics
this->wspace_.cut.push_back(0.0f);
this->wspace_.rptr.push_back(this->wspace_.cut.size());
}
CHECK_EQ(this->wspace_.rptr.size(),
(tree.param.num_feature + 1) * this->qexpand_.size() + 1);
}
private:
// summary array
std::vector<WXQSketch::SummaryContainer> summary_array_;
// reducer for summary
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer_;
// local temp column data structure
std::vector<size_t> col_ptr_;
// local storage of column data
std::vector<Entry> col_data_;
std::vector<std::vector<size_t> > thread_col_ptr_;
// per node, per feature sketch
std::vector<common::WQuantileSketch<bst_float, bst_float> > sketchs_;
};
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker") XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
.describe("Tree constructor that uses approximate histogram construction.") .describe("Tree constructor that uses approximate histogram construction.")
.set_body([]() { .set_body([]() {