make clear separation

This commit is contained in:
tqchen 2014-10-16 13:03:42 -07:00
parent 47145a7fac
commit a21df0770d
6 changed files with 63 additions and 24 deletions

View File

@ -11,11 +11,11 @@ else
endif endif
# specify tensor path # specify tensor path
BIN = BIN = xgboost
OBJ = updater.o gbm.o io.o main.o OBJ = updater.o gbm.o io.o main.o sync_empty.o
MPIOBJ = sync.o MPIOBJ = sync_mpi.o
MPIBIN = xgboost MPIBIN = xgboost-mpi
SLIB = #wrapper/libxgboostwrapper.so SLIB = wrapper/libxgboostwrapper.so
.PHONY: clean all python Rpack .PHONY: clean all python Rpack
@ -27,11 +27,12 @@ wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp $(OBJ)
updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h
gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
sync.o: src/sync/sync.cpp sync_mpi.o: src/sync/sync_mpi.cpp
sync_empty.o: src/sync/sync_empty.cpp
main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h
xgboost: $(OBJ) $(MPIOBJ) xgboost: updater.o gbm.o io.o main.o sync_empty.o
#wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h $(OBJ) xgboost-mpi: updater.o gbm.o io.o main.o sync_mpi.o
test/test: test/test.cpp sync.o wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h $(OBJ)
$(BIN) : $(BIN) :
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)

View File

@ -4,12 +4,12 @@ python mapfeat.py
# split train and test # split train and test
python mknfold.py agaricus.txt 1 python mknfold.py agaricus.txt 1
# training and output the models # training and output the models
../../xgboost mushroom.conf mpirun ../../xgboost mushroom.conf
# output prediction task=pred # output prediction task=pred
../../xgboost mushroom.conf task=pred model_in=0002.model mpirun ../../xgboost mushroom.conf task=pred model_in=0002.model
# print the boosters of 00002.model in dump.raw.txt # print the boosters of 00002.model in dump.raw.txt
../../xgboost mushroom.conf task=dump model_in=0002.model name_dump=dump.raw.txt mpirun ../../xgboost mushroom.conf task=dump model_in=0002.model name_dump=dump.raw.txt
# use the feature map in printing for better visualization # use the feature map in printing for better visualization
../../xgboost mushroom.conf task=dump model_in=0002.model fmap=featmap.txt name_dump=dump.nice.txt mpirun ../../xgboost mushroom.conf task=dump model_in=0002.model fmap=featmap.txt name_dump=dump.nice.txt
cat dump.nice.txt cat dump.nice.txt

27
src/sync/sync_empty.cpp Normal file
View File

@ -0,0 +1,27 @@
#include "./sync.h"
#include "../utils/utils.h"
// no synchronization module, single thread mode does not need it anyway
namespace xgboost {
namespace sync {
// Stub implementation of the sync interface: every entry point in this
// file has an empty body, and the reported rank is always 0.

// Rank query; this implementation always answers 0.
int GetRank() { return 0; }

// Start-up hook; the arguments are accepted and ignored.
void Init(int argc, char *argv[]) {}

// Shutdown hook; nothing to do here.
void Finalize() {}

// AllReduce for uint32_t: empty body, so sendrecvbuf is left as passed in.
template<>
void AllReduce<uint32_t>(uint32_t *sendrecvbuf, int count, ReduceOp op) {}

// AllReduce for float: empty body, so sendrecvbuf is left as passed in.
template<>
void AllReduce<float>(float *sendrecvbuf, int count, ReduceOp op) {}

// Broadcast: empty body, so *sendrecv_data is left as passed in.
void Bcast(std::string *sendrecv_data, int root) {}

// ReduceHandle: the constructor only sets handle to NULL; the destructor,
// Init and AllReduce do nothing.
ReduceHandle::ReduceHandle() : handle(NULL) {}

ReduceHandle::~ReduceHandle() {}

void ReduceHandle::Init(ReduceFunction redfunc, bool commute) {}

void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) {}

}  // namespace sync
}  // namespace xgboost

View File

@ -1,10 +1,9 @@
#include "./sync.h" #include "./sync.h"
#include "../utils/utils.h" #include "../utils/utils.h"
#include "mpi.h" #include "mpi.h"
// use MPI to implement sync
namespace xgboost { namespace xgboost {
namespace sync { namespace sync {
int GetRank(void) { int GetRank(void) {
return MPI::COMM_WORLD.Get_rank(); return MPI::COMM_WORLD.Get_rank();
} }

View File

@ -32,13 +32,13 @@ class DistColMaker : public ColMaker<TStats> {
utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time"); utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time");
// build the tree // build the tree
builder.Update(gpair, p_fmat, info, trees[0]); builder.Update(gpair, p_fmat, info, trees[0]);
// prune the tree //// prune the tree
pruner.Update(gpair, p_fmat, info, trees); pruner.Update(gpair, p_fmat, info, trees);
this->SyncTrees(trees[0]); this->SyncTrees(trees[0]);
// update position after the tree is pruned // update position after the tree is pruned
builder.UpdatePosition(p_fmat, *trees[0]); builder.UpdatePosition(p_fmat, *trees[0]);
} }
private: private:
inline void SyncTrees(RegTree *tree) { inline void SyncTrees(RegTree *tree) {
std::string s_model; std::string s_model;
@ -63,10 +63,12 @@ class DistColMaker : public ColMaker<TStats> {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) { for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
int nid = this->position[ridx]; int nid = this->DecodePosition(ridx);
if (nid < 0) { while (tree[nid].is_deleted()) {
nid = tree[nid].parent();
utils::Assert(nid >=0, "distributed learning error");
} }
this->position[ridx] = nid;
} }
} }
protected: protected:
@ -111,6 +113,7 @@ class DistColMaker : public ColMaker<TStats> {
} }
} }
} }
// communicate bitmap // communicate bitmap
sync::AllReduce(BeginPtr(bitmap.data), bitmap.data.size(), sync::kBitwiseOR); sync::AllReduce(BeginPtr(bitmap.data), bitmap.data.size(), sync::kBitwiseOR);
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset(); const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
@ -125,7 +128,7 @@ class DistColMaker : public ColMaker<TStats> {
if (tree[nid].default_left()) { if (tree[nid].default_left()) {
this->SetEncodePosition(ridx, tree[nid].cright()); this->SetEncodePosition(ridx, tree[nid].cright());
} else { } else {
this->SetEncodePosition(ridx, tree[nid].cright()); this->SetEncodePosition(ridx, tree[nid].cleft());
} }
} }
} }

View File

@ -5,6 +5,7 @@
#include <string> #include <string>
#include <cstring> #include <cstring>
#include "io/io.h" #include "io/io.h"
#include "sync/sync.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/config.h" #include "utils/config.h"
#include "learner/learner-inl.hpp" #include "learner/learner-inl.hpp"
@ -19,7 +20,7 @@ class BoostLearnTask{
if (argc < 2) { if (argc < 2) {
printf("Usage: <config>\n"); printf("Usage: <config>\n");
return 0; return 0;
} }
utils::ConfigIterator itr(argv[1]); utils::ConfigIterator itr(argv[1]);
while (itr.Next()) { while (itr.Next()) {
this->SetParam(itr.name(), itr.val()); this->SetParam(itr.name(), itr.val());
@ -30,6 +31,9 @@ class BoostLearnTask{
this->SetParam(name, val); this->SetParam(name, val);
} }
} }
if (sync::GetRank() != 0) {
this->SetParam("silent", "2");
}
this->InitData(); this->InitData();
this->InitLearner(); this->InitLearner();
if (task == "dump") { if (task == "dump") {
@ -145,7 +149,9 @@ class BoostLearnTask{
if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed); if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed);
learner.UpdateOneIter(i, *data); learner.UpdateOneIter(i, *data);
std::string res = learner.EvalOneIter(i, devalall, eval_data_names); std::string res = learner.EvalOneIter(i, devalall, eval_data_names);
fprintf(stderr, "%s\n", res.c_str()); if (silent < 1) {
fprintf(stderr, "%s\n", res.c_str());
}
if (save_period != 0 && (i + 1) % save_period == 0) { if (save_period != 0 && (i + 1) % save_period == 0) {
this->SaveModel(i); this->SaveModel(i);
} }
@ -243,7 +249,10 @@ class BoostLearnTask{
} }
int main(int argc, char *argv[]){ int main(int argc, char *argv[]){
xgboost::sync::Init(argc, argv);
xgboost::random::Seed(0); xgboost::random::Seed(0);
xgboost::BoostLearnTask tsk; xgboost::BoostLearnTask tsk;
return tsk.Run(argc, argv); int ret = tsk.Run(argc, argv);
xgboost::sync::Finalize();
return ret;
} }