diff --git a/Makefile b/Makefile index 9e827f43f..20045bbd6 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ ALIB= lib/librabit.a lib/librabit_mpi.a lib/librabit_empty.a lib/librabit_mock.a HEADERS=src/*.h include/*.h include/rabit/*.h .PHONY: clean all install mpi python -all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so +all: lib/librabit.a lib/librabit_mock.a wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so lib/librabit_base.a mpi: lib/librabit_mpi.a wrapper/librabit_wrapper_mpi.so python: wrapper/librabit_wrapper.so wrapper/librabit_wrapper_mock.so diff --git a/include/rabit.h b/include/rabit.h index 1c5c70e5f..eb2834b30 100644 --- a/include/rabit.h +++ b/include/rabit.h @@ -135,7 +135,6 @@ template inline void Allreduce(DType *sendrecvbuf, size_t count, void (*prepare_fun)(void *arg) = NULL, void *prepare_arg = NULL); - // C++11 support for lambda prepare function #if __cplusplus >= 201103L /*! @@ -238,11 +237,13 @@ class ReduceHandle; } // namespace engine /*! * \brief template class to make customized reduce and all reduce easy - * Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize + * Do not use reducer directly in the function you call Finalize, + * because the destructor can execute after Finalize * \tparam DType data type that to be reduced - * DType must be a struct, with no pointer, and contain a function Reduce(const DType &d); + * \tparam freduce the customized reduction function + * DType must be a struct, with no pointer */ -template +template class Reducer { public: Reducer(void); @@ -280,7 +281,8 @@ class Reducer { * Do not use reducer directly in the function you call Finalize, because the destructor can execute after Finalize * * \tparam DType data type that to be reduced, DType must contain the following functions: - * (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &d); + * \tparam freduce the customized reduction function + * (1) Save(IStream &fs) (2) Load(IStream &fs) (3) Reduce(const DType &src, size_t max_nbyte) */ template class SerializeReducer { diff --git a/include/rabit/rabit-inl.h b/include/rabit/rabit-inl.h index 4ffd812e7..d2a20eecb 100644 --- a/include/rabit/rabit-inl.h +++ b/include/rabit/rabit-inl.h @@ -195,8 +195,8 @@ inline int VersionNumber(void) { // Code to handle customized Reduce // --------------------------------- // function to perform reduction for Reducer -template -inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { +template +inline void ReducerSafe_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { const size_t kUnit = sizeof(DType); const char *psrc = reinterpret_cast(src_); char *pdst = reinterpret_cast(dst_); @@ -205,18 +205,32 @@ inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Data // use memcpy to avoid alignment issue std::memcpy(&tdst, pdst + i * kUnit, sizeof(tdst)); std::memcpy(&tsrc, psrc + i * kUnit, sizeof(tsrc)); - tdst.Reduce(tsrc); + freduce(tdst, tsrc); std::memcpy(pdst + i * kUnit, &tdst, sizeof(tdst)); } } -template -inline Reducer::Reducer(void) { - this->handle_.Init(ReducerFunc_, sizeof(DType)); +// function to perform reduction for Reducer +template +inline void ReducerAlign_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { + const DType *psrc = reinterpret_cast(src_); + DType *pdst = reinterpret_cast(dst_); + for (int i = 0; i < len_; ++i) { + freduce(pdst[i], psrc[i]); + } } -template -inline void Reducer::Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *arg), - void *prepare_arg) { +template +inline Reducer::Reducer(void) { + // it is safe to directly use handle for aligned data types + if (sizeof(DType) == 8 || sizeof(DType) == 4 || sizeof(DType) == 1) { + this->handle_.Init(ReducerAlign_, sizeof(DType)); + } else { + this->handle_.Init(ReducerSafe_, sizeof(DType)); + } +} +template +inline void Reducer::Allreduce(DType *sendrecvbuf, size_t count, + void (*prepare_fun)(void *arg), + void *prepare_arg) { handle_.Allreduce(sendrecvbuf, sizeof(DType), count, prepare_fun, prepare_arg); } // function to perform reduction for SerializeReducer diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 9962745a2..829051231 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -159,12 +159,17 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice"); if (type_nbytes != 0) { MPI::Datatype *dtype = new MPI::Datatype(); - *dtype = MPI::CHAR.Create_contiguous(type_nbytes); + if (type_nbytes % 8 == 0) { + *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); + } else if (type_nbytes % 4 == 0) { + *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int)); + } else { + *dtype = MPI::CHAR.Create_contiguous(type_nbytes); + } dtype->Commit(); created_type_nbytes_ = type_nbytes; htype_ = dtype; } - MPI::Op *op = new MPI::Op(); MPI::User_function *pf = redfunc; op->Init(pf, true); @@ -183,7 +188,13 @@ void ReduceHandle::Allreduce(void *sendrecvbuf, } else { dtype->Free(); } - *dtype = MPI::CHAR.Create_contiguous(type_nbytes); + if (type_nbytes % 8 == 0) { + *dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); + } else if (type_nbytes % 4 == 0) { + *dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int)); + } else { + *dtype = MPI::CHAR.Create_contiguous(type_nbytes); + } dtype->Commit(); created_type_nbytes_ = type_nbytes; }