diff --git a/include/rabit/internal/engine.h b/include/rabit/internal/engine.h index 98bfe7fb5..0db10e7f0 100644 --- a/include/rabit/internal/engine.h +++ b/include/rabit/internal/engine.h @@ -69,13 +69,14 @@ class IEngine { * \param _line caller line number used to generate unique cache key * \param _caller caller function name used to generate unique cache key */ - virtual void Allgather(void *sendrecvbuf, size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice, - const char* _file = _FILE, - const int _line = _LINE, - const char* _caller = _CALLER) = 0; + virtual void Allgather(void *sendrecvbuf, + size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER) = 0; /*! * \brief performs in-place Allreduce, on sendrecvbuf * this function is NOT thread-safe @@ -249,14 +250,18 @@ enum DataType { * \param slice_begin beginning of the current slice * \param slice_end end of the current slice * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType + * \param _file caller file name used to generate unique cache key + * \param _line caller line number used to generate unique cache key + * \param _caller caller function name used to generate unique cache key */ void Allgather(void* sendrecvbuf, size_t total_size, size_t slice_begin, size_t slice_end, - size_t size_prev_slice); + size_t size_prev_slice, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER); /*! 
* \brief perform in-place Allreduce, on sendrecvbuf * this is an internal function used by rabit to be able to compile with MPI diff --git a/include/rabit/internal/rabit-inl.h b/include/rabit/internal/rabit-inl.h index f56b37669..8ae604c4c 100644 --- a/include/rabit/internal/rabit-inl.h +++ b/include/rabit/internal/rabit-inl.h @@ -196,17 +196,22 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, engine::mpi::GetType(), OP::kType, InvokeLambda_, &prepare_fun, _file, _line, _caller); } -#endif // C++11 // Performs inplace Allgather template -inline void Allgather(DType *sendrecvbuf, size_t totalSize, - size_t beginIndex, size_t sizeNodeSlice, - size_t sizePrevSlice) { - engine::Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType), +inline void Allgather(DType *sendrecvbuf, + size_t totalSize, + size_t beginIndex, + size_t sizeNodeSlice, + size_t sizePrevSlice, + const char* _file, + const int _line, + const char* _caller) { + engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType), (beginIndex + sizeNodeSlice) * sizeof(DType), - sizePrevSlice * sizeof(DType)); + sizePrevSlice * sizeof(DType), _file, _line, _caller); } +#endif // C++11 // print message to the tracker inline void TrackerPrint(const std::string &msg) { diff --git a/include/rabit/rabit.h b/include/rabit/rabit.h index da963ecbf..396354e68 100644 --- a/include/rabit/rabit.h +++ b/include/rabit/rabit.h @@ -205,6 +205,32 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, const int _line = _LINE, const char* _caller = _CALLER); +/*! 
+* \brief Allgather function, each node has a segment of data in the ring of sendrecvbuf_, +* the data provided by current node k is [slice_begin, slice_end), +* the next node's segment must start with slice_end +* after the call to Allgather, sendrecvbuf_ contains all the contents including all segments +* uses a ring-based algorithm +* +* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually +* \param total_size total size of data to be gathered +* \param slice_begin beginning of the current slice +* \param slice_end end of the current slice +* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size +* \param _file caller file name used to generate unique cache key +* \param _line caller line number used to generate unique cache key +* \param _caller caller function name used to generate unique cache key +*/ +template +inline void Allgather(DType *sendrecvbuf_, + size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file = _FILE, + const int _line = _LINE, + const char* _caller = _CALLER); + // C++11 support for lambda prepare function #if DMLC_USE_CXX11 /*! 
diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h index 95f429945..ab9f0e0e7 100644 --- a/src/allreduce_mock.h +++ b/src/allreduce_mock.h @@ -71,7 +71,8 @@ class AllreduceMock : public AllreduceRobust { this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather"); double tstart = utils::GetTime(); AllreduceRobust::Allgather(sendrecvbuf, total_size, - slice_begin, slice_end, size_prev_slice); + slice_begin, slice_end, + size_prev_slice, _file, _line, _caller); tsum_allgather += utils::GetTime() - tstart; } virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, diff --git a/src/c_api.cc b/src/c_api.cc index 61e5045dc..d85331983 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -131,33 +131,54 @@ void Allgather(void *sendrecvbuf_, switch (enum_dtype) { case kChar: type_size = sizeof(char); + rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, + beginIndex * type_size, (beginIndex + size_node_slice) * type_size, + size_prev_slice * type_size); break; case kUChar: type_size = sizeof(unsigned char); + rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, + beginIndex * type_size, (beginIndex + size_node_slice) * type_size, + size_prev_slice * type_size); break; case kInt: type_size = sizeof(int); + rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, + beginIndex * type_size, (beginIndex + size_node_slice) * type_size, + size_prev_slice * type_size); break; case kUInt: type_size = sizeof(unsigned); + rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, + beginIndex * type_size, (beginIndex + size_node_slice) * type_size, + size_prev_slice * type_size); break; case kLong: type_size = sizeof(int64_t); + rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, + beginIndex * type_size, (beginIndex + size_node_slice) * type_size, + size_prev_slice * type_size); break; case kULong: type_size = sizeof(uint64_t); + rabit::Allgather(static_cast(sendrecvbuf_), 
total_size * type_size, + beginIndex * type_size, (beginIndex + size_node_slice) * type_size, + size_prev_slice * type_size); break; case kFloat: type_size = sizeof(float); + rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, + beginIndex * type_size, (beginIndex + size_node_slice) * type_size, + size_prev_slice * type_size); break; case kDouble: type_size = sizeof(double); + rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, + beginIndex * type_size, (beginIndex + size_node_slice) * type_size, + size_prev_slice * type_size); break; default: utils::Error("unknown data_type"); } - engine::Allgather - (sendrecvbuf_, total_size * type_size, beginIndex * type_size, - (beginIndex + size_node_slice) * type_size, size_prev_slice * type_size); } // wrapper for serialization @@ -251,11 +272,11 @@ void RabitAllgather(void *sendrecvbuf_, size_t size_prev_slice, int enum_dtype) { rabit::c_api::Allgather(sendrecvbuf_, - total_size, - beginIndex, - size_node_slice, - size_prev_slice, - enum_dtype); + total_size, + beginIndex, + size_node_slice, + size_prev_slice, + static_cast(enum_dtype)); } diff --git a/src/engine.cc b/src/engine.cc index d43017d1a..4701d2fd7 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -88,8 +88,12 @@ IEngine *GetEngine() { void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, size_t slice_end, - size_t size_prev_slice) { - GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice); + size_t size_prev_slice, + const char* _file, + const int _line, + const char* _caller) { + GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, + slice_end, size_prev_slice, _file, _line, _caller); } diff --git a/src/engine_empty.cc b/src/engine_empty.cc index b548c1782..5cecc03cd 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -26,17 +26,18 @@ class EmptyEngine : public IEngine { version_number = 0; } virtual void Allgather(void *sendrecvbuf_, - size_t total_size, - 
size_t slice_begin, - size_t slice_end, - size_t size_prev_slice, - const char* _file, - const int _line, - const char* _caller) { + size_t total_size, + size_t slice_begin, + size_t slice_end, + size_t size_prev_slice, + const char* _file, + const int _line, + const char* _caller) { utils::Error("EmptyEngine:: Allgather is not supported"); } virtual int GetRingPrevRank(void) const { utils::Error("EmptyEngine:: GetRingPrevRank is not supported"); + return -1; } virtual void Allreduce(void *sendrecvbuf_, size_t type_nbytes,