fix xgboost build failure introduced by allgather interface (#129)
* fix missing allgether rabit declaration * fix allgather signature mismatch * fix type conversion * fix GetRingPrevRank
This commit is contained in:
parent
493ad834a1
commit
0d6a853212
@ -69,7 +69,8 @@ class IEngine {
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf, size_t total_size,
|
||||
virtual void Allgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
@ -249,14 +250,18 @@ enum DataType {
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allgather(void* sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice);
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
/*!
|
||||
* \brief perform in-place Allreduce, on sendrecvbuf
|
||||
* this is an internal function used by rabit to be able to compile with MPI
|
||||
|
||||
@ -196,17 +196,22 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
engine::mpi::GetType<DType>(), OP::kType, InvokeLambda_, &prepare_fun,
|
||||
_file, _line, _caller);
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
// Performs inplace Allgather
|
||||
template<typename DType>
|
||||
inline void Allgather(DType *sendrecvbuf, size_t totalSize,
|
||||
size_t beginIndex, size_t sizeNodeSlice,
|
||||
size_t sizePrevSlice) {
|
||||
engine::Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
|
||||
inline void Allgather(DType *sendrecvbuf,
|
||||
size_t totalSize,
|
||||
size_t beginIndex,
|
||||
size_t sizeNodeSlice,
|
||||
size_t sizePrevSlice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType),
|
||||
(beginIndex + sizeNodeSlice) * sizeof(DType),
|
||||
sizePrevSlice * sizeof(DType));
|
||||
sizePrevSlice * sizeof(DType), _file, _line, _caller);
|
||||
}
|
||||
#endif // C++11
|
||||
|
||||
// print message to the tracker
|
||||
inline void TrackerPrint(const std::string &msg) {
|
||||
|
||||
@ -205,6 +205,32 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
|
||||
/*!
|
||||
* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
template<typename DType>
|
||||
inline void Allgather(DType *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
#if DMLC_USE_CXX11
|
||||
/*!
|
||||
|
||||
@ -71,7 +71,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allgather(sendrecvbuf, total_size,
|
||||
slice_begin, slice_end, size_prev_slice);
|
||||
slice_begin, slice_end,
|
||||
size_prev_slice, _file, _line, _caller);
|
||||
tsum_allgather += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
|
||||
29
src/c_api.cc
29
src/c_api.cc
@ -131,33 +131,54 @@ void Allgather(void *sendrecvbuf_,
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
type_size = sizeof(char);
|
||||
rabit::Allgather(static_cast<char*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kUChar:
|
||||
type_size = sizeof(unsigned char);
|
||||
rabit::Allgather(static_cast<unsigned char*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kInt:
|
||||
type_size = sizeof(int);
|
||||
rabit::Allgather(static_cast<int*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kUInt:
|
||||
type_size = sizeof(unsigned);
|
||||
rabit::Allgather(static_cast<unsigned*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kLong:
|
||||
type_size = sizeof(int64_t);
|
||||
rabit::Allgather(static_cast<int64_t*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kULong:
|
||||
type_size = sizeof(uint64_t);
|
||||
rabit::Allgather(static_cast<uint64_t*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kFloat:
|
||||
type_size = sizeof(float);
|
||||
rabit::Allgather(static_cast<float*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kDouble:
|
||||
type_size = sizeof(double);
|
||||
rabit::Allgather(static_cast<double*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
default: utils::Error("unknown data_type");
|
||||
}
|
||||
engine::Allgather
|
||||
(sendrecvbuf_, total_size * type_size, beginIndex * type_size,
|
||||
(beginIndex + size_node_slice) * type_size, size_prev_slice * type_size);
|
||||
}
|
||||
|
||||
// wrapper for serialization
|
||||
@ -255,7 +276,7 @@ void RabitAllgather(void *sendrecvbuf_,
|
||||
beginIndex,
|
||||
size_node_slice,
|
||||
size_prev_slice,
|
||||
enum_dtype);
|
||||
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -88,8 +88,12 @@ IEngine *GetEngine() {
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice) {
|
||||
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice);
|
||||
size_t size_prev_slice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
|
||||
slice_end, size_prev_slice, _file, _line, _caller);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -37,6 +37,7 @@ class EmptyEngine : public IEngine {
|
||||
}
|
||||
virtual int GetRingPrevRank(void) const {
|
||||
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
|
||||
return -1;
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user