add complex types
This commit is contained in:
parent
e72a869fd1
commit
5570e7ceae
@ -7,7 +7,6 @@
|
|||||||
#ifndef RABIT_RABIT_INL_H
|
#ifndef RABIT_RABIT_INL_H
|
||||||
#define RABIT_RABIT_INL_H
|
#define RABIT_RABIT_INL_H
|
||||||
// use engine for implementation
|
// use engine for implementation
|
||||||
#include "./engine.h"
|
|
||||||
#include "./io.h"
|
#include "./io.h"
|
||||||
#include "./utils.h"
|
#include "./utils.h"
|
||||||
|
|
||||||
@ -176,7 +175,7 @@ inline int VersionNumber(void) {
|
|||||||
// ---------------------------------
|
// ---------------------------------
|
||||||
// function to perform reduction for Reducer
|
// function to perform reduction for Reducer
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void Reducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
inline void ReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
const size_t kUnit = sizeof(DType);
|
const size_t kUnit = sizeof(DType);
|
||||||
const char *psrc = reinterpret_cast<const char*>(src_);
|
const char *psrc = reinterpret_cast<const char*>(src_);
|
||||||
char *pdst = reinterpret_cast<char*>(dst_);
|
char *pdst = reinterpret_cast<char*>(dst_);
|
||||||
@ -191,7 +190,7 @@ inline void Reducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, c
|
|||||||
}
|
}
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline Reducer<DType>::Reducer(void) {
|
inline Reducer<DType>::Reducer(void) {
|
||||||
handle_.Init(Reducer<DType>::ReduceFunc, sizeof(DType));
|
this->handle_.Init(ReducerFunc_<DType>, sizeof(DType));
|
||||||
}
|
}
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
@ -201,8 +200,7 @@ inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
|||||||
}
|
}
|
||||||
// function to perform reduction for SerializeReducer
|
// function to perform reduction for SerializeReducer
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void
|
inline void SerializeReducerFunc_(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
||||||
SerializeReducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype) {
|
|
||||||
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
int nbytes = engine::ReduceHandle::TypeSize(dtype);
|
||||||
// temp space
|
// temp space
|
||||||
DType tsrc, tdst;
|
DType tsrc, tdst;
|
||||||
@ -219,7 +217,7 @@ SerializeReducer<DType>::ReduceFunc(const void *src_, void *dst_, int len_, cons
|
|||||||
}
|
}
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline SerializeReducer<DType>::SerializeReducer(void) {
|
inline SerializeReducer<DType>::SerializeReducer(void) {
|
||||||
handle_.Init(SerializeReducer<DType>::ReduceFunc, sizeof(DType));
|
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
|
||||||
}
|
}
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||||
@ -237,5 +235,19 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|||||||
sendrecvobj[i].Load(fs);
|
sendrecvobj[i].Load(fs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
template<typename DType>
|
||||||
|
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
|
std::function<void()> prepare_fun) {
|
||||||
|
this->AllReduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
||||||
|
}
|
||||||
|
template<typename DType>
|
||||||
|
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||||
|
size_t max_nbytes, size_t count,
|
||||||
|
std::function<void()> prepare_fun) {
|
||||||
|
this->AllReduce(sendrecvobj, count, max_nbytes, InvokeLambda_, &prepare_fun);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
34
src/rabit.h
34
src/rabit.h
@ -17,6 +17,10 @@
|
|||||||
#endif // C++11
|
#endif // C++11
|
||||||
// contains definition of ISerializable
|
// contains definition of ISerializable
|
||||||
#include "./serializable.h"
|
#include "./serializable.h"
|
||||||
|
// engine definition of rabit, defines internal implementation
|
||||||
|
// to use rabit interface, there is no need to read engine.h rabit.h and serializable.h
|
||||||
|
// is suffice to use the interface
|
||||||
|
#include "./engine.h"
|
||||||
|
|
||||||
/*! \brief namespace of rabit */
|
/*! \brief namespace of rabit */
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
@ -210,10 +214,17 @@ class Reducer {
|
|||||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
void (*prepare_fun)(void *arg) = NULL,
|
void (*prepare_fun)(void *arg) = NULL,
|
||||||
void *prepare_arg = NULL);
|
void *prepare_arg = NULL);
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
/*!
|
||||||
|
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
||||||
|
* \param sendrecvbuf pointer to the array of objects to be reduced
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
||||||
|
*/
|
||||||
|
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
|
std::function<void()> prepare_fun);
|
||||||
|
#endif
|
||||||
private:
|
private:
|
||||||
// inner implementation of reducer
|
|
||||||
inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype);
|
|
||||||
/*! \brief function handle to do reduce */
|
/*! \brief function handle to do reduce */
|
||||||
engine::ReduceHandle handle_;
|
engine::ReduceHandle handle_;
|
||||||
};
|
};
|
||||||
@ -245,10 +256,21 @@ class SerializeReducer {
|
|||||||
size_t max_nbyte, size_t count,
|
size_t max_nbyte, size_t count,
|
||||||
void (*prepare_fun)(void *arg) = NULL,
|
void (*prepare_fun)(void *arg) = NULL,
|
||||||
void *prepare_arg = NULL);
|
void *prepare_arg = NULL);
|
||||||
|
// C++11 support for lambda prepare function
|
||||||
|
#if __cplusplus >= 201103L
|
||||||
|
/*!
|
||||||
|
* \brief customized in-place all reduce operation, with lambda function as preprocessor
|
||||||
|
* \param sendrecvobj pointer to the array of objects to be reduced
|
||||||
|
* \param max_nbyte maximum amount of memory needed to serialize each object
|
||||||
|
* this includes budget limit for intermediate and final result
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param prepare_fun lambda function executed to prepare the data, if necessary
|
||||||
|
*/
|
||||||
|
inline void Allreduce(DType *sendrecvobj,
|
||||||
|
size_t max_nbyte, size_t count,
|
||||||
|
std::function<void()> prepare_fun);
|
||||||
|
#endif
|
||||||
private:
|
private:
|
||||||
// inner implementation of reducer
|
|
||||||
inline static void ReduceFunc(const void *src_, void *dst_, int len_, const MPI::Datatype &dtype);
|
|
||||||
/*! \brief function handle to do reduce */
|
/*! \brief function handle to do reduce */
|
||||||
engine::ReduceHandle handle_;
|
engine::ReduceHandle handle_;
|
||||||
/*! \brief temporal buffer used to do reduce*/
|
/*! \brief temporal buffer used to do reduce*/
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user