middle version

This commit is contained in:
tqchen 2014-10-16 10:38:49 -07:00
parent 6680bffaae
commit aefe58a207
8 changed files with 188 additions and 64 deletions

View File

@ -11,11 +11,11 @@ else
endif endif
# specify tensor path # specify tensor path
BIN = xgboost BIN =
OBJ = updater.o gbm.o io.o OBJ = updater.o gbm.o io.o main.o
MPIOBJ = sync.o MPIOBJ = sync.o
MPIBIN = test/test MPIBIN = test/test xgboost
SLIB = wrapper/libxgboostwrapper.so SLIB = #wrapper/libxgboostwrapper.so
.PHONY: clean all python Rpack .PHONY: clean all python Rpack
@ -28,8 +28,9 @@ updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h
gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
sync.o: src/sync/sync.cpp sync.o: src/sync/sync.cpp
xgboost: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h $(OBJ) main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h $(OBJ) xgboost: $(OBJ) $(MPIOBJ)
#wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h $(OBJ)
test/test: test/test.cpp sync.o test/test: test/test.cpp sync.o
$(BIN) : $(BIN) :

View File

@ -5,29 +5,6 @@
namespace xgboost { namespace xgboost {
namespace sync { namespace sync {
// code for reduce handle
/*! \brief create a handle with no operator attached yet (handle is NULL until Init) */
ReduceHandle::ReduceHandle(void)
    : handle(NULL) {}
/*! \brief free and delete the MPI::Op stored in handle, if one was created */
ReduceHandle::~ReduceHandle(void) {
  // nothing to release when Init was never called
  if (handle == NULL) return;
  MPI::Op *mpi_op = reinterpret_cast<MPI::Op*>(handle);
  mpi_op->Free();
  delete mpi_op;
}
/*!
 * \brief allocate the MPI::Op behind this handle and bind the reduce function to it
 * \param redfunc reduction function, reinterpreted as an MPI::User_function
 * \param commute flag forwarded unchanged to MPI::Op::Init
 */
void ReduceHandle::Init(ReduceFunction redfunc, bool commute) {
  utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
  MPI::Op *mpi_op = new MPI::Op();
  mpi_op->Init(reinterpret_cast<MPI::User_function*>(redfunc), commute);
  // keep the operator behind the opaque pointer; the destructor releases it
  handle = mpi_op;
}
/*!
 * \brief in-place all reduce over MPI::COMM_WORLD using the operator stored in handle
 * \param sendrecvbuf buffer used as both input and output (MPI_IN_PLACE is the send side)
 * \param n4byte number of MPI_INT elements in sendrecvbuf
 */
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) {
  utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
  // the count argument of Allreduce is an int; reject sizes it cannot represent
  // instead of letting the implicit conversion silently change the element count
  const int count = static_cast<int>(n4byte);
  utils::Assert(count >= 0 && static_cast<size_t>(count) == n4byte,
                "AllReduce: buffer size exceeds the range of int count");
  MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
  MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, MPI_INT, *op);
}
int GetRank(void) { int GetRank(void) {
return MPI::COMM_WORLD.Get_rank(); return MPI::COMM_WORLD.Get_rank();
} }
@ -57,5 +34,37 @@ void AllReduce<float>(float *sendrecvbuf, int count, ReduceOp op) {
AllReduce_(sendrecvbuf, count, MPI::FLOAT, op); AllReduce_(sendrecvbuf, count, MPI::FLOAT, op);
} }
/*!
 * \brief broadcast a string from root over MPI::COMM_WORLD
 * \param sendrecv_data string whose content is broadcast; it is resized to the
 *        broadcast length before the payload is transferred
 * \param root rank passed to both Bcast calls as the root
 */
void Bcast(std::string *sendrecv_data, int root) {
  // step 1: broadcast the length so the string can be resized to match
  unsigned nbytes = static_cast<unsigned>(sendrecv_data->length());
  MPI::COMM_WORLD.Bcast(&nbytes, 1, MPI::UNSIGNED, root);
  sendrecv_data->resize(nbytes);
  // an empty string has no payload to transfer
  if (nbytes == 0) return;
  // step 2: broadcast the characters themselves
  std::string &buf = *sendrecv_data;
  MPI::COMM_WORLD.Bcast(&buf[0], nbytes, MPI::CHAR, root);
}
// code for reduce handle
/*! \brief create a handle with no operator attached yet (handle is NULL until Init) */
ReduceHandle::ReduceHandle(void)
    : handle(NULL) {}
/*! \brief free and delete the MPI::Op stored in handle, if one was created */
ReduceHandle::~ReduceHandle(void) {
  // nothing to release when Init was never called
  if (handle == NULL) return;
  MPI::Op *mpi_op = reinterpret_cast<MPI::Op*>(handle);
  mpi_op->Free();
  delete mpi_op;
}
/*!
 * \brief allocate the MPI::Op behind this handle and bind the reduce function to it
 * \param redfunc reduction function, reinterpreted as an MPI::User_function
 * \param commute flag forwarded unchanged to MPI::Op::Init
 */
void ReduceHandle::Init(ReduceFunction redfunc, bool commute) {
  utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
  MPI::Op *mpi_op = new MPI::Op();
  mpi_op->Init(reinterpret_cast<MPI::User_function*>(redfunc), commute);
  // keep the operator behind the opaque pointer; the destructor releases it
  handle = mpi_op;
}
/*!
 * \brief in-place all reduce over MPI::COMM_WORLD using the operator stored in handle
 * \param sendrecvbuf buffer used as both input and output (MPI_IN_PLACE is the send side)
 * \param n4byte number of MPI_INT elements in sendrecvbuf
 */
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) {
  utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
  // the count argument of Allreduce is an int; reject sizes it cannot represent
  // instead of letting the implicit conversion silently change the element count
  const int count = static_cast<int>(n4byte);
  utils::Assert(count >= 0 && static_cast<size_t>(count) == n4byte,
                "AllReduce: buffer size exceeds the range of int count");
  MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
  MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, MPI_INT, *op);
}
} // namespace sync } // namespace sync
} // namespace xgboost } // namespace xgboost

View File

@ -18,11 +18,39 @@ enum ReduceOp {
kBitwiseOR kBitwiseOR
}; };
typedef void (ReduceFunction) (const void *src, void *dst, int len); /*! \brief get rank of current process */
int GetRank(void);
/*! \brief initialize the synchronization module */
void Init(int argc, char *argv[]);
/*! \brief finalize the synchronization module */
void Finalize(void);
/* !\brief handle for customized reducer */ /*!
* \brief in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer
* \param count count of data
* \param op reduction function
*/
template<typename DType>
void AllReduce(DType *sendrecvbuf, int count, ReduceOp op);
/*!
* \brief broadcast an std::string to all others from root
* \param sendrecv_data the pointer to the send or receive buffer,
* receive buffer does not need to be pre-allocated
* and string will be resized to correct length
* \param root the root of process
*/
void Bcast(std::string *sendrecv_data, int root);
/*!
* \brief handle for customized reducer
* users do not need to use this directly; use Reducer instead
*/
class ReduceHandle { class ReduceHandle {
public: public:
// reduce function
typedef void (ReduceFunction) (const void *src, void *dst, int len);
// constructor // constructor
ReduceHandle(void); ReduceHandle(void);
// destructor // destructor
@ -41,22 +69,8 @@ class ReduceHandle {
void *handle; void *handle;
}; };
/*! \brief get rank of current process */ // ----- extensions for ease of use ------
int GetRank(void);
/*! \brief initialize the synchronization module */
void Init(int argc, char *argv[]);
/*! \brief finalize the synchronization module */
void Finalize(void);
/*! /*!
* \brief in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer
* \param count count of data
* \param op reduction function
*/
template<typename DType>
void AllReduce(DType *sendrecvbuf, int count, ReduceOp op);
/*!
* \brief template class to make customized reduce and all reduce easy * \brief template class to make customized reduce and all reduce easy
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize * Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
* \tparam DType data type that to be reduced * \tparam DType data type that to be reduced

View File

@ -110,6 +110,10 @@ class TreeModel {
inline bool is_left_child(void) const { inline bool is_left_child(void) const {
return (parent_ & (1U << 31)) != 0; return (parent_ & (1U << 31)) != 0;
} }
/*! \brief whether this node carries the deletion mark, i.e. sindex_ equals the largest unsigned value */
inline bool is_deleted(void) const {
  const unsigned deleted_mark = std::numeric_limits<unsigned>::max();
  return sindex_ == deleted_mark;
}
/*! \brief whether current node is root */ /*! \brief whether current node is root */
inline bool is_root(void) const { inline bool is_root(void) const {
return parent_ == -1; return parent_ == -1;
@ -144,7 +148,11 @@ class TreeModel {
this->cleft_ = -1; this->cleft_ = -1;
this->cright_ = right; this->cright_ = right;
} }
/*! \brief mark that this node is deleted */
inline void mark_delete(void) {
this->sindex_ = std::numeric_limits<unsigned>::max();
}
private: private:
friend class TreeModel<TSplitCond, TNodeStat>; friend class TreeModel<TSplitCond, TNodeStat>;
/*! /*!
@ -197,11 +205,11 @@ class TreeModel {
leaf_vector.resize(param.num_nodes * param.size_leaf_vector); leaf_vector.resize(param.num_nodes * param.size_leaf_vector);
return nd; return nd;
} }
// delete a tree node // delete a tree node, keep the parent field to allow trace back
inline void DeleteNode(int nid) { inline void DeleteNode(int nid) {
utils::Assert(nid >= param.num_roots, "can not delete root"); utils::Assert(nid >= param.num_roots, "can not delete root");
deleted_nodes.push_back(nid); deleted_nodes.push_back(nid);
nodes[nid].set_parent(-1); nodes[nid].mark_delete();
++param.num_deleted; ++param.num_deleted;
} }

View File

@ -345,6 +345,10 @@ struct SplitEntry{
return false; return false;
} }
} }
/*! \brief same as update, used by AllReduce*/
inline void Reduce(const SplitEntry &e) {
this->Update(e);
}
/*!\return feature index to split on */ /*!\return feature index to split on */
inline unsigned split_index(void) const { inline unsigned split_index(void) const {
return sindex & ((1U << 31) - 1U); return sindex & ((1U << 31) - 1U);

View File

@ -486,13 +486,17 @@ class ColMaker: public IUpdater {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) { for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
const int nid = position[ridx]; int nid = position[ridx];
if (nid >= 0) { if (nid < 0) nid = ~nid;
if (tree[nid].is_leaf()) { if (tree[nid].is_leaf()) {
position[ridx] = - nid - 1; position[ridx] = ~nid;
} else {
// push to default branch, correct latter
int pid = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
if (position[ridx] < 0) {
position[ridx] = ~pid;
} else { } else {
// push to default branch, correct latter position[ridx] = pid;
position[ridx] = tree[nid].default_left() ? tree[nid].cleft(): tree[nid].cright();
} }
} }
} }
@ -535,7 +539,8 @@ class ColMaker: public IUpdater {
const bst_uint ridx = col[j].index; const bst_uint ridx = col[j].index;
const float fvalue = col[j].fvalue; const float fvalue = col[j].fvalue;
int nid = position[ridx]; int nid = position[ridx];
if (nid < 0) continue; if (nid < 0) nid = ~nid;
// go back to parent, correct those who are not default // go back to parent, correct those who are not default
nid = tree[nid].parent(); nid = tree[nid].parent();
if (tree[nid].split_index() == fid) { if (tree[nid].split_index() == fid) {

View File

@ -7,7 +7,10 @@
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#include "../utils/bitmap.h" #include "../utils/bitmap.h"
#include "../utils/io.h"
#include "../sync/sync.h"
#include "./updater_colmaker-inl.hpp" #include "./updater_colmaker-inl.hpp"
#include "./updater_prune-inl.hpp"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -19,6 +22,7 @@ class DistColMaker : public ColMaker<TStats> {
// set training parameter // set training parameter
virtual void SetParam(const char *name, const char *val) { virtual void SetParam(const char *name, const char *val) {
param.SetParam(name, val); param.SetParam(name, val);
pruner.SetParam(name, val);
} }
virtual void Update(const std::vector<bst_gpair> &gpair, virtual void Update(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat, IFMatrix *p_fmat,
@ -26,15 +30,46 @@ class DistColMaker : public ColMaker<TStats> {
const std::vector<RegTree*> &trees) { const std::vector<RegTree*> &trees) {
TStats::CheckInfo(info); TStats::CheckInfo(info);
utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time"); utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time");
// build the tree
builder.Update(gpair, p_fmat, info, trees[0]); builder.Update(gpair, p_fmat, info, trees[0]);
// prune the tree
pruner.Update(gpair, p_fmat, info, trees);
this->SyncTrees(trees[0]);
// update position after the tree is pruned
builder.UpdatePosition(p_fmat, *trees[0]);
} }
private: private:
/*!
 * \brief replace the tree on every non-zero rank with the tree held by rank 0
 * \param tree saved into the buffer on rank 0, loaded from the buffer elsewhere
 */
inline void SyncTrees(RegTree *tree) {
  std::string s_model;
  utils::MemoryBufferStream fs(&s_model);
  const bool is_root = (sync::GetRank() == 0);
  // rank 0 serializes before the broadcast, every other rank deserializes after it
  if (is_root) tree->SaveModel(fs);
  sync::Bcast(&s_model, 0);
  if (!is_root) tree->LoadModel(fs);
}
struct Builder : public ColMaker<TStats>::Builder { struct Builder : public ColMaker<TStats>::Builder {
public: public:
Builder(const TrainParam &param) Builder(const TrainParam &param)
: ColMaker<TStats>::Builder(param) { : ColMaker<TStats>::Builder(param) {
} }
protected: inline void UpdatePosition(IFMatrix *p_fmat, const RegTree &tree) {
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i];
int nid = this->position[ridx];
if (nid < 0) {
}
}
}
protected:
virtual void SetNonDefaultPosition(const std::vector<int> &qexpand, virtual void SetNonDefaultPosition(const std::vector<int> &qexpand,
IFMatrix *p_fmat, const RegTree &tree) { IFMatrix *p_fmat, const RegTree &tree) {
// step 2, classify the non-default data into right places // step 2, classify the non-default data into right places
@ -80,8 +115,8 @@ class DistColMaker : public ColMaker<TStats> {
} }
} }
// communicate bitmap // communicate bitmap
//sync::AllReduce(); sync::AllReduce(BeginPtr(bitmap.data), bitmap.data.size(), sync::kBitwiseOR);
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset(); const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
// get the new position // get the new position
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size()); const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
@ -100,19 +135,29 @@ class DistColMaker : public ColMaker<TStats> {
} }
// synchronize the best solution of each node // synchronize the best solution of each node
virtual void SyncBestSolution(const std::vector<int> &qexpand) { virtual void SyncBestSolution(const std::vector<int> &qexpand) {
std::vector<SplitEntry> vec;
for (size_t i = 0; i < qexpand.size(); ++i) { for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i]; const int nid = qexpand[i];
for (int tid = 0; tid < this->nthread; ++tid) { for (int tid = 0; tid < this->nthread; ++tid) {
this->snode[nid].best.Update(this->stemp[tid][nid].best); this->snode[nid].best.Update(this->stemp[tid][nid].best);
} }
vec.push_back(this->snode[nid].best);
} }
// communicate best solution // communicate best solution
// sync::AllReduce reducer.AllReduce(BeginPtr(vec), vec.size());
// assign solution back
for (size_t i = 0; i < qexpand.size(); ++i) {
const int nid = qexpand[i];
this->snode[nid].best = vec[i];
}
} }
private: private:
utils::BitMap bitmap; utils::BitMap bitmap;
sync::Reducer<SplitEntry> reducer;
}; };
// we directly introduce pruner here
TreePruner pruner;
// training parameter // training parameter
TrainParam param; TrainParam param;
// pointer to the builder // pointer to the builder

View File

@ -92,11 +92,49 @@ class IStream {
class ISeekStream: public IStream { class ISeekStream: public IStream {
public: public:
/*! \brief seek to certain position of the file */ /*! \brief seek to certain position of the file */
virtual void Seek(long pos) = 0; virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */ /*! \brief tell the position of the stream */
virtual long Tell(void) = 0; virtual size_t Tell(void) = 0;
}; };
/*! \brief seekable stream that reads from and writes to a caller-provided std::string */
struct MemoryBufferStream : public ISeekStream {
 public:
  /*! \param p_buffer string used as backing storage; the stream position starts at 0 */
  MemoryBufferStream(std::string *p_buffer)
      : p_buffer_(p_buffer), curr_ptr_(0) {}
  virtual ~MemoryBufferStream(void) {}
  /*!
   * \brief copy up to size bytes from the current position into ptr
   * \return number of bytes copied, which is smaller than size when the buffer runs out
   */
  virtual size_t Read(void *ptr, size_t size) {
    utils::Assert(curr_ptr_ <= p_buffer_->length(),
                  "read can not have position excceed buffer length");
    const size_t remaining = p_buffer_->length() - curr_ptr_;
    const size_t nread = std::min(remaining, size);
    if (nread != 0) {
      memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
    }
    curr_ptr_ += nread;
    return nread;
  }
  /*! \brief copy size bytes from ptr to the current position, growing the buffer if needed */
  virtual void Write(const void *ptr, size_t size) {
    if (size == 0) return;
    const size_t end_pos = curr_ptr_ + size;
    if (end_pos > p_buffer_->length()) {
      p_buffer_->resize(end_pos);
    }
    memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
    curr_ptr_ = end_pos;
  }
  /*! \brief move the current position to pos */
  virtual void Seek(size_t pos) {
    curr_ptr_ = pos;
  }
  /*! \brief report the current position */
  virtual size_t Tell(void) {
    return curr_ptr_;
  }

 private:
  /*! \brief backing string; not owned by the stream */
  std::string *p_buffer_;
  /*! \brief current read/write offset into the buffer */
  size_t curr_ptr_;
};  // class MemoryBufferStream
/*! \brief implementation of file i/o stream */ /*! \brief implementation of file i/o stream */
class FileStream : public ISeekStream { class FileStream : public ISeekStream {
public: public:
@ -110,10 +148,10 @@ class FileStream : public ISeekStream {
virtual void Write(const void *ptr, size_t size) { virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp); std::fwrite(ptr, size, 1, fp);
} }
virtual void Seek(long pos) { virtual void Seek(size_t pos) {
std::fseek(fp, pos, SEEK_SET); std::fseek(fp, pos, SEEK_SET);
} }
virtual long Tell(void) { virtual size_t Tell(void) {
return std::ftell(fp); return std::ftell(fp);
} }
inline void Close(void) { inline void Close(void) {