Squashed 'subtree/rabit/' changes from 1db6449..85b7463
85b7463change def of reducer to take function ptrfe6366eadd engine basea98720emore deps git-subtree-dir: subtree/rabit git-subtree-split:85b746394e
This commit is contained in:
@@ -195,8 +195,8 @@ inline int VersionNumber(void) {
|
||||
// Code to handle customized Reduce
|
||||
// ---------------------------------
|
||||
// function to perform reduction for Reducer
|
||||
template<typename DType>
|
||||
inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||
inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||
const size_t kUnit = sizeof(DType);
|
||||
const char *psrc = reinterpret_cast<const char*>(src_);
|
||||
char *pdst = reinterpret_cast<char*>(dst_);
|
||||
@@ -205,18 +205,32 @@ inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Data
|
||||
// use memcpy to avoid alignment issue
|
||||
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
|
||||
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
|
||||
tdst.Reduce(tsrc);
|
||||
freduce(tdst, tsrc);
|
||||
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
|
||||
}
|
||||
}
|
||||
template<typename DType>
|
||||
inline Reducer<DType>::Reducer(void) {
|
||||
this->handle_.Init(ReducerFunc_<DType>, sizeof(DType));
|
||||
// function to perform reduction for Reducer
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||
inline void ReducerAlign_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||
const DType *psrc = reinterpret_cast<const DType*>(src_);
|
||||
DType *pdst = reinterpret_cast<DType*>(dst_);
|
||||
for (int i = 0; i < len_; ++i) {
|
||||
freduce(pdst[i], psrc[i]);
|
||||
}
|
||||
}
|
||||
template<typename DType>
|
||||
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||
inline Reducer<DType, freduce>::Reducer(void) {
|
||||
// it is safe to directly use handle for aligned data types
|
||||
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
|
||||
this->handle_.Init(ReducerAlign_<DType, freduce>, sizeof(DType));
|
||||
} else {
|
||||
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
|
||||
}
|
||||
}
|
||||
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg);
|
||||
}
|
||||
// function to perform reduction for SerializeReducer
|
||||
|
||||
Reference in New Issue
Block a user