Correct style warnings from clang-tidy for rabit. (#6095)

This commit is contained in:
Jiaming Yuan
2020-09-08 12:13:58 +08:00
committed by GitHub
parent da61d9460b
commit b0001a6e29
19 changed files with 736 additions and 868 deletions

View File

@@ -15,7 +15,7 @@
namespace rabit {
namespace engine {
// constructor
AllreduceBase::AllreduceBase(void) {
AllreduceBase::AllreduceBase() {
tracker_uri = "NULL";
tracker_port = 9000;
host_uri = "";
@@ -24,7 +24,7 @@ AllreduceBase::AllreduceBase(void) {
rank = 0;
world_size = -1;
connect_retry = 5;
hadoop_mode = 0;
hadoop_mode = false;
version_number = 0;
// 32 K items
reduce_ring_mincount = 32 << 10;
@@ -32,27 +32,27 @@ AllreduceBase::AllreduceBase(void) {
tree_reduce_minsize = 1 << 20;
// tracker URL
task_id = "NULL";
err_link = NULL;
err_link = nullptr;
dmlc_role = "worker";
this->SetParam("rabit_reduce_buffer", "256MB");
// setup possible enviroment variable of interest
// include dmlc support direct variables
env_vars.push_back("DMLC_TASK_ID");
env_vars.push_back("DMLC_ROLE");
env_vars.push_back("DMLC_NUM_ATTEMPT");
env_vars.push_back("DMLC_TRACKER_URI");
env_vars.push_back("DMLC_TRACKER_PORT");
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
env_vars.emplace_back("DMLC_TASK_ID");
env_vars.emplace_back("DMLC_ROLE");
env_vars.emplace_back("DMLC_NUM_ATTEMPT");
env_vars.emplace_back("DMLC_TRACKER_URI");
env_vars.emplace_back("DMLC_TRACKER_PORT");
env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY");
}
// initialization function
bool AllreduceBase::Init(int argc, char* argv[]) {
// setup from enviroment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
const char *value = getenv(env_vars[i].c_str());
if (value != NULL) {
this->SetParam(env_vars[i].c_str(), value);
for (auto & env_var : env_vars) {
const char *value = getenv(env_var.c_str());
if (value != nullptr) {
this->SetParam(env_var.c_str(), value);
}
}
// pass in arguments override env variable.
@@ -66,35 +66,35 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");
if (task_id == NULL) {
if (task_id == nullptr) {
task_id = getenv("mapreduce_task_id");
}
if (hadoop_mode) {
utils::Check(task_id != NULL,
utils::Check(task_id != nullptr,
"hadoop_mode is set but cannot find mapred_task_id");
}
if (task_id != NULL) {
if (task_id != nullptr) {
this->SetParam("rabit_task_id", task_id);
this->SetParam("rabit_hadoop_mode", "1");
}
const char *attempt_id = getenv("mapred_task_id");
if (attempt_id != 0) {
if (attempt_id != nullptr) {
const char *att = strrchr(attempt_id, '_');
int num_trial;
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) {
this->SetParam("rabit_num_trial", att + 1);
}
}
// handling for hadoop
const char *num_task = getenv("mapred_map_tasks");
if (num_task == NULL) {
if (num_task == nullptr) {
num_task = getenv("mapreduce_job_maps");
}
if (hadoop_mode) {
utils::Check(num_task != NULL,
utils::Check(num_task != nullptr,
"hadoop_mode is set but cannot find mapred_map_tasks");
}
if (num_task != NULL) {
if (num_task != nullptr) {
this->SetParam("rabit_world_size", num_task);
}
}
@@ -115,10 +115,10 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
return this->ReConnectLinks();
}
bool AllreduceBase::Shutdown(void) {
bool AllreduceBase::Shutdown() {
try {
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
for (auto & all_link : all_links) {
all_link.sock.Close();
}
all_links.clear();
tree_links.plinks.clear();
@@ -208,17 +208,18 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
}
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
if (!strcmp(val, "true"))
if (!strcmp(val, "true")) {
rabit_enable_tcp_no_delay = true;
else
} else {
rabit_enable_tcp_no_delay = false;
}
}
}
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
utils::TCPSocket AllreduceBase::ConnectTracker() const {
int magic = kMagic;
// get information from tracker
utils::TCPSocket tracker;
@@ -241,7 +242,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
}
}
break;
} while (1);
} while (true);
using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
@@ -320,19 +321,19 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
do {
// send over good links
std::vector<int> good_link;
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i].rank));
for (auto & all_link : all_links) {
if (!all_link.sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_link.rank));
} else {
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
if (!all_link.sock.IsClosed()) all_link.sock.Close();
}
}
int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) {
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
sizeof(good_link[i]), "ReConnectLink failure 6");
for (int & i : good_link) {
Assert(tracker.SendAll(&i, sizeof(i)) == \
sizeof(i), "ReConnectLink failure 6");
}
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
@@ -362,11 +363,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
Assert(all_links[i].sock.IsClosed(),
for (auto & all_link : all_links) {
if (all_link.rank == hrank) {
Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
all_link.sock = r.sock;
match = true;
break;
}
@@ -390,11 +391,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(),
for (auto & all_link : all_links) {
if (all_link.rank == r.rank) {
utils::Assert(all_link.sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
all_link.sock = r.sock;
match = true;
break;
}
@@ -406,29 +407,29 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
// setup tree links and ring structure
tree_links.plinks.clear();
int tcpNoDelay = 1;
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
for (auto & all_link : all_links) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true);
all_link.sock.SetNonBlock(true);
all_link.sock.SetKeepAlive(true);
if (rabit_enable_tcp_no_delay) {
setsockopt(all_links[i].sock, IPPROTO_TCP,
setsockopt(all_link.sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
}
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());
}
tree_links.plinks.push_back(&all_links[i]);
tree_links.plinks.push_back(&all_link);
}
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
if (all_link.rank == prev_rank) ring_prev = &all_link;
if (all_link.rank == next_rank) ring_next = &all_link;
}
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL,
Assert(prev_rank == -1 || ring_prev != nullptr,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
Assert(next_rank == -1 || ring_next != nullptr,
"cannot find next ring in the link");
return true;
} catch (const std::exception& e) {
@@ -479,11 +480,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
size_t count,
ReduceFunction reducer) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || count == 0) return kSuccess;
if (links.Size() == 0 || count == 0) return kSuccess;
// total size of message
const size_t total_size = type_nbytes * count;
// number of links
const int nlink = static_cast<int>(links.size());
const int nlink = static_cast<int>(links.Size());
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// size of space that we already performs reduce in up pass
@@ -581,8 +582,9 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
// if max reduce is less than total size, we reduce multiple times of
// eachreduce size
if (max_reduce < total_size)
if (max_reduce < total_size) {
max_reduce = max_reduce - max_reduce % eachreduce;
}
// peform reduce, can be at most two rounds
while (size_up_reduce < max_reduce) {
@@ -677,11 +679,11 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || total_size == 0) return kSuccess;
if (links.Size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size,
"Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.size());
const int nlink = static_cast<int>(links.Size());
// size of space already read from data
size_t size_in = 0;
// input link, -2 means unknown yet, -1 means this is root

View File

@@ -25,7 +25,7 @@
#endif // RABIT_CXXTESTDEFS_H
namespace MPI {
namespace MPI { // NOLINT
// MPI data type to be compatible with existing MPI interface
class Datatype {
public:
@@ -41,12 +41,12 @@ class AllreduceBase : public IEngine {
// magic number to verify server
static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
AllreduceBase();
virtual ~AllreduceBase() = default;
// initialize the manager
virtual bool Init(int argc, char* argv[]);
// shutdown the engine
virtual bool Shutdown(void);
virtual bool Shutdown();
/*!
* \brief set parameters to the engine
* \param name parameter name
@@ -59,27 +59,27 @@ class AllreduceBase : public IEngine {
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
void TrackerPrint(const std::string &msg) override;
/*! \brief get rank of previous node in ring topology*/
virtual int GetRingPrevRank(void) const {
int GetRingPrevRank() const override {
return ring_prev->rank;
}
/*! \brief get rank */
virtual int GetRank(void) const {
int GetRank() const override {
return rank;
}
/*! \brief get rank */
virtual int GetWorldSize(void) const {
int GetWorldSize() const override {
if (world_size == -1) return 1;
return world_size;
}
/*! \brief whether is distributed or not */
virtual bool IsDistributed(void) const {
bool IsDistributed() const override {
return tracker_uri != "NULL";
}
/*! \brief get rank */
virtual std::string GetHost(void) const {
std::string GetHost() const override {
return host_uri;
}
@@ -99,13 +99,10 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
slice_begin, slice_end, size_prev_slice) == kSuccess,
@@ -126,16 +123,12 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr, const char *_file = _FILE,
const int _line = _LINE,
const char *_caller = _CALLER) override {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
@@ -150,8 +143,9 @@ class AllreduceBase : public IEngine {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) {
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"Broadcast failed");
@@ -178,8 +172,8 @@ class AllreduceBase : public IEngine {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = nullptr) override {
return 0;
}
/*!
@@ -198,8 +192,8 @@ class AllreduceBase : public IEngine {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number += 1;
}
/*!
@@ -222,7 +216,7 @@ class AllreduceBase : public IEngine {
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const Serializable *global_model) {
void LazyCheckPoint(const Serializable *global_model) override {
version_number += 1;
}
/*!
@@ -230,7 +224,7 @@ class AllreduceBase : public IEngine {
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const {
int VersionNumber() const override {
return version_number;
}
/*!
@@ -238,14 +232,14 @@ class AllreduceBase : public IEngine {
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
void InitAfterException() override {
utils::Error("InitAfterException: not implemented");
}
/*!
* \brief report current status to the job tracker
* depending on the job tracker we are in
*/
inline void ReportStatus(void) const {
inline void ReportStatus() const {
if (hadoop_mode != 0) {
fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
version_number, seq_counter);
@@ -274,7 +268,7 @@ class AllreduceBase : public IEngine {
/*! \brief internal return type */
ReturnTypeEnum value;
// constructor
ReturnType() {}
ReturnType() = default;
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
inline bool operator==(const ReturnTypeEnum &v) const {
return value == v;
@@ -306,13 +300,11 @@ class AllreduceBase : public IEngine {
// size of data sent to the link
size_t size_write;
// pointer to buffer head
char *buffer_head;
char *buffer_head {nullptr};
// buffer size, in bytes
size_t buffer_size;
size_t buffer_size {0};
// constructor
LinkRecord(void)
: buffer_head(NULL), buffer_size(0) {
}
LinkRecord() = default;
// initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
@@ -328,7 +320,7 @@ class AllreduceBase : public IEngine {
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
}
// reset the recv and sent size
inline void ResetSize(void) {
inline void ResetSize() {
size_write = size_read = 0;
}
/*!
@@ -340,7 +332,7 @@ class AllreduceBase : public IEngine {
* \return the type of reading
*/
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated");
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
@@ -405,7 +397,7 @@ class AllreduceBase : public IEngine {
inline LinkRecord &operator[](size_t i) {
return *plinks[i];
}
inline size_t size(void) const {
inline size_t Size() const {
return plinks.size();
}
};
@@ -413,7 +405,7 @@ class AllreduceBase : public IEngine {
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker(void) const;
utils::TCPSocket ConnectTracker() const;
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
@@ -525,64 +517,64 @@ class AllreduceBase : public IEngine {
//---- data structure related to model ----
// call sequence counter, records how many calls we made so far
// from last call to CheckPoint, LoadCheckPoint
int seq_counter;
int seq_counter; // NOLINT
// version number of model
int version_number;
int version_number; // NOLINT
// whether the job is running in hadoop
bool hadoop_mode;
bool hadoop_mode; // NOLINT
//---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree
int parent_index;
int parent_index; // NOLINT
// rank of parent node, can be -1
int parent_rank;
int parent_rank; // NOLINT
// sockets of all links this connects to
std::vector<LinkRecord> all_links;
std::vector<LinkRecord> all_links; // NOLINT
// used to record the link where things goes wrong
LinkRecord *err_link;
LinkRecord *err_link; // NOLINT
// all the links in the reduction tree connection
RefLinkVector tree_links;
RefLinkVector tree_links; // NOLINT
// pointer to links in the ring
LinkRecord *ring_prev, *ring_next;
LinkRecord *ring_prev, *ring_next; // NOLINT
//----- meta information-----
// list of enviroment variables that are of possible interest
std::vector<std::string> env_vars;
std::vector<std::string> env_vars; // NOLINT
// unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL
std::string task_id;
std::string task_id; // NOLINT
// uri of current host, to be set by Init
std::string host_uri;
std::string host_uri; // NOLINT
// uri of tracker
std::string tracker_uri;
std::string tracker_uri; // NOLINT
// role in dmlc jobs
std::string dmlc_role;
std::string dmlc_role; // NOLINT
// port of tracker address
int tracker_port;
int tracker_port; // NOLINT
// port of slave process
int slave_port, nport_trial;
int slave_port, nport_trial; // NOLINT
// reduce buffer size
size_t reduce_buffer_size;
size_t reduce_buffer_size; // NOLINT
// reduction method
int reduce_method;
int reduce_method; // NOLINT
// mininum count of cells to use ring based method
size_t reduce_ring_mincount;
size_t reduce_ring_mincount; // NOLINT
// minimul block size per tree reduce
size_t tree_reduce_minsize;
size_t tree_reduce_minsize; // NOLINT
// current rank
int rank;
int rank; // NOLINT
// world size
int world_size;
int world_size; // NOLINT
// connect retry time
int connect_retry;
int connect_retry; // NOLINT
// enable bootstrap cache 0 false 1 true
bool rabit_bootstrap_cache = false;
bool rabit_bootstrap_cache = false; // NOLINT
// enable detailed logging
bool rabit_debug = false;
bool rabit_debug = false; // NOLINT
// by default, if rabit worker not recover in half an hour exit
int timeout_sec = 1800;
int timeout_sec = 1800; // NOLINT
// flag to enable rabit_timeout
bool rabit_timeout = false;
bool rabit_timeout = false; // NOLINT
// Enable TCP node delay
bool rabit_enable_tcp_no_delay = false;
bool rabit_enable_tcp_no_delay = false; // NOLINT
};
} // namespace engine
} // namespace rabit

View File

@@ -20,74 +20,65 @@ namespace engine {
class AllreduceMock : public AllreduceRobust {
public:
// constructor
AllreduceMock(void) {
num_trial = 0;
force_local = 0;
report_stats = 0;
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
AllreduceMock() {
num_trial_ = 0;
force_local_ = 0;
report_stats_ = 0;
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
}
// destructor
virtual ~AllreduceMock(void) {}
virtual void SetParam(const char *name, const char *val) {
~AllreduceMock() override = default;
void SetParam(const char *name, const char *val) override {
AllreduceRobust::SetParam(name, val);
// additional parameters
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val);
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
if (!strcmp(name, "force_local")) force_local = atoi(val);
if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
if (!strcmp(name, "report_stats")) report_stats_ = atoi(val);
if (!strcmp(name, "force_local")) force_local_ = atoi(val);
if (!strcmp(name, "mock")) {
MockKey k;
utils::Check(sscanf(val, "%d,%d,%d,%d",
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
"invalid mock parameter");
mock_map[k] = 1;
mock_map_[k] = 1;
}
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg, const char *_file = _FILE,
const int _line = _LINE,
const char *_caller = _CALLER) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
double tstart = utils::GetTime();
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg,
_file, _line, _caller);
tsum_allreduce += utils::GetTime() - tstart;
tsum_allreduce_ += utils::GetTime() - tstart;
}
virtual void Allgather(void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
double tstart = utils::GetTime();
AllreduceRobust::Allgather(sendrecvbuf, total_size,
slice_begin, slice_end,
size_prev_slice, _file, _line, _caller);
tsum_allgather += utils::GetTime() - tstart;
tsum_allgather_ += utils::GetTime() - tstart;
}
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) {
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
time_checkpoint = utils::GetTime();
if (force_local == 0) {
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) override {
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
time_checkpoint_ = utils::GetTime();
if (force_local_ == 0) {
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
} else {
DummySerializer dum;
@@ -95,56 +86,54 @@ class AllreduceMock : public AllreduceRobust {
return AllreduceRobust::LoadCheckPoint(&dum, &com);
}
}
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
void CheckPoint(const Serializable *global_model,
const Serializable *local_model) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint");
double tstart = utils::GetTime();
double tbet_chkpt = tstart - time_checkpoint;
if (force_local == 0) {
double tbet_chkpt = tstart - time_checkpoint_;
if (force_local_ == 0) {
AllreduceRobust::CheckPoint(global_model, local_model);
} else {
DummySerializer dum;
ComboSerializer com(global_model, local_model);
AllreduceRobust::CheckPoint(&dum, &com);
}
time_checkpoint = utils::GetTime();
time_checkpoint_ = utils::GetTime();
double tcost = utils::GetTime() - tstart;
if (report_stats != 0 && rank == 0) {
if (report_stats_ != 0 && rank == 0) {
std::stringstream ss;
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
ss << "[v" << version_number << "] global_size=" << global_checkpoint_.length()
<< ",local_size=" << (local_chkpt_[0].length() + local_chkpt_[1].length())
<< ",check_tcost="<< tcost <<" sec"
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
<< ",allgather_tcost=" << tsum_allgather << " sec"
<< ",allreduce_tcost=" << tsum_allreduce_ << " sec"
<< ",allgather_tcost=" << tsum_allgather_ << " sec"
<< ",between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str());
}
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
tsum_allreduce_ = 0.0;
tsum_allgather_ = 0.0;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
void LazyCheckPoint(const Serializable *global_model) override {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model);
}
protected:
// force checkpoint to local
int force_local;
int force_local_;
// whether report statistics
int report_stats;
int report_stats_;
// sum of allreduce
double tsum_allreduce;
double tsum_allreduce_;
// sum of allgather
double tsum_allgather;
double time_checkpoint;
double tsum_allgather_;
double time_checkpoint_;
private:
struct DummySerializer : public Serializable {
virtual void Load(Stream *fi) {
}
virtual void Save(Stream *fo) const {
}
void Load(Stream *fi) override {}
void Save(Stream *fo) const override {}
};
struct ComboSerializer : public Serializable {
Serializable *lhs;
@@ -155,15 +144,15 @@ class AllreduceMock : public AllreduceRobust {
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
}
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
: lhs(nullptr), rhs(nullptr), c_lhs(lhs), c_rhs(rhs) {
}
virtual void Load(Stream *fi) {
if (lhs != NULL) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi);
void Load(Stream *fi) override {
if (lhs != nullptr) lhs->Load(fi);
if (rhs != nullptr) rhs->Load(fi);
}
virtual void Save(Stream *fo) const {
if (c_lhs != NULL) c_lhs->Save(fo);
if (c_rhs != NULL) c_rhs->Save(fo);
void Save(Stream *fo) const override {
if (c_lhs != nullptr) c_lhs->Save(fo);
if (c_rhs != nullptr) c_rhs->Save(fo);
}
};
// key to identify the mock stage
@@ -172,7 +161,7 @@ class AllreduceMock : public AllreduceRobust {
int version;
int seqno;
int ntrial;
MockKey(void) {}
MockKey() = default;
MockKey(int rank, int version, int seqno, int ntrial)
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
inline bool operator==(const MockKey &b) const {
@@ -189,15 +178,15 @@ class AllreduceMock : public AllreduceRobust {
}
};
// number of failure trials
int num_trial;
int num_trial_;
// record all mock actions
std::map<MockKey, int> mock_map;
std::map<MockKey, int> mock_map_;
// used to generate all kinds of exceptions
inline void Verify(const MockKey &key, const char *name) {
if (mock_map.count(key) != 0) {
num_trial += 1;
if (mock_map_.count(key) != 0) {
num_trial_ += 1;
// data processing frameworks runs on shared process
_error("[%d]@@@Hit Mock Error:%s ", rank, name);
error_("[%d]@@@Hit Mock Error:%s ", rank, name);
}
}
};

View File

@@ -40,9 +40,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)) {
RefLinkVector &links = tree_links;
if (links.size() == 0) return kSuccess;
if (links.Size() == 0) return kSuccess;
// number of links
const int nlink = static_cast<int>(links.size());
const int nlink = static_cast<int>(links.Size());
// initialize the pointers
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();

File diff suppressed because it is too large Load Diff

View File

@@ -22,18 +22,18 @@ namespace engine {
/*! \brief implementation of fault tolerant all reduce engine */
class AllreduceRobust : public AllreduceBase {
public:
AllreduceRobust(void);
virtual ~AllreduceRobust(void) {}
AllreduceRobust();
~AllreduceRobust() override = default;
// initialize the manager
virtual bool Init(int argc, char* argv[]);
bool Init(int argc, char* argv[]) override;
/*! \brief shutdown the engine */
virtual bool Shutdown(void);
bool Shutdown() override;
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
void SetParam(const char *name, const char *val) override;
/*!
* \brief perform immutable local bootstrap cache insertion
* \param key unique cache key
@@ -67,13 +67,10 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override;
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
@@ -90,15 +87,11 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
void *prepare_arg = nullptr, const char *_file = _FILE,
const int _line = _LINE,
const char *_caller = _CALLER) override;
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and recving data
@@ -108,10 +101,9 @@ class AllreduceRobust : public AllreduceBase {
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char *_file = _FILE, const int _line = _LINE,
const char *_caller = _CALLER) override;
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
@@ -134,8 +126,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL);
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = nullptr) override;
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
@@ -152,9 +144,9 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
this->CheckPoint_(global_model, local_model, false);
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
this->CheckPointImpl(global_model, local_model, false);
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
@@ -176,18 +168,20 @@ class AllreduceRobust : public AllreduceBase {
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const Serializable *global_model) {
this->CheckPoint_(global_model, NULL, true);
void LazyCheckPoint(const Serializable *global_model) override {
this->CheckPointImpl(global_model, nullptr, true);
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
void InitAfterException() override {
// simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
for (auto& link : all_links) {
if (link.sock.BadSocket()) {
link.sock.Close();
}
}
ReConnectLinks("recover");
}
@@ -245,69 +239,69 @@ class AllreduceRobust : public AllreduceBase {
// there are nodes request load cache
static const int kLoadBootstrapCache = 16;
// constructor
ActionSummary(void) {}
ActionSummary() = default;
// constructor of action
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
seqcode = (minseqno << 5) | seqno_flag;
maxseqcode = (maxseqno << 5) | cache_flag;
seqcode_ = (minseqno << 5) | seqno_flag;
maxseqcode_ = (maxseqno << 5) | cache_flag;
}
// minimum number of all operations by default
// maximum number of all cache operations otherwise
inline u_int32_t seqno(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline u_int32_t Seqno(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return code >> 5;
}
// whether the operation set contains a load_check
inline bool load_check(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline bool LoadCheck(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kLoadCheck) != 0;
}
// whether the operation set contains a load_cache
inline bool load_cache(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline bool LoadCache(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kLoadBootstrapCache) != 0;
}
// whether the operation set contains a check point
inline bool check_point(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline bool CheckPoint(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kCheckPoint) != 0;
}
// whether the operation set contains a check ack
inline bool check_ack(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline bool CheckAck(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return (code & kCheckAck) != 0;
}
// whether the operation set contains different sequence number
inline bool diff_seq() const {
return (seqcode & kDiffSeq) != 0;
inline bool DiffSeq() const {
return (seqcode_ & kDiffSeq) != 0;
}
// returns the operation flag of the result
inline int flag(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
inline int Flag(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
return code & 31;
}
// print flags in user friendly way
inline void print_flags(int rank, std::string prefix ) {
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
rank, prefix.c_str(),
seqno(), check_point(), check_ack(), load_cache(),
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
inline void PrintFlags(int rank, std::string prefix ) {
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", rank,
prefix.c_str(), Seqno(), CheckPoint(), CheckAck(),
LoadCache(), DiffSeq(), Seqno(SeqType::kCache),
LoadCache(SeqType::kCache));
}
// reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_,
int len, const MPI::Datatype &dtype) {
const ActionSummary *src = (const ActionSummary*)src_;
const ActionSummary *src = static_cast<const ActionSummary*>(src_);
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) {
u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno());
u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache),
dst[i].seqno(SeqType::kCache));
int action_flag = src[i].flag() | dst[i].flag();
u_int32_t min_seqno = std::min(src[i].Seqno(), dst[i].Seqno());
u_int32_t max_seqno = std::max(src[i].Seqno(SeqType::kCache),
dst[i].Seqno(SeqType::kCache));
int action_flag = src[i].Flag() | dst[i].Flag();
// if any node is not requester set to 0 otherwise 1
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
int role_flag = src[i].Flag(SeqType::kCache) & dst[i].Flag(SeqType::kCache);
// if seqno is different in src and destination
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
int seq_diff_flag = src[i].Seqno() != dst[i].Seqno() ? kDiffSeq : 0;
// apply or to both seq diff flag as well as cache seq diff flag
dst[i] = ActionSummary(action_flag | seq_diff_flag,
role_flag, min_seqno, max_seqno);
@@ -316,19 +310,19 @@ class AllreduceRobust : public AllreduceBase {
private:
// internel sequence code min of rabit seqno
u_int32_t seqcode;
u_int32_t seqcode_;
// internal sequence code max of cache seqno
u_int32_t maxseqcode;
u_int32_t maxseqcode_;
};
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
class ResultBuffer{
public:
// constructor
ResultBuffer(void) {
ResultBuffer() {
this->Clear();
}
// clear the existing record
inline void Clear(void) {
inline void Clear() {
seqno_.clear(); size_.clear();
rptr_.clear(); rptr_.push_back(0);
data_.clear();
@@ -358,12 +352,12 @@ class AllreduceRobust : public AllreduceBase {
inline void* Query(int seqid, size_t *p_size) {
size_t idx = std::lower_bound(seqno_.begin(),
seqno_.end(), seqid) - seqno_.begin();
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
if (idx == seqno_.size() || seqno_[idx] != seqid) return nullptr;
*p_size = size_[idx];
return BeginPtr(data_) + rptr_[idx];
}
// drop last stored result
inline void DropLast(void) {
inline void DropLast() {
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
seqno_.pop_back();
rptr_.pop_back();
@@ -371,7 +365,7 @@ class AllreduceRobust : public AllreduceBase {
data_.resize(rptr_.back());
}
// the sequence number of last stored result
inline int LastSeqNo(void) const {
inline int LastSeqNo() const {
if (seqno_.size() == 0) return -1;
return seqno_.back();
}
@@ -407,9 +401,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa CheckPoint, LazyCheckPoint
*/
void CheckPoint_(const Serializable *global_model,
const Serializable *local_model,
bool lazy_checkpt);
void CheckPointImpl(const Serializable *global_model,
const Serializable *local_model, bool lazy_checkpt);
/*!
* \brief reset the all the existing links by sending Out-of-Band message marker
* after this function finishes, all the messages received and sent
@@ -423,7 +416,7 @@ class AllreduceRobust : public AllreduceBase {
* when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery proceduer is needed
*/
ReturnType TryResetLinks(void);
ReturnType TryResetLinks();
/*!
* \brief if err_type indicates an error
* recover links according to the error type reported
@@ -619,29 +612,29 @@ o * the input state must exactly one saved state(local state of current node)
size_t out_index));
//---- recovery data structure ----
// the round of result buffer, used to mode the result
int result_buffer_round;
int result_buffer_round_;
// result buffer of all reduce
ResultBuffer resbuf;
ResultBuffer resbuf_;
// current cached allreduce/braodcast sequence number
int cur_cache_seq;
int cur_cache_seq_;
// result buffer of cached all reduce
ResultBuffer cachebuf;
ResultBuffer cachebuf_;
// key of each cache entry
ResultBuffer lookupbuf;
ResultBuffer lookupbuf_;
// last check point global model
std::string global_checkpoint;
std::string global_checkpoint_;
// lazy checkpoint of global model
const Serializable *global_lazycheck;
const Serializable *global_lazycheck_;
// number of replica for local state/model
int num_local_replica;
int num_local_replica_;
// number of default local replica
int default_local_replica;
int default_local_replica_;
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
int use_local_model;
int use_local_model_;
// number of replica for global state/model
int num_global_replica;
int num_global_replica_;
// number of times recovery happens
int recover_counter;
int recover_counter_;
// --- recovery data structure for local checkpoint
// there is two version of the data structure,
// at one time one version is valid and another is used as temp memory
@@ -649,21 +642,21 @@ o * the input state must exactly one saved state(local state of current node)
// local model is stored in CSR format(like a sparse matrices)
// local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
std::vector<size_t> local_rptr[2];
std::vector<size_t> local_rptr_[2];
// storage for local model replicas
std::string local_chkpt[2];
std::string local_chkpt_[2];
// version of local checkpoint can be 1 or 0
int local_chkpt_version;
int local_chkpt_version_;
// if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
bool checkpoint_loaded;
bool checkpoint_loaded_;
// sidecar executing timeout task
std::future<bool> rabit_timeout_task;
std::future<bool> rabit_timeout_task_;
// flag to shutdown rabit_timeout_task before timeout
std::atomic<bool> shutdown_timeout{false};
std::atomic<bool> shutdown_timeout_{false};
// error handler
void (* _error)(const char *fmt, ...) = utils::Error;
void (* error_)(const char *fmt, ...) = utils::Error;
// assert handler
void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert;
void (* assert_)(bool exp, const char *fmt, ...) = utils::Assert;
};
} // namespace engine
} // namespace rabit

View File

@@ -33,12 +33,12 @@ struct FHelper<op::BitOR, DType> {
};
template<typename OP>
void Allreduce_(void *sendrecvbuf_,
void Allreduce(void *sendrecvbuf_,
size_t count,
engine::mpi::DataType enum_dtype,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi;
using namespace engine::mpi; // NOLINT
switch (enum_dtype) {
case kChar:
rabit::Allreduce<OP>
@@ -89,28 +89,28 @@ void Allreduce(void *sendrecvbuf,
engine::mpi::OpType enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi;
using namespace engine::mpi; // NOLINT
switch (enum_op) {
case kMax:
Allreduce_<op::Max>
Allreduce<op::Max>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kMin:
Allreduce_<op::Min>
Allreduce<op::Min>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kSum:
Allreduce_<op::Sum>
Allreduce<op::Sum>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseOR:
Allreduce_<op::BitOR>
Allreduce<op::BitOR>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
@@ -124,7 +124,7 @@ void Allgather(void *sendrecvbuf_,
size_t size_node_slice,
size_t size_prev_slice,
int enum_dtype) {
using namespace engine::mpi;
using namespace engine::mpi; // NOLINT
size_t type_size = 0;
switch (enum_dtype) {
case kChar:
@@ -184,7 +184,7 @@ struct ReadWrapper : public Serializable {
std::string *p_str;
explicit ReadWrapper(std::string *p_str)
: p_str(p_str) {}
virtual void Load(Stream *fi) {
void Load(Stream *fi) override {
uint64_t sz;
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
"Read pickle string");
@@ -194,7 +194,7 @@ struct ReadWrapper : public Serializable {
"Read pickle string");
}
}
virtual void Save(Stream *fo) const {
void Save(Stream *fo) const override {
utils::Error("not implemented");
}
};
@@ -206,10 +206,10 @@ struct WriteWrapper : public Serializable {
size_t length)
: data(data), length(length) {
}
virtual void Load(Stream *fi) {
void Load(Stream *fi) override {
utils::Error("not implemented");
}
virtual void Save(Stream *fo) const {
void Save(Stream *fo) const override {
uint64_t sz = static_cast<uint16_t>(length);
fo->Write(&sz, sizeof(sz));
fo->Write(data, length * sizeof(char));
@@ -298,8 +298,8 @@ RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
ReadWrapper sl(&local_buffer);
int version;
if (out_local_model == NULL) {
version = rabit::LoadCheckPoint(&sg, NULL);
if (out_local_model == nullptr) {
version = rabit::LoadCheckPoint(&sg, nullptr);
*out_global_model = BeginPtr(global_buffer);
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
} else {
@@ -317,8 +317,8 @@ RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len,
using namespace rabit::c_api; // NOLINT(*)
WriteWrapper sg(global_model, global_len);
WriteWrapper sl(local_model, local_len);
if (local_model == NULL) {
rabit::CheckPoint(&sg, NULL);
if (local_model == nullptr) {
rabit::CheckPoint(&sg, nullptr);
} else {
rabit::CheckPoint(&sg, &sl);
}

View File

@@ -7,18 +7,19 @@
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#include <rabit/base.h>
#include <dmlc/thread_local.h>
#include <memory>
#include "rabit/internal/engine.h"
#include "allreduce_base.h"
#include "allreduce_robust.h"
#include "rabit/internal/thread_local.h"
namespace rabit {
namespace engine {
// singleton sync manager
#ifndef RABIT_USE_BASE
#ifndef RABIT_USE_MOCK
typedef AllreduceRobust Manager;
using Manager = AllreduceRobust;
#else
typedef AllreduceMock Manager;
#endif // RABIT_USE_MOCK
@@ -31,13 +32,13 @@ struct ThreadLocalEntry {
/*! \brief stores the current engine */
std::unique_ptr<Manager> engine;
/*! \brief whether init has been called */
bool initialized;
bool initialized{false};
/*! \brief constructor */
ThreadLocalEntry() : initialized(false) {}
ThreadLocalEntry() = default;
};
// define the threadlocal store.
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
using EngineThreadLocal = dmlc::ThreadLocalStore<ThreadLocalEntry>;
/*! \brief intiialize the synchronization module */
bool Init(int argc, char *argv[]) {
@@ -95,7 +96,7 @@ void Allgather(void *sendrecvbuf_, size_t total_size,
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
void Allreduce_(void *sendrecvbuf, // NOLINT
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
@@ -111,18 +112,15 @@ void Allreduce_(void *sendrecvbuf,
}
// code for reduce handle
ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return static_cast<int>(dtype.type_size);
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
utils::Assert(redfunc_ == nullptr, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}
@@ -133,7 +131,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
const char* _file,
const int _line,
const char* _caller) {
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
utils::Assert(redfunc_ != nullptr, "must intialize handle to call AllReduce");
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
redfunc_, prepare_fun, prepare_arg,
_file, _line, _caller);

View File

@@ -16,78 +16,68 @@ namespace engine {
/*! \brief EmptyEngine */
class EmptyEngine : public IEngine {
public:
EmptyEngine(void) {
version_number = 0;
EmptyEngine() {
version_number_ = 0;
}
virtual void Allgather(void *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
size_t slice_end, size_t size_prev_slice, const char *_file,
const int _line, const char *_caller) override {
utils::Error("EmptyEngine:: Allgather is not supported");
}
virtual int GetRingPrevRank(void) const {
int GetRingPrevRank() const override {
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
return -1;
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
ReduceFunction reducer, PreprocFunction prepare_fun,
void *prepare_arg, const char *_file, const int _line,
const char *_caller) override {
utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line, const char* _caller) {
void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line, const char* _caller) override {
}
virtual void InitAfterException(void) {
void InitAfterException() override {
utils::Error("EmptyEngine is not fault tolerant");
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = nullptr) override {
return 0;
}
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
version_number += 1;
void CheckPoint(const Serializable *global_model,
const Serializable *local_model = nullptr) override {
version_number_ += 1;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1;
void LazyCheckPoint(const Serializable *global_model) override {
version_number_ += 1;
}
virtual int VersionNumber(void) const {
return version_number;
int VersionNumber() const override {
return version_number_;
}
/*! \brief get rank of current node */
virtual int GetRank(void) const {
int GetRank() const override {
return 0;
}
/*! \brief get total number of */
virtual int GetWorldSize(void) const {
int GetWorldSize() const override {
return 1;
}
/*! \brief whether it is distributed */
virtual bool IsDistributed(void) const {
bool IsDistributed() const override {
return false;
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
std::string GetHost() const override {
return std::string("");
}
virtual void TrackerPrint(const std::string &msg) {
void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
int version_number;
int version_number_;
};
// singleton sync manager
@@ -98,12 +88,12 @@ bool Init(int argc, char *argv[]) {
return true;
}
/*! \brief finalize syncrhonization module */
bool Finalize(void) {
bool Finalize() {
return true;
}
/*! \brief singleton method to get engine */
IEngine *GetEngine(void) {
IEngine *GetEngine() {
return &manager;
}
// perform in-place allreduce, on sendrecvbuf
@@ -118,13 +108,12 @@ void Allreduce_(void *sendrecvbuf,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
}
// code for reduce handle
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
ReduceHandle::ReduceHandle() = default;
ReduceHandle::~ReduceHandle() = default;
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return 0;
@@ -137,7 +126,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
}
} // namespace engine
} // namespace rabit

View File

@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc
* \brief this is an engine implementation that will
* \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure
* \author Tianqi Chen
*/
@@ -12,4 +12,3 @@
#include <rabit/base.h>
#include "allreduce_mock.h"
#include "engine.cc"