middle version
This commit is contained in:
@@ -5,29 +5,6 @@
|
||||
namespace xgboost {
|
||||
namespace sync {
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void) : handle(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {
|
||||
if (handle != NULL) {
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
||||
op->Free();
|
||||
delete op;
|
||||
}
|
||||
}
|
||||
void ReduceHandle::Init(ReduceFunction redfunc, bool commute) {
|
||||
utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
|
||||
MPI::Op *op = new MPI::Op();
|
||||
MPI::User_function *pf = reinterpret_cast<MPI::User_function*>(redfunc);
|
||||
op->Init(pf, commute);
|
||||
handle = op;
|
||||
}
|
||||
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) {
|
||||
utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, n4byte, MPI_INT, *op);
|
||||
}
|
||||
|
||||
int GetRank(void) {
|
||||
return MPI::COMM_WORLD.Get_rank();
|
||||
}
|
||||
@@ -57,5 +34,37 @@ void AllReduce<float>(float *sendrecvbuf, int count, ReduceOp op) {
|
||||
AllReduce_(sendrecvbuf, count, MPI::FLOAT, op);
|
||||
}
|
||||
|
||||
void Bcast(std::string *sendrecv_data, int root) {
|
||||
unsigned len = static_cast<unsigned>(sendrecv_data->length());
|
||||
MPI::COMM_WORLD.Bcast(&len, 1, MPI::UNSIGNED, root);
|
||||
sendrecv_data->resize(len);
|
||||
if (len != 0) {
|
||||
MPI::COMM_WORLD.Bcast(&(*sendrecv_data)[0], len, MPI::CHAR, root);
|
||||
}
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void) : handle(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {
|
||||
if (handle != NULL) {
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
||||
op->Free();
|
||||
delete op;
|
||||
}
|
||||
}
|
||||
void ReduceHandle::Init(ReduceFunction redfunc, bool commute) {
|
||||
utils::Assert(handle == NULL, "cannot initialize reduce handle twice");
|
||||
MPI::Op *op = new MPI::Op();
|
||||
MPI::User_function *pf = reinterpret_cast<MPI::User_function*>(redfunc);
|
||||
op->Init(pf, commute);
|
||||
handle = op;
|
||||
}
|
||||
void ReduceHandle::AllReduce(void *sendrecvbuf, size_t n4byte) {
|
||||
utils::Assert(handle != NULL, "must intialize handle to call AllReduce");
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, n4byte, MPI_INT, *op);
|
||||
}
|
||||
|
||||
} // namespace sync
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -18,11 +18,39 @@ enum ReduceOp {
|
||||
kBitwiseOR
|
||||
};
|
||||
|
||||
typedef void (ReduceFunction) (const void *src, void *dst, int len);
|
||||
/*! \brief get rank of current process */
|
||||
int GetRank(void);
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]);
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void);
|
||||
|
||||
/* !\brief handle for customized reducer */
|
||||
/*!
|
||||
* \brief in-place all reduce operation
|
||||
* \param sendrecvbuf the in place send-recv buffer
|
||||
* \param count count of data
|
||||
* \param op reduction function
|
||||
*/
|
||||
template<typename DType>
|
||||
void AllReduce(DType *sendrecvbuf, int count, ReduceOp op);
|
||||
|
||||
/*!
|
||||
* \brief broadcast an std::string to all others from root
|
||||
* \param sendrecv_data the pointer to send or recive buffer,
|
||||
* receive buffer does not need to be pre-allocated
|
||||
* and string will be resized to correct length
|
||||
* \param root the root of process
|
||||
*/
|
||||
void Bcast(std::string *sendrecv_data, int root);
|
||||
|
||||
/*!
|
||||
* \brief handle for customized reducer
|
||||
* user do not need to use this, used Reducer instead
|
||||
*/
|
||||
class ReduceHandle {
|
||||
public:
|
||||
// reduce function
|
||||
typedef void (ReduceFunction) (const void *src, void *dst, int len);
|
||||
// constructor
|
||||
ReduceHandle(void);
|
||||
// destructor
|
||||
@@ -41,22 +69,8 @@ class ReduceHandle {
|
||||
void *handle;
|
||||
};
|
||||
|
||||
/*! \brief get rank of current process */
|
||||
int GetRank(void);
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]);
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void);
|
||||
// ----- extensions for ease of use ------
|
||||
/*!
|
||||
* \brief in-place all reduce operation
|
||||
* \param sendrecvbuf the in place send-recv buffer
|
||||
* \param count count of data
|
||||
* \param op reduction function
|
||||
*/
|
||||
template<typename DType>
|
||||
void AllReduce(DType *sendrecvbuf, int count, ReduceOp op);
|
||||
|
||||
/*!
|
||||
* \brief template class to make customized reduce and all reduce easy
|
||||
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
|
||||
* \tparam DType data type that to be reduced
|
||||
|
||||
Reference in New Issue
Block a user