fix xgboost build failure introduced by allgather interface (#129)
* fix missing allgether rabit declaration * fix allgather signature mismatch * fix type conversion * fix GetRingPrevRank
This commit is contained in:
@@ -71,7 +71,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allgather(sendrecvbuf, total_size,
|
||||
slice_begin, slice_end, size_prev_slice);
|
||||
slice_begin, slice_end,
|
||||
size_prev_slice, _file, _line, _caller);
|
||||
tsum_allgather += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
|
||||
37
src/c_api.cc
37
src/c_api.cc
@@ -131,33 +131,54 @@ void Allgather(void *sendrecvbuf_,
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
type_size = sizeof(char);
|
||||
rabit::Allgather(static_cast<char*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kUChar:
|
||||
type_size = sizeof(unsigned char);
|
||||
rabit::Allgather(static_cast<unsigned char*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kInt:
|
||||
type_size = sizeof(int);
|
||||
rabit::Allgather(static_cast<int*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kUInt:
|
||||
type_size = sizeof(unsigned);
|
||||
rabit::Allgather(static_cast<unsigned*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kLong:
|
||||
type_size = sizeof(int64_t);
|
||||
rabit::Allgather(static_cast<int64_t*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kULong:
|
||||
type_size = sizeof(uint64_t);
|
||||
rabit::Allgather(static_cast<uint64_t*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kFloat:
|
||||
type_size = sizeof(float);
|
||||
rabit::Allgather(static_cast<float*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
case kDouble:
|
||||
type_size = sizeof(double);
|
||||
rabit::Allgather(static_cast<double*>(sendrecvbuf_), total_size * type_size,
|
||||
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
|
||||
size_prev_slice * type_size);
|
||||
break;
|
||||
default: utils::Error("unknown data_type");
|
||||
}
|
||||
engine::Allgather
|
||||
(sendrecvbuf_, total_size * type_size, beginIndex * type_size,
|
||||
(beginIndex + size_node_slice) * type_size, size_prev_slice * type_size);
|
||||
}
|
||||
|
||||
// wrapper for serialization
|
||||
@@ -251,11 +272,11 @@ void RabitAllgather(void *sendrecvbuf_,
|
||||
size_t size_prev_slice,
|
||||
int enum_dtype) {
|
||||
rabit::c_api::Allgather(sendrecvbuf_,
|
||||
total_size,
|
||||
beginIndex,
|
||||
size_node_slice,
|
||||
size_prev_slice,
|
||||
enum_dtype);
|
||||
total_size,
|
||||
beginIndex,
|
||||
size_node_slice,
|
||||
size_prev_slice,
|
||||
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -88,8 +88,12 @@ IEngine *GetEngine() {
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice) {
|
||||
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice);
|
||||
size_t size_prev_slice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
|
||||
slice_end, size_prev_slice, _file, _line, _caller);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -26,17 +26,18 @@ class EmptyEngine : public IEngine {
|
||||
version_number = 0;
|
||||
}
|
||||
virtual void Allgather(void *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
utils::Error("EmptyEngine:: Allgather is not supported");
|
||||
}
|
||||
virtual int GetRingPrevRank(void) const {
|
||||
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
|
||||
return -1;
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
|
||||
Reference in New Issue
Block a user