add complex reducer in
This commit is contained in:
parent
2c0a0671ad
commit
e72a869fd1
@ -48,5 +48,27 @@ void Allreduce_(void *sendrecvbuf,
|
||||
void *prepare_arg) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {}
|
||||
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return static_cast<int>(dtype.type_size);
|
||||
}
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
|
||||
handle_ = reinterpret_cast<void*>(redfunc);
|
||||
}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||
reinterpret_cast<IEngine::ReduceFunction*>(handle_),
|
||||
prepare_fun, prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
42
src/engine.h
42
src/engine.h
@ -177,6 +177,48 @@ void Allreduce_(void *sendrecvbuf,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
|
||||
/*!
|
||||
* \brief handle for customized reducer, used to handle customized reduce
|
||||
* this class is mainly created for compatiblity issue with MPI's customized reduce
|
||||
*/
|
||||
class ReduceHandle {
|
||||
public:
|
||||
// constructor
|
||||
ReduceHandle(void);
|
||||
// destructor
|
||||
~ReduceHandle(void);
|
||||
/*!
|
||||
* \brief initialize the reduce function,
|
||||
* with the type the reduce function need to deal with
|
||||
* the reduce function MUST be communicative
|
||||
*/
|
||||
void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes);
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation
|
||||
* \param sendrecvbuf the in place send-recv buffer
|
||||
* \param type_n4bytes unit size of the type, in terms of 4bytes
|
||||
* \param count number of elements to send
|
||||
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
*/
|
||||
void Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
/*! \return the number of bytes occupied by the type */
|
||||
static int TypeSize(const MPI::Datatype &dtype);
|
||||
|
||||
private:
|
||||
// handle data field
|
||||
void *handle_;
|
||||
// handle to the type field
|
||||
void *htype_;
|
||||
// the created type in 4 bytes
|
||||
size_t created_type_nbytes_;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ENGINE_H
|
||||
|
||||
@ -86,5 +86,19 @@ void Allreduce_(void *sendrecvbuf,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {}
|
||||
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return 0;
|
||||
}
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@ -124,5 +124,59 @@ void Allreduce_(void *sendrecvbuf,
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {
|
||||
if (handle_ != NULL) {
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
|
||||
op->Free();
|
||||
delete op;
|
||||
}
|
||||
if (htype_ != NULL) {
|
||||
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
|
||||
dtype->Free();
|
||||
delete dtype;
|
||||
}
|
||||
}
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return dtype.Get_size();
|
||||
}
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
|
||||
if (type_nbytes != 0) {
|
||||
MPI::Datatype *dtype = new MPI::Datatype();
|
||||
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
||||
dtype->Commit();
|
||||
created_type_nbytes_ = type_nbytes;
|
||||
htype_ = dtype;
|
||||
}
|
||||
|
||||
MPI::Op *op = new MPI::Op();
|
||||
MPI::User_function *pf = redfunc;
|
||||
op->Init(pf, true);
|
||||
handle_ = op;
|
||||
}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
|
||||
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
|
||||
if (created_type_nbytes_ != type_nbytes || dtype == NULL) {
|
||||
if (dtype == NULL) {
|
||||
dtype = new MPI::Datatype();
|
||||
} else {
|
||||
dtype->Free();
|
||||
}
|
||||
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
||||
dtype->Commit();
|
||||
created_type_nbytes_ = type_nbytes;
|
||||
}
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#define RABIT_RABIT_INL_H
|
||||
// use engine for implementation
|
||||
#include "./engine.h"
|
||||
#include "./io.h"
|
||||
#include "./utils.h"
|
||||
|
||||
namespace rabit {
|
||||
@ -170,5 +171,71 @@ inline void CheckPoint(const ISerializable *global_model,
|
||||
inline int VersionNumber(void) {
|
||||
return engine::GetEngine()->VersionNumber();
|
||||
}
|
||||
// ---------------------------------
|
||||
// Code to handle customized Reduce
|
||||
// ---------------------------------
|
||||
// function to perform reduction for Reducer
|
||||
template<typename DType>
|
||||
inline void Reducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||
const size_t kUnit = sizeof(DType);
|
||||
const char *psrc = reinterpret_cast<const char*>(src_);
|
||||
char *pdst = reinterpret_cast<char*>(dst_);
|
||||
DType tdst, tsrc;
|
||||
for (size_t i = 0; i < len_; ++i) {
|
||||
// use memcpy to avoid alignment issue
|
||||
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
|
||||
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
|
||||
tdst.Reduce(tsrc);
|
||||
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
|
||||
}
|
||||
}
|
||||
template<typename DType>
|
||||
inline Reducer<DType>::Reducer(void) {
|
||||
handle_.Init(Reducer<DType>::ReduceFunc, sizeof(DType));
|
||||
}
|
||||
template<typename DType>
|
||||
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg);
|
||||
}
|
||||
// function to perform reduction for SerializeReducer
|
||||
template<typename DType>
|
||||
inline void
|
||||
SerializeReducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
||||
// temp space
|
||||
DType tsrc, tdst;
|
||||
for (int i = 0; i < len_; ++i) {
|
||||
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes);
|
||||
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes);
|
||||
tsrc.Load(fsrc);
|
||||
tdst.Load(fdst);
|
||||
// govern const check
|
||||
tdst.Reduce(static_cast<const DType &>(tsrc), nbytes);
|
||||
fdst.Seek(0);
|
||||
tdst.Save(fdst);
|
||||
}
|
||||
}
|
||||
template<typename DType>
|
||||
inline SerializeReducer<DType>::SerializeReducer(void) {
|
||||
handle_.Init(SerializeReducer<DType>::ReduceFunc, sizeof(DType));
|
||||
}
|
||||
template<typename DType>
|
||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
buffer_.resize(max_nbyte);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||
sendrecvobj[i].Save(fs);
|
||||
}
|
||||
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, prepare_fun, prepare_arg);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||
sendrecvobj[i].Load(fs);
|
||||
}
|
||||
}
|
||||
} // namespace rabit
|
||||
#endif
|
||||
|
||||
71
src/rabit.h
71
src/rabit.h
@ -183,6 +183,77 @@ inline void CheckPoint(const ISerializable *global_model,
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
inline int VersionNumber(void);
|
||||
// ----- extensions that allow customized reducer ------
|
||||
// helper class to do customized reduce, user do not need to know the type
|
||||
namespace engine {
|
||||
class ReduceHandle;
|
||||
} // namespace engine
|
||||
/*!
|
||||
* \brief template class to make customized reduce and all reduce easy
|
||||
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
|
||||
* \tparam DType data type that to be reduced
|
||||
* DType must be a struct, with no pointer, and contains a function Reduce(const DType &d);
|
||||
*/
|
||||
template<typename DType>
|
||||
class Reducer {
|
||||
public:
|
||||
Reducer(void);
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation
|
||||
* \param sendrecvbuf the in place send-recv buffer
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg) = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
|
||||
private:
|
||||
// inner implementation of reducer
|
||||
inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype);
|
||||
/*! \brief function handle to do reduce */
|
||||
engine::ReduceHandle handle_;
|
||||
};
|
||||
/*!
|
||||
* \brief template class to make customized reduce,
|
||||
* this class defines complex reducer handles all the data structure that can be
|
||||
* serialized/deserialzed into fixed size buffer
|
||||
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
|
||||
*
|
||||
* \tparam DType data type that to be reduced, DType must contain following functions:
|
||||
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
|
||||
*/
|
||||
template<typename DType>
|
||||
class SerializeReducer {
|
||||
public:
|
||||
SerializeReducer(void);
|
||||
/*!
|
||||
* \brief customized in-place all reduce operation
|
||||
* \param sendrecvobj pointer to the array of objects to be reduced
|
||||
* \param max_nbyte maximum amount of memory needed to serialize each object
|
||||
* this includes budget limit for intermediate and final result
|
||||
* \param count number of elements to be reduced
|
||||
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
*/
|
||||
inline void Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
void (*prepare_fun)(void *arg) = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
|
||||
private:
|
||||
// inner implementation of reducer
|
||||
inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype);
|
||||
/*! \brief function handle to do reduce */
|
||||
engine::ReduceHandle handle_;
|
||||
/*! \brief temporal buffer used to do reduce*/
|
||||
std::string buffer_;
|
||||
};
|
||||
} // namespace rabit
|
||||
// implementation of template functions
|
||||
#include "./rabit-inl.h"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user