diff --git a/src/rabit-inl.h b/src/rabit-inl.h index 54e2c05d5..20ee39720 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -219,17 +219,39 @@ template inline SerializeReducer::SerializeReducer(void) { handle_.Init(SerializeReducerFunc_, sizeof(DType)); } +// closure to call Allreduce +template +struct SerializeReduceClosure { + DType *sendrecvobj; + size_t max_nbyte, count; + void (*prepare_fun)(void *arg); + void *prepare_arg; + std::string *p_buffer; + // invoke the closure + inline void Run(void) { + if (prepare_fun != NULL) prepare_fun(prepare_arg); + for (size_t i = 0; i < count; ++i) { + utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte); + sendrecvobj[i].Save(fs); + } + } + inline static void Invoke(void *c) { + static_cast*>(c)->Run(); + } +}; template inline void SerializeReducer::Allreduce(DType *sendrecvobj, size_t max_nbyte, size_t count, void (*prepare_fun)(void *arg), void *prepare_arg) { - buffer_.resize(max_nbyte); - for (size_t i = 0; i < count; ++i) { - utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte); - sendrecvobj[i].Save(fs); - } - handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, prepare_fun, prepare_arg); + buffer_.resize(max_nbyte * count); + // setup closure + SerializeReduceClosure c; + c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count; + c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_; + // invoke here + handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count, + SerializeReduceClosure::Invoke, &c); for (size_t i = 0; i < count; ++i) { utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte); sendrecvobj[i].Load(fs); @@ -240,13 +262,13 @@ inline void SerializeReducer::Allreduce(DType *sendrecvobj, template inline void Reducer::Allreduce(DType *sendrecvbuf, size_t count, std::function prepare_fun) { - this->AllReduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun); + this->Allreduce(sendrecvbuf, count, InvokeLambda_, &prepare_fun); } template inline void SerializeReducer::Allreduce(DType *sendrecvobj, size_t max_nbytes, size_t count, std::function prepare_fun) { - this->AllReduce(sendrecvobj, count, max_nbytes, InvokeLambda_, &prepare_fun); + this->Allreduce(sendrecvobj, max_nbytes, count, InvokeLambda_, &prepare_fun); } #endif } // namespace rabit diff --git a/src/rabit.h b/src/rabit.h index f5c94e1c9..834c21fd1 100644 --- a/src/rabit.h +++ b/src/rabit.h @@ -49,6 +49,10 @@ inline void Finalize(void); inline int GetRank(void); /*! \brief get total number of process */ inline int GetWorldSize(void); +/*! \brief whether rabit env is in distributed mode */ +inline bool IsDistributed(void) { + return GetWorldSize() != 1; +} /*! \brief get name of processor */ inline std::string GetProcessorName(void); /*!