parent
785094db53
commit
4b892c2b30
@ -746,130 +746,6 @@ class GlobalProposalHistMaker: public CQHistMaker<TStats> {
|
|||||||
std::vector<bst_float> cached_cut_;
|
std::vector<bst_float> cached_cut_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
template<typename TStats>
|
|
||||||
class QuantileHistMaker: public HistMaker<TStats> {
|
|
||||||
protected:
|
|
||||||
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
|
|
||||||
void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
|
|
||||||
DMatrix *p_fmat,
|
|
||||||
const std::vector <bst_uint> &fset,
|
|
||||||
const RegTree &tree) override {
|
|
||||||
const MetaInfo &info = p_fmat->Info();
|
|
||||||
// initialize the data structure
|
|
||||||
const int nthread = omp_get_max_threads();
|
|
||||||
sketchs_.resize(this->qexpand_.size() * tree.param.num_feature);
|
|
||||||
for (size_t i = 0; i < sketchs_.size(); ++i) {
|
|
||||||
sketchs_[i].Init(info.num_row_, this->param_.sketch_eps);
|
|
||||||
}
|
|
||||||
// start accumulating statistics
|
|
||||||
for (const auto &batch : p_fmat->GetRowBatches()) {
|
|
||||||
// parallel convert to column major format
|
|
||||||
common::ParallelGroupBuilder<Entry>
|
|
||||||
builder(&col_ptr_, &col_data_, &thread_col_ptr_);
|
|
||||||
builder.InitBudget(tree.param.num_feature, nthread);
|
|
||||||
|
|
||||||
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.Size());
|
|
||||||
#pragma omp parallel for schedule(static)
|
|
||||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
|
||||||
SparsePage::Inst inst = batch[i];
|
|
||||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
|
||||||
int nid = this->position_[ridx];
|
|
||||||
if (nid >= 0) {
|
|
||||||
if (!tree[nid].IsLeaf()) {
|
|
||||||
this->position_[ridx] = nid = HistMaker<TStats>::NextLevel(inst, tree, nid);
|
|
||||||
}
|
|
||||||
if (this->node2workindex_[nid] < 0) {
|
|
||||||
this->position_[ridx] = ~nid;
|
|
||||||
} else {
|
|
||||||
for (auto& ins : inst) {
|
|
||||||
builder.AddBudget(ins.index, omp_get_thread_num());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
builder.InitStorage();
|
|
||||||
#pragma omp parallel for schedule(static)
|
|
||||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
|
||||||
SparsePage::Inst inst = batch[i];
|
|
||||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
|
||||||
const int nid = this->position_[ridx];
|
|
||||||
if (nid >= 0) {
|
|
||||||
for (auto& ins : inst) {
|
|
||||||
builder.Push(ins.index,
|
|
||||||
Entry(nid, ins.fvalue),
|
|
||||||
omp_get_thread_num());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// start putting things into sketch
|
|
||||||
const bst_omp_uint nfeat = col_ptr_.size() - 1;
|
|
||||||
#pragma omp parallel for schedule(dynamic, 1)
|
|
||||||
for (bst_omp_uint k = 0; k < nfeat; ++k) {
|
|
||||||
for (size_t i = col_ptr_[k]; i < col_ptr_[k+1]; ++i) {
|
|
||||||
const Entry &e = col_data_[i];
|
|
||||||
const int wid = this->node2workindex_[e.index];
|
|
||||||
sketchs_[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].GetHess());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// setup maximum size
|
|
||||||
unsigned max_size = this->param_.MaxSketchSize();
|
|
||||||
// synchronize sketch
|
|
||||||
summary_array_.resize(sketchs_.size());
|
|
||||||
for (size_t i = 0; i < sketchs_.size(); ++i) {
|
|
||||||
common::WQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
|
||||||
sketchs_[i].GetSummary(&out);
|
|
||||||
summary_array_[i].Reserve(max_size);
|
|
||||||
summary_array_[i].SetPrune(out, max_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
|
||||||
sreducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size());
|
|
||||||
// now we get the final result of sketch, setup the cut
|
|
||||||
this->wspace_.cut.clear();
|
|
||||||
this->wspace_.rptr.clear();
|
|
||||||
this->wspace_.rptr.push_back(0);
|
|
||||||
for (size_t wid = 0; wid < this->qexpand_.size(); ++wid) {
|
|
||||||
for (int fid = 0; fid < tree.param.num_feature; ++fid) {
|
|
||||||
const WXQSketch::Summary &a = summary_array_[wid * tree.param.num_feature + fid];
|
|
||||||
for (size_t i = 1; i < a.size; ++i) {
|
|
||||||
bst_float cpt = a.data[i].value - kRtEps;
|
|
||||||
if (i == 1 || cpt > this->wspace_.cut.back()) {
|
|
||||||
this->wspace_.cut.push_back(cpt);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// push a value that is greater than anything
|
|
||||||
if (a.size != 0) {
|
|
||||||
bst_float cpt = a.data[a.size - 1].value;
|
|
||||||
// this must be bigger than last value in a scale
|
|
||||||
bst_float last = cpt + fabs(cpt) + kRtEps;
|
|
||||||
this->wspace_.cut.push_back(last);
|
|
||||||
}
|
|
||||||
this->wspace_.rptr.push_back(this->wspace_.cut.size());
|
|
||||||
}
|
|
||||||
// reserve last value for global statistics
|
|
||||||
this->wspace_.cut.push_back(0.0f);
|
|
||||||
this->wspace_.rptr.push_back(this->wspace_.cut.size());
|
|
||||||
}
|
|
||||||
CHECK_EQ(this->wspace_.rptr.size(),
|
|
||||||
(tree.param.num_feature + 1) * this->qexpand_.size() + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// summary array
|
|
||||||
std::vector<WXQSketch::SummaryContainer> summary_array_;
|
|
||||||
// reducer for summary
|
|
||||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer_;
|
|
||||||
// local temp column data structure
|
|
||||||
std::vector<size_t> col_ptr_;
|
|
||||||
// local storage of column data
|
|
||||||
std::vector<Entry> col_data_;
|
|
||||||
std::vector<std::vector<size_t> > thread_col_ptr_;
|
|
||||||
// per node, per feature sketch
|
|
||||||
std::vector<common::WQuantileSketch<bst_float, bst_float> > sketchs_;
|
|
||||||
};
|
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
|
||||||
.describe("Tree constructor that uses approximate histogram construction.")
|
.describe("Tree constructor that uses approximate histogram construction.")
|
||||||
.set_body([]() {
|
.set_body([]() {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user