Cleanup RABIT. (#6290)

* Remove recovery and MPI speed tests.
* Remove readme.
* Remove Python binding.
* Add checks in C API.
This commit is contained in:
Jiaming Yuan
2020-10-27 08:48:22 +08:00
committed by GitHub
parent 8e0f5a6fc7
commit b180223d18
40 changed files with 113 additions and 1875 deletions

View File

@@ -275,7 +275,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
}
try {
utils::TCPSocket tracker = this->ConnectTracker();
fprintf(stdout, "task %s connected to the tracker\n", task_id.c_str());
LOG(INFO) << "task " << task_id << " connected to the tracker";
tracker.SendStr(std::string(cmd));
// the rank of previous link, next link in ring

View File

@@ -98,14 +98,9 @@ class AllreduceBase : public IEngine {
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
size_t slice_end, size_t size_prev_slice) override {
if (world_size == 1 || world_size == -1) {
return;
}
@@ -124,15 +119,10 @@ class AllreduceBase : public IEngine {
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr, const char *_file = _FILE,
const int _line = _LINE,
const char *_caller = _CALLER) override {
void *prepare_arg = nullptr) override {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) ==
@@ -148,9 +138,7 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"Broadcast failed");
@@ -232,14 +220,6 @@ class AllreduceBase : public IEngine {
int VersionNumber() const override {
return version_number;
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
void InitAfterException() override {
utils::Error("InitAfterException: not implemented");
}
/*!
* \brief report current status to the job tracker
* depending on the job tracker we are in

View File

@@ -11,8 +11,8 @@
#include <vector>
#include <map>
#include <sstream>
#include <dmlc/timer.h>
#include "rabit/internal/engine.h"
#include "rabit/internal/timer.h"
#include "allreduce_base.h"
namespace rabit {
@@ -46,36 +46,30 @@ class AllreduceMock : public AllreduceBase {
}
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg, const char *_file = _FILE,
const int _line = _LINE,
const char *_caller = _CALLER) override {
void *prepare_arg) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
double tstart = utils::GetTime();
double tstart = dmlc::GetTime();
AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer,
prepare_fun, prepare_arg, _file, _line, _caller);
tsum_allreduce_ += utils::GetTime() - tstart;
prepare_fun, prepare_arg);
tsum_allreduce_ += dmlc::GetTime() - tstart;
}
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
size_t slice_end, size_t size_prev_slice) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
double tstart = utils::GetTime();
double tstart = dmlc::GetTime();
AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end,
size_prev_slice, _file, _line, _caller);
tsum_allgather_ += utils::GetTime() - tstart;
size_prev_slice);
tsum_allgather_ += dmlc::GetTime() - tstart;
}
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
AllreduceBase::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
AllreduceBase::Broadcast(sendrecvbuf_, total_size, root);
}
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) override {
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
time_checkpoint_ = utils::GetTime();
time_checkpoint_ = dmlc::GetTime();
if (force_local_ == 0) {
return AllreduceBase::LoadCheckPoint(global_model, local_model);
} else {
@@ -87,7 +81,7 @@ class AllreduceMock : public AllreduceBase {
void CheckPoint(const Serializable *global_model,
const Serializable *local_model) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
double tstart = utils::GetTime();
double tstart = dmlc::GetTime();
double tbet_chkpt = tstart - time_checkpoint_;
if (force_local_ == 0) {
AllreduceBase::CheckPoint(global_model, local_model);
@@ -96,8 +90,8 @@ class AllreduceMock : public AllreduceBase {
ComboSerializer com(global_model, local_model);
AllreduceBase::CheckPoint(&dum, &com);
}
time_checkpoint_ = utils::GetTime();
double tcost = utils::GetTime() - tstart;
time_checkpoint_ = dmlc::GetTime();
double tcost = dmlc::GetTime() - tstart;
if (report_stats_ != 0 && rank == 0) {
std::stringstream ss;
ss << "[v" << version_number << "] global_size="

View File

@@ -228,12 +228,12 @@ RABIT_DLL bool RabitInit(int argc, char *argv[]) {
return ret;
}
RABIT_DLL bool RabitFinalize() {
RABIT_DLL int RabitFinalize() {
auto ret = rabit::Finalize();
if (!ret) {
XGBAPISetLastError("Failed to shutdown RABIT worker.");
}
return ret;
return static_cast<int>(ret);
}
RABIT_DLL int RabitGetRingPrevRank() {
@@ -270,37 +270,39 @@ RABIT_DLL void RabitGetProcessorName(char *out_name,
*out_len = static_cast<rbt_ulong>(s.length());
}
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
RABIT_DLL int RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root) {
API_BEGIN()
rabit::Broadcast(sendrecv_data, size, root);
API_END()
}
RABIT_DLL void RabitAllgather(void *sendrecvbuf_, size_t total_size,
RABIT_DLL int RabitAllgather(void *sendrecvbuf_, size_t total_size,
size_t beginIndex, size_t size_node_slice,
size_t size_prev_slice, int enum_dtype) {
rabit::c_api::Allgather(sendrecvbuf_,
total_size,
beginIndex,
size_node_slice,
size_prev_slice,
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
API_BEGIN()
rabit::c_api::Allgather(
sendrecvbuf_, total_size, beginIndex, size_node_slice, size_prev_slice,
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
API_END()
}
RABIT_DLL void RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
int enum_op, void (*prepare_fun)(void *arg),
void *prepare_arg) {
rabit::c_api::Allreduce
(sendrecvbuf, count,
static_cast<rabit::engine::mpi::DataType>(enum_dtype),
static_cast<rabit::engine::mpi::OpType>(enum_op),
prepare_fun, prepare_arg);
API_BEGIN()
rabit::c_api::Allreduce(sendrecvbuf, count,
static_cast<rabit::engine::mpi::DataType>(enum_dtype),
static_cast<rabit::engine::mpi::OpType>(enum_op),
prepare_fun, prepare_arg);
API_END()
}
RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
rbt_ulong *out_global_len,
char **out_local_model,
rbt_ulong *out_local_len) {
// NOTE: this function is not thread-safe
// no-op as XGBoost 1.3
using rabit::BeginPtr;
using namespace rabit::c_api; // NOLINT(*)
static std::string global_buffer;

View File

@@ -85,12 +85,9 @@ IEngine *GetEngine() {
void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
size_t size_prev_slice) {
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
slice_end, size_prev_slice, _file, _line, _caller);
slice_end, size_prev_slice);
}
@@ -102,12 +99,9 @@ void Allreduce_(void *sendrecvbuf, // NOLINT
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
void *prepare_arg) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun,
prepare_arg, _file, _line, _caller);
prepare_arg);
}
// code for reduce handle
@@ -126,14 +120,10 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
void *prepare_arg) {
utils::Assert(redfunc_ != nullptr, "must intialize handle to call AllReduce");
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
redfunc_, prepare_fun, prepare_arg,
_file, _line, _caller);
redfunc_, prepare_fun, prepare_arg);
}
} // namespace engine
} // namespace rabit