fix xgboost build failure introduced by allgather interface (#129)

* fix missing allgether rabit declaration

* fix allgather signature mismatch

* fix type conversion

* fix GetRingPrevRank
This commit is contained in:
Chen Qin 2020-01-01 06:45:14 -08:00 committed by Jiaming Yuan
parent 493ad834a1
commit 0d6a853212
7 changed files with 97 additions and 34 deletions

View File

@ -69,13 +69,14 @@ class IEngine {
* \param _line caller line number used to generate unique cache key * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key * \param _caller caller function name used to generate unique cache key
*/ */
virtual void Allgather(void *sendrecvbuf, size_t total_size, virtual void Allgather(void *sendrecvbuf,
size_t slice_begin, size_t total_size,
size_t slice_end, size_t slice_begin,
size_t size_prev_slice, size_t slice_end,
const char* _file = _FILE, size_t size_prev_slice,
const int _line = _LINE, const char* _file = _FILE,
const char* _caller = _CALLER) = 0; const int _line = _LINE,
const char* _caller = _CALLER) = 0;
/*! /*!
* \brief performs in-place Allreduce, on sendrecvbuf * \brief performs in-place Allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
@ -249,14 +250,18 @@ enum DataType {
* \param slice_begin beginning of the current slice * \param slice_begin beginning of the current slice
* \param slice_end end of the current slice * \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details * \param _file caller file name used to generate unique cache key
* \sa ReturnType * \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/ */
void Allgather(void* sendrecvbuf, void Allgather(void* sendrecvbuf,
size_t total_size, size_t total_size,
size_t slice_begin, size_t slice_begin,
size_t slice_end, size_t slice_end,
size_t size_prev_slice); size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
/*! /*!
* \brief perform in-place Allreduce, on sendrecvbuf * \brief perform in-place Allreduce, on sendrecvbuf
* this is an internal function used by rabit to be able to compile with MPI * this is an internal function used by rabit to be able to compile with MPI

View File

@ -196,17 +196,22 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun, engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun,
_file, _line, _caller); _file, _line, _caller);
} }
#endif // C++11
// Performs inplace Allgather // Performs inplace Allgather
template<typename DType> template<typename DType>
inline void Allgather(DType *sendrecvbuf, size_t totalSize, inline void Allgather(DType *sendrecvbuf,
size_t beginIndex, size_t sizeNodeSlice, size_t totalSize,
size_t sizePrevSlice) { size_t beginIndex,
engine::Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType), size_t sizeNodeSlice,
size_t sizePrevSlice,
const char* _file,
const int _line,
const char* _caller) {
engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
(beginIndex + sizeNodeSlice) * sizeof(DType), (beginIndex + sizeNodeSlice) * sizeof(DType),
sizePrevSlice * sizeof(DType)); sizePrevSlice * sizeof(DType), _file, _line, _caller);
} }
#endif // C++11
// print message to the tracker // print message to the tracker
inline void TrackerPrint(const std::string &msg) { inline void TrackerPrint(const std::string &msg) {

View File

@ -205,6 +205,32 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
const int _line = _LINE, const int _line = _LINE,
const char* _caller = _CALLER); const char* _caller = _CALLER);
/*!
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
template<typename DType>
inline void Allgather(DType *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
// C++11 support for lambda prepare function // C++11 support for lambda prepare function
#if DMLC_USE_CXX11 #if DMLC_USE_CXX11
/*! /*!

View File

@ -71,7 +71,8 @@ class AllreduceMock : public AllreduceRobust {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
double tstart = utils::GetTime(); double tstart = utils::GetTime();
AllreduceRobust::Allgather(sendrecvbuf, total_size, AllreduceRobust::Allgather(sendrecvbuf, total_size,
slice_begin, slice_end, size_prev_slice); slice_begin, slice_end,
size_prev_slice, _file, _line, _caller);
tsum_allgather += utils::GetTime() - tstart; tsum_allgather += utils::GetTime() - tstart;
} }
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root, virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,

View File

@ -131,33 +131,54 @@ void Allgather(void *sendrecvbuf_,
switch (enum_dtype) { switch (enum_dtype) {
case kChar: case kChar:
type_size = sizeof(char); type_size = sizeof(char);
rabit::Allgather(static_cast<char*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break; break;
case kUChar: case kUChar:
type_size = sizeof(unsigned char); type_size = sizeof(unsigned char);
rabit::Allgather(static_cast<unsigned char*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break; break;
case kInt: case kInt:
type_size = sizeof(int); type_size = sizeof(int);
rabit::Allgather(static_cast<int*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break; break;
case kUInt: case kUInt:
type_size = sizeof(unsigned); type_size = sizeof(unsigned);
rabit::Allgather(static_cast<unsigned*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break; break;
case kLong: case kLong:
type_size = sizeof(int64_t); type_size = sizeof(int64_t);
rabit::Allgather(static_cast<int64_t*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break; break;
case kULong: case kULong:
type_size = sizeof(uint64_t); type_size = sizeof(uint64_t);
rabit::Allgather(static_cast<uint64_t*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break; break;
case kFloat: case kFloat:
type_size = sizeof(float); type_size = sizeof(float);
rabit::Allgather(static_cast<float*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break; break;
case kDouble: case kDouble:
type_size = sizeof(double); type_size = sizeof(double);
rabit::Allgather(static_cast<double*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break; break;
default: utils::Error("unknown data_type"); default: utils::Error("unknown data_type");
} }
engine::Allgather
(sendrecvbuf_, total_size * type_size, beginIndex * type_size,
(beginIndex + size_node_slice) * type_size, size_prev_slice * type_size);
} }
// wrapper for serialization // wrapper for serialization
@ -251,11 +272,11 @@ void RabitAllgather(void *sendrecvbuf_,
size_t size_prev_slice, size_t size_prev_slice,
int enum_dtype) { int enum_dtype) {
rabit::c_api::Allgather(sendrecvbuf_, rabit::c_api::Allgather(sendrecvbuf_,
total_size, total_size,
beginIndex, beginIndex,
size_node_slice, size_node_slice,
size_prev_slice, size_prev_slice,
enum_dtype); static_cast<rabit::engine::mpi::DataType>(enum_dtype));
} }

View File

@ -88,8 +88,12 @@ IEngine *GetEngine() {
void Allgather(void *sendrecvbuf_, size_t total_size, void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin, size_t slice_begin,
size_t slice_end, size_t slice_end,
size_t size_prev_slice) { size_t size_prev_slice,
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice); const char* _file,
const int _line,
const char* _caller) {
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
slice_end, size_prev_slice, _file, _line, _caller);
} }

View File

@ -26,17 +26,18 @@ class EmptyEngine : public IEngine {
version_number = 0; version_number = 0;
} }
virtual void Allgather(void *sendrecvbuf_, virtual void Allgather(void *sendrecvbuf_,
size_t total_size, size_t total_size,
size_t slice_begin, size_t slice_begin,
size_t slice_end, size_t slice_end,
size_t size_prev_slice, size_t size_prev_slice,
const char* _file, const char* _file,
const int _line, const int _line,
const char* _caller) { const char* _caller) {
utils::Error("EmptyEngine:: Allgather is not supported"); utils::Error("EmptyEngine:: Allgather is not supported");
} }
virtual int GetRingPrevRank(void) const { virtual int GetRingPrevRank(void) const {
utils::Error("EmptyEngine:: GetRingPrevRank is not supported"); utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
return -1;
} }
virtual void Allreduce(void *sendrecvbuf_, virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes, size_t type_nbytes,