change allreduce lib to rabit library, xgboost now run with rabit
This commit is contained in:
@@ -8,8 +8,8 @@
|
||||
#include "./updater_refresh-inl.hpp"
|
||||
#include "./updater_colmaker-inl.hpp"
|
||||
#include "./updater_distcol-inl.hpp"
|
||||
//#include "./updater_skmaker-inl.hpp"
|
||||
#include "./updater_histmaker-inl.hpp"
|
||||
//#include "./updater_skmaker-inl.hpp"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <rabit.h>
|
||||
#include "../utils/random.h"
|
||||
#include "../utils/quantile.h"
|
||||
|
||||
@@ -50,7 +51,7 @@ class BaseMaker: public IUpdater {
|
||||
}
|
||||
}
|
||||
}
|
||||
sync::AllReduce(BeginPtr(fminmax), fminmax.size(), sync::kMax);
|
||||
rabit::Allreduce<rabit::op::Max>(BeginPtr(fminmax), fminmax.size());
|
||||
}
|
||||
// get feature type, 0:empty 1:binary 2:real
|
||||
inline int Type(bst_uint fid) const {
|
||||
@@ -80,11 +81,11 @@ class BaseMaker: public IUpdater {
|
||||
std::string s_cache;
|
||||
utils::MemoryBufferStream fc(&s_cache);
|
||||
utils::IStream &fs = fc;
|
||||
if (sync::GetRank() == 0) {
|
||||
if (rabit::GetRank() == 0) {
|
||||
fs.Write(findex);
|
||||
sync::Bcast(&s_cache, 0);
|
||||
rabit::Broadcast(&s_cache, 0);
|
||||
} else {
|
||||
sync::Bcast(&s_cache, 0);
|
||||
rabit::Broadcast(&s_cache, 0);
|
||||
fs.Read(&findex);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@
|
||||
* and construct a tree
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <rabit.h>
|
||||
#include "../utils/bitmap.h"
|
||||
#include "../utils/io.h"
|
||||
#include "../sync/sync.h"
|
||||
#include "./updater_colmaker-inl.hpp"
|
||||
#include "./updater_prune-inl.hpp"
|
||||
|
||||
@@ -114,7 +114,7 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
|
||||
bitmap.InitFromBool(boolmap);
|
||||
// communicate bitmap
|
||||
sync::AllReduce(BeginPtr(bitmap.data), bitmap.data.size(), sync::kBitwiseOR);
|
||||
rabit::Allreduce<rabit::op::BitOR>(BeginPtr(bitmap.data), bitmap.data.size());
|
||||
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
|
||||
// get the new position
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
@@ -142,8 +142,9 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
}
|
||||
vec.push_back(this->snode[nid].best);
|
||||
}
|
||||
// TODO, lazy version
|
||||
// communicate best solution
|
||||
reducer.AllReduce(BeginPtr(vec), vec.size());
|
||||
reducer.Allreduce(BeginPtr(vec), vec.size());
|
||||
// assign solution back
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
const int nid = qexpand[i];
|
||||
@@ -154,7 +155,7 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
private:
|
||||
utils::BitMap bitmap;
|
||||
std::vector<int> boolmap;
|
||||
sync::Reducer<SplitEntry> reducer;
|
||||
rabit::Reducer<SplitEntry> reducer;
|
||||
};
|
||||
// we directly introduce pruner here
|
||||
TreePruner pruner;
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
*/
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "../sync/sync.h"
|
||||
#include <rabit.h>
|
||||
#include "../utils/quantile.h"
|
||||
#include "../utils/group_data.h"
|
||||
#include "./updater_basemaker-inl.hpp"
|
||||
@@ -117,7 +117,7 @@ class HistMaker: public BaseMaker {
|
||||
// workspace of thread
|
||||
ThreadWSpace wspace;
|
||||
// reducer for histogram
|
||||
sync::Reducer<TStats> histred;
|
||||
rabit::Reducer<TStats> histred;
|
||||
// set of working features
|
||||
std::vector<bst_uint> fwork_set;
|
||||
// update function implementation
|
||||
@@ -331,7 +331,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
.data[0] = node_stats[nid];
|
||||
}
|
||||
// sync the histogram
|
||||
this->histred.AllReduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
|
||||
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
|
||||
}
|
||||
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
@@ -394,8 +394,8 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
if (summary_array.size() != 0) {
|
||||
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
|
||||
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
}
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
@@ -540,7 +540,7 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
// summary array
|
||||
std::vector<WXQSketch::SummaryContainer> summary_array;
|
||||
// reducer for summary
|
||||
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
// per node, per feature sketch
|
||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
@@ -623,8 +623,8 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
summary_array[i].Reserve(max_size);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
size_t n4bytes = (WXQSketch::SummaryContainer::CalcMemCost(max_size) + 3) / 4;
|
||||
sreducer.AllReduce(BeginPtr(summary_array), n4bytes, summary_array.size());
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
this->wspace.rptr.clear();
|
||||
@@ -660,7 +660,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
// summary array
|
||||
std::vector<WXQSketch::SummaryContainer> summary_array;
|
||||
// reducer for summary
|
||||
sync::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
rabit::SerializeReducer<WXQSketch::SummaryContainer> sreducer;
|
||||
// local temp column data structure
|
||||
std::vector<size_t> col_ptr;
|
||||
// local storage of column data
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
*/
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <rabit.h>
|
||||
#include "./param.h"
|
||||
#include "./updater.h"
|
||||
#include "../utils/omp.h"
|
||||
#include "../sync/sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -85,7 +85,7 @@ class TreeRefresher: public IUpdater {
|
||||
}
|
||||
}
|
||||
// AllReduce, add statistics up
|
||||
reducer.AllReduce(BeginPtr(stemp[0]), stemp[0].size());
|
||||
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
@@ -137,7 +137,7 @@ class TreeRefresher: public IUpdater {
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
// reducer
|
||||
sync::Reducer<TStats> reducer;
|
||||
rabit::Reducer<TStats> reducer;
|
||||
};
|
||||
|
||||
} // namespace tree
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
*/
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "../sync/sync.h"
|
||||
#include <rabit.h>
|
||||
#include "../utils/quantile.h"
|
||||
#include "./updater_basemaker-inl.hpp"
|
||||
|
||||
@@ -166,8 +166,8 @@ class SketchMaker: public BaseMaker {
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array.Set(i, out);
|
||||
}
|
||||
size_t n4bytes = (summary_array.MemSize() + 3) / 4;
|
||||
sketch_reducer.AllReduce(&summary_array, n4bytes);
|
||||
size_t nbytes = summary_array.MemSize();;
|
||||
sketch_reducer.Allreduce(&summary_array, nbytes);
|
||||
}
|
||||
// update sketch information in column fid
|
||||
inline void UpdateSketchCol(const std::vector<bst_gpair> &gpair,
|
||||
@@ -256,7 +256,7 @@ class SketchMaker: public BaseMaker {
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
tmp[i] = node_stats[qexpand[i]];
|
||||
}
|
||||
stats_reducer.AllReduce(BeginPtr(tmp), tmp.size());
|
||||
stats_reducer.Allreduce(BeginPtr(tmp), tmp.size());
|
||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||
node_stats[qexpand[i]] = tmp[i];
|
||||
}
|
||||
@@ -382,9 +382,9 @@ class SketchMaker: public BaseMaker {
|
||||
// summary array
|
||||
WXQSketch::SummaryArray summary_array;
|
||||
// reducer for summary
|
||||
sync::Reducer<SKStats> stats_reducer;
|
||||
rabit::Reducer<SKStats> stats_reducer;
|
||||
// reducer for summary
|
||||
sync::ComplexReducer<WXQSketch::SummaryArray> sketch_reducer;
|
||||
rabit::SerializeReducer<WXQSketch::SummaryArray> sketch_reducer;
|
||||
// per node, per feature sketch
|
||||
std::vector< utils::WXQuantileSketch<bst_float, bst_float> > sketchs;
|
||||
};
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
*/
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <rabit.h>
|
||||
#include "./updater.h"
|
||||
#include "../sync/sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -32,22 +32,22 @@ class TreeSyncher: public IUpdater {
|
||||
private:
|
||||
// synchronize the trees in different nodes, take tree from rank 0
|
||||
inline void SyncTrees(const std::vector<RegTree *> &trees) {
|
||||
if (sync::GetWorldSize() == 1) return;
|
||||
if (rabit::GetWorldSize() == 1) return;
|
||||
std::string s_model;
|
||||
utils::MemoryBufferStream fs(&s_model);
|
||||
int rank = sync::GetRank();
|
||||
int rank = rabit::GetRank();
|
||||
if (rank == 0) {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->SaveModel(fs);
|
||||
}
|
||||
sync::Bcast(&s_model, 0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
} else {
|
||||
sync::Bcast(&s_model, 0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->LoadModel(fs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user