add complex reducer in
This commit is contained in:
parent
2c0a0671ad
commit
e72a869fd1
@ -48,5 +48,27 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
|
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// code for reduce handle
|
||||||
|
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
||||||
|
}
|
||||||
|
ReduceHandle::~ReduceHandle(void) {}
|
||||||
|
|
||||||
|
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||||
|
return static_cast<int>(dtype.type_size);
|
||||||
|
}
|
||||||
|
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||||
|
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
|
||||||
|
handle_ = reinterpret_cast<void*>(redfunc);
|
||||||
|
}
|
||||||
|
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||||
|
size_t type_nbytes, size_t count,
|
||||||
|
IEngine::PreprocFunction prepare_fun,
|
||||||
|
void *prepare_arg) {
|
||||||
|
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
|
||||||
|
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||||
|
reinterpret_cast<IEngine::ReduceFunction*>(handle_),
|
||||||
|
prepare_fun, prepare_arg);
|
||||||
|
}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
42
src/engine.h
42
src/engine.h
@ -177,6 +177,48 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
mpi::OpType op,
|
mpi::OpType op,
|
||||||
IEngine::PreprocFunction prepare_fun = NULL,
|
IEngine::PreprocFunction prepare_fun = NULL,
|
||||||
void *prepare_arg = NULL);
|
void *prepare_arg = NULL);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief handle for customized reducer, used to handle customized reduce
|
||||||
|
* this class is mainly created for compatiblity issue with MPI's customized reduce
|
||||||
|
*/
|
||||||
|
class ReduceHandle {
|
||||||
|
public:
|
||||||
|
// constructor
|
||||||
|
ReduceHandle(void);
|
||||||
|
// destructor
|
||||||
|
~ReduceHandle(void);
|
||||||
|
/*!
|
||||||
|
* \brief initialize the reduce function,
|
||||||
|
* with the type the reduce function need to deal with
|
||||||
|
* the reduce function MUST be communicative
|
||||||
|
*/
|
||||||
|
void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes);
|
||||||
|
/*!
|
||||||
|
* \brief customized in-place all reduce operation
|
||||||
|
* \param sendrecvbuf the in place send-recv buffer
|
||||||
|
* \param type_n4bytes unit size of the type, in terms of 4bytes
|
||||||
|
* \param count number of elements to send
|
||||||
|
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
|
||||||
|
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||||
|
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||||
|
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||||
|
*/
|
||||||
|
void Allreduce(void *sendrecvbuf,
|
||||||
|
size_t type_nbytes, size_t count,
|
||||||
|
IEngine::PreprocFunction prepare_fun = NULL,
|
||||||
|
void *prepare_arg = NULL);
|
||||||
|
/*! \return the number of bytes occupied by the type */
|
||||||
|
static int TypeSize(const MPI::Datatype &dtype);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// handle data field
|
||||||
|
void *handle_;
|
||||||
|
// handle to the type field
|
||||||
|
void *htype_;
|
||||||
|
// the created type in 4 bytes
|
||||||
|
size_t created_type_nbytes_;
|
||||||
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif // RABIT_ENGINE_H
|
#endif // RABIT_ENGINE_H
|
||||||
|
|||||||
@ -86,5 +86,19 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
IEngine::PreprocFunction prepare_fun,
|
IEngine::PreprocFunction prepare_fun,
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// code for reduce handle
|
||||||
|
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
||||||
|
}
|
||||||
|
ReduceHandle::~ReduceHandle(void) {}
|
||||||
|
|
||||||
|
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
|
||||||
|
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||||
|
size_t type_nbytes, size_t count,
|
||||||
|
IEngine::PreprocFunction prepare_fun,
|
||||||
|
void *prepare_arg) {}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -124,5 +124,59 @@ void Allreduce_(void *sendrecvbuf,
|
|||||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
|
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// code for reduce handle
|
||||||
|
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
||||||
|
}
|
||||||
|
ReduceHandle::~ReduceHandle(void) {
|
||||||
|
if (handle_ != NULL) {
|
||||||
|
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
|
||||||
|
op->Free();
|
||||||
|
delete op;
|
||||||
|
}
|
||||||
|
if (htype_ != NULL) {
|
||||||
|
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
|
||||||
|
dtype->Free();
|
||||||
|
delete dtype;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||||
|
return dtype.Get_size();
|
||||||
|
}
|
||||||
|
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||||
|
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
|
||||||
|
if (type_nbytes != 0) {
|
||||||
|
MPI::Datatype *dtype = new MPI::Datatype();
|
||||||
|
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
||||||
|
dtype->Commit();
|
||||||
|
created_type_nbytes_ = type_nbytes;
|
||||||
|
htype_ = dtype;
|
||||||
|
}
|
||||||
|
|
||||||
|
MPI::Op *op = new MPI::Op();
|
||||||
|
MPI::User_function *pf = redfunc;
|
||||||
|
op->Init(pf, true);
|
||||||
|
handle_ = op;
|
||||||
|
}
|
||||||
|
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||||
|
size_t type_nbytes, size_t count,
|
||||||
|
IEngine::PreprocFunction prepare_fun,
|
||||||
|
void *prepare_arg) {
|
||||||
|
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
|
||||||
|
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
|
||||||
|
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
|
||||||
|
if (created_type_nbytes_ != type_nbytes || dtype == NULL) {
|
||||||
|
if (dtype == NULL) {
|
||||||
|
dtype = new MPI::Datatype();
|
||||||
|
} else {
|
||||||
|
dtype->Free();
|
||||||
|
}
|
||||||
|
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
||||||
|
dtype->Commit();
|
||||||
|
created_type_nbytes_ = type_nbytes;
|
||||||
|
}
|
||||||
|
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||||
|
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
|
||||||
|
}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#define RABIT_RABIT_INL_H
|
#define RABIT_RABIT_INL_H
|
||||||
// use engine for implementation
|
// use engine for implementation
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
|
#include "./io.h"
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
@ -170,5 +171,71 @@ inline void CheckPoint(const ISerializable *global_model,
|
|||||||
inline int VersionNumber(void) {
|
inline int VersionNumber(void) {
|
||||||
return engine::GetEngine()->VersionNumber();
|
return engine::GetEngine()->VersionNumber();
|
||||||
}
|
}
|
||||||
|
// ---------------------------------
|
||||||
|
// Code to handle customized Reduce
|
||||||
|
// ---------------------------------
|
||||||
|
// function to perform reduction for Reducer
|
||||||
|
template<typename DType>
|
||||||
|
inline void Reducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
|
const size_t kUnit = sizeof(DType);
|
||||||
|
const char *psrc = reinterpret_cast<const char*>(src_);
|
||||||
|
char *pdst = reinterpret_cast<char*>(dst_);
|
||||||
|
DType tdst, tsrc;
|
||||||
|
for (size_t i = 0; i < len_; ++i) {
|
||||||
|
// use memcpy to avoid alignment issue
|
||||||
|
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
|
||||||
|
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
|
||||||
|
tdst.Reduce(tsrc);
|
||||||
|
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template<typename DType>
|
||||||
|
inline Reducer<DType>::Reducer(void) {
|
||||||
|
handle_.Init(Reducer<DType>::ReduceFunc, sizeof(DType));
|
||||||
|
}
|
||||||
|
template<typename DType>
|
||||||
|
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg) {
|
||||||
|
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg);
|
||||||
|
}
|
||||||
|
// function to perform reduction for SerializeReducer
|
||||||
|
template<typename DType>
|
||||||
|
inline void
|
||||||
|
SerializeReducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
|
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
||||||
|
// temp space
|
||||||
|
DType tsrc, tdst;
|
||||||
|
for (int i = 0; i < len_; ++i) {
|
||||||
|
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes);
|
||||||
|
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes);
|
||||||
|
tsrc.Load(fsrc);
|
||||||
|
tdst.Load(fdst);
|
||||||
|
// govern const check
|
||||||
|
tdst.Reduce(static_cast<const DType &>(tsrc), nbytes);
|
||||||
|
fdst.Seek(0);
|
||||||
|
tdst.Save(fdst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template<typename DType>
|
||||||
|
inline SerializeReducer<DType>::SerializeReducer(void) {
|
||||||
|
handle_.Init(SerializeReducer<DType>::ReduceFunc, sizeof(DType));
|
||||||
|
}
|
||||||
|
template<typename DType>
|
||||||
|
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||||
|
size_t max_nbyte, size_t count,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg) {
|
||||||
|
buffer_.resize(max_nbyte);
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||||
|
sendrecvobj[i].Save(fs);
|
||||||
|
}
|
||||||
|
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, prepare_fun, prepare_arg);
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||||
|
sendrecvobj[i].Load(fs);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
71
src/rabit.h
71
src/rabit.h
@ -183,6 +183,77 @@ inline void CheckPoint(const ISerializable *global_model,
|
|||||||
* \sa LoadCheckPoint, CheckPoint
|
* \sa LoadCheckPoint, CheckPoint
|
||||||
*/
|
*/
|
||||||
inline int VersionNumber(void);
|
inline int VersionNumber(void);
|
||||||
|
// ----- extensions that allow customized reducer ------
|
||||||
|
// helper class to do customized reduce, user do not need to know the type
|
||||||
|
namespace engine {
|
||||||
|
class ReduceHandle;
|
||||||
|
} // namespace engine
|
||||||
|
/*!
|
||||||
|
* \brief template class to make customized reduce and all reduce easy
|
||||||
|
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
|
||||||
|
* \tparam DType data type that to be reduced
|
||||||
|
* DType must be a struct, with no pointer, and contains a function Reduce(const DType &d);
|
||||||
|
*/
|
||||||
|
template<typename DType>
|
||||||
|
class Reducer {
|
||||||
|
public:
|
||||||
|
Reducer(void);
|
||||||
|
/*!
|
||||||
|
* \brief customized in-place all reduce operation
|
||||||
|
* \param sendrecvbuf the in place send-recv buffer
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||||
|
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||||
|
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||||
|
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||||
|
*/
|
||||||
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
|
void (*prepare_fun)(void *arg) = NULL,
|
||||||
|
void *prepare_arg = NULL);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// inner implementation of reducer
|
||||||
|
inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype);
|
||||||
|
/*! \brief function handle to do reduce */
|
||||||
|
engine::ReduceHandle handle_;
|
||||||
|
};
|
||||||
|
/*!
|
||||||
|
* \brief template class to make customized reduce,
|
||||||
|
* this class defines complex reducer handles all the data structure that can be
|
||||||
|
* serialized/deserialzed into fixed size buffer
|
||||||
|
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
|
||||||
|
*
|
||||||
|
* \tparam DType data type that to be reduced, DType must contain following functions:
|
||||||
|
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
|
||||||
|
*/
|
||||||
|
template<typename DType>
|
||||||
|
class SerializeReducer {
|
||||||
|
public:
|
||||||
|
SerializeReducer(void);
|
||||||
|
/*!
|
||||||
|
* \brief customized in-place all reduce operation
|
||||||
|
* \param sendrecvobj pointer to the array of objects to be reduced
|
||||||
|
* \param max_nbyte maximum amount of memory needed to serialize each object
|
||||||
|
* this includes budget limit for intermediate and final result
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
|
||||||
|
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||||
|
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||||
|
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||||
|
*/
|
||||||
|
inline void Allreduce(DType *sendrecvobj,
|
||||||
|
size_t max_nbyte, size_t count,
|
||||||
|
void (*prepare_fun)(void *arg) = NULL,
|
||||||
|
void *prepare_arg = NULL);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// inner implementation of reducer
|
||||||
|
inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype);
|
||||||
|
/*! \brief function handle to do reduce */
|
||||||
|
engine::ReduceHandle handle_;
|
||||||
|
/*! \brief temporal buffer used to do reduce*/
|
||||||
|
std::string buffer_;
|
||||||
|
};
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
// implementation of template functions
|
// implementation of template functions
|
||||||
#include "./rabit-inl.h"
|
#include "./rabit-inl.h"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user