Merge commit 'ea50f8e030111f659dd69b89c86eba51abd39eba'

This commit is contained in:
tqchen
2015-01-19 21:26:25 -08:00
10 changed files with 78 additions and 24 deletions

View File

@@ -135,7 +135,6 @@ template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL);
// C++11 support for lambda prepare function
#if __cplusplus >= 201103L
/*!
@@ -238,11 +237,13 @@ class ReduceHandle;
} // namespace engine
/*!
* \brief template class to make customized reduce and all reduce easy
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
* Do not use reducer directly in the function you call Finalize,
* because the destructor can execute after Finalize
* \tparam DType data type to be reduced
* DType must be a struct, with no pointer, and contain a function Reduce(const DType &d);
* \tparam freduce the customized reduction function
* DType must be a struct, with no pointer
*/
template<typename DType>
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
class Reducer {
public:
Reducer(void);
@@ -280,7 +281,8 @@ class Reducer {
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
*
* \tparam DType data type to be reduced; DType must contain the following functions:
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
* \tparam freduce the customized reduction function
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte)
*/
template<typename DType>
class SerializeReducer {

View File

@@ -195,8 +195,8 @@ inline int VersionNumber(void) {
// Code to handle customized Reduce
// ---------------------------------
// function to perform reduction for Reducer
template<typename DType>
inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
const size_t kUnit = sizeof(DType);
const char *psrc = reinterpret_cast<const char*>(src_);
char *pdst = reinterpret_cast<char*>(dst_);
@@ -205,18 +205,32 @@ inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Data
// use memcpy to avoid alignment issue
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
tdst.Reduce(tsrc);
freduce(tdst, tsrc);
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
}
}
template<typename DType>
inline Reducer<DType>::Reducer(void) {
this->handle_.Init(ReducerFunc_<DType>, sizeof(DType));
/*!
 * \brief reduction callback for Reducer that operates on the buffers in place
 *  Both buffers are reinterpreted as arrays of DType and handed to freduce
 *  element by element, without staging through temporaries.
 *  NOTE(review): this presumably requires both buffers to be suitably
 *  aligned for DType - confirm against the caller that selects it.
 * \param src_ source buffer, read as len_ elements of DType
 * \param dst_ destination buffer, updated in place as len_ elements of DType
 * \param len_ number of DType elements to reduce
 * \param dtype not used by this function
 * \tparam DType element type stored in the buffers
 * \tparam freduce reduction function, called as freduce(dst, src)
 */
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
inline void ReducerAlign_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
  const DType *src_it = reinterpret_cast<const DType*>(src_);
  DType *dst_it = reinterpret_cast<DType*>(dst_);
  // walk both arrays in lock-step; nothing is done when len_ <= 0
  for (int remaining = len_; remaining > 0; --remaining, ++src_it, ++dst_it) {
    freduce(*dst_it, *src_it);
  }
}
template<typename DType>
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
/*!
 * \brief constructor, registers the reduction callback on the handle
 *  Types whose size is 1, 4 or 8 bytes are treated as aligned and get
 *  ReducerAlign_; every other size gets ReducerSafe_.
 */
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
inline Reducer<DType, freduce>::Reducer(void) {
  const size_t kUnit = sizeof(DType);
  // sizes for which the buffers are accessed directly as DType
  const bool use_direct = (kUnit == 1 || kUnit == 4 || kUnit == 8);
  if (!use_direct) {
    this->handle_.Init(ReducerSafe_<DType, freduce>, kUnit);
    return;
  }
  this->handle_.Init(ReducerAlign_<DType, freduce>, kUnit);
}
/*!
 * \brief forwards an all-reduce request on sendrecvbuf to the underlying handle
 * \param sendrecvbuf buffer passed through to the handle
 * \param count number of DType elements in sendrecvbuf
 * \param prepare_fun optional callback passed through to the handle
 * \param prepare_arg argument passed through together with prepare_fun
 */
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
                                               void (*prepare_fun)(void *arg),
                                               void *prepare_arg) {
  // the handle needs the per-element size alongside the element count
  const size_t type_nbytes = sizeof(DType);
  this->handle_.Allreduce(sendrecvbuf, type_nbytes, count, prepare_fun, prepare_arg);
}
// function to perform reduction for SerializeReducer