fix xgboost build failure introduced by allgather interface (#129)

* fix missing allgether rabit declaration

* fix allgather signature mismatch

* fix type conversion

* fix GetRingPrevRank
This commit is contained in:
Chen Qin
2020-01-01 06:45:14 -08:00
committed by Jiaming Yuan
parent 493ad834a1
commit 0d6a853212
7 changed files with 97 additions and 34 deletions

View File

@@ -69,13 +69,14 @@ class IEngine {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) = 0;
virtual void Allgather(void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) = 0;
/*!
* \brief performs in-place Allreduce, on sendrecvbuf
* this function is NOT thread-safe
@@ -249,14 +250,18 @@ enum DataType {
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Allgather(void* sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice);
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
/*!
* \brief perform in-place Allreduce, on sendrecvbuf
* this is an internal function used by rabit to be able to compile with MPI

View File

@@ -196,17 +196,22 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun,
_file, _line, _caller);
}
#endif // C++11
// Performs inplace Allgather
template<typename DType>
inline void Allgather(DType *sendrecvbuf, size_t totalSize,
size_t beginIndex, size_t sizeNodeSlice,
size_t sizePrevSlice) {
engine::Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
inline void Allgather(DType *sendrecvbuf,
size_t totalSize,
size_t beginIndex,
size_t sizeNodeSlice,
size_t sizePrevSlice,
const char* _file,
const int _line,
const char* _caller) {
engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
(beginIndex + sizeNodeSlice) * sizeof(DType),
sizePrevSlice * sizeof(DType));
sizePrevSlice * sizeof(DType), _file, _line, _caller);
}
#endif // C++11
// print message to the tracker
inline void TrackerPrint(const std::string &msg) {