Correct style warnings from clang-tidy for rabit. (#6095)
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
// constructor
|
||||
AllreduceBase::AllreduceBase(void) {
|
||||
AllreduceBase::AllreduceBase() {
|
||||
tracker_uri = "NULL";
|
||||
tracker_port = 9000;
|
||||
host_uri = "";
|
||||
@@ -24,7 +24,7 @@ AllreduceBase::AllreduceBase(void) {
|
||||
rank = 0;
|
||||
world_size = -1;
|
||||
connect_retry = 5;
|
||||
hadoop_mode = 0;
|
||||
hadoop_mode = false;
|
||||
version_number = 0;
|
||||
// 32 K items
|
||||
reduce_ring_mincount = 32 << 10;
|
||||
@@ -32,27 +32,27 @@ AllreduceBase::AllreduceBase(void) {
|
||||
tree_reduce_minsize = 1 << 20;
|
||||
// tracker URL
|
||||
task_id = "NULL";
|
||||
err_link = NULL;
|
||||
err_link = nullptr;
|
||||
dmlc_role = "worker";
|
||||
this->SetParam("rabit_reduce_buffer", "256MB");
|
||||
// set up the environment variables that may be of interest
|
||||
// include dmlc support direct variables
|
||||
env_vars.push_back("DMLC_TASK_ID");
|
||||
env_vars.push_back("DMLC_ROLE");
|
||||
env_vars.push_back("DMLC_NUM_ATTEMPT");
|
||||
env_vars.push_back("DMLC_TRACKER_URI");
|
||||
env_vars.push_back("DMLC_TRACKER_PORT");
|
||||
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
|
||||
env_vars.emplace_back("DMLC_TASK_ID");
|
||||
env_vars.emplace_back("DMLC_ROLE");
|
||||
env_vars.emplace_back("DMLC_NUM_ATTEMPT");
|
||||
env_vars.emplace_back("DMLC_TRACKER_URI");
|
||||
env_vars.emplace_back("DMLC_TRACKER_PORT");
|
||||
env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY");
|
||||
}
|
||||
|
||||
// initialization function
|
||||
bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
// set up from environment variables
|
||||
// handler to get variables from env
|
||||
for (size_t i = 0; i < env_vars.size(); ++i) {
|
||||
const char *value = getenv(env_vars[i].c_str());
|
||||
if (value != NULL) {
|
||||
this->SetParam(env_vars[i].c_str(), value);
|
||||
for (auto & env_var : env_vars) {
|
||||
const char *value = getenv(env_var.c_str());
|
||||
if (value != nullptr) {
|
||||
this->SetParam(env_var.c_str(), value);
|
||||
}
|
||||
}
|
||||
// pass in arguments override env variable.
|
||||
@@ -66,35 +66,35 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
{
|
||||
// handling for hadoop
|
||||
const char *task_id = getenv("mapred_tip_id");
|
||||
if (task_id == NULL) {
|
||||
if (task_id == nullptr) {
|
||||
task_id = getenv("mapreduce_task_id");
|
||||
}
|
||||
if (hadoop_mode) {
|
||||
utils::Check(task_id != NULL,
|
||||
utils::Check(task_id != nullptr,
|
||||
"hadoop_mode is set but cannot find mapred_task_id");
|
||||
}
|
||||
if (task_id != NULL) {
|
||||
if (task_id != nullptr) {
|
||||
this->SetParam("rabit_task_id", task_id);
|
||||
this->SetParam("rabit_hadoop_mode", "1");
|
||||
}
|
||||
const char *attempt_id = getenv("mapred_task_id");
|
||||
if (attempt_id != 0) {
|
||||
if (attempt_id != nullptr) {
|
||||
const char *att = strrchr(attempt_id, '_');
|
||||
int num_trial;
|
||||
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
|
||||
if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) {
|
||||
this->SetParam("rabit_num_trial", att + 1);
|
||||
}
|
||||
}
|
||||
// handling for hadoop
|
||||
const char *num_task = getenv("mapred_map_tasks");
|
||||
if (num_task == NULL) {
|
||||
if (num_task == nullptr) {
|
||||
num_task = getenv("mapreduce_job_maps");
|
||||
}
|
||||
if (hadoop_mode) {
|
||||
utils::Check(num_task != NULL,
|
||||
utils::Check(num_task != nullptr,
|
||||
"hadoop_mode is set but cannot find mapred_map_tasks");
|
||||
}
|
||||
if (num_task != NULL) {
|
||||
if (num_task != nullptr) {
|
||||
this->SetParam("rabit_world_size", num_task);
|
||||
}
|
||||
}
|
||||
@@ -115,10 +115,10 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
return this->ReConnectLinks();
|
||||
}
|
||||
|
||||
bool AllreduceBase::Shutdown(void) {
|
||||
bool AllreduceBase::Shutdown() {
|
||||
try {
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
all_links[i].sock.Close();
|
||||
for (auto & all_link : all_links) {
|
||||
all_link.sock.Close();
|
||||
}
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
@@ -208,17 +208,18 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
|
||||
}
|
||||
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
|
||||
if (!strcmp(val, "true"))
|
||||
if (!strcmp(val, "true")) {
|
||||
rabit_enable_tcp_no_delay = true;
|
||||
else
|
||||
} else {
|
||||
rabit_enable_tcp_no_delay = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
utils::TCPSocket AllreduceBase::ConnectTracker() const {
|
||||
int magic = kMagic;
|
||||
// get information from tracker
|
||||
utils::TCPSocket tracker;
|
||||
@@ -241,7 +242,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
}
|
||||
}
|
||||
break;
|
||||
} while (1);
|
||||
} while (true);
|
||||
|
||||
using utils::Assert;
|
||||
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||
@@ -320,19 +321,19 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
do {
|
||||
// send over good links
|
||||
std::vector<int> good_link;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) {
|
||||
good_link.push_back(static_cast<int>(all_links[i].rank));
|
||||
for (auto & all_link : all_links) {
|
||||
if (!all_link.sock.BadSocket()) {
|
||||
good_link.push_back(static_cast<int>(all_link.rank));
|
||||
} else {
|
||||
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
||||
if (!all_link.sock.IsClosed()) all_link.sock.Close();
|
||||
}
|
||||
}
|
||||
int ngood = static_cast<int>(good_link.size());
|
||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||
"ReConnectLink failure 5");
|
||||
for (size_t i = 0; i < good_link.size(); ++i) {
|
||||
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
|
||||
sizeof(good_link[i]), "ReConnectLink failure 6");
|
||||
for (int & i : good_link) {
|
||||
Assert(tracker.SendAll(&i, sizeof(i)) == \
|
||||
sizeof(i), "ReConnectLink failure 6");
|
||||
}
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
"ReConnectLink failure 7");
|
||||
@@ -362,11 +363,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
utils::Check(hrank == r.rank,
|
||||
"ReConnectLink failure, link rank inconsistent");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == hrank) {
|
||||
Assert(all_links[i].sock.IsClosed(),
|
||||
for (auto & all_link : all_links) {
|
||||
if (all_link.rank == hrank) {
|
||||
Assert(all_link.sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock;
|
||||
all_link.sock = r.sock;
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
@@ -390,11 +391,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 15");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == r.rank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(),
|
||||
for (auto & all_link : all_links) {
|
||||
if (all_link.rank == r.rank) {
|
||||
utils::Assert(all_link.sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock;
|
||||
all_link.sock = r.sock;
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
@@ -406,29 +407,29 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
int tcpNoDelay = 1;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
for (auto & all_link : all_links) {
|
||||
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
all_links[i].sock.SetKeepAlive(true);
|
||||
all_link.sock.SetNonBlock(true);
|
||||
all_link.sock.SetKeepAlive(true);
|
||||
if (rabit_enable_tcp_no_delay) {
|
||||
setsockopt(all_links[i].sock, IPPROTO_TCP,
|
||||
setsockopt(all_link.sock, IPPROTO_TCP,
|
||||
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
|
||||
}
|
||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||
if (all_links[i].rank == parent_rank) {
|
||||
if (tree_neighbors.count(all_link.rank) != 0) {
|
||||
if (all_link.rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
}
|
||||
tree_links.plinks.push_back(&all_links[i]);
|
||||
tree_links.plinks.push_back(&all_link);
|
||||
}
|
||||
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
||||
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
||||
if (all_link.rank == prev_rank) ring_prev = &all_link;
|
||||
if (all_link.rank == next_rank) ring_next = &all_link;
|
||||
}
|
||||
Assert(parent_rank == -1 || parent_index != -1,
|
||||
"cannot find parent in the link");
|
||||
Assert(prev_rank == -1 || ring_prev != NULL,
|
||||
Assert(prev_rank == -1 || ring_prev != nullptr,
|
||||
"cannot find prev ring in the link");
|
||||
Assert(next_rank == -1 || ring_next != NULL,
|
||||
Assert(next_rank == -1 || ring_next != nullptr,
|
||||
"cannot find next ring in the link");
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
@@ -479,11 +480,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0 || count == 0) return kSuccess;
|
||||
if (links.Size() == 0 || count == 0) return kSuccess;
|
||||
// total size of message
|
||||
const size_t total_size = type_nbytes * count;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
const int nlink = static_cast<int>(links.Size());
|
||||
// send recv buffer
|
||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
// size of space on which we have already performed reduce in the up pass
|
||||
@@ -581,8 +582,9 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
|
||||
// if max reduce is less than total size, we reduce multiple times of
|
||||
// eachreduce size
|
||||
if (max_reduce < total_size)
|
||||
if (max_reduce < total_size) {
|
||||
max_reduce = max_reduce - max_reduce % eachreduce;
|
||||
}
|
||||
|
||||
// perform reduce, can be at most two rounds
|
||||
while (size_up_reduce < max_reduce) {
|
||||
@@ -677,11 +679,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||
if (links.Size() == 0 || total_size == 0) return kSuccess;
|
||||
utils::Check(root < world_size,
|
||||
"Broadcast: root should be smaller than world size");
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
const int nlink = static_cast<int>(links.Size());
|
||||
// size of space already read from data
|
||||
size_t size_in = 0;
|
||||
// input link, -2 means unknown yet, -1 means this is root
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
#endif // RABIT_CXXTESTDEFS_H
|
||||
|
||||
|
||||
namespace MPI {
|
||||
namespace MPI { // NOLINT
|
||||
// MPI data type to be compatible with existing MPI interface
|
||||
class Datatype {
|
||||
public:
|
||||
@@ -41,12 +41,12 @@ class AllreduceBase : public IEngine {
|
||||
// magic number to verify server
|
||||
static const int kMagic = 0xff99;
|
||||
// constant one byte out of band message to indicate error happening
|
||||
AllreduceBase(void);
|
||||
virtual ~AllreduceBase(void) {}
|
||||
AllreduceBase();
|
||||
virtual ~AllreduceBase() = default;
|
||||
// initialize the manager
|
||||
virtual bool Init(int argc, char* argv[]);
|
||||
// shutdown the engine
|
||||
virtual bool Shutdown(void);
|
||||
virtual bool Shutdown();
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
@@ -59,27 +59,27 @@ class AllreduceBase : public IEngine {
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg);
|
||||
void TrackerPrint(const std::string &msg) override;
|
||||
|
||||
/*! \brief get rank of previous node in ring topology*/
|
||||
virtual int GetRingPrevRank(void) const {
|
||||
int GetRingPrevRank() const override {
|
||||
return ring_prev->rank;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
virtual int GetRank(void) const {
|
||||
int GetRank() const override {
|
||||
return rank;
|
||||
}
|
||||
/*! \brief get world size, reported as 1 when world_size is -1 */
|
||||
virtual int GetWorldSize(void) const {
|
||||
int GetWorldSize() const override {
|
||||
if (world_size == -1) return 1;
|
||||
return world_size;
|
||||
}
|
||||
/*! \brief whether is distributed or not */
|
||||
virtual bool IsDistributed(void) const {
|
||||
bool IsDistributed() const override {
|
||||
return tracker_uri != "NULL";
|
||||
}
|
||||
/*! \brief get the uri of the current host */
|
||||
virtual std::string GetHost(void) const {
|
||||
std::string GetHost() const override {
|
||||
return host_uri;
|
||||
}
|
||||
|
||||
@@ -99,13 +99,10 @@ class AllreduceBase : public IEngine {
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
|
||||
slice_begin, slice_end, size_prev_slice) == kSuccess,
|
||||
@@ -126,16 +123,12 @@ class AllreduceBase : public IEngine {
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
|
||||
void *prepare_arg = nullptr, const char *_file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_,
|
||||
type_nbytes, count, reducer) == kSuccess,
|
||||
@@ -150,8 +143,9 @@ class AllreduceBase : public IEngine {
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) {
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
||||
"Broadcast failed");
|
||||
@@ -178,8 +172,8 @@ class AllreduceBase : public IEngine {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL) {
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = nullptr) override {
|
||||
return 0;
|
||||
}
|
||||
/*!
|
||||
@@ -198,8 +192,8 @@ class AllreduceBase : public IEngine {
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL) {
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = nullptr) override {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
@@ -222,7 +216,7 @@ class AllreduceBase : public IEngine {
|
||||
* is the same in all nodes
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
@@ -230,7 +224,7 @@ class AllreduceBase : public IEngine {
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
virtual int VersionNumber(void) const {
|
||||
int VersionNumber() const override {
|
||||
return version_number;
|
||||
}
|
||||
/*!
|
||||
@@ -238,14 +232,14 @@ class AllreduceBase : public IEngine {
|
||||
* call this function when IEngine throw an exception out,
|
||||
* this function is only used for test purpose
|
||||
*/
|
||||
virtual void InitAfterException(void) {
|
||||
void InitAfterException() override {
|
||||
utils::Error("InitAfterException: not implemented");
|
||||
}
|
||||
/*!
|
||||
* \brief report current status to the job tracker
|
||||
* depending on the job tracker we are in
|
||||
*/
|
||||
inline void ReportStatus(void) const {
|
||||
inline void ReportStatus() const {
|
||||
if (hadoop_mode != 0) {
|
||||
fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
|
||||
version_number, seq_counter);
|
||||
@@ -274,7 +268,7 @@ class AllreduceBase : public IEngine {
|
||||
/*! \brief internal return type */
|
||||
ReturnTypeEnum value;
|
||||
// constructor
|
||||
ReturnType() {}
|
||||
ReturnType() = default;
|
||||
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
|
||||
inline bool operator==(const ReturnTypeEnum &v) const {
|
||||
return value == v;
|
||||
@@ -306,13 +300,11 @@ class AllreduceBase : public IEngine {
|
||||
// size of data sent to the link
|
||||
size_t size_write;
|
||||
// pointer to buffer head
|
||||
char *buffer_head;
|
||||
char *buffer_head {nullptr};
|
||||
// buffer size, in bytes
|
||||
size_t buffer_size;
|
||||
size_t buffer_size {0};
|
||||
// constructor
|
||||
LinkRecord(void)
|
||||
: buffer_head(NULL), buffer_size(0) {
|
||||
}
|
||||
LinkRecord() = default;
|
||||
// initialize buffer
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count,
|
||||
size_t reduce_buffer_size) {
|
||||
@@ -328,7 +320,7 @@ class AllreduceBase : public IEngine {
|
||||
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
||||
}
|
||||
// reset the recv and sent size
|
||||
inline void ResetSize(void) {
|
||||
inline void ResetSize() {
|
||||
size_write = size_read = 0;
|
||||
}
|
||||
/*!
|
||||
@@ -340,7 +332,7 @@ class AllreduceBase : public IEngine {
|
||||
* \return the type of reading
|
||||
*/
|
||||
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
|
||||
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
|
||||
utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated");
|
||||
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
|
||||
size_t ngap = size_read - protect_start;
|
||||
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||
@@ -405,7 +397,7 @@ class AllreduceBase : public IEngine {
|
||||
inline LinkRecord &operator[](size_t i) {
|
||||
return *plinks[i];
|
||||
}
|
||||
inline size_t size(void) const {
|
||||
inline size_t Size() const {
|
||||
return plinks.size();
|
||||
}
|
||||
};
|
||||
@@ -413,7 +405,7 @@ class AllreduceBase : public IEngine {
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
utils::TCPSocket ConnectTracker(void) const;
|
||||
utils::TCPSocket ConnectTracker() const;
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine start up
|
||||
@@ -525,64 +517,64 @@ class AllreduceBase : public IEngine {
|
||||
//---- data structure related to model ----
|
||||
// call sequence counter, records how many calls we made so far
|
||||
// from last call to CheckPoint, LoadCheckPoint
|
||||
int seq_counter;
|
||||
int seq_counter; // NOLINT
|
||||
// version number of model
|
||||
int version_number;
|
||||
int version_number; // NOLINT
|
||||
// whether the job is running in hadoop
|
||||
bool hadoop_mode;
|
||||
bool hadoop_mode; // NOLINT
|
||||
//---- local data related to link ----
|
||||
// index of parent link, can be -1, meaning this is root of the tree
|
||||
int parent_index;
|
||||
int parent_index; // NOLINT
|
||||
// rank of parent node, can be -1
|
||||
int parent_rank;
|
||||
int parent_rank; // NOLINT
|
||||
// sockets of all links this connects to
|
||||
std::vector<LinkRecord> all_links;
|
||||
std::vector<LinkRecord> all_links; // NOLINT
|
||||
// used to record the link where things goes wrong
|
||||
LinkRecord *err_link;
|
||||
LinkRecord *err_link; // NOLINT
|
||||
// all the links in the reduction tree connection
|
||||
RefLinkVector tree_links;
|
||||
RefLinkVector tree_links; // NOLINT
|
||||
// pointer to links in the ring
|
||||
LinkRecord *ring_prev, *ring_next;
|
||||
LinkRecord *ring_prev, *ring_next; // NOLINT
|
||||
//----- meta information-----
|
||||
// list of environment variables that are of possible interest
|
||||
std::vector<std::string> env_vars;
|
||||
std::vector<std::string> env_vars; // NOLINT
|
||||
// unique identifier of the possible job this process is doing
|
||||
// used to assign ranks, optional, default to NULL
|
||||
std::string task_id;
|
||||
std::string task_id; // NOLINT
|
||||
// uri of current host, to be set by Init
|
||||
std::string host_uri;
|
||||
std::string host_uri; // NOLINT
|
||||
// uri of tracker
|
||||
std::string tracker_uri;
|
||||
std::string tracker_uri; // NOLINT
|
||||
// role in dmlc jobs
|
||||
std::string dmlc_role;
|
||||
std::string dmlc_role; // NOLINT
|
||||
// port of tracker address
|
||||
int tracker_port;
|
||||
int tracker_port; // NOLINT
|
||||
// port of slave process
|
||||
int slave_port, nport_trial;
|
||||
int slave_port, nport_trial; // NOLINT
|
||||
// reduce buffer size
|
||||
size_t reduce_buffer_size;
|
||||
size_t reduce_buffer_size; // NOLINT
|
||||
// reduction method
|
||||
int reduce_method;
|
||||
int reduce_method; // NOLINT
|
||||
// minimum count of cells to use ring based method
|
||||
size_t reduce_ring_mincount;
|
||||
size_t reduce_ring_mincount; // NOLINT
|
||||
// minimal block size per tree reduce
|
||||
size_t tree_reduce_minsize;
|
||||
size_t tree_reduce_minsize; // NOLINT
|
||||
// current rank
|
||||
int rank;
|
||||
int rank; // NOLINT
|
||||
// world size
|
||||
int world_size;
|
||||
int world_size; // NOLINT
|
||||
// connect retry time
|
||||
int connect_retry;
|
||||
int connect_retry; // NOLINT
|
||||
// enable bootstrap cache 0 false 1 true
|
||||
bool rabit_bootstrap_cache = false;
|
||||
bool rabit_bootstrap_cache = false; // NOLINT
|
||||
// enable detailed logging
|
||||
bool rabit_debug = false;
|
||||
bool rabit_debug = false; // NOLINT
|
||||
// by default, exit if the rabit worker does not recover within half an hour
|
||||
int timeout_sec = 1800;
|
||||
int timeout_sec = 1800; // NOLINT
|
||||
// flag to enable rabit_timeout
|
||||
bool rabit_timeout = false;
|
||||
bool rabit_timeout = false; // NOLINT
|
||||
// Enable TCP no-delay (TCP_NODELAY)
|
||||
bool rabit_enable_tcp_no_delay = false;
|
||||
bool rabit_enable_tcp_no_delay = false; // NOLINT
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@@ -20,74 +20,65 @@ namespace engine {
|
||||
class AllreduceMock : public AllreduceRobust {
|
||||
public:
|
||||
// constructor
|
||||
AllreduceMock(void) {
|
||||
num_trial = 0;
|
||||
force_local = 0;
|
||||
report_stats = 0;
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
AllreduceMock() {
|
||||
num_trial_ = 0;
|
||||
force_local_ = 0;
|
||||
report_stats_ = 0;
|
||||
tsum_allreduce_ = 0.0;
|
||||
tsum_allgather_ = 0.0;
|
||||
}
|
||||
// destructor
|
||||
virtual ~AllreduceMock(void) {}
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
~AllreduceMock() override = default;
|
||||
void SetParam(const char *name, const char *val) override {
|
||||
AllreduceRobust::SetParam(name, val);
|
||||
// additional parameters
|
||||
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
|
||||
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val);
|
||||
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
|
||||
if (!strcmp(name, "force_local")) force_local = atoi(val);
|
||||
if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
|
||||
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
|
||||
if (!strcmp(name, "report_stats")) report_stats_ = atoi(val);
|
||||
if (!strcmp(name, "force_local")) force_local_ = atoi(val);
|
||||
if (!strcmp(name, "mock")) {
|
||||
MockKey k;
|
||||
utils::Check(sscanf(val, "%d,%d,%d,%d",
|
||||
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
|
||||
"invalid mock parameter");
|
||||
mock_map[k] = 1;
|
||||
mock_map_[k] = 1;
|
||||
}
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun,
|
||||
void *prepare_arg, const char *_file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||
count, reducer, prepare_fun, prepare_arg,
|
||||
_file, _line, _caller);
|
||||
tsum_allreduce += utils::GetTime() - tstart;
|
||||
tsum_allreduce_ += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Allgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
|
||||
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allgather(sendrecvbuf, total_size,
|
||||
slice_begin, slice_end,
|
||||
size_prev_slice, _file, _line, _caller);
|
||||
tsum_allgather += utils::GetTime() - tstart;
|
||||
tsum_allgather_ += utils::GetTime() - tstart;
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
|
||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
|
||||
}
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) {
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
time_checkpoint = utils::GetTime();
|
||||
if (force_local == 0) {
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) override {
|
||||
tsum_allreduce_ = 0.0;
|
||||
tsum_allgather_ = 0.0;
|
||||
time_checkpoint_ = utils::GetTime();
|
||||
if (force_local_ == 0) {
|
||||
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
|
||||
} else {
|
||||
DummySerializer dum;
|
||||
@@ -95,56 +86,54 @@ class AllreduceMock : public AllreduceRobust {
|
||||
return AllreduceRobust::LoadCheckPoint(&dum, &com);
|
||||
}
|
||||
}
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
|
||||
double tstart = utils::GetTime();
|
||||
double tbet_chkpt = tstart - time_checkpoint;
|
||||
if (force_local == 0) {
|
||||
double tbet_chkpt = tstart - time_checkpoint_;
|
||||
if (force_local_ == 0) {
|
||||
AllreduceRobust::CheckPoint(global_model, local_model);
|
||||
} else {
|
||||
DummySerializer dum;
|
||||
ComboSerializer com(global_model, local_model);
|
||||
AllreduceRobust::CheckPoint(&dum, &com);
|
||||
}
|
||||
time_checkpoint = utils::GetTime();
|
||||
time_checkpoint_ = utils::GetTime();
|
||||
double tcost = utils::GetTime() - tstart;
|
||||
if (report_stats != 0 && rank == 0) {
|
||||
if (report_stats_ != 0 && rank == 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
|
||||
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
|
||||
ss << "[v" << version_number << "] global_size=" << global_checkpoint_.length()
|
||||
<< ",local_size=" << (local_chkpt_[0].length() + local_chkpt_[1].length())
|
||||
<< ",check_tcost="<< tcost <<" sec"
|
||||
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
|
||||
<< ",allgather_tcost=" << tsum_allgather << " sec"
|
||||
<< ",allreduce_tcost=" << tsum_allreduce_ << " sec"
|
||||
<< ",allgather_tcost=" << tsum_allgather_ << " sec"
|
||||
<< ",between_chpt=" << tbet_chkpt << "sec\n";
|
||||
this->TrackerPrint(ss.str());
|
||||
}
|
||||
tsum_allreduce = 0.0;
|
||||
tsum_allgather = 0.0;
|
||||
tsum_allreduce_ = 0.0;
|
||||
tsum_allgather_ = 0.0;
|
||||
}
|
||||
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint");
|
||||
AllreduceRobust::LazyCheckPoint(global_model);
|
||||
}
|
||||
|
||||
protected:
|
||||
// force checkpoint to local
|
||||
int force_local;
|
||||
int force_local_;
|
||||
// whether report statistics
|
||||
int report_stats;
|
||||
int report_stats_;
|
||||
// sum of allreduce
|
||||
double tsum_allreduce;
|
||||
double tsum_allreduce_;
|
||||
// sum of allgather
|
||||
double tsum_allgather;
|
||||
double time_checkpoint;
|
||||
double tsum_allgather_;
|
||||
double time_checkpoint_;
|
||||
|
||||
private:
|
||||
struct DummySerializer : public Serializable {
|
||||
virtual void Load(Stream *fi) {
|
||||
}
|
||||
virtual void Save(Stream *fo) const {
|
||||
}
|
||||
void Load(Stream *fi) override {}
|
||||
void Save(Stream *fo) const override {}
|
||||
};
|
||||
struct ComboSerializer : public Serializable {
|
||||
Serializable *lhs;
|
||||
@@ -155,15 +144,15 @@ class AllreduceMock : public AllreduceRobust {
|
||||
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
|
||||
}
|
||||
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
|
||||
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
|
||||
: lhs(nullptr), rhs(nullptr), c_lhs(lhs), c_rhs(rhs) {
|
||||
}
|
||||
virtual void Load(Stream *fi) {
|
||||
if (lhs != NULL) lhs->Load(fi);
|
||||
if (rhs != NULL) rhs->Load(fi);
|
||||
void Load(Stream *fi) override {
|
||||
if (lhs != nullptr) lhs->Load(fi);
|
||||
if (rhs != nullptr) rhs->Load(fi);
|
||||
}
|
||||
virtual void Save(Stream *fo) const {
|
||||
if (c_lhs != NULL) c_lhs->Save(fo);
|
||||
if (c_rhs != NULL) c_rhs->Save(fo);
|
||||
void Save(Stream *fo) const override {
|
||||
if (c_lhs != nullptr) c_lhs->Save(fo);
|
||||
if (c_rhs != nullptr) c_rhs->Save(fo);
|
||||
}
|
||||
};
|
||||
// key to identify the mock stage
|
||||
@@ -172,7 +161,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
int version;
|
||||
int seqno;
|
||||
int ntrial;
|
||||
MockKey(void) {}
|
||||
MockKey() = default;
|
||||
MockKey(int rank, int version, int seqno, int ntrial)
|
||||
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
|
||||
inline bool operator==(const MockKey &b) const {
|
||||
@@ -189,15 +178,15 @@ class AllreduceMock : public AllreduceRobust {
|
||||
}
|
||||
};
|
||||
// number of failure trials
|
||||
int num_trial;
|
||||
int num_trial_;
|
||||
// record all mock actions
|
||||
std::map<MockKey, int> mock_map;
|
||||
std::map<MockKey, int> mock_map_;
|
||||
// used to generate all kinds of exceptions
|
||||
inline void Verify(const MockKey &key, const char *name) {
|
||||
if (mock_map.count(key) != 0) {
|
||||
num_trial += 1;
|
||||
if (mock_map_.count(key) != 0) {
|
||||
num_trial_ += 1;
|
||||
// data processing frameworks runs on shared process
|
||||
_error("[%d]@@@Hit Mock Error:%s ", rank, name);
|
||||
error_("[%d]@@@Hit Mock Error:%s ", rank, name);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -40,9 +40,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0) return kSuccess;
|
||||
if (links.Size() == 0) return kSuccess;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
const int nlink = static_cast<int>(links.Size());
|
||||
// initialize the pointers
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,18 +22,18 @@ namespace engine {
|
||||
/*! \brief implementation of fault tolerant all reduce engine */
|
||||
class AllreduceRobust : public AllreduceBase {
|
||||
public:
|
||||
AllreduceRobust(void);
|
||||
virtual ~AllreduceRobust(void) {}
|
||||
AllreduceRobust();
|
||||
~AllreduceRobust() override = default;
|
||||
// initialize the manager
|
||||
virtual bool Init(int argc, char* argv[]);
|
||||
bool Init(int argc, char* argv[]) override;
|
||||
/*! \brief shutdown the engine */
|
||||
virtual bool Shutdown(void);
|
||||
bool Shutdown() override;
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
* \param val parameter value
|
||||
*/
|
||||
virtual void SetParam(const char *name, const char *val);
|
||||
void SetParam(const char *name, const char *val) override;
|
||||
/*!
|
||||
* \brief perform immutable local bootstrap cache insertion
|
||||
* \param key unique cache key
|
||||
@@ -67,13 +67,10 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override;
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@@ -90,15 +87,11 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
|
||||
void *prepare_arg = nullptr, const char *_file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override;
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data
|
||||
@@ -108,10 +101,9 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char* _file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char* _caller = _CALLER);
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override;
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
@@ -134,8 +126,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL);
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = nullptr) override;
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
@@ -152,9 +144,9 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL) {
|
||||
this->CheckPoint_(global_model, local_model, false);
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = nullptr) override {
|
||||
this->CheckPointImpl(global_model, local_model, false);
|
||||
}
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
@@ -176,18 +168,20 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* is the same in all nodes
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
this->CheckPoint_(global_model, NULL, true);
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
this->CheckPointImpl(global_model, nullptr, true);
|
||||
}
|
||||
/*!
|
||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throw an exception out,
|
||||
* this function is only used for test purpose
|
||||
*/
|
||||
virtual void InitAfterException(void) {
|
||||
void InitAfterException() override {
|
||||
// simple way, shutdown all links
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
|
||||
for (auto& link : all_links) {
|
||||
if (link.sock.BadSocket()) {
|
||||
link.sock.Close();
|
||||
}
|
||||
}
|
||||
ReConnectLinks("recover");
|
||||
}
|
||||
@@ -245,69 +239,69 @@ class AllreduceRobust : public AllreduceBase {
|
||||
// there are nodes request load cache
|
||||
static const int kLoadBootstrapCache = 16;
|
||||
// constructor
|
||||
ActionSummary(void) {}
|
||||
ActionSummary() = default;
|
||||
// constructor of action
|
||||
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
|
||||
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
|
||||
seqcode = (minseqno << 5) | seqno_flag;
|
||||
maxseqcode = (maxseqno << 5) | cache_flag;
|
||||
seqcode_ = (minseqno << 5) | seqno_flag;
|
||||
maxseqcode_ = (maxseqno << 5) | cache_flag;
|
||||
}
|
||||
// minimum number of all operations by default
|
||||
// maximum number of all cache operations otherwise
|
||||
inline u_int32_t seqno(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return code >> 5;
|
||||
}
|
||||
// whether the operation set contains a load_check
|
||||
inline bool load_check(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
inline bool LoadCheck(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return (code & kLoadCheck) != 0;
|
||||
}
|
||||
// whether the operation set contains a load_cache
|
||||
inline bool load_cache(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
inline bool LoadCache(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return (code & kLoadBootstrapCache) != 0;
|
||||
}
|
||||
// whether the operation set contains a check point
|
||||
inline bool check_point(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
inline bool CheckPoint(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return (code & kCheckPoint) != 0;
|
||||
}
|
||||
// whether the operation set contains a check ack
|
||||
inline bool check_ack(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
inline bool CheckAck(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return (code & kCheckAck) != 0;
|
||||
}
|
||||
// whether the operation set contains different sequence number
|
||||
inline bool diff_seq() const {
|
||||
return (seqcode & kDiffSeq) != 0;
|
||||
inline bool DiffSeq() const {
|
||||
return (seqcode_ & kDiffSeq) != 0;
|
||||
}
|
||||
// returns the operation flag of the result
|
||||
inline int flag(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
|
||||
inline int Flag(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return code & 31;
|
||||
}
|
||||
// print flags in user friendly way
|
||||
inline void print_flags(int rank, std::string prefix ) {
|
||||
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
|
||||
rank, prefix.c_str(),
|
||||
seqno(), check_point(), check_ack(), load_cache(),
|
||||
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
|
||||
inline void PrintFlags(int rank, std::string prefix ) {
|
||||
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", rank,
|
||||
prefix.c_str(), Seqno(), CheckPoint(), CheckAck(),
|
||||
LoadCache(), DiffSeq(), Seqno(SeqType::kCache),
|
||||
LoadCache(SeqType::kCache));
|
||||
}
|
||||
// reducer for Allreduce, get the result ActionSummary from all nodes
|
||||
inline static void Reducer(const void *src_, void *dst_,
|
||||
int len, const MPI::Datatype &dtype) {
|
||||
const ActionSummary *src = (const ActionSummary*)src_;
|
||||
const ActionSummary *src = static_cast<const ActionSummary*>(src_);
|
||||
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno());
|
||||
u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache),
|
||||
dst[i].seqno(SeqType::kCache));
|
||||
int action_flag = src[i].flag() | dst[i].flag();
|
||||
u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno());
|
||||
u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache),
|
||||
dst[i].Seqno(SeqType::kCache));
|
||||
int action_flag = src[i].Flag() | dst[i].Flag();
|
||||
// if any node is not requester set to 0 otherwise 1
|
||||
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
|
||||
int role_flag = src[i].Flag(SeqType::kCache) & dst[i].Flag(SeqType::kCache);
|
||||
// if seqno is different in src and destination
|
||||
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
|
||||
int seq_diff_flag = src[i].Seqno() != dst[i].Seqno() ? kDiffSeq : 0;
|
||||
// apply or to both seq diff flag as well as cache seq diff flag
|
||||
dst[i] = ActionSummary(action_flag | seq_diff_flag,
|
||||
role_flag, min_seqno, max_seqno);
|
||||
@@ -316,19 +310,19 @@ class AllreduceRobust : public AllreduceBase {
|
||||
|
||||
private:
|
||||
// internal sequence code min of rabit seqno
|
||||
u_int32_t seqcode;
|
||||
u_int32_t seqcode_;
|
||||
// internal sequence code max of cache seqno
|
||||
u_int32_t maxseqcode;
|
||||
u_int32_t maxseqcode_;
|
||||
};
|
||||
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
|
||||
class ResultBuffer{
|
||||
public:
|
||||
// constructor
|
||||
ResultBuffer(void) {
|
||||
ResultBuffer() {
|
||||
this->Clear();
|
||||
}
|
||||
// clear the existing record
|
||||
inline void Clear(void) {
|
||||
inline void Clear() {
|
||||
seqno_.clear(); size_.clear();
|
||||
rptr_.clear(); rptr_.push_back(0);
|
||||
data_.clear();
|
||||
@@ -358,12 +352,12 @@ class AllreduceRobust : public AllreduceBase {
|
||||
inline void* Query(int seqid, size_t *p_size) {
|
||||
size_t idx = std::lower_bound(seqno_.begin(),
|
||||
seqno_.end(), seqid) - seqno_.begin();
|
||||
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
|
||||
if (idx == seqno_.size() || seqno_[idx] != seqid) return nullptr;
|
||||
*p_size = size_[idx];
|
||||
return BeginPtr(data_) + rptr_[idx];
|
||||
}
|
||||
// drop last stored result
|
||||
inline void DropLast(void) {
|
||||
inline void DropLast() {
|
||||
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
|
||||
seqno_.pop_back();
|
||||
rptr_.pop_back();
|
||||
@@ -371,7 +365,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
data_.resize(rptr_.back());
|
||||
}
|
||||
// the sequence number of last stored result
|
||||
inline int LastSeqNo(void) const {
|
||||
inline int LastSeqNo() const {
|
||||
if (seqno_.size() == 0) return -1;
|
||||
return seqno_.back();
|
||||
}
|
||||
@@ -407,9 +401,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*
|
||||
* \sa CheckPoint, LazyCheckPoint
|
||||
*/
|
||||
void CheckPoint_(const Serializable *global_model,
|
||||
const Serializable *local_model,
|
||||
bool lazy_checkpt);
|
||||
void CheckPointImpl(const Serializable *global_model,
|
||||
const Serializable *local_model, bool lazy_checkpt);
|
||||
/*!
|
||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||
* after this function finishes, all the messages received and sent
|
||||
@@ -423,7 +416,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* when kSockError is returned, it simply means there are bad sockets in the links,
|
||||
* and some link recovery procedure is needed
|
||||
*/
|
||||
ReturnType TryResetLinks(void);
|
||||
ReturnType TryResetLinks();
|
||||
/*!
|
||||
* \brief if err_type indicates an error
|
||||
* recover links according to the error type reported
|
||||
@@ -619,29 +612,29 @@ o * the input state must exactly one saved state(local state of current node)
|
||||
size_t out_index));
|
||||
//---- recovery data structure ----
|
||||
// the round of result buffer, used to mode the result
|
||||
int result_buffer_round;
|
||||
int result_buffer_round_;
|
||||
// result buffer of all reduce
|
||||
ResultBuffer resbuf;
|
||||
ResultBuffer resbuf_;
|
||||
// current cached allreduce/broadcast sequence number
|
||||
int cur_cache_seq;
|
||||
int cur_cache_seq_;
|
||||
// result buffer of cached all reduce
|
||||
ResultBuffer cachebuf;
|
||||
ResultBuffer cachebuf_;
|
||||
// key of each cache entry
|
||||
ResultBuffer lookupbuf;
|
||||
ResultBuffer lookupbuf_;
|
||||
// last check point global model
|
||||
std::string global_checkpoint;
|
||||
std::string global_checkpoint_;
|
||||
// lazy checkpoint of global model
|
||||
const Serializable *global_lazycheck;
|
||||
const Serializable *global_lazycheck_;
|
||||
// number of replica for local state/model
|
||||
int num_local_replica;
|
||||
int num_local_replica_;
|
||||
// number of default local replica
|
||||
int default_local_replica;
|
||||
int default_local_replica_;
|
||||
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
|
||||
int use_local_model;
|
||||
int use_local_model_;
|
||||
// number of replica for global state/model
|
||||
int num_global_replica;
|
||||
int num_global_replica_;
|
||||
// number of times recovery happens
|
||||
int recover_counter;
|
||||
int recover_counter_;
|
||||
// --- recovery data structure for local checkpoint
|
||||
// there are two versions of the data structure,
|
||||
// at one time one version is valid and another is used as temp memory
|
||||
@@ -649,21 +642,21 @@ o * the input state must exactly one saved state(local state of current node)
|
||||
// local model is stored in CSR format(like a sparse matrices)
|
||||
// local_model[rptr[0]:rptr[1]] stores the model of current node
|
||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
|
||||
std::vector<size_t> local_rptr[2];
|
||||
std::vector<size_t> local_rptr_[2];
|
||||
// storage for local model replicas
|
||||
std::string local_chkpt[2];
|
||||
std::string local_chkpt_[2];
|
||||
// version of local checkpoint can be 1 or 0
|
||||
int local_chkpt_version;
|
||||
int local_chkpt_version_;
|
||||
// if checkpoint was loaded, used to distinguish results of bootstrap cache from seqno cache
|
||||
bool checkpoint_loaded;
|
||||
bool checkpoint_loaded_;
|
||||
// sidecar executing timeout task
|
||||
std::future<bool> rabit_timeout_task;
|
||||
std::future<bool> rabit_timeout_task_;
|
||||
// flag to shutdown rabit_timeout_task before timeout
|
||||
std::atomic<bool> shutdown_timeout{false};
|
||||
std::atomic<bool> shutdown_timeout_{false};
|
||||
// error handler
|
||||
void (* _error)(const char *fmt, ...) = utils::Error;
|
||||
void (* error_)(const char *fmt, ...) = utils::Error;
|
||||
// assert handler
|
||||
void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert;
|
||||
void (* assert_)(bool exp, const char *fmt, ...) = utils::Assert;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@@ -33,12 +33,12 @@ struct FHelper<op::BitOR, DType> {
|
||||
};
|
||||
|
||||
template<typename OP>
|
||||
void Allreduce_(void *sendrecvbuf_,
|
||||
void Allreduce(void *sendrecvbuf_,
|
||||
size_t count,
|
||||
engine::mpi::DataType enum_dtype,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
using namespace engine::mpi;
|
||||
using namespace engine::mpi; // NOLINT
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
rabit::Allreduce<OP>
|
||||
@@ -89,28 +89,28 @@ void Allreduce(void *sendrecvbuf,
|
||||
engine::mpi::OpType enum_op,
|
||||
void (*prepare_fun)(void *arg),
|
||||
void *prepare_arg) {
|
||||
using namespace engine::mpi;
|
||||
using namespace engine::mpi; // NOLINT
|
||||
switch (enum_op) {
|
||||
case kMax:
|
||||
Allreduce_<op::Max>
|
||||
Allreduce<op::Max>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kMin:
|
||||
Allreduce_<op::Min>
|
||||
Allreduce<op::Min>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kSum:
|
||||
Allreduce_<op::Sum>
|
||||
Allreduce<op::Sum>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
return;
|
||||
case kBitwiseOR:
|
||||
Allreduce_<op::BitOR>
|
||||
Allreduce<op::BitOR>
|
||||
(sendrecvbuf,
|
||||
count, enum_dtype,
|
||||
prepare_fun, prepare_arg);
|
||||
@@ -124,7 +124,7 @@ void Allgather(void *sendrecvbuf_,
|
||||
size_t size_node_slice,
|
||||
size_t size_prev_slice,
|
||||
int enum_dtype) {
|
||||
using namespace engine::mpi;
|
||||
using namespace engine::mpi; // NOLINT
|
||||
size_t type_size = 0;
|
||||
switch (enum_dtype) {
|
||||
case kChar:
|
||||
@@ -184,7 +184,7 @@ struct ReadWrapper : public Serializable {
|
||||
std::string *p_str;
|
||||
explicit ReadWrapper(std::string *p_str)
|
||||
: p_str(p_str) {}
|
||||
virtual void Load(Stream *fi) {
|
||||
void Load(Stream *fi) override {
|
||||
uint64_t sz;
|
||||
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
|
||||
"Read pickle string");
|
||||
@@ -194,7 +194,7 @@ struct ReadWrapper : public Serializable {
|
||||
"Read pickle string");
|
||||
}
|
||||
}
|
||||
virtual void Save(Stream *fo) const {
|
||||
void Save(Stream *fo) const override {
|
||||
utils::Error("not implemented");
|
||||
}
|
||||
};
|
||||
@@ -206,10 +206,10 @@ struct WriteWrapper : public Serializable {
|
||||
size_t length)
|
||||
: data(data), length(length) {
|
||||
}
|
||||
virtual void Load(Stream *fi) {
|
||||
void Load(Stream *fi) override {
|
||||
utils::Error("not implemented");
|
||||
}
|
||||
virtual void Save(Stream *fo) const {
|
||||
void Save(Stream *fo) const override {
|
||||
uint64_t sz = static_cast<uint16_t>(length);
|
||||
fo->Write(&sz, sizeof(sz));
|
||||
fo->Write(data, length * sizeof(char));
|
||||
@@ -298,8 +298,8 @@ RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
|
||||
ReadWrapper sl(&local_buffer);
|
||||
int version;
|
||||
|
||||
if (out_local_model == NULL) {
|
||||
version = rabit::LoadCheckPoint(&sg, NULL);
|
||||
if (out_local_model == nullptr) {
|
||||
version = rabit::LoadCheckPoint(&sg, nullptr);
|
||||
*out_global_model = BeginPtr(global_buffer);
|
||||
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
|
||||
} else {
|
||||
@@ -317,8 +317,8 @@ RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len,
|
||||
using namespace rabit::c_api; // NOLINT(*)
|
||||
WriteWrapper sg(global_model, global_len);
|
||||
WriteWrapper sl(local_model, local_len);
|
||||
if (local_model == NULL) {
|
||||
rabit::CheckPoint(&sg, NULL);
|
||||
if (local_model == nullptr) {
|
||||
rabit::CheckPoint(&sg, nullptr);
|
||||
} else {
|
||||
rabit::CheckPoint(&sg, &sl);
|
||||
}
|
||||
|
||||
@@ -7,18 +7,19 @@
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#include <rabit/base.h>
|
||||
#include <dmlc/thread_local.h>
|
||||
|
||||
#include <memory>
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "allreduce_base.h"
|
||||
#include "allreduce_robust.h"
|
||||
#include "rabit/internal/thread_local.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
// singleton sync manager
|
||||
#ifndef RABIT_USE_BASE
|
||||
#ifndef RABIT_USE_MOCK
|
||||
typedef AllreduceRobust Manager;
|
||||
using Manager = AllreduceRobust;
|
||||
#else
|
||||
typedef AllreduceMock Manager;
|
||||
#endif // RABIT_USE_MOCK
|
||||
@@ -31,13 +32,13 @@ struct ThreadLocalEntry {
|
||||
/*! \brief stores the current engine */
|
||||
std::unique_ptr<Manager> engine;
|
||||
/*! \brief whether init has been called */
|
||||
bool initialized;
|
||||
bool initialized{false};
|
||||
/*! \brief constructor */
|
||||
ThreadLocalEntry() : initialized(false) {}
|
||||
ThreadLocalEntry() = default;
|
||||
};
|
||||
|
||||
// define the threadlocal store.
|
||||
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
|
||||
using EngineThreadLocal = dmlc::ThreadLocalStore<ThreadLocalEntry>;
|
||||
|
||||
/*! \brief initialize the synchronization module */
|
||||
bool Init(int argc, char *argv[]) {
|
||||
@@ -95,7 +96,7 @@ void Allgather(void *sendrecvbuf_, size_t total_size,
|
||||
|
||||
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
void Allreduce_(void *sendrecvbuf, // NOLINT
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
@@ -111,18 +112,15 @@ void Allreduce_(void *sendrecvbuf,
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void)
|
||||
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
||||
}
|
||||
|
||||
ReduceHandle::~ReduceHandle(void) {}
|
||||
ReduceHandle::ReduceHandle() = default;
|
||||
ReduceHandle::~ReduceHandle() = default;
|
||||
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return static_cast<int>(dtype.type_size);
|
||||
}
|
||||
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
|
||||
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
|
||||
redfunc_ = redfunc;
|
||||
}
|
||||
|
||||
@@ -133,7 +131,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
|
||||
utils::Assert(redfunc_ != nullptr, "must intialize handle to call AllReduce");
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||
redfunc_, prepare_fun, prepare_arg,
|
||||
_file, _line, _caller);
|
||||
|
||||
@@ -16,78 +16,68 @@ namespace engine {
|
||||
/*! \brief EmptyEngine */
|
||||
class EmptyEngine : public IEngine {
|
||||
public:
|
||||
EmptyEngine(void) {
|
||||
version_number = 0;
|
||||
EmptyEngine() {
|
||||
version_number_ = 0;
|
||||
}
|
||||
virtual void Allgather(void *sendrecvbuf_,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
size_t slice_end,
|
||||
size_t size_prev_slice,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice, const char *_file,
|
||||
const int _line, const char *_caller) override {
|
||||
utils::Error("EmptyEngine:: Allgather is not supported");
|
||||
}
|
||||
virtual int GetRingPrevRank(void) const {
|
||||
int GetRingPrevRank() const override {
|
||||
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
|
||||
return -1;
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun,
|
||||
void *prepare_arg, const char *_file, const int _line,
|
||||
const char *_caller) override {
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
const char* _file, const int _line, const char* _caller) {
|
||||
void Broadcast(void *sendrecvbuf_, size_t size, int root,
|
||||
const char* _file, const int _line, const char* _caller) override {
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
void InitAfterException() override {
|
||||
utils::Error("EmptyEngine is not fault tolerant");
|
||||
}
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL) {
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = nullptr) override {
|
||||
return 0;
|
||||
}
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = nullptr) override {
|
||||
version_number_ += 1;
|
||||
}
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
version_number += 1;
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
version_number_ += 1;
|
||||
}
|
||||
virtual int VersionNumber(void) const {
|
||||
return version_number;
|
||||
int VersionNumber() const override {
|
||||
return version_number_;
|
||||
}
|
||||
/*! \brief get rank of current node */
|
||||
virtual int GetRank(void) const {
|
||||
int GetRank() const override {
|
||||
return 0;
|
||||
}
|
||||
/*! \brief get total number of nodes */
|
||||
virtual int GetWorldSize(void) const {
|
||||
int GetWorldSize() const override {
|
||||
return 1;
|
||||
}
|
||||
/*! \brief whether it is distributed */
|
||||
virtual bool IsDistributed(void) const {
|
||||
bool IsDistributed() const override {
|
||||
return false;
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const {
|
||||
std::string GetHost() const override {
|
||||
return std::string("");
|
||||
}
|
||||
virtual void TrackerPrint(const std::string &msg) {
|
||||
void TrackerPrint(const std::string &msg) override {
|
||||
// simply print information into the tracker
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
|
||||
private:
|
||||
int version_number;
|
||||
int version_number_;
|
||||
};
|
||||
|
||||
// singleton sync manager
|
||||
@@ -98,12 +88,12 @@ bool Init(int argc, char *argv[]) {
|
||||
return true;
|
||||
}
|
||||
/*! \brief finalize synchronization module */
|
||||
bool Finalize(void) {
|
||||
bool Finalize() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void) {
|
||||
IEngine *GetEngine() {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
@@ -118,13 +108,12 @@ void Allreduce_(void *sendrecvbuf,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {}
|
||||
ReduceHandle::ReduceHandle() = default;
|
||||
ReduceHandle::~ReduceHandle() = default;
|
||||
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return 0;
|
||||
@@ -137,7 +126,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
const char* _file,
|
||||
const int _line,
|
||||
const char* _caller) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_mock.cc
|
||||
* \brief this is an engine implementation that will
|
||||
* \brief this is an engine implementation that will
|
||||
* insert failures in certain call point, to test if the engine is robust to failure
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
@@ -12,4 +12,3 @@
|
||||
#include <rabit/base.h>
|
||||
#include "allreduce_mock.h"
|
||||
#include "engine.cc"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user