Distributed Fast Histogram Algorithm (#4011)

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* add back train method but mark as deprecated

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* fix scalastyle error

* init

* allow hist algo

* more changes

* temp

* update

* remove hist sync

* udpate rabit

* change hist size

* change the histogram

* update kfactor

* sync per node stats

* temp

* update

* final

* code clean

* update rabit

* more cleanup

* fix errors

* fix failed tests

* enforce c++11

* fix lint issue

* broadcast subsampled feature correctly

* revert some changes

* fix lint issue

* enable monotone and interaction constraints

* don't specify default for monotone and interactions

* update docs
This commit is contained in:
Nan Zhu
2019-02-05 05:12:53 -08:00
committed by GitHub
parent 8905df4a18
commit ae3bb9c2d5
16 changed files with 169 additions and 88 deletions

View File

@@ -126,6 +126,7 @@ class HistMaker: public BaseMaker {
virtual void Update(const std::vector<GradientPair> &gpair,
DMatrix *p_fmat,
RegTree *p_tree) {
CHECK(param_.max_depth > 0) << "max_depth must be larger than 0";
this->InitData(gpair, *p_fmat, *p_tree);
this->InitWorkSet(p_fmat, *p_tree, &fwork_set_);
// mark root node as fresh.
@@ -345,10 +346,7 @@ class CQHistMaker: public HistMaker<TStats> {
this->wspace_.Init(this->param_, 1);
// if it is C++11, use lazy evaluation for Allreduce,
// to gain speedup in recovery
#if __cplusplus >= 201103L
auto lazy_get_hist = [&]()
#endif
{
auto lazy_get_hist = [&]() {
thread_hist_.resize(omp_get_max_threads());
// start accumulating statistics
for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
@@ -371,22 +369,18 @@ class CQHistMaker: public HistMaker<TStats> {
for (size_t i = 0; i < this->qexpand_.size(); ++i) {
const int nid = this->qexpand_[i];
const int wid = this->node2workindex_[nid];
this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
.data[0] = node_stats_[nid];
this->wspace_.hset[0][fset.size() + wid * (fset.size() + 1)]
.data[0] = node_stats_[nid];
}
};
// sync the histogram
// if it is C++11, use lazy evaluation for Allreduce
#if __cplusplus >= 201103L
this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
this->wspace_.hset[0].data.size(), lazy_get_hist);
#else
this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
this->wspace_.hset[0].data.size());
#endif
this->wspace_.hset[0].data.size(), lazy_get_hist);
}
void ResetPositionAfterSplit(DMatrix *p_fmat,
const RegTree &tree) override {
const RegTree &tree) override {
this->GetSplitSet(this->qexpand_, tree, &fsplit_set_);
}
void ResetPosAndPropose(const std::vector<GradientPair> &gpair,

View File

@@ -156,12 +156,18 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
const int cright = (*p_tree)[nid].RightChild();
hist_.AddHistRow(cleft);
hist_.AddHistRow(cright);
if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
if (rabit::IsDistributed()) {
// in distributed mode, we need to keep consistent across workers
BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft]);
SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
} else {
BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright]);
SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft]);
SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
} else {
BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright]);
SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
}
}
time_build_hist += dmlc::GetTime() - tstart;
@@ -617,23 +623,34 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
{
auto& stats = snode_[nid].stats;
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
/* specialized code for dense data
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
GHistRow hist = hist_[nid];
const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
const uint32_t ibegin = row_ptr[fid_least_bins_];
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
for (uint32_t i = ibegin; i < iend; ++i) {
const GHistEntry et = hist.begin[i];
stats.Add(et.sum_grad, et.sum_hess);
GHistRow hist = hist_[nid];
if (rabit::IsDistributed()) {
// in distributed mode, the node's stats should be calculated from histogram, otherwise,
// we will have wrong results in EnumerateSplit()
// here we take the last feature in cut
for (size_t i = gmat.cut.row_ptr[0]; i < gmat.cut.row_ptr[1]; i++) {
stats.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
}
} else {
const RowSetCollection::Elem e = row_set_collection_[nid];
for (const size_t* it = e.begin; it < e.end; ++it) {
stats.Add(gpair[*it]);
if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased ||
rabit::IsDistributed()) {
/* specialized code for dense data
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid]
GHistRow hist = hist_[nid];*/
const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
const uint32_t ibegin = row_ptr[fid_least_bins_];
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
for (uint32_t i = ibegin; i < iend; ++i) {
const GHistEntry et = hist.begin[i];
stats.Add(et.sum_grad, et.sum_hess);
}
} else {
const RowSetCollection::Elem e = row_set_collection_[nid];
for (const size_t* it = e.begin; it < e.end; ++it) {
stats.Add(gpair[*it]);
}
}
}
}

View File

@@ -105,6 +105,7 @@ class QuantileHistMaker: public TreeUpdater {
} else {
hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
}
this->histred_.Allreduce(hist.begin, hist_builder_.GetNumBins());
}
inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
@@ -225,6 +226,8 @@ class QuantileHistMaker: public TreeUpdater {
enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
DataLayout data_layout_;
rabit::Reducer<GHistEntry, GHistEntry::Reduce> histred_;
};
std::unique_ptr<Builder> builder_;

View File

@@ -52,10 +52,7 @@ class TreeRefresher: public TreeUpdater {
}
// if it is C++11, use lazy evaluation for Allreduce,
// to gain speedup in recovery
#if __cplusplus >= 201103L
auto lazy_get_stats = [&]()
#endif
{
auto lazy_get_stats = [&]() {
const MetaInfo &info = p_fmat->Info();
// start accumulating statistics
for (const auto &batch : p_fmat->GetRowBatches()) {
@@ -86,11 +83,7 @@ class TreeRefresher: public TreeUpdater {
}
}
};
#if __cplusplus >= 201103L
reducer_.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
#else
reducer_.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size());
#endif
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();