change def of reducer to take function ptr
This commit is contained in:
parent
fe6366eb40
commit
85b746394e
2
Makefile
2
Makefile
@ -16,7 +16,7 @@ ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a
|
|||||||
HEADERS=src/*.h include/*.h include/rabit/*.h
|
HEADERS=src/*.h include/*.h include/rabit/*.h
|
||||||
.PHONY: clean all install mpi python
|
.PHONY: clean all install mpi python
|
||||||
|
|
||||||
all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
|
all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so lib/librabit_base.a
|
||||||
mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so
|
mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so
|
||||||
python: wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
|
python: wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so
|
||||||
|
|
||||||
|
|||||||
@ -135,7 +135,6 @@ template<typename OP, typename DType>
|
|||||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
void (*prepare_fun)(void *arg) = NULL,
|
void (*prepare_fun)(void *arg) = NULL,
|
||||||
void *prepare_arg = NULL);
|
void *prepare_arg = NULL);
|
||||||
|
|
||||||
// C++11 support for lambda prepare function
|
// C++11 support for lambda prepare function
|
||||||
#if __cplusplus >= 201103L
|
#if __cplusplus >= 201103L
|
||||||
/*!
|
/*!
|
||||||
@ -238,11 +237,13 @@ class ReduceHandle;
|
|||||||
} // namespace engine
|
} // namespace engine
|
||||||
/*!
|
/*!
|
||||||
* \brief template class to make customized reduce and all reduce easy
|
* \brief template class to make customized reduce and all reduce easy
|
||||||
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
|
* Do not use reducer directly in the function you call Finalize,
|
||||||
|
* because the destructor can execute after Finalize
|
||||||
* \tparam DType data type that to be reduced
|
* \tparam DType data type that to be reduced
|
||||||
* DType must be a struct, with no pointer, and contain a function Reduce(const DType &d);
|
* \tparam freduce the customized reduction function
|
||||||
|
* DType must be a struct, with no pointer
|
||||||
*/
|
*/
|
||||||
template<typename DType>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||||
class Reducer {
|
class Reducer {
|
||||||
public:
|
public:
|
||||||
Reducer(void);
|
Reducer(void);
|
||||||
@ -280,7 +281,8 @@ class Reducer {
|
|||||||
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
|
* Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize
|
||||||
*
|
*
|
||||||
* \tparam DType data type that to be reduced, DType must contain the following functions:
|
* \tparam DType data type that to be reduced, DType must contain the following functions:
|
||||||
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d);
|
* \tparam freduce the customized reduction function
|
||||||
|
* (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte)
|
||||||
*/
|
*/
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
class SerializeReducer {
|
class SerializeReducer {
|
||||||
|
|||||||
@ -195,8 +195,8 @@ inline int VersionNumber(void) {
|
|||||||
// Code to handle customized Reduce
|
// Code to handle customized Reduce
|
||||||
// ---------------------------------
|
// ---------------------------------
|
||||||
// function to perform reduction for Reducer
|
// function to perform reduction for Reducer
|
||||||
template<typename DType>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||||
inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
const size_t kUnit = sizeof(DType);
|
const size_t kUnit = sizeof(DType);
|
||||||
const char *psrc = reinterpret_cast<const char*>(src_);
|
const char *psrc = reinterpret_cast<const char*>(src_);
|
||||||
char *pdst = reinterpret_cast<char*>(dst_);
|
char *pdst = reinterpret_cast<char*>(dst_);
|
||||||
@ -205,18 +205,32 @@ inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Data
|
|||||||
// use memcpy to avoid alignment issue
|
// use memcpy to avoid alignment issue
|
||||||
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
|
std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst));
|
||||||
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
|
std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc));
|
||||||
tdst.Reduce(tsrc);
|
freduce(tdst, tsrc);
|
||||||
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
|
std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template<typename DType>
|
// function to perform reduction for Reducer
|
||||||
inline Reducer<DType>::Reducer(void) {
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||||
this->handle_.Init(ReducerFunc_<DType>, sizeof(DType));
|
inline void ReducerAlign_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
|
const DType *psrc = reinterpret_cast<const DType*>(src_);
|
||||||
|
DType *pdst = reinterpret_cast<DType*>(dst_);
|
||||||
|
for (int i = 0; i < len_; ++i) {
|
||||||
|
freduce(pdst[i], psrc[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
template<typename DType>
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||||
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
inline Reducer<DType, freduce>::Reducer(void) {
|
||||||
void (*prepare_fun)(void *arg),
|
// it is safe to directly use handle for aligned data types
|
||||||
void *prepare_arg) {
|
if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) {
|
||||||
|
this->handle_.Init(ReducerAlign_<DType, freduce>, sizeof(DType));
|
||||||
|
} else {
|
||||||
|
this->handle_.Init(ReducerSafe_<DType, freduce>, sizeof(DType));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template<typename DType, void (*freduce)(DType &dst, const DType &src)>
|
||||||
|
inline void Reducer<DType, freduce>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
|
void (*prepare_fun)(void *arg),
|
||||||
|
void *prepare_arg) {
|
||||||
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg);
|
handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg);
|
||||||
}
|
}
|
||||||
// function to perform reduction for SerializeReducer
|
// function to perform reduction for SerializeReducer
|
||||||
|
|||||||
@ -159,12 +159,17 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
|||||||
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
|
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
|
||||||
if (type_nbytes != 0) {
|
if (type_nbytes != 0) {
|
||||||
MPI::Datatype *dtype = new MPI::Datatype();
|
MPI::Datatype *dtype = new MPI::Datatype();
|
||||||
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
if (type_nbytes % 8 == 0) {
|
||||||
|
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long));
|
||||||
|
} else if (type_nbytes % 4 == 0) {
|
||||||
|
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
|
||||||
|
} else {
|
||||||
|
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
||||||
|
}
|
||||||
dtype->Commit();
|
dtype->Commit();
|
||||||
created_type_nbytes_ = type_nbytes;
|
created_type_nbytes_ = type_nbytes;
|
||||||
htype_ = dtype;
|
htype_ = dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
MPI::Op *op = new MPI::Op();
|
MPI::Op *op = new MPI::Op();
|
||||||
MPI::User_function *pf = redfunc;
|
MPI::User_function *pf = redfunc;
|
||||||
op->Init(pf, true);
|
op->Init(pf, true);
|
||||||
@ -183,7 +188,13 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
|
|||||||
} else {
|
} else {
|
||||||
dtype->Free();
|
dtype->Free();
|
||||||
}
|
}
|
||||||
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
if (type_nbytes % 8 == 0) {
|
||||||
|
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long));
|
||||||
|
} else if (type_nbytes % 4 == 0) {
|
||||||
|
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
|
||||||
|
} else {
|
||||||
|
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
||||||
|
}
|
||||||
dtype->Commit();
|
dtype->Commit();
|
||||||
created_type_nbytes_ = type_nbytes;
|
created_type_nbytes_ = type_nbytes;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user