fix bug in lambda allreduce

This commit is contained in:
tqchen 2014-12-20 05:04:16 -08:00
parent 5570e7ceae
commit 77d74f6c0d
2 changed files with 34 additions and 8 deletions

View File

@ -219,17 +219,39 @@ template<typename DType>
inline SerializeReducer<DType>::SerializeReducer(void) { inline SerializeReducer<DType>::SerializeReducer(void) {
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType)); handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
} }
// closure to call Allreduce
template<typename DType>
struct SerializeReduceClosure {
DType *sendrecvobj;
size_t max_nbyte, count;
void (*prepare_fun)(void *arg);
void *prepare_arg;
std::string *p_buffer;
// invoke the closure
inline void Run(void) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Save(fs);
}
}
inline static void Invoke(void *c) {
static_cast<SerializeReduceClosure<DType>*>(c)->Run();
}
};
template<typename DType> template<typename DType>
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj, inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
size_t max_nbyte, size_t count, size_t max_nbyte, size_t count,
void (*prepare_fun)(void *arg), void (*prepare_fun)(void *arg),
void *prepare_arg) { void *prepare_arg) {
buffer_.resize(max_nbyte); buffer_.resize(max_nbyte * count);
for (size_t i = 0; i < count; ++i) { // setup closure
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte); SerializeReduceClosure<DType> c;
sendrecvobj[i].Save(fs); c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
} c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, prepare_fun, prepare_arg); // invoke here
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
SerializeReduceClosure<DType>::Invoke, &c);
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte); utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Load(fs); sendrecvobj[i].Load(fs);
@ -240,13 +262,13 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
template<typename DType> template<typename DType>
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count, inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun) { std::function<void()> prepare_fun) {
this->AllReduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun); this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
} }
template<typename DType> template<typename DType>
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj, inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
size_t max_nbytes, size_t count, size_t max_nbytes, size_t count,
std::function<void()> prepare_fun) { std::function<void()> prepare_fun) {
this->AllReduce(sendrecvobj, count, max_nbytes, InvokeLambda_, &prepare_fun); this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun);
} }
#endif #endif
} // namespace rabit } // namespace rabit

View File

@ -49,6 +49,10 @@ inline void Finalize(void);
inline int GetRank(void); inline int GetRank(void);
/*! \brief get total number of process */ /*! \brief get total number of process */
inline int GetWorldSize(void); inline int GetWorldSize(void);
/*! \brief whether rabit env is in distributed mode */
inline bool IsDistributed(void) {
return GetWorldSize() != 1;
}
/*! \brief get name of processor */ /*! \brief get name of processor */
inline std::string GetProcessorName(void); inline std::string GetProcessorName(void);
/*! /*!