/*! * \file rabit-inl.h * \brief implementation of inline template function for rabit interface * * \author Tianqi Chen */ #ifndef RABIT_RABIT_INL_H #define RABIT_RABIT_INL_H namespace rabit { namespace engine { namespace mpi { // template function to translate type to enum indicator template inline DataType GetType(void); template<> inline DataType GetType(void) { return kInt; } template<> inline DataType GetType(void) { return kUInt; } template<> inline DataType GetType(void) { return kFloat; } template<> inline DataType GetType(void) { return kDouble; } } // namespace mpi } // namespace engine namespace op { struct Max { const static engine::mpi::OpType kType = engine::mpi::kMax; template inline static void Reduce(DType &dst, const DType &src) { if (dst < src) dst = src; } }; struct Min { const static engine::mpi::OpType kType = engine::mpi::kMin; template inline static void Reduce(DType &dst, const DType &src) { if (dst > src) dst = src; } }; struct Sum { const static engine::mpi::OpType kType = engine::mpi::kSum; template inline static void Reduce(DType &dst, const DType &src) { dst += src; } }; struct BitOR { const static engine::mpi::OpType kType = engine::mpi::kBitwiseOR; template inline static void Reduce(DType &dst, const DType &src) { dst |= src; } }; template inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { const DType *src = (const DType*)src_; DType *dst = (DType*)dst_; for (int i = 0; i < len; ++i) { OP::Reduce(dst[i], src[i]); } } } // namespace op // intialize the rabit engine inline void Init(int argc, char *argv[]) { engine::Init(argc, argv); } // finalize the rabit engine inline void Finalize(void) { engine::Finalize(); } // get the rank of current process inline int GetRank(void) { return engine::GetEngine()->GetRank(); } // the the size of the world inline int GetWorldSize(void) { return engine::GetEngine()->GetWorldSize(); } // get the name of current processor inline std::string GetProcessorName(void) { return engine::GetEngine()->GetHost(); } // broadcast data to all other nodes from root inline void Broadcast(void *sendrecv_data, size_t size, int root) { engine::GetEngine()->Broadcast(sendrecv_data, size, root); } template inline void Broadcast(std::vector *sendrecv_data, int root) { size_t size = sendrecv_data->size(); Broadcast(&size, sizeof(size), root); if (sendrecv_data->size() != size) { sendrecv_data->resize(size); } if (size != 0) { Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root); } } inline void Broadcast(std::string *sendrecv_data, int root) { size_t size = sendrecv_data->length(); Broadcast(&size, sizeof(size), root); if (sendrecv_data->length() != size) { sendrecv_data->resize(size); } if (size != 0) { Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root); } } // perform inplace Allreduce template inline void Allreduce(DType *sendrecvbuf, size_t count) { engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer, engine::mpi::GetType(), OP::kType); } // load latest check point inline int LoadCheckPoint(utils::ISerializable *p_model) { return engine::GetEngine()->LoadCheckPoint(p_model); } // checkpoint the model, meaning we finished a stage of execution inline void CheckPoint(const utils::ISerializable &model) { engine::GetEngine()->CheckPoint(model); } // return the version number of currently stored model inline int VersionNumber(void) { return engine::GetEngine()->VersionNumber(); } } // namespace rabit #endif