diff --git a/src/rabit-inl.h b/src/rabit-inl.h index 679f6d49e..54e2c05d5 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -7,7 +7,6 @@ #ifndef RABIT_RABIT_INL_H #define RABIT_RABIT_INL_H // use engine for implementation -#include "./engine.h" #include "./io.h" #include "./utils.h" @@ -176,7 +175,7 @@ inline int VersionNumber(void) { // --------------------------------- // function to perform reduction for Reducer template -inline void Reducer::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { +inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { const size_t kUnit = sizeof(DType); const char *psrc = reinterpret_cast(src_); char *pdst = reinterpret_cast(dst_); @@ -191,7 +190,7 @@ inline void Reducer::ReduceFunc(const void *src_, void *dst_, int len_, c } template inline Reducer::Reducer(void) { - handle_.Init(Reducer::ReduceFunc, sizeof(DType)); + this->handle_.Init(ReducerFunc_, sizeof(DType)); } template inline void Reducer::Allreduce(DType *sendrecvbuf, size_t count, @@ -201,8 +200,7 @@ inline void Reducer::Allreduce(DType *sendrecvbuf, size_t count, } // function to perform reduction for SerializeReducer template -inline void -SerializeReducer::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { +inline void SerializeReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) { int nbytes = engine::ReduceHandle::TypeSize(dtype); // temp space DType tsrc, tdst; @@ -219,7 +217,7 @@ SerializeReducer::ReduceFunc(const void *src_, void *dst_, int len_, cons } template inline SerializeReducer::SerializeReducer(void) { - handle_.Init(SerializeReducer::ReduceFunc, sizeof(DType)); + handle_.Init(SerializeReducerFunc_, sizeof(DType)); } template inline void SerializeReducer::Allreduce(DType *sendrecvobj, @@ -237,5 +235,19 @@ inline void SerializeReducer::Allreduce(DType *sendrecvobj, sendrecvobj[i].Load(fs); } } + +#if __cplusplus >= 201103L 
+template +inline void Reducer::Allreduce(DType *sendrecvbuf, size_t count, + std::function prepare_fun) { + this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun); +} +template +inline void SerializeReducer::Allreduce(DType *sendrecvobj, + size_t max_nbytes, size_t count, + std::function prepare_fun) { + this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun); +} +#endif } // namespace rabit #endif diff --git a/src/rabit.h b/src/rabit.h index 316da65c9..f5c94e1c9 100644 --- a/src/rabit.h +++ b/src/rabit.h @@ -17,6 +17,10 @@ #endif // C++11 // contains definition of ISerializable #include "./serializable.h" +// engine definition of rabit, defines internal implementation +// to use rabit interface, there is no need to read engine.h; rabit.h and serializable.h +// suffice to use the interface +#include "./engine.h" /*! \brief namespace of rabit */ namespace rabit { @@ -210,10 +214,17 @@ class Reducer { inline void Allreduce(DType *sendrecvbuf, size_t count, void (*prepare_fun)(void *arg) = NULL, void *prepare_arg = NULL); - +#if __cplusplus >= 201103L + /*! + * \brief customized in-place all reduce operation, with lambda function as preprocessor + * \param sendrecvbuf pointer to the array of objects to be reduced + * \param count number of elements to be reduced + * \param prepare_fun lambda function executed to prepare the data, if necessary + */ + inline void Allreduce(DType *sendrecvbuf, size_t count, + std::function prepare_fun); +#endif private: - // inner implementation of reducer - inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype); /*! \brief function handle to do reduce */ engine::ReduceHandle handle_; }; @@ -245,10 +256,21 @@ class SerializeReducer { size_t max_nbyte, size_t count, void (*prepare_fun)(void *arg) = NULL, void *prepare_arg = NULL); - +// C++11 support for lambda prepare function +#if __cplusplus >= 201103L + /*! 
+ * \brief customized in-place all reduce operation, with lambda function as preprocessor + * \param sendrecvobj pointer to the array of objects to be reduced + * \param max_nbyte maximum amount of memory needed to serialize each object + * this includes budget limit for intermediate and final result + * \param count number of elements to be reduced + * \param prepare_fun lambda function executed to prepare the data, if necessary + */ + inline void Allreduce(DType *sendrecvobj, + size_t max_nbyte, size_t count, + std::function prepare_fun); +#endif private: - // inner implementation of reducer - inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype); /*! \brief function handle to do reduce */ engine::ReduceHandle handle_; /*! \brief temporal buffer used to do reduce*/