parent
785094db53
commit
4b892c2b30
@ -746,130 +746,6 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
|
||||
std::vector<bst_float> cached_cut_;
|
||||
};
|
||||
|
||||
|
||||
template<typename TStats>
|
||||
class QuantileHistMaker: public HistMaker<TStats> {
|
||||
protected:
|
||||
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
|
||||
void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector <bst_uint> &fset,
|
||||
const RegTree &tree) override {
|
||||
const MetaInfo &info = p_fmat->Info();
|
||||
// initialize the data structure
|
||||
const int nthread = omp_get_max_threads();
|
||||
sketchs_.resize(this->qexpand_.size() * tree.param.num_feature);
|
||||
for (size_t i = 0; i < sketchs_.size(); ++i) {
|
||||
sketchs_[i].Init(info.num_row_, this->param_.sketch_eps);
|
||||
}
|
||||
// start accumulating statistics
|
||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
||||
// parallel convert to column major format
|
||||
common::ParallelGroupBuilder<Entry>
|
||||
builder(&col_ptr_, &col_data_, &thread_col_ptr_);
|
||||
builder.InitBudget(tree.param.num_feature, nthread);
|
||||
|
||||
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.Size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
||||
SparsePage::Inst inst = batch[i];
|
||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
int nid = this->position_[ridx];
|
||||
if (nid >= 0) {
|
||||
if (!tree[nid].IsLeaf()) {
|
||||
this->position_[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
||||
}
|
||||
if (this->node2workindex_[nid] < 0) {
|
||||
this->position_[ridx] = ~nid;
|
||||
} else {
|
||||
for (auto& ins : inst) {
|
||||
builder.AddBudget(ins.index, omp_get_thread_num());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
||||
SparsePage::Inst inst = batch[i];
|
||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
const int nid = this->position_[ridx];
|
||||
if (nid >= 0) {
|
||||
for (auto& ins : inst) {
|
||||
builder.Push(ins.index,
|
||||
Entry(nid, ins.fvalue),
|
||||
omp_get_thread_num());
|
||||
}
|
||||
}
|
||||
}
|
||||
// start putting things into sketch
|
||||
const bst_omp_uint nfeat = col_ptr_.size() - 1;
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint k = 0; k < nfeat; ++k) {
|
||||
for (size_t i = col_ptr_[k]; i < col_ptr_[k+1]; ++i) {
|
||||
const Entry &e = col_data_[i];
|
||||
const int wid = this->node2workindex_[e.index];
|
||||
sketchs_[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].GetHess());
|
||||
}
|
||||
}
|
||||
}
|
||||
// setup maximum size
|
||||
unsigned max_size = this->param_.MaxSketchSize();
|
||||
// synchronize sketch
|
||||
summary_array_.resize(sketchs_.size());
|
||||
for (size_t i = 0; i < sketchs_.size(); ++i) {
|
||||
common::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs_[i].GetSummary(&out);
|
||||
summary_array_[i].Reserve(max_size);
|
||||
summary_array_[i].SetPrune(out, max_size);
|
||||
}
|
||||
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
sreducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size());
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace_.cut.clear();
|
||||
this->wspace_.rptr.clear();
|
||||
this->wspace_.rptr.push_back(0);
|
||||
for (size_t wid = 0; wid < this->qexpand_.size(); ++wid) {
|
||||
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
||||
const WXQSketch::Summary &a = summary_array_[wid * tree.param.num_feature + fid];
|
||||
for (size_t i = 1; i < a.size; ++i) {
|
||||
bst_float cpt = a.data[i].value - kRtEps;
|
||||
if (i == 1 || cpt > this->wspace_.cut.back()) {
|
||||
this->wspace_.cut.push_back(cpt);
|
||||
}
|
||||
}
|
||||
// push a value that is greater than anything
|
||||
if (a.size != 0) {
|
||||
bst_float cpt = a.data[a.size - 1].value;
|
||||
// this must be bigger than last value in a scale
|
||||
bst_float last = cpt + fabs(cpt) + kRtEps;
|
||||
this->wspace_.cut.push_back(last);
|
||||
}
|
||||
this->wspace_.rptr.push_back(this->wspace_.cut.size());
|
||||
}
|
||||
// reserve last value for global statistics
|
||||
this->wspace_.cut.push_back(0.0f);
|
||||
this->wspace_.rptr.push_back(this->wspace_.cut.size());
|
||||
}
|
||||
CHECK_EQ(this->wspace_.rptr.size(),
|
||||
(tree.param.num_feature + 1) * this->qexpand_.size() + 1);
|
||||
}
|
||||
|
||||
private:
|
||||
// summary array
|
||||
std::vector<WXQSketch::SummaryContainer> summary_array_;
|
||||
// reducer for summary
|
||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer_;
|
||||
// local temp column data structure
|
||||
std::vector<size_t> col_ptr_;
|
||||
// local storage of column data
|
||||
std::vector<Entry> col_data_;
|
||||
std::vector<std::vector<size_t> > thread_col_ptr_;
|
||||
// per node, per feature sketch
|
||||
std::vector<common::WQuantileSketch<bst_float, bst_float> > sketchs_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
||||
.describe("Tree constructor that uses approximate histogram construction.")
|
||||
.set_body([]() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user