fix xgboost build failure introduced by allgather interface (#129)

* fix missing allgether rabit declaration

* fix allgather signature mismatch

* fix type conversion

* fix GetRingPrevRank
This commit is contained in:
Chen Qin
2020-01-01 06:45:14 -08:00
committed by Jiaming Yuan
parent 493ad834a1
commit 0d6a853212
7 changed files with 97 additions and 34 deletions

View File

@@ -71,7 +71,8 @@ class AllreduceMock : public AllreduceRobust {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
double tstart = utils::GetTime();
AllreduceRobust::Allgather(sendrecvbuf, total_size,
slice_begin, slice_end, size_prev_slice);
slice_begin, slice_end,
size_prev_slice, _file, _line, _caller);
tsum_allgather += utils::GetTime() - tstart;
}
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,

View File

@@ -131,33 +131,54 @@ void Allgather(void *sendrecvbuf_,
switch (enum_dtype) {
case kChar:
type_size = sizeof(char);
rabit::Allgather(static_cast<char*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kUChar:
type_size = sizeof(unsigned char);
rabit::Allgather(static_cast<unsigned char*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kInt:
type_size = sizeof(int);
rabit::Allgather(static_cast<int*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kUInt:
type_size = sizeof(unsigned);
rabit::Allgather(static_cast<unsigned*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kLong:
type_size = sizeof(int64_t);
rabit::Allgather(static_cast<int64_t*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kULong:
type_size = sizeof(uint64_t);
rabit::Allgather(static_cast<uint64_t*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kFloat:
type_size = sizeof(float);
rabit::Allgather(static_cast<float*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kDouble:
type_size = sizeof(double);
rabit::Allgather(static_cast<double*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
default: utils::Error("unknown data_type");
}
engine::Allgather
(sendrecvbuf_, total_size * type_size, beginIndex * type_size,
(beginIndex + size_node_slice) * type_size, size_prev_slice * type_size);
}
// wrapper for serialization
@@ -251,11 +272,11 @@ void RabitAllgather(void *sendrecvbuf_,
size_t size_prev_slice,
int enum_dtype) {
rabit::c_api::Allgather(sendrecvbuf_,
total_size,
beginIndex,
size_node_slice,
size_prev_slice,
enum_dtype);
total_size,
beginIndex,
size_node_slice,
size_prev_slice,
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
}

View File

@@ -88,8 +88,12 @@ IEngine *GetEngine() {
void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice) {
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, slice_end, size_prev_slice);
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
slice_end, size_prev_slice, _file, _line, _caller);
}

View File

@@ -26,17 +26,18 @@ class EmptyEngine : public IEngine {
version_number = 0;
}
virtual void Allgather(void *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
utils::Error("EmptyEngine:: Allgather is not supported");
}
virtual int GetRingPrevRank(void) const {
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
return -1;
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,