[TREE] Enable global proposal for faster speed

This commit is contained in:
tqchen 2016-01-28 17:16:38 -08:00
parent 2f2080a337
commit ce4d59ed69
2 changed files with 141 additions and 22 deletions

View File

@ -138,23 +138,23 @@ DMatrix* DMatrix::Load(const std::string& uri,
cache_file = uri.substr(dlm_pos + 1, uri.length()); cache_file = uri.substr(dlm_pos + 1, uri.length());
fname = uri.substr(0, dlm_pos); fname = uri.substr(0, dlm_pos);
CHECK_EQ(cache_file.find('#'), std::string::npos) CHECK_EQ(cache_file.find('#'), std::string::npos)
<< "Only one `#` is allowed in file path for cache file specification."; << "Only one `#` is allowed in file path for cache file specification.";
if (load_row_split) { if (load_row_split) {
std::ostringstream os; std::ostringstream os;
std::vector<std::string> cache_shards = common::Split(cache_file, ':'); std::vector<std::string> cache_shards = common::Split(cache_file, ':');
for (size_t i = 0; i < cache_shards.size(); ++i) { for (size_t i = 0; i < cache_shards.size(); ++i) {
size_t pos = cache_shards[i].rfind('.'); size_t pos = cache_shards[i].rfind('.');
if (pos == std::string::npos) { if (pos == std::string::npos) {
os << cache_shards[i] os << cache_shards[i]
<< ".r" << rabit::GetRank() << ".r" << rabit::GetRank()
<< "-" << rabit::GetWorldSize(); << "-" << rabit::GetWorldSize();
} else { } else {
os << cache_shards[i].substr(0, pos) os << cache_shards[i].substr(0, pos)
<< ".r" << rabit::GetRank() << ".r" << rabit::GetRank()
<< "-" << rabit::GetWorldSize() << "-" << rabit::GetWorldSize()
<< cache_shards[i].substr(pos, cache_shards[i].length()); << cache_shards[i].substr(pos, cache_shards[i].length());
} }
if (i + 1 != cache_shards.size()) os << ':'; if (i + 1 != cache_shards.size()) os << ':';
} }
cache_file = os.str(); cache_file = os.str();
} }
@ -172,7 +172,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
if (npart != 1) { if (npart != 1) {
LOG(CONSOLE) << "Load part of data " << partid LOG(CONSOLE) << "Load part of data " << partid
<< " of " << npart << " parts"; << " of " << npart << " parts";
} }
// legacy handling of binary data loading // legacy handling of binary data loading
if (file_format == "auto" && !load_row_split) { if (file_format == "auto" && !load_row_split) {

View File

@ -268,6 +268,10 @@ class HistMaker: public BaseMaker {
template<typename TStats> template<typename TStats>
class CQHistMaker: public HistMaker<TStats> { class CQHistMaker: public HistMaker<TStats> {
public:
CQHistMaker() : cache_dmatrix_(nullptr) {
}
protected: protected:
struct HistEntry { struct HistEntry {
typename HistMaker<TStats>::HistUnit hist; typename HistMaker<TStats>::HistUnit hist;
@ -290,9 +294,13 @@ class CQHistMaker: public HistMaker<TStats> {
*/ */
inline void Add(bst_float fv, inline void Add(bst_float fv,
bst_gpair gstats) { bst_gpair gstats) {
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart; if (fv < hist.cut[istart]) {
CHECK_NE(istart, hist.size); hist.data[istart].Add(gstats);
hist.data[istart].Add(gstats); } else {
while (istart < hist.size && !(fv < hist.cut[istart])) ++istart;
CHECK_NE(istart, hist.size);
hist.data[istart].Add(gstats);
}
} }
}; };
// sketch type used for this // sketch type used for this
@ -301,7 +309,10 @@ class CQHistMaker: public HistMaker<TStats> {
void InitWorkSet(DMatrix *p_fmat, void InitWorkSet(DMatrix *p_fmat,
const RegTree &tree, const RegTree &tree,
std::vector<bst_uint> *p_fset) override { std::vector<bst_uint> *p_fset) override {
feat_helper.InitByCol(p_fmat, tree); if (p_fmat != cache_dmatrix_) {
feat_helper.InitByCol(p_fmat, tree);
cache_dmatrix_ = p_fmat;
}
feat_helper.SampleCol(this->param.colsample_bytree, p_fset); feat_helper.SampleCol(this->param.colsample_bytree, p_fset);
} }
// code to create histogram // code to create histogram
@ -342,6 +353,9 @@ class CQHistMaker: public HistMaker<TStats> {
} }
} }
} }
// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&thread_stats, &node_stats);
for (size_t i = 0; i < this->qexpand.size(); ++i) { for (size_t i = 0; i < this->qexpand.size(); ++i) {
const int nid = this->qexpand[i]; const int nid = this->qexpand[i];
const int wid = this->node2workindex[nid]; const int wid = this->node2workindex[nid];
@ -434,9 +448,6 @@ class CQHistMaker: public HistMaker<TStats> {
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
} }
// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&thread_stats, &node_stats);
// now we get the final result of sketch, setup the cut // now we get the final result of sketch, setup the cut
this->wspace.cut.clear(); this->wspace.cut.clear();
this->wspace.rptr.clear(); this->wspace.rptr.clear();
@ -475,7 +486,6 @@ class CQHistMaker: public HistMaker<TStats> {
(fset.size() + 1) * this->qexpand.size() + 1); (fset.size() + 1) * this->qexpand.size() + 1);
} }
private:
inline void UpdateHistCol(const std::vector<bst_gpair> &gpair, inline void UpdateHistCol(const std::vector<bst_gpair> &gpair,
const ColBatch::Inst &c, const ColBatch::Inst &c,
const MetaInfo &info, const MetaInfo &info,
@ -607,6 +617,8 @@ class CQHistMaker: public HistMaker<TStats> {
sbuilder[nid].Finalize(max_size); sbuilder[nid].Finalize(max_size);
} }
} }
// cached dmatrix where we initialized the feature on.
const DMatrix* cache_dmatrix_;
// feature helper // feature helper
BaseMaker::FMetaHelper feat_helper; BaseMaker::FMetaHelper feat_helper;
// temp space to map feature id to working index // temp space to map feature id to working index
@ -631,6 +643,107 @@ class CQHistMaker: public HistMaker<TStats> {
std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs; std::vector<common::WXQuantileSketch<bst_float, bst_float> > sketchs;
}; };
// Histogram maker that reuses one cached cut proposal instead of
// re-proposing cuts on every call to ResetPosAndPropose.
// The proposal produced by CQHistMaker is stored in cached_rptr_ /
// cached_cut_ and replicated for every node in the expand queue.
template<typename TStats>
class GlobalProposalHistMaker: public CQHistMaker<TStats> {
protected:
// Fill this->wspace.rptr / this->wspace.cut for the nodes in this->qexpand.
// - gpair, p_fmat, fset, tree: forwarded unchanged to
//   CQHistMaker<TStats>::ResetPosAndPropose when a fresh proposal is built.
// When the cache is empty the base class computes the proposal and the
// result is copied into the cache; otherwise the cached layout is repeated
// once per entry of this->qexpand.
// NOTE(review): on the cached path only the segment count is checked against
// fset; if the contents of fset can change while the cache is kept, the
// cached cuts may not correspond to it - confirm against callers.
void ResetPosAndPropose(const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_uint> &fset,
const RegTree &tree) override {
// drop the cache when a single node is queued and caching of the
// proposal was not requested via param.cache_global_proposal
if (this->qexpand.size() == 1 && !this->param.cache_global_proposal) {
cached_rptr_.clear();
cached_cut_.clear();
}
if (cached_rptr_.size() == 0) {
// a fresh proposal is only built while exactly one node is queued
CHECK_EQ(this->qexpand.size(), 1);
CQHistMaker<TStats>::ResetPosAndPropose(gpair, p_fmat, fset, tree);
// remember the proposal that the base class left in the workspace
cached_rptr_ = this->wspace.rptr;
cached_cut_ = this->wspace.cut;
} else {
// rebuild the workspace from the cache: one copy of the cached
// segments and cut values per queued node
this->wspace.cut.clear();
this->wspace.rptr.clear();
this->wspace.rptr.push_back(0);
for (size_t i = 0; i < this->qexpand.size(); ++i) {
for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) {
// append the length of cached segment j to the running offset
this->wspace.rptr.push_back(
this->wspace.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]);
}
this->wspace.cut.insert(this->wspace.cut.end(), cached_cut_.begin(), cached_cut_.end());
}
// expect (fset.size() + 1) segments per queued node plus the leading 0
CHECK_EQ(this->wspace.rptr.size(),
(fset.size() + 1) * this->qexpand.size() + 1);
// the last offset must equal the number of cut values stored
CHECK_EQ(this->wspace.rptr.back(), this->wspace.cut.size());
}
}
// Build the histograms for the nodes in this->qexpand.
// - gpair: gradient pairs passed on to UpdateHistCol and GetNodeStats.
// - p_fmat: data matrix; its column iterator is walked once.
// - fset: features to build histograms for.
// - tree: tree passed on to the position and statistics helpers.
// The accumulated data in this->wspace.hset[0] is combined through
// this->histred.Allreduce at the end.
void CreateHist(const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat,
const std::vector<bst_uint> &fset,
const RegTree &tree) override {
const MetaInfo &info = p_fmat->info();
// fill in reverse map: feature id -> index in fset, -1 when not in fset
this->feat2workindex.resize(tree.param.num_feature);
this->work_set = fset;
std::fill(this->feat2workindex.begin(), this->feat2workindex.end(), -1);
for (size_t i = 0; i < fset.size(); ++i) {
this->feat2workindex[fset[i]] = static_cast<int>(i);
}
// start to work
this->wspace.Init(this->param, 1);
// to gain speedup in recovery
{
// one histogram buffer per thread reported by get_nthread()
this->thread_hist.resize(this->get_nthread());
// TWOPASS: use the real set + split set in the column iteration.
this->SetDefaultPostion(p_fmat, tree);
// work_set becomes the sorted, de-duplicated union of fset and fsplit_set
this->work_set.insert(this->work_set.end(), this->fsplit_set.begin(), this->fsplit_set.end());
std::sort(this->work_set.begin(), this->work_set.end());
this->work_set.resize(
std::unique(this->work_set.begin(), this->work_set.end()) - this->work_set.begin());
// start accumulating statistics
dmlc::DataIter<ColBatch> *iter = p_fmat->ColIterator(this->work_set);
iter->BeforeFirst();
while (iter->Next()) {
const ColBatch &batch = iter->Value();
// TWOPASS: use the real set + split set in the column iteration.
this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set, tree);
// start enumeration
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < nsize; ++i) {
int offset = this->feat2workindex[batch.col_index[i]];
// columns that are not in fset (offset == -1) add nothing to the histogram
if (offset >= 0) {
this->UpdateHistCol(gpair, batch[i], info, tree,
fset, offset,
&this->thread_hist[omp_get_thread_num()]);
}
}
}
// update node statistics.
this->GetNodeStats(gpair, *p_fmat, tree,
&(this->thread_stats), &(this->node_stats));
// store each queued node's statistics in the extra slot that follows
// its fset.size() feature histograms
for (size_t i = 0; i < this->qexpand.size(); ++i) {
const int nid = this->qexpand[i];
const int wid = this->node2workindex[nid];
this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
.data[0] = this->node_stats[nid];
}
}
// combine the histogram data through the histred reducer
this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data),
this->wspace.hset[0].data.size());
}
// cached segment offsets (copy of wspace.rptr from the last fresh proposal)
std::vector<unsigned> cached_rptr_;
// cached cut values (copy of wspace.cut from the last fresh proposal)
std::vector<bst_float> cached_cut_;
};
template<typename TStats> template<typename TStats>
class QuantileHistMaker: public HistMaker<TStats> { class QuantileHistMaker: public HistMaker<TStats> {
@ -763,5 +876,11 @@ XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
.set_body([]() { .set_body([]() {
return new CQHistMaker<GradStats>(); return new CQHistMaker<GradStats>();
}); });
// Register the updater named "grow_global_histmaker"; its factory body
// returns a newly allocated GlobalProposalHistMaker<GradStats>.
XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_histmaker")
.describe("Tree constructor that uses approximate global proposal of histogram construction.")
.set_body([]() {
return new GlobalProposalHistMaker<GradStats>();
});
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost