Cleanup RABIT. (#6290)
* Remove recovery and MPI speed tests. * Remove readme. * Remove Python binding. * Add checks in C API.
This commit is contained in:
@@ -275,7 +275,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
}
|
||||
try {
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
fprintf(stdout, "task %s connected to the tracker\n", task_id.c_str());
|
||||
LOG(INFO) << "task " << task_id << " connected to the tracker";
|
||||
tracker.SendStr(std::string(cmd));
|
||||
|
||||
// the rank of previous link, next link in ring
|
||||
|
||||
@@ -98,14 +98,9 @@ class AllreduceBase : public IEngine {
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
size_t slice_end, size_t size_prev_slice) override {
|
||||
if (world_size == 1 || world_size == -1) {
|
||||
return;
|
||||
}
|
||||
@@ -124,15 +119,10 @@ class AllreduceBase : public IEngine {
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
|
||||
void *prepare_arg = nullptr, const char *_file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
void *prepare_arg = nullptr) override {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) ==
|
||||
@@ -148,9 +138,7 @@ class AllreduceBase : public IEngine {
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
||||
"Broadcast failed");
|
||||
@@ -232,14 +220,6 @@ class AllreduceBase : public IEngine {
|
||||
int VersionNumber() const override {
|
||||
return version_number;
|
||||
}
|
||||
/*!
|
||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throw an exception out,
|
||||
* this function is only used for test purpose
|
||||
*/
|
||||
void InitAfterException() override {
|
||||
utils::Error("InitAfterException: not implemented");
|
||||
}
|
||||
/*!
|
||||
* \brief report current status to the job tracker
|
||||
* depending on the job tracker we are in
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <dmlc/timer.h>
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "rabit/internal/timer.h"
|
||||
#include "allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
@@ -46,36 +46,30 @@ class AllreduceMock : public AllreduceBase {
|
||||
}
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun,
|
||||
void *prepare_arg, const char *_file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
void *prepare_arg) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
|
||||
double tstart = utils::GetTime();
|
||||
double tstart = dmlc::GetTime();
|
||||
AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer,
|
||||
prepare_fun, prepare_arg, _file, _line, _caller);
|
||||
tsum_allreduce_ += utils::GetTime() - tstart;
|
||||
prepare_fun, prepare_arg);
|
||||
tsum_allreduce_ += dmlc::GetTime() - tstart;
|
||||
}
|
||||
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
size_t slice_end, size_t size_prev_slice) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
|
||||
double tstart = utils::GetTime();
|
||||
double tstart = dmlc::GetTime();
|
||||
AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end,
|
||||
size_prev_slice, _file, _line, _caller);
|
||||
tsum_allgather_ += utils::GetTime() - tstart;
|
||||
size_prev_slice);
|
||||
tsum_allgather_ += dmlc::GetTime() - tstart;
|
||||
}
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
|
||||
AllreduceBase::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
|
||||
AllreduceBase::Broadcast(sendrecvbuf_, total_size, root);
|
||||
}
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) override {
|
||||
tsum_allreduce_ = 0.0;
|
||||
tsum_allgather_ = 0.0;
|
||||
time_checkpoint_ = utils::GetTime();
|
||||
time_checkpoint_ = dmlc::GetTime();
|
||||
if (force_local_ == 0) {
|
||||
return AllreduceBase::LoadCheckPoint(global_model, local_model);
|
||||
} else {
|
||||
@@ -87,7 +81,7 @@ class AllreduceMock : public AllreduceBase {
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
|
||||
double tstart = utils::GetTime();
|
||||
double tstart = dmlc::GetTime();
|
||||
double tbet_chkpt = tstart - time_checkpoint_;
|
||||
if (force_local_ == 0) {
|
||||
AllreduceBase::CheckPoint(global_model, local_model);
|
||||
@@ -96,8 +90,8 @@ class AllreduceMock : public AllreduceBase {
|
||||
ComboSerializer com(global_model, local_model);
|
||||
AllreduceBase::CheckPoint(&dum, &com);
|
||||
}
|
||||
time_checkpoint_ = utils::GetTime();
|
||||
double tcost = utils::GetTime() - tstart;
|
||||
time_checkpoint_ = dmlc::GetTime();
|
||||
double tcost = dmlc::GetTime() - tstart;
|
||||
if (report_stats_ != 0 && rank == 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[v" << version_number << "] global_size="
|
||||
|
||||
@@ -228,12 +228,12 @@ RABIT_DLL bool RabitInit(int argc, char *argv[]) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
RABIT_DLL bool RabitFinalize() {
|
||||
RABIT_DLL int RabitFinalize() {
|
||||
auto ret = rabit::Finalize();
|
||||
if (!ret) {
|
||||
XGBAPISetLastError("Failed to shutdown RABIT worker.");
|
||||
}
|
||||
return ret;
|
||||
return static_cast<int>(ret);
|
||||
}
|
||||
|
||||
RABIT_DLL int RabitGetRingPrevRank() {
|
||||
@@ -270,37 +270,39 @@ RABIT_DLL void RabitGetProcessorName(char *out_name,
|
||||
*out_len = static_cast<rbt_ulong>(s.length());
|
||||
}
|
||||
|
||||
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
|
||||
RABIT_DLL int RabitBroadcast(void *sendrecv_data,
|
||||
rbt_ulong size, int root) {
|
||||
API_BEGIN()
|
||||
rabit::Broadcast(sendrecv_data, size, root);
|
||||
API_END()
|
||||
}
|
||||
|
||||
RABIT_DLL void RabitAllgather(void *sendrecvbuf_, size_t total_size,
|
||||
RABIT_DLL int RabitAllgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t beginIndex, size_t size_node_slice,
|
||||
size_t size_prev_slice, int enum_dtype) {
|
||||
rabit::c_api::Allgather(sendrecvbuf_,
|
||||
total_size,
|
||||
beginIndex,
|
||||
size_node_slice,
|
||||
size_prev_slice,
|
||||
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
|
||||
API_BEGIN()
|
||||
rabit::c_api::Allgather(
|
||||
sendrecvbuf_, total_size, beginIndex, size_node_slice, size_prev_slice,
|
||||
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
|
||||
API_END()
|
||||
}
|
||||
|
||||
RABIT_DLL void RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
|
||||
RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
|
||||
int enum_op, void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
rabit::c_api::Allreduce
|
||||
(sendrecvbuf, count,
|
||||
static_cast<rabit::engine::mpi::DataType>(enum_dtype),
|
||||
static_cast<rabit::engine::mpi::OpType>(enum_op),
|
||||
prepare_fun, prepare_arg);
|
||||
API_BEGIN()
|
||||
rabit::c_api::Allreduce(sendrecvbuf, count,
|
||||
static_cast<rabit::engine::mpi::DataType>(enum_dtype),
|
||||
static_cast<rabit::engine::mpi::OpType>(enum_op),
|
||||
prepare_fun, prepare_arg);
|
||||
API_END()
|
||||
}
|
||||
|
||||
RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
|
||||
rbt_ulong *out_global_len,
|
||||
char **out_local_model,
|
||||
rbt_ulong *out_local_len) {
|
||||
// NOTE: this function is not thread-safe
|
||||
// no-op as XGBoost 1.3
|
||||
using rabit::BeginPtr;
|
||||
using namespace rabit::c_api; // NOLINT(*)
|
||||
static std::string global_buffer;
|
||||
|
||||
@@ -85,12 +85,9 @@ IEngine *GetEngine() {
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
size_t size_prev_slice) {
|
||||
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
|
||||
slice_end, size_prev_slice, _file, _line, _caller);
|
||||
slice_end, size_prev_slice);
|
||||
}
|
||||
|
||||
|
||||
@@ -102,12 +99,9 @@ void Allreduce_(void *sendrecvbuf, // NOLINT
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
void *prepare_arg) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun,
|
||||
prepare_arg, _file, _line, _caller);
|
||||
prepare_arg);
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
@@ -126,14 +120,10 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
void *prepare_arg) {
|
||||
utils::Assert(redfunc_ != nullptr, "must intialize handle to call AllReduce");
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||
redfunc_, prepare_fun, prepare_arg,
|
||||
_file, _line, _caller);
|
||||
redfunc_, prepare_fun, prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
Reference in New Issue
Block a user