[TREE] Enable global proposal for faster speed

This commit is contained in:
tqchen 2016-01-28 17:16:38 -08:00
parent 2f2080a337
commit ce4d59ed69
2 changed files with 141 additions and 22 deletions

View File

@ -268,6 +268,10 @@ class HistMaker: public BaseMaker {
template<typename TStats>
class CQHistMaker: public HistMaker<TStats> {
public:
CQHistMaker() : cache_dmatrix_(nullptr) {
}
protected:
struct HistEntry {
typename HistMaker<TStats>::HistUnit hist;
@ -290,10 +294,14 @@ class CQHistMaker: public HistMaker<TStats> {
*/
inline void Add(bst_float fv,
bst_gpair gstats) {
if (fv < hist.cut[istart]) {
hist.data[istart].Add(gstats);
} else {
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
CHECK_NE(istart, hist.size);
hist.data[istart].Add(gstats);
}
}
};
// sketch type used for this
typedef common::WXQuantileSketch<bst_float, bst_float> WXQSketch;
@ -301,7 +309,10 @@ class CQHistMaker: public HistMaker<TStats> {
void InitWorkSet(DMatrix *p_fmat,
const RegTree &tree,
std::vector<bst_uint> *p_fset) override {
if (p_fmat != cache_dmatrix_) {
feat_helper.InitByCol(p_fmat, tree);
cache_dmatrix_ = p_fmat;
}
feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
}
// code to create histogram
@ -342,6 +353,9 @@ class CQHistMaker: public HistMaker<TStats> {
}
}
}
// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&thread_stats, &node_stats);
for (size_t i = 0; i < this->qexpand.size(); ++i) {
const int nid = this->qexpand[i];
const int wid = this->node2workindex[nid];
@ -434,9 +448,6 @@ class CQHistMaker: public HistMaker<TStats> {
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
}
// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&thread_stats, &node_stats);
// now we get the final result of sketch, setup the cut
this->wspace.cut.clear();
this->wspace.rptr.clear();
@ -475,7 +486,6 @@ class CQHistMaker: public HistMaker<TStats> {
(fset.size() + 1) * this->qexpand.size() + 1);
}
private:
inline void UpdateHistCol(const std::vector<bst_gpair> &gpair,
const ColBatch::Inst &c,
const MetaInfo &info,
@ -607,6 +617,8 @@ class CQHistMaker: public HistMaker<TStats> {
sbuilder[nid].Finalize(max_size);
}
}
// cached dmatrix where we initialized the feature on.
const DMatrix* cache_dmatrix_;
// feature helper
BaseMaker::FMetaHelper feat_helper;
// temp space to map feature id to working index
@ -631,6 +643,107 @@ class CQHistMaker: public HistMaker<TStats> {
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
};
// global proposal
template<typename TStats>
class GlobalProposalHistMaker: public CQHistMaker<TStats> {
protected:
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_uint> &fset,
const RegTree &tree) override {
if (this->qexpand.size() == 1 && !this->param.cache_global_proposal) {
cached_rptr_.clear();
cached_cut_.clear();
}
if (cached_rptr_.size() == 0) {
CHECK_EQ(this->qexpand.size(), 1);
CQHistMaker<TStats>::ResetPosAndPropose(gpair, p_fmat, fset, tree);
cached_rptr_ = this->wspace.rptr;
cached_cut_ = this->wspace.cut;
} else {
this->wspace.cut.clear();
this->wspace.rptr.clear();
this->wspace.rptr.push_back(0);
for (size_t i = 0; i < this->qexpand.size(); ++i) {
for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
this->wspace.rptr.push_back(
this->wspace.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
}
this->wspace.cut.insert(this->wspace.cut.end(), cached_cut_.begin(), cached_cut_.end());
}
CHECK_EQ(this->wspace.rptr.size(),
(fset.size() + 1) * this->qexpand.size() + 1);
CHECK_EQ(this->wspace.rptr.back(), this->wspace.cut.size());
}
}
// code to create histogram
void CreateHist(const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_uint> &fset,
const RegTree &tree) override {
const MetaInfo &info = p_fmat->info();
// fill in reverse map
this->feat2workindex.resize(tree.param.num_feature);
this->work_set = fset;
std::fill(this->feat2workindex.begin(), this->feat2workindex.end(), -1);
for (size_t i = 0; i < fset.size(); ++i) {
this->feat2workindex[fset[i]] = static_cast<int>(i);
}
// start to work
this->wspace.Init(this->param, 1);
// to gain speedup in recovery
{
this->thread_hist.resize(this->get_nthread());
// TWOPASS: use the real set + split set in the column iteration.
this->SetDefaultPostion(p_fmat, tree);
this->work_set.insert(this->work_set.end(), this->fsplit_set.begin(), this->fsplit_set.end());
std::sort(this->work_set.begin(), this->work_set.end());
this->work_set.resize(
std::unique(this->work_set.begin(), this->work_set.end()) - this->work_set.begin());
// start accumulating statistics
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(this->work_set);
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set, tree);
// start enumeration
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
int offset = this->feat2workindex[batch.col_index[i]];
if (offset >= 0) {
this->UpdateHistCol(gpair, batch[i], info, tree,
fset, offset,
&this->thread_hist[omp_get_thread_num()]);
}
}
}
// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&(this->thread_stats), &(this->node_stats));
for (size_t i = 0; i < this->qexpand.size(); ++i) {
const int nid = this->qexpand[i];
const int wid = this->node2workindex[nid];
this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
.data[0] = this->node_stats[nid];
}
}
this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data),
this->wspace.hset[0].data.size());
}
// cached unit pointer
std::vector<unsigned> cached_rptr_;
// cached cut value.
std::vector<bst_float> cached_cut_;
};
template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> {
@ -763,5 +876,11 @@ XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
.set_body([]() {
return new CQHistMaker<GradStats>();
});
XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_histmaker")
.describe("Tree constructor that uses approximate global proposal of histogram construction.")
.set_body([]() {
return new GlobalProposalHistMaker<GradStats>();
});
} // namespace tree
} // namespace xgboost