[TREE] Enable global proposal for faster speed
This commit is contained in:
parent
2f2080a337
commit
ce4d59ed69
@ -268,6 +268,10 @@ class HistMaker: public BaseMaker {
|
||||
|
||||
template<typename TStats>
|
||||
class CQHistMaker: public HistMaker<TStats> {
|
||||
public:
|
||||
CQHistMaker() : cache_dmatrix_(nullptr) {
|
||||
}
|
||||
|
||||
protected:
|
||||
struct HistEntry {
|
||||
typename HistMaker<TStats>::HistUnit hist;
|
||||
@ -290,10 +294,14 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
*/
|
||||
inline void Add(bst_float fv,
|
||||
bst_gpair gstats) {
|
||||
if (fv < hist.cut[istart]) {
|
||||
hist.data[istart].Add(gstats);
|
||||
} else {
|
||||
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
|
||||
CHECK_NE(istart, hist.size);
|
||||
hist.data[istart].Add(gstats);
|
||||
}
|
||||
}
|
||||
};
|
||||
// sketch type used for this
|
||||
typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
|
||||
@ -301,7 +309,10 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
void InitWorkSet(DMatrix *p_fmat,
|
||||
const RegTree &tree,
|
||||
std::vector<bst_uint> *p_fset) override {
|
||||
if (p_fmat != cache_dmatrix_) {
|
||||
feat_helper.InitByCol(p_fmat, tree);
|
||||
cache_dmatrix_ = p_fmat;
|
||||
}
|
||||
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
|
||||
}
|
||||
// code to create histogram
|
||||
@ -342,6 +353,9 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
}
|
||||
}
|
||||
// update node statistics.
|
||||
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||
&thread_stats, &node_stats);
|
||||
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||
const int nid = this->qexpand[i];
|
||||
const int wid = this->node2workindex[nid];
|
||||
@ -434,9 +448,6 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
}
|
||||
// update node statistics.
|
||||
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||
&thread_stats, &node_stats);
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
this->wspace.rptr.clear();
|
||||
@ -475,7 +486,6 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
(fset.size() + 1) * this->qexpand.size() + 1);
|
||||
}
|
||||
|
||||
private:
|
||||
inline void UpdateHistCol(const std::vector<bst_gpair> &gpair,
|
||||
const ColBatch::Inst &c,
|
||||
const MetaInfo &info,
|
||||
@ -607,6 +617,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
sbuilder[nid].Finalize(max_size);
|
||||
}
|
||||
}
|
||||
// cached dmatrix where we initialized the feature on.
|
||||
const DMatrix* cache_dmatrix_;
|
||||
// feature helper
|
||||
BaseMaker::FMetaHelper feat_helper;
|
||||
// temp space to map feature id to working index
|
||||
@ -631,6 +643,107 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
// global proposal
|
||||
template<typename TStats>
|
||||
class GlobalProposalHistMaker: public CQHistMaker<TStats> {
|
||||
protected:
|
||||
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<bst_uint> &fset,
|
||||
const RegTree &tree) override {
|
||||
if (this->qexpand.size() == 1 && !this->param.cache_global_proposal) {
|
||||
cached_rptr_.clear();
|
||||
cached_cut_.clear();
|
||||
}
|
||||
if (cached_rptr_.size() == 0) {
|
||||
CHECK_EQ(this->qexpand.size(), 1);
|
||||
CQHistMaker<TStats>::ResetPosAndPropose(gpair, p_fmat, fset, tree);
|
||||
cached_rptr_ = this->wspace.rptr;
|
||||
cached_cut_ = this->wspace.cut;
|
||||
} else {
|
||||
this->wspace.cut.clear();
|
||||
this->wspace.rptr.clear();
|
||||
this->wspace.rptr.push_back(0);
|
||||
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||
for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
|
||||
this->wspace.rptr.push_back(
|
||||
this->wspace.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
|
||||
}
|
||||
this->wspace.cut.insert(this->wspace.cut.end(), cached_cut_.begin(), cached_cut_.end());
|
||||
}
|
||||
CHECK_EQ(this->wspace.rptr.size(),
|
||||
(fset.size() + 1) * this->qexpand.size() + 1);
|
||||
CHECK_EQ(this->wspace.rptr.back(), this->wspace.cut.size());
|
||||
}
|
||||
}
|
||||
|
||||
// code to create histogram
|
||||
void CreateHist(const std::vector<bst_gpair> &gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<bst_uint> &fset,
|
||||
const RegTree &tree) override {
|
||||
const MetaInfo &info = p_fmat->info();
|
||||
// fill in reverse map
|
||||
this->feat2workindex.resize(tree.param.num_feature);
|
||||
this->work_set = fset;
|
||||
std::fill(this->feat2workindex.begin(), this->feat2workindex.end(), -1);
|
||||
for (size_t i = 0; i < fset.size(); ++i) {
|
||||
this->feat2workindex[fset[i]] = static_cast<int>(i);
|
||||
}
|
||||
// start to work
|
||||
this->wspace.Init(this->param, 1);
|
||||
// to gain speedup in recovery
|
||||
{
|
||||
this->thread_hist.resize(this->get_nthread());
|
||||
|
||||
// TWOPASS: use the real set + split set in the column iteration.
|
||||
this->SetDefaultPostion(p_fmat, tree);
|
||||
this->work_set.insert(this->work_set.end(), this->fsplit_set.begin(), this->fsplit_set.end());
|
||||
std::sort(this->work_set.begin(), this->work_set.end());
|
||||
this->work_set.resize(
|
||||
std::unique(this->work_set.begin(), this->work_set.end()) - this->work_set.begin());
|
||||
|
||||
// start accumulating statistics
|
||||
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(this->work_set);
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const ColBatch &batch = iter->Value();
|
||||
// TWOPASS: use the real set + split set in the column iteration.
|
||||
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set, tree);
|
||||
|
||||
// start enumeration
|
||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
int offset = this->feat2workindex[batch.col_index[i]];
|
||||
if (offset >= 0) {
|
||||
this->UpdateHistCol(gpair, batch[i], info, tree,
|
||||
fset, offset,
|
||||
&this->thread_hist[omp_get_thread_num()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update node statistics.
|
||||
this->GetNodeStats(gpair, *p_fmat, tree,
|
||||
&(this->thread_stats), &(this->node_stats));
|
||||
for (size_t i = 0; i < this->qexpand.size(); ++i) {
|
||||
const int nid = this->qexpand[i];
|
||||
const int wid = this->node2workindex[nid];
|
||||
this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
|
||||
.data[0] = this->node_stats[nid];
|
||||
}
|
||||
}
|
||||
this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data),
|
||||
this->wspace.hset[0].data.size());
|
||||
}
|
||||
|
||||
// cached unit pointer
|
||||
std::vector<unsigned> cached_rptr_;
|
||||
// cached cut value.
|
||||
std::vector<bst_float> cached_cut_;
|
||||
};
|
||||
|
||||
|
||||
template<typename TStats>
|
||||
class QuantileHistMaker: public HistMaker<TStats> {
|
||||
@ -763,5 +876,11 @@ XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
|
||||
.set_body([]() {
|
||||
return new CQHistMaker<GradStats>();
|
||||
});
|
||||
|
||||
XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_histmaker")
|
||||
.describe("Tree constructor that uses approximate global proposal of histogram construction.")
|
||||
.set_body([]() {
|
||||
return new GlobalProposalHistMaker<GradStats>();
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user