change hist update to lazy
This commit is contained in:
parent
deb21351b9
commit
7a35e1a906
6
Makefile
6
Makefile
@ -10,6 +10,12 @@ else
|
|||||||
CFLAGS += -fopenmp
|
CFLAGS += -fopenmp
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
# by default use c++11
|
||||||
|
ifeq ($(no_cxx11),1)
|
||||||
|
else
|
||||||
|
CFLAGS += -std=c++11
|
||||||
|
endif
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = xgboost
|
BIN = xgboost
|
||||||
OBJ = updater.o gbm.o io.o main.o
|
OBJ = updater.o gbm.o io.o main.o
|
||||||
|
|||||||
@ -306,6 +306,12 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
// start to work
|
// start to work
|
||||||
this->wspace.Init(this->param, 1);
|
this->wspace.Init(this->param, 1);
|
||||||
|
// if it is C++11, use lazy evaluation for Allreduce,
|
||||||
|
// to gain speedup in recovery
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
auto lazy_get_hist = [&]()
|
||||||
|
#endif
|
||||||
|
{
|
||||||
thread_hist.resize(this->get_nthread());
|
thread_hist.resize(this->get_nthread());
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fset);
|
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fset);
|
||||||
@ -330,8 +336,15 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
|
this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
|
||||||
.data[0] = node_stats[nid];
|
.data[0] = node_stats[nid];
|
||||||
}
|
}
|
||||||
|
};
|
||||||
// sync the histogram
|
// sync the histogram
|
||||||
|
// if it is C++11, use lazy evaluation for Allreduce
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data),
|
||||||
|
this->wspace.hset[0].data.size(), lazy_get_hist);
|
||||||
|
#else
|
||||||
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
|
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
|
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
|
||||||
const RegTree &tree) {
|
const RegTree &tree) {
|
||||||
@ -354,13 +367,24 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
feat2workindex[fset[i]] = -2;
|
feat2workindex[fset[i]] = -2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this->GetNodeStats(gpair, *p_fmat, tree, info,
|
this->GetNodeStats(gpair, *p_fmat, tree, info,
|
||||||
&thread_stats, &node_stats);
|
&thread_stats, &node_stats);
|
||||||
sketchs.resize(this->qexpand.size() * freal_set.size());
|
sketchs.resize(this->qexpand.size() * freal_set.size());
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||||
}
|
}
|
||||||
|
// intitialize the summary array
|
||||||
|
summary_array.resize(sketchs.size());
|
||||||
|
// setup maximum size
|
||||||
|
unsigned max_size = this->param.max_sketch_size();
|
||||||
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
|
summary_array[i].Reserve(max_size);
|
||||||
|
}
|
||||||
|
// if it is C++11, use lazy evaluation for Allreduce
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
auto lazy_get_summary = [&]()
|
||||||
|
#endif
|
||||||
|
{// get smmary
|
||||||
thread_sketch.resize(this->get_nthread());
|
thread_sketch.resize(this->get_nthread());
|
||||||
// number of rows in
|
// number of rows in
|
||||||
const size_t nrows = p_fmat->buffered_rowset().size();
|
const size_t nrows = p_fmat->buffered_rowset().size();
|
||||||
@ -383,19 +407,20 @@ class CQHistMaker: public HistMaker<TStats> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// setup maximum size
|
|
||||||
unsigned max_size = this->param.max_sketch_size();
|
|
||||||
// synchronize sketch
|
|
||||||
summary_array.resize(sketchs.size());
|
|
||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||||
sketchs[i].GetSummary(&out);
|
sketchs[i].GetSummary(&out);
|
||||||
summary_array[i].Reserve(max_size);
|
|
||||||
summary_array[i].SetPrune(out, max_size);
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
|
utils::Assert(summary_array.size() == sketchs.size(), "shape mismatch");
|
||||||
|
};
|
||||||
if (summary_array.size() != 0) {
|
if (summary_array.size() != 0) {
|
||||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size(), lazy_get_summary);
|
||||||
|
#else
|
||||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
this->wspace.cut.clear();
|
this->wspace.cut.clear();
|
||||||
@ -623,6 +648,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
|||||||
summary_array[i].Reserve(max_size);
|
summary_array[i].Reserve(max_size);
|
||||||
summary_array[i].SetPrune(out, max_size);
|
summary_array[i].SetPrune(out, max_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
// now we get the final result of sketch, setup the cut
|
// now we get the final result of sketch, setup the cut
|
||||||
|
|||||||
@ -52,6 +52,12 @@ class TreeRefresher: public IUpdater {
|
|||||||
std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param));
|
std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param));
|
||||||
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
||||||
}
|
}
|
||||||
|
// if it is C++11, use lazy evaluation for Allreduce,
|
||||||
|
// to gain speedup in recovery
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
auto lazy_get_stats = [&]()
|
||||||
|
#endif
|
||||||
|
{
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||||
iter->BeforeFirst();
|
iter->BeforeFirst();
|
||||||
@ -84,8 +90,12 @@ class TreeRefresher: public IUpdater {
|
|||||||
stemp[0][nid].Add(stemp[tid][nid]);
|
stemp[0][nid].Add(stemp[tid][nid]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// AllReduce, add statistics up
|
};
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
|
||||||
|
#else
|
||||||
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size());
|
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size());
|
||||||
|
#endif
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param.learning_rate;
|
float lr = param.learning_rate;
|
||||||
param.learning_rate = lr / trees.size();
|
param.learning_rate = lr / trees.size();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user