Expose RabitAllGatherRing and RabitGetRingPrevRank (#113)

* add unittests

* Expose RabitAllGatherRing and RabitGetRingPrevRank

* Enabled TCP_NODELAY to decrease latency
This commit is contained in:
nateagr
2019-11-12 12:55:32 +01:00
committed by Jiaming Yuan
parent 90e2239372
commit 1907b25cd0
15 changed files with 509 additions and 25 deletions

View File

@@ -43,6 +43,12 @@ RABIT_DLL bool RabitInit(int argc, char *argv[]);
*/
RABIT_DLL bool RabitFinalize(void);
/*!
 * \brief get the rank of the previous worker in the ring topology,
 *        i.e. the worker whose data segment precedes this worker's segment
 *        (presumably node (rank - 1) % world_size, per the Allgather docs
 *        below -- confirm against the engine implementation)
 * \return rank number of the previous worker in the ring
 */
RABIT_DLL int RabitGetRingPrevRank(void);
/*!
* \brief get rank of current process
* \return rank number of worker
@@ -87,6 +93,30 @@ RABIT_DLL void RabitGetProcessorName(char *out_name,
*/
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root);
/*!
 * \brief Allgather function: each node holds one segment of the conceptual
 *        ring buffer sendrecvbuf; the data provided by the current node k is
 *        [slice_begin, slice_end), and the next node's segment must start at
 *        slice_end. After the call, sendrecvbuf contains the full contents,
 *        i.e. all segments from all nodes.
 *        Implemented with a ring-based algorithm.
 *
 * \param sendrecvbuf buffer for both sending and receiving data; it is conceptually a ring
 * \param total_size total size of the data to be gathered
 * \param beginIndex beginning of the current node's slice within sendrecvbuf
 *        (in units of elements of type enum_dtype -- TODO(review): confirm
 *        element vs. byte units against the C API implementation)
 * \param size_node_slice size of the current node's slice
 * \param size_prev_slice size of the previous slice, i.e. the slice of
 *        node (rank - 1) % world_size
 * \param enum_dtype enumeration of the element data type,
 *        see rabit::engine::mpi::DataType in rabit's engine.h
 */
RABIT_DLL void RabitAllgather(void *sendrecvbuf,
                              size_t total_size,
                              size_t beginIndex,
                              size_t size_node_slice,
                              size_t size_prev_slice,
                              int enum_dtype);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe

View File

@@ -53,6 +53,29 @@ class IEngine {
const MPI::Datatype &dtype);
/*! \brief virtual destructor */
virtual ~IEngine() {}
/*!
 * \brief Allgather function: each node holds one segment of the conceptual
 *        ring buffer sendrecvbuf; the data provided by the current node k is
 *        [slice_begin, slice_end), and the next node's segment must start at
 *        slice_end. After the call, sendrecvbuf contains the full contents,
 *        i.e. all segments from all nodes.
 *        Implemented with a ring-based algorithm.
 *
 * \param sendrecvbuf buffer for both sending and receiving data; it is conceptually a ring
 * \param total_size total size of the data to be gathered
 * \param slice_begin beginning of the current node's slice
 * \param slice_end end of the current node's slice
 * \param size_prev_slice size of the previous slice, i.e. the slice of
 *        node (rank - 1) % world_size
 * \param _file caller file name, used to generate a unique cache key
 * \param _line caller line number, used to generate a unique cache key
 * \param _caller caller function name, used to generate a unique cache key
 */
virtual void Allgather(void *sendrecvbuf, size_t total_size,
                       size_t slice_begin,
                       size_t slice_end,
                       size_t size_prev_slice,
                       const char* _file = _FILE,
                       const int _line = _LINE,
                       const char* _caller = _CALLER) = 0;
/*!
* \brief performs in-place Allreduce, on sendrecvbuf
* this function is NOT thread-safe
@@ -165,6 +188,8 @@ class IEngine {
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const = 0;
/*! \brief gets the rank of the previous node in the ring topology */
virtual int GetRingPrevRank(void) const = 0;
/*! \brief gets the rank of the current node */
virtual int GetRank(void) const = 0;
/*! \brief gets total number of nodes */
@@ -212,6 +237,26 @@ enum DataType {
kULongLong = 9
};
} // namespace mpi
/*!
 * \brief Allgather function: each node holds one segment of the conceptual
 *        ring buffer sendrecvbuf; the data provided by the current node k is
 *        [slice_begin, slice_end), and the next node's segment must start at
 *        slice_end. After the call, sendrecvbuf contains the full contents,
 *        i.e. all segments from all nodes.
 *        Implemented with a ring-based algorithm.
 *
 * \param sendrecvbuf buffer for both sending and receiving data; it is conceptually a ring
 * \param total_size total size of the data to be gathered
 * \param slice_begin beginning of the current node's slice
 * \param slice_end end of the current node's slice
 * \param size_prev_slice size of the previous slice, i.e. the slice of
 *        node (rank - 1) % world_size
 */
void Allgather(void* sendrecvbuf,
               size_t total_size,
               size_t slice_begin,
               size_t slice_end,
               size_t size_prev_slice);
/*!
* \brief perform in-place Allreduce, on sendrecvbuf
* this is an internal function used by rabit to be able to compile with MPI

View File

@@ -110,6 +110,10 @@ inline bool Init(int argc, char *argv[]) {
// finalize the rabit engine; returns true on success
inline bool Finalize(void) {
  return engine::Finalize();
}
// get the rank of the previous worker in the ring topology
// (thin wrapper delegating to the engine singleton)
inline int GetRingPrevRank(void) {
  return engine::GetEngine()->GetRingPrevRank();
}
// get the rank of current process
inline int GetRank(void) {
return engine::GetEngine()->GetRank();
@@ -194,6 +198,16 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
}
#endif // C++11
// Performs in-place Allgather over a typed buffer: converts the
// element-based offsets and sizes into bytes, then forwards to the
// untyped engine-level Allgather.
template<typename DType>
inline void Allgather(DType *sendrecvbuf, size_t totalSize,
                      size_t beginIndex, size_t sizeNodeSlice,
                      size_t sizePrevSlice) {
  const size_t elemBytes = sizeof(DType);            // bytes per element
  const size_t sliceBegin = beginIndex * elemBytes;  // start of this node's slice, in bytes
  const size_t sliceEnd = sliceBegin + sizeNodeSlice * elemBytes;  // one past the end, in bytes
  engine::Allgather(sendrecvbuf, totalSize * elemBytes,
                    sliceBegin, sliceEnd,
                    sizePrevSlice * elemBytes);
}
// print message to the tracker
inline void TrackerPrint(const std::string &msg) {
engine::GetEngine()->TrackerPrint(msg);