fix bug in lambda allreduce
This commit is contained in:
parent
5570e7ceae
commit
77d74f6c0d
@ -219,17 +219,39 @@ template<typename DType>
|
|||||||
inline SerializeReducer<DType>::SerializeReducer(void) {
|
inline SerializeReducer<DType>::SerializeReducer(void) {
|
||||||
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
|
handle_.Init(SerializeReducerFunc_<DType>, sizeof(DType));
|
||||||
}
|
}
|
||||||
|
// closure to call Allreduce
|
||||||
|
template<typename DType>
|
||||||
|
struct SerializeReduceClosure {
|
||||||
|
DType *sendrecvobj;
|
||||||
|
size_t max_nbyte, count;
|
||||||
|
void (*prepare_fun)(void *arg);
|
||||||
|
void *prepare_arg;
|
||||||
|
std::string *p_buffer;
|
||||||
|
// invoke the closure
|
||||||
|
inline void Run(void) {
|
||||||
|
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
|
||||||
|
sendrecvobj[i].Save(fs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inline static void Invoke(void *c) {
|
||||||
|
static_cast<SerializeReduceClosure<DType>*>(c)->Run();
|
||||||
|
}
|
||||||
|
};
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||||
size_t max_nbyte, size_t count,
|
size_t max_nbyte, size_t count,
|
||||||
void (*prepare_fun)(void *arg),
|
void (*prepare_fun)(void *arg),
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
buffer_.resize(max_nbyte);
|
buffer_.resize(max_nbyte * count);
|
||||||
for (size_t i = 0; i < count; ++i) {
|
// setup closure
|
||||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
SerializeReduceClosure<DType> c;
|
||||||
sendrecvobj[i].Save(fs);
|
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
|
||||||
}
|
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
|
||||||
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, prepare_fun, prepare_arg);
|
// invoke here
|
||||||
|
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
|
||||||
|
SerializeReduceClosure<DType>::Invoke, &c);
|
||||||
for (size_t i = 0; i < count; ++i) {
|
for (size_t i = 0; i < count; ++i) {
|
||||||
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
|
||||||
sendrecvobj[i].Load(fs);
|
sendrecvobj[i].Load(fs);
|
||||||
@ -240,13 +262,13 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
|||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
inline void Reducer<DType>::Allreduce(DType *sendrecvbuf, size_t count,
|
||||||
std::function<void()> prepare_fun) {
|
std::function<void()> prepare_fun) {
|
||||||
this->AllReduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun);
|
||||||
}
|
}
|
||||||
template<typename DType>
|
template<typename DType>
|
||||||
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
|
||||||
size_t max_nbytes, size_t count,
|
size_t max_nbytes, size_t count,
|
||||||
std::function<void()> prepare_fun) {
|
std::function<void()> prepare_fun) {
|
||||||
this->AllReduce(sendrecvobj, count, max_nbytes, InvokeLambda_, &prepare_fun);
|
this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -49,6 +49,10 @@ inline void Finalize(void);
|
|||||||
inline int GetRank(void);
|
inline int GetRank(void);
|
||||||
/*! \brief get total number of process */
|
/*! \brief get total number of process */
|
||||||
inline int GetWorldSize(void);
|
inline int GetWorldSize(void);
|
||||||
|
/*! \brief whether rabit env is in distributed mode */
|
||||||
|
inline bool IsDistributed(void) {
|
||||||
|
return GetWorldSize() != 1;
|
||||||
|
}
|
||||||
/*! \brief get name of processor */
|
/*! \brief get name of processor */
|
||||||
inline std::string GetProcessorName(void);
|
inline std::string GetProcessorName(void);
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user