add complex reducer in

This commit is contained in:
tqchen 2014-12-19 20:57:53 -08:00
parent 2c0a0671ad
commit e72a869fd1
6 changed files with 270 additions and 0 deletions

View File

@ -48,5 +48,27 @@ void Allreduce_(void *sendrecvbuf,
void *prepare_arg) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
}
// code for reduce handle
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return static_cast<int>(dtype.type_size);
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
handle_ = reinterpret_cast<void*>(redfunc);
}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
reinterpret_cast<IEngine::ReduceFunction*>(handle_),
prepare_fun, prepare_arg);
}
} // namespace engine
} // namespace rabit

View File

@ -177,6 +177,48 @@ void Allreduce_(void *sendrecvbuf,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL);
/*!
* \brief handle for customized reducer, used to handle customized reduce
* this class is mainly created for compatiblity issue with MPI's customized reduce
*/
class ReduceHandle {
public:
// constructor
ReduceHandle(void);
// destructor
~ReduceHandle(void);
/*!
* \brief initialize the reduce function,
* with the type the reduce function need to deal with
* the reduce function MUST be communicative
*/
void Init(IEngine::ReduceFunction redfunc, size_t type_nbytes);
/*!
* \brief customized in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer
* \param type_n4bytes unit size of the type, in terms of 4bytes
* \param count number of elements to send
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
void Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL);
/*! \return the number of bytes occupied by the type */
static int TypeSize(const MPI::Datatype &dtype);
private:
// handle data field
void *handle_;
// handle to the type field
void *htype_;
// the created type in 4 bytes
size_t created_type_nbytes_;
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_H

View File

@ -86,5 +86,19 @@ void Allreduce_(void *sendrecvbuf,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
}
// code for reduce handle
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return 0;
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {}
} // namespace engine
} // namespace rabit

View File

@ -124,5 +124,59 @@ void Allreduce_(void *sendrecvbuf,
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
}
// code for reduce handle
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {
if (handle_ != NULL) {
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
op->Free();
delete op;
}
if (htype_ != NULL) {
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
dtype->Free();
delete dtype;
}
}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return dtype.Get_size();
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
if (type_nbytes != 0) {
MPI::Datatype *dtype = new MPI::Datatype();
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
dtype->Commit();
created_type_nbytes_ = type_nbytes;
htype_ = dtype;
}
MPI::Op *op = new MPI::Op();
MPI::User_function *pf = redfunc;
op->Init(pf, true);
handle_ = op;
}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
if (created_type_nbytes_ != type_nbytes || dtype == NULL) {
if (dtype == NULL) {
dtype = new MPI::Datatype();
} else {
dtype->Free();
}
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
dtype->Commit();
created_type_nbytes_ = type_nbytes;
}
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
}
} // namespace engine
} // namespace rabit

View File

@ -8,6 +8,7 @@
#define RABIT_RABIT_INL_H
// use engine for implementation
#include "./engine.h"
#include "./io.h"
#include "./utils.h"
namespace rabit {
@ -170,5 +171,71 @@ inline void CheckPoint(const ISerializable *global_model,
inline int VersionNumber(void) {
return engine::GetEngine()->VersionNumber();
}
// ---------------------------------
// Code to handle customized Reduce
// ---------------------------------
// function to perform reduction for Reducer
template<typename DType>
inline void Reducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
const size_t kUnit = sizeof(DType);
const char *psrc = reinterpret_cast<const char*>(src_);
char *pdst = reinterpret_cast<char*>(dst_);
DType tdst, tsrc;
for (size_t i = 0; i < len_; ++i) {
// use memcpy to avoid alignment issue
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
tdst.Reduce(tsrc);
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
}
}
template<typename DType>
inline Reducer<DType>::Reducer(void) {
handle_.Init(Reducer<DType>::ReduceFunc, sizeof(DType));
}
template<typename DType>
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg);
}
// function to perform reduction for SerializeReducer
template<typename DType>
inline void
SerializeReducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
int nbytes = engine::ReduceHandle::TypeSize(dtype);
// temp space
DType tsrc, tdst;
for (int i = 0; i < len_; ++i) {
utils::MemoryFixSizeBuffer fsrc((char*)(src_) + i * nbytes, nbytes);
utils::MemoryFixSizeBuffer fdst((char*)(dst_) + i * nbytes, nbytes);
tsrc.Load(fsrc);
tdst.Load(fdst);
// govern const check
tdst.Reduce(static_cast<const DType &>(tsrc), nbytes);
fdst.Seek(0);
tdst.Save(fdst);
}
}
template<typename DType>
inline SerializeReducer<DType>::SerializeReducer(void) {
handle_.Init(SerializeReducer<DType>::ReduceFunc, sizeof(DType));
}
template<typename DType>
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
buffer_.resize(max_nbyte);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Save(fs);
}
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, prepare_fun, prepare_arg);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Load(fs);
}
}
} // namespace rabit
#endif

View File

@ -183,6 +183,77 @@ inline void CheckPoint(const ISerializable *global_model,
* \sa LoadCheckPoint, CheckPoint
*/
inline int VersionNumber(void);
// ----- extensions that allow customized reducer ------
// helper class to do customized reduce, user do not need to know the type
namespace engine {
class ReduceHandle;
} // namespace engine
/*!
* \brief template class to make customized reduce and all reduce easy
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
* \tparam DType data type that to be reduced
* DType must be a struct, with no pointer, and contains a function Reduce(const DType &d);
*/
template<typename DType>
class Reducer {
public:
Reducer(void);
/*!
* \brief customized in-place all reduce operation
* \param sendrecvbuf the in place send-recv buffer
* \param count number of elements to be reduced
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL);
private:
// inner implementation of reducer
inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype);
/*! \brief function handle to do reduce */
engine::ReduceHandle handle_;
};
/*!
* \brief template class to make customized reduce,
* this class defines complex reducer handles all the data structure that can be
* serialized/deserialzed into fixed size buffer
* Do not use reducer directly in the function you call Finalize, because the destructor can happen after Finalize
*
* \tparam DType data type that to be reduced, DType must contain following functions:
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
*/
template<typename DType>
class SerializeReducer {
public:
SerializeReducer(void);
/*!
* \brief customized in-place all reduce operation
* \param sendrecvobj pointer to the array of objects to be reduced
* \param max_nbyte maximum amount of memory needed to serialize each object
* this includes budget limit for intermediate and final result
* \param count number of elements to be reduced
* \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
inline void Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count,
void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL);
private:
// inner implementation of reducer
inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype);
/*! \brief function handle to do reduce */
engine::ReduceHandle handle_;
/*! \brief temporal buffer used to do reduce*/
std::string buffer_;
};
} // namespace rabit
// implementation of template functions
#include "./rabit-inl.h"