fix bug in lambda allreduce
This commit is contained in:
parent
5570e7ceae
commit
77d74f6c0d
@ -219,17 +219,39 @@ template<typename DType>
|
||||
// Default constructor: registers SerializeReducerFunc_<DType> with the
// underlying handle, passing the byte size of a single DType.
inline SerializeReducer<DType>::SerializeReducer(void) {
  const size_t type_nbytes = sizeof(DType);
  handle_.Init(SerializeReducerFunc_<DType>, type_nbytes);
}
|
||||
// closure to call Allreduce
|
||||
template<typename DType>
|
||||
struct SerializeReduceClosure {
|
||||
DType *sendrecvobj;
|
||||
size_t max_nbyte, count;
|
||||
void (*prepare_fun)(void *arg);
|
||||
void *prepare_arg;
|
||||
std::string *p_buffer;
|
||||
// invoke the closure
|
||||
inline void Run(void) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
|
||||
sendrecvobj[i].Save(fs);
|
||||
}
|
||||
}
|
||||
inline static void Invoke(void *c) {
|
||||
static_cast<SerializeReduceClosure<DType>*>(c)->Run();
|
||||
}
|
||||
};
|
||||
template<typename DType>
|
||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbyte, size_t count,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
buffer_.resize(max_nbyte);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||
sendrecvobj[i].Save(fs);
|
||||
}
|
||||
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, prepare_fun, prepare_arg);
|
||||
buffer_.resize(max_nbyte * count);
|
||||
// setup closure
|
||||
SerializeReduceClosure<DType> c;
|
||||
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
|
||||
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
|
||||
// invoke here
|
||||
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
|
||||
SerializeReduceClosure<DType>::Invoke, &c);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||
sendrecvobj[i].Load(fs);
|
||||
@ -240,13 +262,13 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
template<typename DType>
|
||||
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun) {
|
||||
this->AllReduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
||||
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
||||
}
|
||||
template<typename DType>
|
||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||
size_t max_nbytes, size_t count,
|
||||
std::function<void()> prepare_fun) {
|
||||
this->AllReduce(sendrecvobj, count, max_nbytes, InvokeLambda_, &prepare_fun);
|
||||
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun);
|
||||
}
|
||||
#endif
|
||||
} // namespace rabit
|
||||
|
||||
@ -49,6 +49,10 @@ inline void Finalize(void);
|
||||
inline int GetRank(void);
|
||||
/*! \brief get total number of processes */
|
||||
inline int GetWorldSize(void);
|
||||
/*! \brief whether rabit env is in distributed mode */
|
||||
inline bool IsDistributed(void) {
|
||||
return GetWorldSize() != 1;
|
||||
}
|
||||
/*! \brief get name of processor */
|
||||
inline std::string GetProcessorName(void);
|
||||
/*!
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user