change hist update to lazy

tqchen 2014-12-20 05:02:38 -08:00
parent deb21351b9
commit 7a35e1a906
3 changed files with 127 additions and 85 deletions

Makefile

@@ -10,6 +10,12 @@ else
 CFLAGS += -fopenmp
 endif
+
+# by default use c++11
+ifeq ($(no_cxx11),1)
+else
+CFLAGS += -std=c++11
+endif
 
 # specify tensor path
 BIN = xgboost
 OBJ = updater.o gbm.o io.o main.o
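The `no_cxx11` switch makes the C++11 requirement opt-out rather than hard: assuming the standard GNU make command-line override, building with `make no_cxx11=1` takes the empty `ifeq` branch and leaves `-std=c++11` out of `CFLAGS` for older compilers. The `#if __cplusplus >= 201103L` guards in the hunks below then compile to the eager fallback paths.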

src/tree/updater_histmaker-inl.hpp

@@ -152,7 +152,7 @@ class HistMaker: public BaseMaker {
                                   IFMatrix *p_fmat,
                                   const BoosterInfo &info,
                                   const std::vector <bst_uint> &fset,
                                   const RegTree &tree) = 0;
   // initialize the current working set of features in this round
   virtual void InitWorkSet(IFMatrix *p_fmat,
                            const RegTree &tree,
@@ -306,32 +306,45 @@ class CQHistMaker: public HistMaker<TStats> {
     }
     // start to work
     this->wspace.Init(this->param, 1);
-    thread_hist.resize(this->get_nthread());
-    // start accumulating statistics
-    utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fset);
-    iter->BeforeFirst();
-    while (iter->Next()) {
-      const ColBatch &batch = iter->Value();
-      // start enumeration
-      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
-      #pragma omp parallel for schedule(dynamic, 1)
-      for (bst_omp_uint i = 0; i < nsize; ++i) {
-        int offset = feat2workindex[batch.col_index[i]];
-        if (offset >= 0) {
-          this->UpdateHistCol(gpair, batch[i], info, tree,
-                              fset, offset,
-                              &thread_hist[omp_get_thread_num()]);
+    // if it is C++11, use lazy evaluation for Allreduce,
+    // to gain speedup in recovery
+#if __cplusplus >= 201103L
+    auto lazy_get_hist = [&]()
+#endif
+    {
+      thread_hist.resize(this->get_nthread());
+      // start accumulating statistics
+      utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fset);
+      iter->BeforeFirst();
+      while (iter->Next()) {
+        const ColBatch &batch = iter->Value();
+        // start enumeration
+        const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+        #pragma omp parallel for schedule(dynamic, 1)
+        for (bst_omp_uint i = 0; i < nsize; ++i) {
+          int offset = feat2workindex[batch.col_index[i]];
+          if (offset >= 0) {
+            this->UpdateHistCol(gpair, batch[i], info, tree,
+                                fset, offset,
+                                &thread_hist[omp_get_thread_num()]);
+          }
         }
       }
-    }
-    for (size_t i = 0; i < this->qexpand.size(); ++i) {
-      const int nid = this->qexpand[i];
-      const int wid = this->node2workindex[nid];
-      this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
-          .data[0] = node_stats[nid];
-    }
+      for (size_t i = 0; i < this->qexpand.size(); ++i) {
+        const int nid = this->qexpand[i];
+        const int wid = this->node2workindex[nid];
+        this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
+            .data[0] = node_stats[nid];
+      }
+    };
     // sync the histogram
+    // if it is C++11, use lazy evaluation for Allreduce
+#if __cplusplus >= 201103L
+    this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data),
+                            this->wspace.hset[0].data.size(), lazy_get_hist);
+#else
     this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
+#endif
   }
   virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
                                        const RegTree &tree) {
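The shape of the new code is worth a second look: thanks to the `#if` guards, the same braces act as the lambda body under C++11 and as a plain block that runs eagerly on older compilers, with the trailing `;` degrading to an empty statement. A minimal compilable sketch of that trick, using a hypothetical `ToyReducer` in place of the real reducer from the sync layer:

#include <cstddef>
#include <cstdio>
#include <vector>
#if __cplusplus >= 201103L
#include <functional>
#endif

// Illustrative stand-in for the reducer; not the real xgboost sync API.
struct ToyReducer {
#if __cplusplus >= 201103L
  // Lazy variant: a real implementation may skip calling `lazy`
  // when the reduced result can be restored from a checkpoint.
  void Allreduce(double *buf, size_t n, std::function<void()> lazy) {
    lazy();
    this->Allreduce(buf, n);
  }
#endif
  void Allreduce(double *, size_t) { /* combine buffers across workers */ }
};

int main() {
  std::vector<double> hist(8, 0.0);
  ToyReducer red;
#if __cplusplus >= 201103L
  auto lazy_get_hist = [&]()
#endif
  {  // lambda body under C++11; an eagerly executed block otherwise
    for (size_t i = 0; i < hist.size(); ++i) hist[i] += 1.0;
  };  // with C++11 this defines the lambda; without, the work is already done
#if __cplusplus >= 201103L
  red.Allreduce(&hist[0], hist.size(), lazy_get_hist);
#else
  red.Allreduce(&hist[0], hist.size());
#endif
  std::printf("hist[0] = %g\n", hist[0]);
  return 0;
}

Either way the buffer is populated before the two-argument Allreduce runs; the three-argument overload merely gains the option of not invoking the closure at all.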
@@ -354,48 +367,60 @@ class CQHistMaker: public HistMaker<TStats> {
         feat2workindex[fset[i]] = -2;
       }
     }
     this->GetNodeStats(gpair, *p_fmat, tree, info,
                        &thread_stats, &node_stats);
     sketchs.resize(this->qexpand.size() * freal_set.size());
     for (size_t i = 0; i < sketchs.size(); ++i) {
       sketchs[i].Init(info.num_row, this->param.sketch_eps);
     }
-    thread_sketch.resize(this->get_nthread());
-    // number of rows in
-    const size_t nrows = p_fmat->buffered_rowset().size();
-    // start accumulating statistics
-    utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(freal_set);
-    iter->BeforeFirst();
-    while (iter->Next()) {
-      const ColBatch &batch = iter->Value();
-      // start enumeration
-      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
-      #pragma omp parallel for schedule(dynamic, 1)
-      for (bst_omp_uint i = 0; i < nsize; ++i) {
-        int offset = feat2workindex[batch.col_index[i]];
-        if (offset >= 0) {
-          this->UpdateSketchCol(gpair, batch[i], tree,
-                                node_stats,
-                                freal_set, offset,
-                                batch[i].length == nrows,
-                                &thread_sketch[omp_get_thread_num()]);
-        }
-      }
-    }
+    // initialize the summary array
+    summary_array.resize(sketchs.size());
     // setup maximum size
     unsigned max_size = this->param.max_sketch_size();
-    // synchronize sketch
-    summary_array.resize(sketchs.size());
     for (size_t i = 0; i < sketchs.size(); ++i) {
-      utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
-      sketchs[i].GetSummary(&out);
       summary_array[i].Reserve(max_size);
-      summary_array[i].SetPrune(out, max_size);
     }
+    // if it is C++11, use lazy evaluation for Allreduce
+#if __cplusplus >= 201103L
+    auto lazy_get_summary = [&]()
+#endif
+    { // get summary
+      thread_sketch.resize(this->get_nthread());
+      // number of rows in
+      const size_t nrows = p_fmat->buffered_rowset().size();
+      // start accumulating statistics
+      utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(freal_set);
+      iter->BeforeFirst();
+      while (iter->Next()) {
+        const ColBatch &batch = iter->Value();
+        // start enumeration
+        const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+        #pragma omp parallel for schedule(dynamic, 1)
+        for (bst_omp_uint i = 0; i < nsize; ++i) {
+          int offset = feat2workindex[batch.col_index[i]];
+          if (offset >= 0) {
+            this->UpdateSketchCol(gpair, batch[i], tree,
+                                  node_stats,
+                                  freal_set, offset,
+                                  batch[i].length == nrows,
+                                  &thread_sketch[omp_get_thread_num()]);
+          }
+        }
+      }
+      for (size_t i = 0; i < sketchs.size(); ++i) {
+        utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
+        sketchs[i].GetSummary(&out);
+        summary_array[i].SetPrune(out, max_size);
+      }
+      utils::Assert(summary_array.size() == sketchs.size(), "shape mismatch");
+    };
     if (summary_array.size() != 0) {
       size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
+#if __cplusplus >= 201103L
+      sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size(), lazy_get_summary);
+#else
       sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
+#endif
     }
     // now we get the final result of sketch, setup the cut
     this->wspace.cut.clear();
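Note the split in the hunk above: `summary_array.resize(...)` and the per-entry `Reserve(max_size)` stay outside `lazy_get_summary`, so the Allreduce buffer is fully allocated even on a recovery pass that never runs the closure. Only the sketch scan and the `SetPrune` step move inside, and the added `utils::Assert` guards the shape invariant the reducer relies on.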
@@ -623,6 +648,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
       summary_array[i].Reserve(max_size);
       summary_array[i].SetPrune(out, max_size);
     }
+
     size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
     sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
     // now we get the final result of sketch, setup the cut

src/tree/updater_refresh-inl.hpp

@@ -52,40 +52,50 @@ class TreeRefresher: public IUpdater {
       std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param));
       fvec_temp[tid].Init(trees[0]->param.num_feature);
     }
-    // start accumulating statistics
-    utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
-    iter->BeforeFirst();
-    while (iter->Next()) {
-      const RowBatch &batch = iter->Value();
-      utils::Check(batch.size < std::numeric_limits<unsigned>::max(),
-                   "too large batch size ");
-      const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
-      #pragma omp parallel for schedule(static)
-      for (bst_omp_uint i = 0; i < nbatch; ++i) {
-        RowBatch::Inst inst = batch[i];
-        const int tid = omp_get_thread_num();
-        const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
-        RegTree::FVec &feats = fvec_temp[tid];
-        feats.Fill(inst);
-        int offset = 0;
-        for (size_t j = 0; j < trees.size(); ++j) {
-          AddStats(*trees[j], feats, gpair, info, ridx,
-                   BeginPtr(stemp[tid]) + offset);
-          offset += trees[j]->param.num_nodes;
+    // if it is C++11, use lazy evaluation for Allreduce,
+    // to gain speedup in recovery
+#if __cplusplus >= 201103L
+    auto lazy_get_stats = [&]()
+#endif
+    {
+      // start accumulating statistics
+      utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
+      iter->BeforeFirst();
+      while (iter->Next()) {
+        const RowBatch &batch = iter->Value();
+        utils::Check(batch.size < std::numeric_limits<unsigned>::max(),
+                     "too large batch size ");
+        const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
+        #pragma omp parallel for schedule(static)
+        for (bst_omp_uint i = 0; i < nbatch; ++i) {
+          RowBatch::Inst inst = batch[i];
+          const int tid = omp_get_thread_num();
+          const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
+          RegTree::FVec &feats = fvec_temp[tid];
+          feats.Fill(inst);
+          int offset = 0;
+          for (size_t j = 0; j < trees.size(); ++j) {
+            AddStats(*trees[j], feats, gpair, info, ridx,
+                     BeginPtr(stemp[tid]) + offset);
+            offset += trees[j]->param.num_nodes;
+          }
+          feats.Drop(inst);
         }
-        feats.Drop(inst);
       }
-    }
-    // aggregate the statistics
-    int num_nodes = static_cast<int>(stemp[0].size());
-    #pragma omp parallel for schedule(static)
-    for (int nid = 0; nid < num_nodes; ++nid) {
-      for (int tid = 1; tid < nthread; ++tid) {
-        stemp[0][nid].Add(stemp[tid][nid]);
+      // aggregate the statistics
+      int num_nodes = static_cast<int>(stemp[0].size());
+      #pragma omp parallel for schedule(static)
+      for (int nid = 0; nid < num_nodes; ++nid) {
+        for (int tid = 1; tid < nthread; ++tid) {
+          stemp[0][nid].Add(stemp[tid][nid]);
+        }
       }
-    }
-    // AllReduce, add statistics up
+    };
+#if __cplusplus >= 201103L
+    reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
+#else
     reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size());
+#endif
     // rescale learning rate according to size of trees
     float lr = param.learning_rate;
    param.learning_rate = lr / trees.size();
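Why defer the work at all? Because on a recovery pass of the allreduce engine the reduced result can be replayed from a cache, letting the reducer skip the closure and, with it, the whole statistics scan. A self-contained sketch of that idea, with hypothetical names (`LazyReducer`, `cache_`) rather than the actual sync implementation:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

// Hypothetical reducer sketching the recovery-time benefit; the real
// implementation lives in the distributed sync layer, not shown here.
class LazyReducer {
 public:
  explicit LazyReducer(bool recovering) : recovering_(recovering) {}
  void Allreduce(double *buf, size_t n, std::function<void()> prepare) {
    if (recovering_ && cache_.size() == n) {
      // restore `buf` from the cached result of a previous round;
      // `prepare` is never called, so no local stats are recomputed
      for (size_t i = 0; i < n; ++i) buf[i] = cache_[i];
    } else {
      prepare();                    // run the deferred local computation
      // ... real code would now sum `buf` across all workers ...
      cache_.assign(buf, buf + n);  // keep a copy for future recovery
    }
  }

 private:
  bool recovering_;
  std::vector<double> cache_;
};

int main() {
  std::vector<double> stats(4, 0.0);
  LazyReducer red(/*recovering=*/false);
  red.Allreduce(&stats[0], stats.size(), [&]() {
    // stands in for the "expensive" per-row statistics pass above
    for (size_t i = 0; i < stats.size(); ++i) stats[i] = static_cast<double>(i);
  });
  std::printf("stats[3] = %g\n", stats[3]);
  return 0;
}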