diff --git a/Makefile b/Makefile index 32c7070b3..64cba30fa 100644 --- a/Makefile +++ b/Makefile @@ -36,4 +36,4 @@ $(ALIB): ar cr $@ $+ clean: - $(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ + $(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~ diff --git a/include/rabit.h b/include/rabit.h index e3f9edfb3..5b2db3098 100644 --- a/include/rabit.h +++ b/include/rabit.h @@ -1,6 +1,5 @@ -#ifndef RABIT_RABIT_H -#define RABIT_RABIT_H /*! + * Copyright (c) 2014 by Contributors * \file rabit.h * \brief This file defines unified Allreduce/Broadcast interface of rabit * The actual implementation is redirected to rabit engine @@ -9,12 +8,14 @@ * rabit.h and serializable.h is all the user need to use rabit interface * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou */ +#ifndef RABIT_RABIT_H_ +#define RABIT_RABIT_H_ #include #include // optionally support of lambda function in C++11, if available #if __cplusplus >= 201103L #include -#endif // C++11 +#endif // C++11 // contains definition of ISerializable #include "./rabit_serializable.h" // engine definition of rabit, defines internal implementation @@ -116,7 +117,7 @@ inline void Broadcast(std::string *sendrecv_data, int root); */ template inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *arg) = NULL, + void (*prepare_fun)(void *arg) = NULL, void *prepare_arg = NULL); // C++11 support for lambda prepare function @@ -142,9 +143,9 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, * \tparam DType type of data */ template -inline void Allreduce(DType *sendrecvbuf, size_t count, std::function prepare_fun); -#endif // C++11 - +inline void Allreduce(DType *sendrecvbuf, size_t count, + std::function prepare_fun); +#endif // C++11 /*! 
* \brief load latest check point * \param global_model pointer to the globally shared model/state @@ -228,6 +229,7 @@ class Reducer { inline void Allreduce(DType *sendrecvbuf, size_t count, std::function prepare_fun); #endif + private: /*! \brief function handle to do reduce */ engine::ReduceHandle handle_; @@ -274,6 +276,7 @@ class SerializeReducer { size_t max_nbyte, size_t count, std::function prepare_fun); #endif + private: /*! \brief function handle to do reduce */ engine::ReduceHandle handle_; @@ -283,4 +286,4 @@ class SerializeReducer { } // namespace rabit // implementation of template functions #include "./rabit/rabit-inl.h" -#endif // RABIT_ALLREDUCE_H +#endif // RABIT_RABIT_H_ diff --git a/include/rabit/engine.h b/include/rabit/engine.h index ce2f85b66..ce8fb6ee5 100644 --- a/include/rabit/engine.h +++ b/include/rabit/engine.h @@ -1,10 +1,12 @@ /*! + * Copyright (c) 2014 by Contributors * \file engine.h * \brief This file defines the core interface of allreduce library * \author Tianqi Chen, Nacho, Tianyi */ -#ifndef RABIT_ENGINE_H -#define RABIT_ENGINE_H +#ifndef RABIT_ENGINE_H_ +#define RABIT_ENGINE_H_ +#include #include "../rabit_serializable.h" namespace MPI { @@ -122,7 +124,7 @@ class IEngine { virtual int GetRank(void) const = 0; /*! \brief get total number of */ virtual int GetWorldSize(void) const = 0; - /*! \brief get the host name of current node */ + /*! \brief get the host name of current node */ virtual std::string GetHost(void) const = 0; /*! * \brief print the msg in the tracker, @@ -211,7 +213,7 @@ class ReduceHandle { /*! 
\return the number of bytes occupied by the type */ static int TypeSize(const MPI::Datatype &dtype); - private: + protected: // handle data field void *handle_; // handle to the type field @@ -221,5 +223,4 @@ class ReduceHandle { }; } // namespace engine } // namespace rabit -#endif // RABIT_ENGINE_H - +#endif // RABIT_ENGINE_H_ diff --git a/include/rabit/io.h b/include/rabit/io.h index 44d0a0505..29fa7e812 100644 --- a/include/rabit/io.h +++ b/include/rabit/io.h @@ -1,16 +1,19 @@ -#ifndef RABIT_UTILS_IO_H -#define RABIT_UTILS_IO_H -#include -#include -#include -#include -#include "./utils.h" -#include "../rabit_serializable.h" /*! + * Copyright (c) 2014 by Contributors * \file io.h * \brief utilities that implements different serializable interface * \author Tianqi Chen */ +#ifndef RABIT_UTILS_IO_H_ +#define RABIT_UTILS_IO_H_ +#include +#include +#include +#include +#include +#include "./utils.h" +#include "../rabit_serializable.h" + namespace rabit { namespace utils { /*! \brief interface of i/o stream that support seek */ @@ -25,8 +28,9 @@ class ISeekStream: public IStream { /*! \brief fixed size memory buffer */ struct MemoryFixSizeBuffer : public ISeekStream { public: - MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) - : p_buffer_(reinterpret_cast(p_buffer)), buffer_size_(buffer_size) { + MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) + : p_buffer_(reinterpret_cast(p_buffer)), + buffer_size_(buffer_size) { curr_ptr_ = 0; } virtual ~MemoryFixSizeBuffer(void) {} @@ -40,7 +44,7 @@ struct MemoryFixSizeBuffer : public ISeekStream { } virtual void Write(const void *ptr, size_t size) { if (size == 0) return; - utils::Assert(curr_ptr_ + size <= buffer_size_, + utils::Assert(curr_ptr_ + size <= buffer_size_, "write position exceed fixed buffer size"); memcpy(p_buffer_ + curr_ptr_, ptr, size); curr_ptr_ += size; @@ -59,12 +63,12 @@ struct MemoryFixSizeBuffer : public ISeekStream { size_t buffer_size_; /*! 
\brief current pointer */ size_t curr_ptr_; -}; // class MemoryFixSizeBuffer +}; // class MemoryFixSizeBuffer /*! \brief a in memory buffer that can be read and write as stream interface */ struct MemoryBufferStream : public ISeekStream { public: - MemoryBufferStream(std::string *p_buffer) + explicit MemoryBufferStream(std::string *p_buffer) : p_buffer_(p_buffer) { curr_ptr_ = 0; } @@ -82,7 +86,7 @@ struct MemoryBufferStream : public ISeekStream { if (curr_ptr_ + size > p_buffer_->length()) { p_buffer_->resize(curr_ptr_+size); } - memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); + memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); curr_ptr_ += size; } virtual void Seek(size_t pos) { @@ -97,36 +101,7 @@ struct MemoryBufferStream : public ISeekStream { std::string *p_buffer_; /*! \brief current pointer */ size_t curr_ptr_; -}; // class MemoryBufferStream - -/*! \brief implementation of file i/o stream */ -class FileStream : public ISeekStream { - public: - explicit FileStream(FILE *fp) : fp(fp) {} - explicit FileStream(void) { - this->fp = NULL; - } - virtual size_t Read(void *ptr, size_t size) { - return std::fread(ptr, size, 1, fp); - } - virtual void Write(const void *ptr, size_t size) { - std::fwrite(ptr, size, 1, fp); - } - virtual void Seek(size_t pos) { - std::fseek(fp, static_cast(pos), SEEK_SET); - } - virtual size_t Tell(void) { - return std::ftell(fp); - } - inline void Close(void) { - if (fp != NULL){ - std::fclose(fp); fp = NULL; - } - } - - private: - FILE *fp; -}; +}; // class MemoryBufferStream } // namespace utils } // namespace rabit -#endif +#endif // RABIT_UTILS_IO_H_ diff --git a/include/rabit/utils.h b/include/rabit/utils.h index beae6589f..696000fac 100644 --- a/include/rabit/utils.h +++ b/include/rabit/utils.h @@ -1,10 +1,11 @@ -#ifndef RABIT_UTILS_H_ -#define RABIT_UTILS_H_ /*! 
+ * Copyright (c) 2014 by Contributors * \file utils.h * \brief simple utils to support the code * \author Tianqi Chen */ +#ifndef RABIT_UTILS_H_ +#define RABIT_UTILS_H_ #define _CRT_SECURE_NO_WARNINGS #include #include @@ -19,7 +20,7 @@ #define fopen64 std::fopen #endif #ifdef _MSC_VER -// NOTE: sprintf_s is not equivalent to snprintf, +// NOTE: sprintf_s is not equivalent to snprintf, // they are equivalent when success, which is sufficient for our case #define snprintf sprintf_s #define vsnprintf vsprintf_s @@ -30,7 +31,7 @@ #endif #endif -#ifdef __APPLE__ +#ifdef __APPLE__ #define off64_t off_t #define fopen64 std::fopen #endif @@ -186,5 +187,5 @@ inline const char* BeginPtr(const std::string &str) { if (str.length() == 0) return NULL; return &str[0]; } -} // namespace rabit +} // namespace rabit #endif // RABIT_UTILS_H_ diff --git a/include/rabit_serializable.h b/include/rabit_serializable.h index eabc03f81..0b2ccf3cb 100644 --- a/include/rabit_serializable.h +++ b/include/rabit_serializable.h @@ -1,13 +1,14 @@ -#ifndef RABIT_RABIT_SERIALIZABLE_H -#define RABIT_RABIT_SERIALIZABLE_H -#include -#include -#include "./rabit/utils.h" /*! + * Copyright (c) 2014 by Contributors * \file serializable.h * \brief defines serializable interface of rabit * \author Tianqi Chen */ +#ifndef RABIT_RABIT_SERIALIZABLE_H_ +#define RABIT_RABIT_SERIALIZABLE_H_ +#include +#include +#include "./rabit/utils.h" namespace rabit { /*! * \brief interface of stream I/O, used by ISerializable @@ -96,4 +97,4 @@ class ISerializable { virtual void Save(IStream &fo) const = 0; }; } // namespace rabit -#endif +#endif // RABIT_RABIT_SERIALIZABLE_H_ diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 418b0fd66..671c53877 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -1,4 +1,5 @@ /*! 
+ * Copyright (c) 2014 by Contributors * \file allreduce_base.cc * \brief Basic implementation of AllReduce * @@ -32,13 +33,15 @@ AllreduceBase::AllreduceBase(void) { // initialization function void AllreduceBase::Init(void) { // setup from enviroment variables - {// handling for hadoop + { + // handling for hadoop const char *task_id = getenv("mapred_tip_id"); if (task_id == NULL) { task_id = getenv("mapreduce_task_id"); } if (hadoop_mode != 0) { - utils::Check(task_id != NULL, "hadoop_mode is set but cannot find mapred_task_id"); + utils::Check(task_id != NULL, + "hadoop_mode is set but cannot find mapred_task_id"); } if (task_id != NULL) { this->SetParam("rabit_task_id", task_id); @@ -48,7 +51,7 @@ void AllreduceBase::Init(void) { if (attempt_id != 0) { const char *att = strrchr(attempt_id, '_'); int num_trial; - if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) { + if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) { this->SetParam("rabit_num_trial", att + 1); } } @@ -58,7 +61,8 @@ void AllreduceBase::Init(void) { num_task = getenv("mapreduce_job_maps"); } if (hadoop_mode != 0) { - utils::Check(num_task != NULL, "hadoop_mode is set but cannot find mapred_map_tasks"); + utils::Check(num_task != NULL, + "hadoop_mode is set but cannot find mapred_map_tasks"); } if (num_task != NULL) { this->SetParam("rabit_world_size", num_task); @@ -81,11 +85,11 @@ void AllreduceBase::Shutdown(void) { } all_links.clear(); tree_links.plinks.clear(); - + if (tracker_uri == "NULL") return; // notify tracker rank i have shutdown utils::TCPSocket tracker = this->ConnectTracker(); - tracker.SendStr(std::string("shutdown")); + tracker.SendStr(std::string("shutdown")); tracker.Close(); utils::TCPSocket::Finalize(); } @@ -107,11 +111,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val; if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val); if (!strcmp(name, "rabit_task_id")) task_id 
= val; - if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); + if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val); if (!strcmp(name, "rabit_reduce_buffer")) { char unit; - unsigned long amount; + uint64_t amount; if (sscanf(val, "%lu%c", &amount, &unit) == 2) { switch (unit) { case 'B': reduce_buffer_size = (amount + 7)/ 8; break; @@ -121,7 +125,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) { default: utils::Error("invalid format for reduce buffer"); } } else { - utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}"); + utils::Error("invalid format for reduce_buffer,"\ + "shhould be {integer}{unit}, unit can be {B, KB, MB, GB}"); } } } @@ -137,11 +142,16 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) { utils::Socket::Error("Connect"); } - utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1"); - utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2"); + using utils::Assert; + Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), + "ReConnectLink failure 1"); + Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), + "ReConnectLink failure 2"); utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure"); - utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); - utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3"); + Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), + "ReConnectLink failure 3"); + Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), + "ReConnectLink failure 3"); tracker.SendStr(task_id); return tracker; } @@ -161,29 +171,30 @@ void 
AllreduceBase::ReConnectLinks(const char *cmd) { int prev_rank, next_rank; // the rank of neighbors std::map tree_neighbors; - {// get new ranks - int newrank, num_neighbors; - utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), - "ReConnectLink failure 4"); - utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank), - "ReConnectLink failure 4"); - utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), - "ReConnectLink failure 4"); - utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one"); - rank = newrank; - utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors), - "ReConnectLink failure 4"); - for (int i = 0; i < num_neighbors; ++i) { - int nrank; - utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank), - "ReConnectLink failure 4"); - tree_neighbors[nrank] = 1; - } - utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank), - "ReConnectLink failure 4"); - utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank), - "ReConnectLink failure 4"); + using utils::Assert; + // get new ranks + int newrank, num_neighbors; + Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), + "ReConnectLink failure 4"); + Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\ + sizeof(parent_rank), "ReConnectLink failure 4"); + Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), + "ReConnectLink failure 4"); + Assert(rank == -1 || newrank == rank, + "must keep rank to same if the node already have one"); + rank = newrank; + Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \ + sizeof(num_neighbors), "ReConnectLink failure 4"); + for (int i = 0; i < num_neighbors; ++i) { + int nrank; + Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank), + "ReConnectLink failure 4"); + 
tree_neighbors[nrank] = 1; } + Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank), + "ReConnectLink failure 4"); + Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank), + "ReConnectLink failure 4"); // create listening socket utils::TCPSocket sock_listen; sock_listen.Create(); @@ -204,56 +215,67 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { } } int ngood = static_cast(good_link.size()); - utils::Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), - "ReConnectLink failure 5"); + Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), + "ReConnectLink failure 5"); for (size_t i = 0; i < good_link.size(); ++i) { - utils::Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == sizeof(good_link[i]), - "ReConnectLink failure 6"); + Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \ + sizeof(good_link[i]), "ReConnectLink failure 6"); } - utils::Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), - "ReConnectLink failure 7"); - utils::Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept), - "ReConnectLink failure 8"); + Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), + "ReConnectLink failure 7"); + Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \ + sizeof(num_accept), "ReConnectLink failure 8"); num_error = 0; for (int i = 0; i < num_conn; ++i) { LinkRecord r; int hport, hrank; std::string hname; - tracker.RecvStr(&hname); - utils::Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); - utils::Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); + tracker.RecvStr(&hname); + Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), + "ReConnectLink failure 9"); + Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), + "ReConnectLink failure 10"); r.sock.Create(); if 
(!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) { num_error += 1; r.sock.Close(); continue; } - utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12"); - utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13"); - utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent"); + Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), + "ReConnectLink failure 12"); + Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), + "ReConnectLink failure 13"); + utils::Check(hrank == r.rank, + "ReConnectLink failure, link rank inconsistent"); bool match = false; for (size_t i = 0; i < all_links.size(); ++i) { if (all_links[i].rank == hrank) { - utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active"); + Assert(all_links[i].sock.IsClosed(), + "Override a link that is active"); all_links[i].sock = r.sock; match = true; break; } } if (!match) all_links.push_back(r); } - utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14"); + Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), + "ReConnectLink failure 14"); } while (num_error != 0); // send back socket listening port to tracker - utils::Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); + Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), + "ReConnectLink failure 14"); // close connection to tracker - tracker.Close(); + tracker.Close(); // listen to incoming links for (int i = 0; i < num_accept; ++i) { LinkRecord r; r.sock = sock_listen.Accept(); - utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15"); - utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15"); + Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), + "ReConnectLink failure 15"); + 
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), + "ReConnectLink failure 15"); bool match = false; for (size_t i = 0; i < all_links.size(); ++i) { if (all_links[i].rank == r.rank) { - utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active"); + utils::Assert(all_links[i].sock.IsClosed(), + "Override a link that is active"); all_links[i].sock = r.sock; match = true; break; } } @@ -278,9 +300,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { if (all_links[i].rank == prev_rank) ring_prev = &all_links[i]; if (all_links[i].rank == next_rank) ring_next = &all_links[i]; } - utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link"); - utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link"); - utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link"); + Assert(parent_rank == -1 || parent_index != -1, + "cannot find parent in the link"); + Assert(prev_rank == -1 || ring_prev != NULL, + "cannot find prev ring in the link"); + Assert(next_rank == -1 || ring_next != NULL, + "cannot find next ring in the link"); } /*! 
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure @@ -326,7 +351,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, // if no childs, no need to reduce if (nlink == static_cast(parent_index != -1)) { size_up_reduce = total_size; - } + } // while we have not passed the messages out while (true) { // select helper @@ -347,7 +372,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, if (links[i].size_read != total_size) { selecter.WatchRead(links[i].sock); } - // size_write <= size_read + // size_write <= size_read if (links[i].size_write != total_size) { selecter.WatchWrite(links[i].sock); // only watch for exception in live channels @@ -358,11 +383,11 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, } // finish runing allreduce if (finished) break; - // select must return + // select must return selecter.Select(); // exception handling for (int i = 0; i < nlink; ++i) { - // recive OOB message from some link + // recive OOB message from some link if (selecter.CheckExcept(links[i].sock)) return kGetExcept; } // read data from childs @@ -392,7 +417,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, // start position size_t start = size_up_reduce % buffer_size; // peform read till end of buffer - size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce); + size_t nread = std::min(buffer_size - start, + max_reduce - size_up_reduce); utils::Assert(nread % type_nbytes == 0, "Allreduce: size check"); for (int i = 0; i < nlink; ++i) { if (i != parent_index) { @@ -407,7 +433,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, } if (parent_index != -1) { // pass message up to parent, can pass data that are already been reduced - if (selecter.CheckWrite(links[parent_index].sock)) { + if (selecter.CheckWrite(links[parent_index].sock)) { ssize_t len = links[parent_index].sock. 
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); if (len != -1) { @@ -417,7 +443,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, } } // read data from parent - if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) { + if (selecter.CheckRead(links[parent_index].sock) && + total_size > size_down_in) { ssize_t len = links[parent_index].sock. Recv(sendrecvbuf + size_down_in, total_size - size_down_in); if (len == 0) { @@ -425,7 +452,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, } if (len != -1) { size_down_in += static_cast(len); - utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error"); + utils::Assert(size_down_in <= size_up_out, + "Allreduce: boundary error"); } else { if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; } @@ -437,11 +465,13 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, // can pass message down to childs for (int i = 0; i < nlink; ++i) { if (i != parent_index && selecter.CheckWrite(links[i].sock)) { - if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError; + if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) { + return kSockError; + } } } } - return kSuccess; + return kSuccess; } /*! 
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure @@ -455,14 +485,15 @@ AllreduceBase::ReturnType AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { RefLinkVector &links = tree_links; if (links.size() == 0 || total_size == 0) return kSuccess; - utils::Check(root < world_size, "Broadcast: root should be smaller than world size"); + utils::Check(root < world_size, + "Broadcast: root should be smaller than world size"); // number of links const int nlink = static_cast(links.size()); // size of space already read from data size_t size_in = 0; // input link, -2 means unknown yet, -1 means this is root int in_link = -2; - + // initialize the link statistics for (int i = 0; i < nlink; ++i) { links[i].ResetSize(); @@ -471,9 +502,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { if (this->rank == root) { size_in = total_size; in_link = -1; - } + } // while we have not passed the messages out - while(true) { + while (true) { bool finished = true; // select helper utils::SelectHelper selecter; @@ -487,7 +518,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { if (in_link != -2 && i != in_link && links[i].size_write != total_size) { selecter.WatchWrite(links[i].sock); finished = false; } - selecter.WatchException(links[i].sock); + selecter.WatchException(links[i].sock); } // finish running if (finished) break; @@ -495,14 +526,16 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { selecter.Select(); // exception handling for (int i = 0; i < nlink; ++i) { - // recive OOB message from some link + // recive OOB message from some link if (selecter.CheckExcept(links[i].sock)) return kGetExcept; } if (in_link == -2) { // probe in-link for (int i = 0; i < nlink; ++i) { if (selecter.CheckRead(links[i].sock)) { - if (!links[i].ReadToArray(sendrecvbuf_, total_size)) return kSockError; + if 
(!links[i].ReadToArray(sendrecvbuf_, total_size)) { + return kSockError; + } size_in = links[i].size_read; if (size_in != 0) { in_link = i; break; @@ -512,7 +545,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { } else { // read from in link if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { - if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) return kSockError; + if (!links[in_link].ReadToArray(sendrecvbuf_, total_size)) { + return kSockError; + } size_in = links[in_link].size_read; } } diff --git a/src/allreduce_base.h b/src/allreduce_base.h index bede2c228..14a8cf339 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -1,4 +1,5 @@ /*! + * Copyright (c) 2014 by Contributors * \file allreduce_base.h * \brief Basic implementation of AllReduce * using TCP non-block socket and tree-shape reduction. @@ -8,13 +9,14 @@ * * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou */ -#ifndef RABIT_ALLREDUCE_BASE_H -#define RABIT_ALLREDUCE_BASE_H +#ifndef RABIT_ALLREDUCE_BASE_H_ +#define RABIT_ALLREDUCE_BASE_H_ #include #include -#include -#include +#include +#include "rabit/utils.h" +#include "rabit/engine.h" #include "./socket.h" namespace MPI { @@ -22,7 +24,7 @@ namespace MPI { class Datatype { public: size_t type_size; - Datatype(size_t type_size) : type_size(type_size) {} + explicit Datatype(size_t type_size) : type_size(type_size) {} }; } namespace rabit { @@ -31,7 +33,7 @@ namespace engine { class AllreduceBase : public IEngine { public: // magic number to verify server - const static int kMagic = 0xff99; + static const int kMagic = 0xff99; // constant one byte out of band message to indicate error happening AllreduceBase(void); virtual ~AllreduceBase(void) {} @@ -79,12 +81,13 @@ class AllreduceBase : public IEngine { */ virtual void Allreduce(void *sendrecvbuf_, size_t type_nbytes, - size_t count, + size_t count, ReduceFunction reducer, PreprocFunction prepare_fun = NULL, void *prepare_arg = NULL) { if 
(prepare_fun != NULL) prepare_fun(prepare_arg); - utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess, + utils::Assert(TryAllreduce(sendrecvbuf_, + type_nbytes, count, reducer) == kSuccess, "Allreduce failed"); } /*! @@ -201,12 +204,16 @@ class AllreduceBase : public IEngine { // constructor LinkRecord(void) {} // initialize buffer - inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) { + inline void InitBuffer(size_t type_nbytes, size_t count, + size_t reduce_buffer_size) { size_t n = (type_nbytes * count + 7)/ 8; buffer_.resize(std::min(reduce_buffer_size, n)); // make sure align to type_nbytes - buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; - utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size); + buffer_size = + buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; + utils::Assert(type_nbytes <= buffer_size, + "too large type_nbytes=%lu, buffer_size=%lu", + type_nbytes, buffer_size); // set buffer head buffer_head = reinterpret_cast(BeginPtr(buffer_)); } @@ -225,7 +232,7 @@ class AllreduceBase : public IEngine { size_t ngap = size_read - protect_start; utils::Assert(ngap <= buffer_size, "Allreduce: boundary check"); size_t offset = size_read % buffer_size; - size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); + size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); if (nmax == 0) return true; ssize_t len = sock.Recv(buffer_head + offset, nmax); // length equals 0, remote disconnected @@ -235,7 +242,7 @@ class AllreduceBase : public IEngine { if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; size_read += static_cast(len); return true; - } + } /*! 
* \brief read data into array, * this function can not be used together with ReadToRingBuffer diff --git a/src/allreduce_robust-inl.h b/src/allreduce_robust-inl.h index 49f8f2c37..e0250e426 100644 --- a/src/allreduce_robust-inl.h +++ b/src/allreduce_robust-inl.h @@ -1,11 +1,13 @@ /*! + * Copyright (c) 2014 by Contributors * \file allreduce_robust-inl.h * \brief implementation of inline template function in AllreduceRobust * * \author Tianqi Chen */ -#ifndef RABIT_ENGINE_ROBUST_INL_H -#define RABIT_ENGINE_ROBUST_INL_H +#ifndef RABIT_ENGINE_ROBUST_INL_H_ +#define RABIT_ENGINE_ROBUST_INL_H_ +#include namespace rabit { namespace engine { @@ -33,14 +35,14 @@ inline AllreduceRobust::ReturnType AllreduceRobust::MsgPassing(const NodeType &node_value, std::vector *p_edge_in, std::vector *p_edge_out, - EdgeType (*func) (const NodeType &node_value, - const std::vector &edge_in, - size_t out_index) - ) { + EdgeType (*func) + (const NodeType &node_value, + const std::vector &edge_in, + size_t out_index)) { RefLinkVector &links = tree_links; if (links.size() == 0) return kSuccess; // number of links - const int nlink = static_cast(links.size()); + const int nlink = static_cast(links.size()); // initialize the pointers for (int i = 0; i < nlink; ++i) { links[i].ResetSize(); @@ -58,7 +60,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, // if no childs, no need to, directly start passing message if (nlink == static_cast(parent_index != -1)) { utils::Assert(parent_index == 0, "parent must be 0"); - edge_out[parent_index] = func(node_value, edge_in, parent_index); + edge_out[parent_index] = func(node_value, edge_in, parent_index); stage = 1; } // while we have not passed the messages out @@ -94,7 +96,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, selecter.Select(); // exception handling for (int i = 0; i < nlink; ++i) { - // recive OOB message from some link + // recive OOB message from some link if (selecter.CheckExcept(links[i].sock)) return kGetExcept; } 
if (stage == 0) { @@ -103,7 +105,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, for (int i = 0; i < nlink; ++i) { if (i != parent_index) { if (selecter.CheckRead(links[i].sock)) { - if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) return kSockError; + if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) { + return kSockError; + } } if (links[i].size_read != sizeof(EdgeType)) finished = false; } @@ -124,13 +128,17 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, if (stage == 1) { const int pid = this->parent_index; utils::Assert(pid != -1, "MsgPassing invalid stage"); - if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) return kSockError; + if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) { + return kSockError; + } if (links[pid].size_write == sizeof(EdgeType)) stage = 2; } if (stage == 2) { const int pid = this->parent_index; - utils::Assert(pid != -1, "MsgPassing invalid stage"); - if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) return kSockError; + utils::Assert(pid != -1, "MsgPassing invalid stage"); + if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) { + return kSockError; + } if (links[pid].size_read == sizeof(EdgeType)) { for (int i = 0; i < nlink; ++i) { if (i != pid) edge_out[i] = func(node_value, edge_in, i); @@ -141,7 +149,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, if (stage == 3) { for (int i = 0; i < nlink; ++i) { if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { - if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError; + if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) { + return kSockError; + } } } } @@ -150,4 +160,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, } } // namespace engine } // namespace rabit -#endif // RABIT_ENGINE_ROBUST_INL_H +#endif // RABIT_ENGINE_ROBUST_INL_H_ diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index fb53a0777..e25a9c85f 
100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -1,4 +1,5 @@ /*! + * Copyright (c) 2014 by Contributors * \file allreduce_robust.cc * \brief Robust implementation of Allreduce * @@ -9,10 +10,10 @@ #define NOMINMAX #include #include -#include -#include -#include -#include +#include "rabit/io.h" +#include "rabit/utils.h" +#include "rabit/engine.h" +#include "rabit/rabit-inl.h" #include "./allreduce_robust.h" namespace rabit { @@ -30,10 +31,10 @@ void AllreduceRobust::Shutdown(void) { utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp), "Shutdown: check point must return true"); // reset result buffer - resbuf.Clear(); seq_counter = 0; + resbuf.Clear(); seq_counter = 0; // execute check ack step, load happens here utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp), - "Shutdown: check ack must return true"); + "Shutdown: check ack must return true"); AllreduceBase::Shutdown(); } /*! @@ -89,7 +90,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, } else { recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); } - } + } } resbuf.PushTemp(seq_counter, type_nbytes, count); seq_counter += 1; @@ -102,7 +103,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, */ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) { // skip action in single node - if (world_size == 1) return; + if (world_size == 1) return; bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter); // now we are free to remove the last result, if any if (resbuf.LastSeqNo() != -1 && @@ -119,7 +120,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) } else { recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter); } - } + } } resbuf.PushTemp(seq_counter, 1, total_size); seq_counter += 1; @@ -151,7 +152,8 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, // skip action in single node 
if (world_size == 1) return 0; if (num_local_replica == 0) { - utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); + utils::Check(local_model == NULL, + "need to set num_local_replica larger than 1 to checkpoint local_model"); } // check if we succesful if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) { @@ -171,9 +173,10 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, // load from buffer utils::MemoryBufferStream fs(&global_checkpoint); if (global_checkpoint.length() == 0) { - version_number = 0; + version_number = 0; } else { - utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number"); + utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, + "read in version number"); global_model->Load(fs); utils::Assert(local_model == NULL || nlocal == num_local_replica + 1, "local model inconsistent, nlocal=%d", nlocal); @@ -212,9 +215,10 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model, version_number += 1; return; } if (num_local_replica == 0) { - utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); - } - if (num_local_replica != 0) { + utils::Check(local_model == NULL, + "need to set num_local_replica larger than 1 to checkpoint local_model"); + } + if (num_local_replica != 0) { while (true) { if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break; // save model model to new version place @@ -247,10 +251,10 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model, fs.Write(&version_number, sizeof(version_number)); global_model->Save(fs); // reset result buffer - resbuf.Clear(); seq_counter = 0; + resbuf.Clear(); seq_counter = 0; // execute check ack step, load happens here utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp), - "check ack must return true"); + "check ack must return 
true"); } /*! * \brief reset the all the existing links by sending Out-of-Band message marker @@ -383,7 +387,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { */ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { if (err_type == kSuccess) return true; - {// simple way, shutdown all links + { + // simple way, shutdown all links for (size_t i = 0; i < all_links.size(); ++i) { if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); } @@ -392,8 +397,8 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { } // this was old way // TryResetLinks still causes possible errors, so not use this one - while(err_type != kSuccess) { - switch(err_type) { + while (err_type != kSuccess) { + switch (err_type) { case kGetExcept: err_type = TryResetLinks(); break; case kSockError: { TryResetLinks(); @@ -416,7 +421,7 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { * \param out_index the edge index of output link * \return the shorest distance result of out edge specified by out_index */ -inline std::pair +inline std::pair ShortestDist(const std::pair &node_value, const std::vector< std::pair > &dist_in, size_t out_index) { @@ -484,8 +489,9 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role, int *p_recvlink, std::vector *p_req_in) { int best_link = -2; - {// get the shortest distance to the request point - std::vector< std::pair > dist_in, dist_out; + { + // get the shortest distance to the request point + std::vector > dist_in, dist_out; ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size), &dist_in, &dist_out, ShortestDist); if (succ != kSuccess) return succ; @@ -512,7 +518,7 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role, &req_in, &req_out, DataRequest); if (succ != kSuccess) return succ; // set p_req_in - p_req_in->resize(req_in.size()); + p_req_in->resize(req_in.size()); for (size_t i = 0; i < req_in.size(); ++i) { // set p_req_in (*p_req_in)[i] = 
(req_in[i] != 0); @@ -591,19 +597,23 @@ AllreduceRobust::TryRecoverData(RecoverType role, if (role == kRequestData) { const int pid = recv_link; if (selecter.CheckRead(links[pid].sock)) { - if(!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError; + if (!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError; } for (int i = 0; i < nlink; ++i) { if (req_in[i] && links[i].size_write != links[pid].size_read && selecter.CheckWrite(links[i].sock)) { - if(!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) return kSockError; + if (!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) { + return kSockError; + } } } } if (role == kHaveData) { for (int i = 0; i < nlink; ++i) { if (req_in[i] && selecter.CheckWrite(links[i].sock)) { - if(!links[i].WriteFromArray(sendrecvbuf_, size)) return kSockError; + if (!links[i].WriteFromArray(sendrecvbuf_, size)) { + return kSockError; + } } } } @@ -616,13 +626,14 @@ AllreduceRobust::TryRecoverData(RecoverType role, if (req_in[i]) min_write = std::min(links[i].size_write, min_write); } utils::Assert(min_write <= links[pid].size_read, "boundary check"); - if (!links[pid].ReadToRingBuffer(min_write)) return kSockError; + if (!links[pid].ReadToRingBuffer(min_write)) return kSockError; } - for (int i = 0; i < nlink; ++i) { - if (req_in[i] && selecter.CheckWrite(links[i].sock) && links[pid].size_read != links[i].size_write) { + for (int i = 0; i < nlink; ++i) { + if (req_in[i] && selecter.CheckWrite(links[i].sock) && + links[pid].size_read != links[i].size_write) { size_t start = links[i].size_write % buffer_size; // send out data from ring buffer - size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write); + size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write); ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite); if (len != -1) { links[i].size_write += len; @@ -648,15 +659,15 @@ 
AllreduceRobust::TryRecoverData(RecoverType role, */ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { // check in local data - RecoverType role = requester ? kRequestData : kHaveData; - ReturnType succ; + RecoverType role = requester ? kRequestData : kHaveData; + ReturnType succ; if (num_local_replica != 0) { if (requester) { // clear existing history, if any, before load local_rptr[local_chkpt_version].clear(); - local_chkpt[local_chkpt_version].clear(); + local_chkpt[local_chkpt_version].clear(); } - // recover local checkpoint + // recover local checkpoint succ = TryRecoverLocalState(&local_rptr[local_chkpt_version], &local_chkpt[local_chkpt_version]); if (succ != kSuccess) return succ; @@ -716,7 +727,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re // if we goes to this place, use must have already setup the state once utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1, "TryGetResult::Checkpoint"); - return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]); + return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]); } // handles normal data recovery RecoverType role; @@ -735,8 +746,9 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re utils::Check(data_size != 0, "zero size check point is not allowed"); if (role == kRequestData || role == kHaveData) { utils::Check(data_size == size, - "Allreduce Recovered data size do not match the specification of function call\n"\ - "Please check if calling sequence of recovered program is the same the original one in current VersionNumber"); + "Allreduce Recovered data size do not match the specification of function call.\n"\ + "Please check if calling sequence of recovered program is the " \ + "same the original one in current VersionNumber"); } return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in); } @@ -766,7 +778,7 @@ bool 
AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) { while (true) { this->ReportStatus(); // action - ActionSummary act = req; + ActionSummary act = req; // get the reduced action if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue; if (act.check_ack()) { @@ -816,7 +828,8 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) { if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue; if (requester) return true; } else { - // all the request is same, this is most recent command that is yet to be executed + // all the request is same, + // this is most recent command that is yet to be executed return false; } } @@ -855,7 +868,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent"); } const int n = num_local_replica; - {// backward passing, passing state in backward direction of the ring + { + // backward passing, passing state in backward direction of the ring const int nlocal = static_cast(rptr.size() - 1); utils::Assert(nlocal <= n + 1, "invalid local replica"); std::vector msg_back(n + 1); @@ -897,10 +911,10 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, // update rptr rptr.resize(nread_end + 1); for (int i = nlocal; i < nread_end; ++i) { - rptr[i + 1] = rptr[i] + sizes[i]; + rptr[i + 1] = rptr[i] + sizes[i]; } chkpt.resize(rptr.back()); - // pass data through the link + // pass data through the link succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], rptr[nwrite_start], rptr[nread_end], ring_next, ring_prev); @@ -908,7 +922,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ; } } - {// forward passing, passing state in forward direction of the ring + { + // forward passing, passing state in forward direction of the ring const int nlocal = 
static_cast(rptr.size() - 1); utils::Assert(nlocal <= n + 1, "invalid local replica"); std::vector msg_forward(n + 1); @@ -926,7 +941,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, 1 * sizeof(int), 2 * sizeof(int), 0 * sizeof(int), 1 * sizeof(int), ring_next, ring_prev); - if (succ != kSuccess) return succ; + if (succ != kSuccess) return succ; // calculate the number of things we can read from next link int nread_end = nlocal, nwrite_end = 1; // have to have itself in order to get other data from prev link @@ -936,7 +951,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, nread_end = std::max(nread_end, i + 1); nwrite_end = i + 1; } - if (nwrite_end > n) nwrite_end = n; + if (nwrite_end > n) nwrite_end = n; } else { nread_end = 0; nwrite_end = 0; } @@ -963,7 +978,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, rptr[i + 1] = rptr[i] + sizes[i]; } chkpt.resize(rptr.back()); - // pass data through the link + // pass data through the link succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], rptr[nwrite_start], rptr[nwrite_end], ring_prev, ring_next); @@ -995,7 +1010,8 @@ AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, if (num_local_replica == 0) return kSuccess; std::vector &rptr = *p_local_rptr; std::string &chkpt = *p_local_chkpt; - utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state"); + utils::Assert(rptr.size() == 2, + "TryCheckinLocalState must have exactly 1 state"); const int n = num_local_replica; std::vector sizes(n + 1); sizes[0] = rptr[1] - rptr[0]; @@ -1012,9 +1028,9 @@ AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, rptr.resize(n + 2); for (int i = 1; i <= n; ++i) { rptr[i + 1] = rptr[i] + sizes[i]; - } + } chkpt.resize(rptr.back()); - // pass data through the link + // pass data through the link succ = RingPassing(BeginPtr(chkpt), rptr[1], rptr[n + 1], rptr[0], rptr[n], @@ -1050,13 +1066,14 @@ 
AllreduceRobust::RingPassing(void *sendrecvbuf_, LinkRecord *read_link, LinkRecord *write_link) { if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess; - utils::Assert(write_end <= read_end, "RingPassing: boundary check1, write_end=%lu, read_end=%lu", write_end, read_end); + utils::Assert(write_end <= read_end, + "RingPassing: boundary check1"); utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2"); utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3"); // take reference LinkRecord &prev = *read_link, &next = *write_link; // send recv buffer - char *buf = reinterpret_cast(sendrecvbuf_); + char *buf = reinterpret_cast(sendrecvbuf_); while (true) { bool finished = true; utils::SelectHelper selecter; @@ -1066,7 +1083,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, } if (write_ptr < read_ptr && write_ptr != write_end) { selecter.WatchWrite(next.sock); - finished = false; + finished = false; } selecter.WatchException(prev.sock); selecter.WatchException(next.sock); @@ -1078,7 +1095,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr); if (len == 0) { prev.sock.Close(); return kSockError; - } + } if (len != -1) { read_ptr += static_cast(len); } else { diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index d178e391a..f2a804e95 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -1,4 +1,5 @@ /*! + * Copyright (c) 2014 by Contributors * \file allreduce_robust.h * \brief Robust implementation of Allreduce * using TCP non-block socket and tree-shape reduction. 
@@ -7,10 +8,12 @@ * * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou */ -#ifndef RABIT_ALLREDUCE_ROBUST_H -#define RABIT_ALLREDUCE_ROBUST_H +#ifndef RABIT_ALLREDUCE_ROBUST_H_ +#define RABIT_ALLREDUCE_ROBUST_H_ #include -#include +#include +#include +#include "rabit/engine.h" #include "./allreduce_base.h" namespace rabit { @@ -111,11 +114,11 @@ class AllreduceRobust : public AllreduceBase { private: // constant one byte out of band message to indicate error happening // and mark for channel cleanup - const static char kOOBReset = 95; + static const char kOOBReset = 95; // and mark for channel cleanup, after OOB signal - const static char kResetMark = 97; + static const char kResetMark = 97; // and mark for channel cleanup - const static char kResetAck = 97; + static const char kResetAck = 97; /*! \brief type of roles each node can play during recovery */ enum RecoverType { /*! \brief current node have data */ @@ -132,29 +135,29 @@ class AllreduceRobust : public AllreduceBase { */ struct ActionSummary { // maximumly allowed sequence id - const static int kSpecialOp = (1 << 26); + static const int kSpecialOp = (1 << 26); // special sequence number for local state checkpoint - const static int kLocalCheckPoint = (1 << 26) - 2; + static const int kLocalCheckPoint = (1 << 26) - 2; // special sequnce number for local state checkpoint ack signal - const static int kLocalCheckAck = (1 << 26) - 1; + static const int kLocalCheckAck = (1 << 26) - 1; //--------------------------------------------- - // The following are bit mask of flag used in + // The following are bit mask of flag used in //---------------------------------------------- // some node want to load check point - const static int kLoadCheck = 1; + static const int kLoadCheck = 1; // some node want to do check point - const static int kCheckPoint = 2; + static const int kCheckPoint = 2; // check point Ack, we use a two phase message in check point, // this is the second phase of check pointing - const static int 
kCheckAck = 4; + static const int kCheckAck = 4; // there are difference sequence number the nodes proposed // this means we want to do recover execution of the lower sequence // action instead of normal execution - const static int kDiffSeq = 8; + static const int kDiffSeq = 8; // constructor ActionSummary(void) {} - // constructor of action - ActionSummary(int flag, int minseqno = kSpecialOp) { + // constructor of action + explicit ActionSummary(int flag, int minseqno = kSpecialOp) { seqcode = (minseqno << 4) | flag; } // minimum number of all operations @@ -181,10 +184,11 @@ class AllreduceRobust : public AllreduceBase { inline int flag(void) const { return seqcode & 15; } - // reducer for Allreduce, used to get the result ActionSummary from all nodes - inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { + // reducer for Allreduce, get the result ActionSummary from all nodes + inline static void Reducer(const void *src_, void *dst_, + int len, const MPI::Datatype &dtype) { const ActionSummary *src = (const ActionSummary*)src_; - ActionSummary *dst = (ActionSummary*)dst_; + ActionSummary *dst = reinterpret_cast(dst_); for (int i = 0; i < len; ++i) { int src_seqno = src[i].min_seqno(); int dst_seqno = dst[i].min_seqno(); @@ -192,7 +196,8 @@ class AllreduceRobust : public AllreduceBase { if (src_seqno == dst_seqno) { dst[i] = ActionSummary(flag, src_seqno); } else { - dst[i] = ActionSummary(flag | kDiffSeq, std::min(src_seqno, dst_seqno)); + dst[i] = ActionSummary(flag | kDiffSeq, + std::min(src_seqno, dst_seqno)); } } } @@ -222,7 +227,7 @@ class AllreduceRobust : public AllreduceBase { data_.resize(rptr_.back() + nhop); return BeginPtr(data_) + rptr_.back(); } - // push the result in temp to the + // push the result in temp to the inline void PushTemp(int seqid, size_t type_nbytes, size_t count) { size_t size = type_nbytes * count; size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); @@ -234,13 +239,14 @@ 
class AllreduceRobust : public AllreduceBase { size_.push_back(size); utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent"); } - // return the stored result of seqid, if any + // return the stored result of seqid, if any inline void* Query(int seqid, size_t *p_size) { - size_t idx = std::lower_bound(seqno_.begin(), seqno_.end(), seqid) - seqno_.begin(); + size_t idx = std::lower_bound(seqno_.begin(), + seqno_.end(), seqid) - seqno_.begin(); if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL; *p_size = size_[idx]; return BeginPtr(data_) + rptr_[idx]; - } + } // drop last stored result inline void DropLast(void) { utils::Assert(seqno_.size() != 0, "there is nothing to be dropped"); @@ -254,15 +260,16 @@ class AllreduceRobust : public AllreduceBase { if (seqno_.size() == 0) return -1; return seqno_.back(); } + private: - // sequence number of each + // sequence number of each std::vector seqno_; // pointer to the positions std::vector rptr_; // actual size of each buffer std::vector size_; // content of the buffer - std::vector data_; + std::vector data_; }; /*! 
* \brief reset the all the existing links by sending Out-of-Band message marker @@ -291,14 +298,16 @@ class AllreduceRobust : public AllreduceBase { * \param buf the buffer to store the result * \param size the total size of the buffer * \param flag flag information about the action \sa ActionSummary - * \param seqno sequence number of the action, if it is special action with flag set, seqno needs to be set to ActionSummary::kSpecialOp + * \param seqno sequence number of the action, if it is special action with flag set, + * seqno needs to be set to ActionSummary::kSpecialOp * * \return if this function can return true or false - * - true means buf already set to the - * result by recovering procedure, the action is complete, no further action is needed + * - true means buf already set to the + * result by recovering procedure, the action is complete, no further action is needed * - false means this is the lastest action that has not yet been executed, need to execute the action */ - bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kSpecialOp); + bool RecoverExec(void *buf, size_t size, int flag, + int seqno = ActionSummary::kSpecialOp); /*! * \brief try to load check point * @@ -363,7 +372,7 @@ class AllreduceRobust : public AllreduceBase { void *sendrecvbuf_, size_t size, int recv_link, - const std::vector &req_in); + const std::vector &req_in); /*! 
* \brief try to recover the local state, making each local state to be the result of itself * plus replication of states in previous num_local_replica hops in the ring @@ -446,17 +455,17 @@ o * the input state must exactly one saved state(local state of current node) inline ReturnType MsgPassing(const NodeType &node_value, std::vector *p_edge_in, std::vector *p_edge_out, - EdgeType (*func) (const NodeType &node_value, - const std::vector &edge_in, - size_t out_index) - ); + EdgeType (*func) + (const NodeType &node_value, + const std::vector &edge_in, + size_t out_index)); //---- recovery data structure ---- // the round of result buffer, used to mode the result int result_buffer_round; // result buffer of all reduce ResultBuffer resbuf; // last check point global model - std::string global_checkpoint; + std::string global_checkpoint; // number of replica for local state/model int num_local_replica; // --- recovery data structure for local checkpoint @@ -465,16 +474,15 @@ o * the input state must exactly one saved state(local state of current node) // pointer to memory position in the local model // local model is stored in CSR format(like a sparse matrices) // local_model[rptr[0]:rptr[1]] stores the model of current node - // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring + // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops std::vector local_rptr[2]; // storage for local model replicas std::string local_chkpt[2]; // version of local checkpoint can be 1 or 0 - int local_chkpt_version; + int local_chkpt_version; }; } // namespace engine } // namespace rabit // implementation of inline template function #include "./allreduce_robust-inl.h" - -#endif // RABIT_ALLREDUCE_ROBUST_H +#endif // RABIT_ALLREDUCE_ROBUST_H_ diff --git a/src/engine.cc b/src/engine.cc index 57e074109..45bef329c 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -1,4 +1,5 @@ /*! 
+ * Copyright (c) 2014 by Contributors * \file engine.cc * \brief this file governs which implementation of engine we are actually using * provides an singleton of engine interface @@ -41,16 +42,17 @@ void Finalize(void) { IEngine *GetEngine(void) { return &manager; } -// perform in-place allreduce, on sendrecvbuf +// perform in-place allreduce, on sendrecvbuf void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, - IEngine::ReduceFunction red, + IEngine::ReduceFunction red, mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun, void *prepare_arg) { - GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg); + GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, + red, prepare_fun, prepare_arg); } // code for reduce handle diff --git a/src/engine_empty.cc b/src/engine_empty.cc index ff838717e..0c7020914 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -1,4 +1,5 @@ /*! + * Copyright (c) 2014 by Contributors * \file engine_empty.cc * \brief this file provides a dummy implementation of engine that does nothing * this file provides a way to fall back to single node program without causing too many dependencies @@ -25,9 +26,10 @@ class EmptyEngine : public IEngine { ReduceFunction reducer, PreprocFunction prepare_fun, void *prepare_arg) { - utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead"); + utils::Error("EmptyEngine:: Allreduce is not supported, "\ + "use Allreduce_ instead"); } - virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) { + virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) { } virtual void InitAfterException(void) { utils::Error("EmptyEngine is not fault tolerant"); @@ -51,7 +53,7 @@ class EmptyEngine : public IEngine { virtual int GetWorldSize(void) const { return 1; } - /*! \brief get the host name of current node */ + /*!
\brief get the host name of current node */ virtual std::string GetHost(void) const { return std::string(""); } @@ -59,6 +61,7 @@ class EmptyEngine : public IEngine { // simply print information into the tracker utils::Printf("%s", msg.c_str()); } + private: int version_number; }; @@ -77,11 +80,11 @@ void Finalize(void) { IEngine *GetEngine(void) { return &manager; } -// perform in-place allreduce, on sendrecvbuf +// perform in-place allreduce, on sendrecvbuf void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, - IEngine::ReduceFunction red, + IEngine::ReduceFunction red, mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun, diff --git a/src/engine_mock.cc b/src/engine_mock.cc index e8a77a6a2..24415a1d5 100644 --- a/src/engine_mock.cc +++ b/src/engine_mock.cc @@ -1,4 +1,5 @@ /*! + * Copyright (c) 2014 by Contributors * \file engine_mock.cc * \brief this is an engine implementation that will * insert failures in certain call point, to test if the engine is robust to failure diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index d8a30cbbc..bdad5e2d1 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -1,4 +1,5 @@ /*! 
+ * Copyright (c) 2014 by Contributors * \file engine_mpi.cc * \brief this file gives an implementation of engine interface using MPI, * this will allow rabit program to run with MPI, but do not comes with fault tolerant @@ -8,10 +9,10 @@ #define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_DEPRECATE #define NOMINMAX -#include -#include -#include #include +#include +#include "rabit/engine.h" +#include "rabit/utils.h" namespace rabit { namespace engine { @@ -27,9 +28,10 @@ class MPIEngine : public IEngine { ReduceFunction reducer, PreprocFunction prepare_fun, void *prepare_arg) { - utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead"); + utils::Error("MPIEngine:: Allreduce is not supported, "\ + "use Allreduce_ instead"); } - virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) { + virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) { MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root); } virtual void InitAfterException(void) { @@ -48,13 +50,13 @@ class MPIEngine : public IEngine { } /*! \brief get rank of current node */ virtual int GetRank(void) const { - return MPI::COMM_WORLD.Get_rank(); + return MPI::COMM_WORLD.Get_rank(); } /*! \brief get total number of */ virtual int GetWorldSize(void) const { - return MPI::COMM_WORLD.Get_size(); + return MPI::COMM_WORLD.Get_size(); } - /*! \brief get the host name of current node */ + /*!
\brief get the host name of current node */ virtual std::string GetHost(void) const { int len; char name[MPI_MAX_PROCESSOR_NAME]; @@ -68,6 +70,7 @@ class MPIEngine : public IEngine { utils::Printf("%s", msg.c_str()); } } + private: int version_number; }; @@ -91,7 +94,7 @@ IEngine *GetEngine(void) { // transform enum to MPI data type inline MPI::Datatype GetType(mpi::DataType dtype) { using namespace mpi; - switch(dtype) { + switch (dtype) { case kInt: return MPI::INT; case kUInt: return MPI::UNSIGNED; case kFloat: return MPI::FLOAT; @@ -103,7 +106,7 @@ inline MPI::Datatype GetType(mpi::DataType dtype) { // transform enum to MPI OP inline MPI::Op GetOp(mpi::OpType otype) { using namespace mpi; - switch(otype) { + switch (otype) { case kMax: return MPI::MAX; case kMin: return MPI::MIN; case kSum: return MPI::SUM; @@ -112,17 +115,18 @@ inline MPI::Op GetOp(mpi::OpType otype) { utils::Error("unknown mpi::OpType"); return MPI::MAX; } -// perform in-place allreduce, on sendrecvbuf +// perform in-place allreduce, on sendrecvbuf void Allreduce_(void *sendrecvbuf, size_t type_nbytes, size_t count, - IEngine::ReduceFunction red, + IEngine::ReduceFunction red, mpi::DataType dtype, mpi::OpType op, IEngine::PreprocFunction prepare_fun, void *prepare_arg) { if (prepare_fun != NULL) prepare_fun(prepare_arg); - MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op)); + MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, + count, GetType(dtype), GetOp(op)); } // code for reduce handle @@ -152,7 +156,7 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) { created_type_nbytes_ = type_nbytes; htype_ = dtype; } - + MPI::Op *op = new MPI::Op(); MPI::User_function *pf = redfunc; op->Init(pf, true); @@ -175,7 +179,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf, dtype->Commit(); created_type_nbytes_ = type_nbytes; } - if (prepare_fun != NULL) prepare_fun(prepare_arg); + if (prepare_fun != NULL) prepare_fun(prepare_arg); 
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op); } } // namespace engine diff --git a/src/socket.h b/src/socket.h index 29d62db35..c40cb6a88 100644 --- a/src/socket.h +++ b/src/socket.h @@ -1,10 +1,11 @@ -#ifndef RABIT_SOCKET_H -#define RABIT_SOCKET_H /*! + * Copyright (c) 2014 by Contributors * \file socket.h * \brief this file aims to provide a wrapper of sockets * \author Tianqi Chen */ +#ifndef RABIT_SOCKET_H_ +#define RABIT_SOCKET_H_ #if defined(_WIN32) #include #include @@ -21,7 +22,7 @@ #endif #include #include -#include +#include "rabit/utils.h" #if defined(_WIN32) typedef int ssize_t; @@ -68,9 +69,11 @@ struct SockAddr { inline std::string AddrStr(void) const { std::string buf; buf.resize(256); #ifdef _WIN32 - const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, &buf[0], buf.length()); + const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, + &buf[0], buf.length()); #else - const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length()); + const char *s = inet_ntop(AF_INET, &addr.sin_addr, + &buf[0], buf.length()); #endif Assert(s != NULL, "cannot decode address"); return std::string(s); @@ -94,12 +97,12 @@ class Socket { */ inline static void Startup(void) { #ifdef _WIN32 - WSADATA wsa_data; + WSADATA wsa_data; if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != -1) { - Socket::Error("Startup"); - } + Socket::Error("Startup"); + } if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { - WSACleanup(); + WSACleanup(); utils::Error("Could not find a usable version of Winsock.dll\n"); } #endif @@ -118,11 +121,11 @@ class Socket { * it will set it back to block mode */ inline void SetNonBlock(bool non_block) { -#ifdef _WIN32 - u_long mode = non_block ? 1 : 0; - if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { +#ifdef _WIN32 + u_long mode = non_block ? 
1 : 0; + if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { Socket::Error("SetNonBlock"); - } + } #else int flag = fcntl(sockfd, F_GETFL, 0); if (flag == -1) { @@ -143,7 +146,8 @@ class Socket { * \param addr */ inline void Bind(const SockAddr &addr) { - if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) { + if (bind(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == -1) { Socket::Error("Bind"); } } @@ -154,10 +158,11 @@ class Socket { * \return the port successfully bind to, return -1 if failed to bind any port */ inline int TryBindHost(int start_port, int end_port) { - // TODO, add prefix check + // TODO(tqchen) add prefix check for (int port = start_port; port < end_port; ++port) { SockAddr addr("0.0.0.0", port); - if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) { + if (bind(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == 0) { return port; } if (errno != EADDRINUSE) { @@ -179,22 +184,22 @@ class Socket { inline bool BadSocket(void) const { if (IsClosed()) return true; int err = GetSockError(); - if (err == EBADF || err == EINTR) return true; + if (err == EBADF || err == EINTR) return true; return false; } /*! \brief check if socket is already closed */ inline bool IsClosed(void) const { return sockfd == INVALID_SOCKET; - } + } /*! 
\brief close the socket */ inline void Close(void) { if (sockfd != INVALID_SOCKET) { #ifdef _WIN32 closesocket(sockfd); #else - close(sockfd); + close(sockfd); #endif - sockfd = INVALID_SOCKET; + sockfd = INVALID_SOCKET; } else { Error("Socket::Close double close the socket or close without create"); } @@ -204,6 +209,7 @@ class Socket { int errsv = errno; utils::Error("Socket %s Error:%s", msg, strerror(errsv)); } + protected: explicit Socket(SOCKET sockfd) : sockfd(sockfd) { } @@ -227,7 +233,7 @@ class TCPSocket : public Socket{ int opt = static_cast(keepalive); if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) { Socket::Error("SetKeepAlive"); - } + } } /*! * \brief create the socket, call this before using socket @@ -273,7 +279,8 @@ class TCPSocket : public Socket{ * \return whether connect is successful */ inline bool Connect(const SockAddr &addr) { - return connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0; + return connect(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == 0; } /*! * \brief send data using the socket @@ -284,7 +291,7 @@ class TCPSocket : public Socket{ * return -1 if error occurs */ inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { - const char *buf = reinterpret_cast(buf_); + const char *buf = reinterpret_cast(buf_); return send(sockfd, buf, static_cast(len), flag); } /*! @@ -296,7 +303,7 @@ class TCPSocket : public Socket{ * return -1 if error occurs */ inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { - char *buf = reinterpret_cast(buf_); + char *buf = reinterpret_cast(buf_); return recv(sockfd, buf, static_cast(len), flags); } /*! 
@@ -331,7 +338,8 @@ class TCPSocket : public Socket{ char *buf = reinterpret_cast(buf_); size_t ndone = 0; while (ndone < len) { - ssize_t ret = recv(sockfd, buf, static_cast(len - ndone), MSG_WAITALL); + ssize_t ret = recv(sockfd, buf, + static_cast(len - ndone), MSG_WAITALL); if (ret == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone; Socket::Error("RecvAll"); @@ -385,7 +393,7 @@ struct SelectHelper { * \param fd file descriptor to be watched */ inline void WatchRead(SOCKET fd) { - FD_SET(fd, &read_set); + FD_SET(fd, &read_set); if (fd > maxfd) maxfd = fd; } /*! @@ -403,7 +411,7 @@ struct SelectHelper { inline void WatchException(SOCKET fd) { FD_SET(fd, &except_set); if (fd > maxfd) maxfd = fd; - } + } /*! * \brief Check if the descriptor is ready for read * \param fd file descriptor to check status @@ -435,8 +443,9 @@ struct SelectHelper { fd_set wait_set; FD_ZERO(&wait_set); FD_SET(fd, &wait_set); - return Select_(static_cast(fd + 1), NULL, NULL, &wait_set, timeout); - } + return Select_(static_cast(fd + 1), + NULL, NULL, &wait_set, timeout); + } /*! * \brief peform select on the set defined * \param select_read whether to watch for read event @@ -454,9 +463,10 @@ struct SelectHelper { } return ret; } - + private: - inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) { + inline static int Select_(int maxfd, fd_set *rfds, + fd_set *wfds, fd_set *efds, long timeout) { utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE"); if (timeout == 0) { return select(maxfd, rfds, wfds, efds, NULL); @@ -465,12 +475,12 @@ struct SelectHelper { tm.tv_usec = (timeout % 1000) * 1000; tm.tv_sec = timeout / 1000; return select(maxfd, rfds, wfds, efds, &tm); - } + } } - - SOCKET maxfd; + + SOCKET maxfd; fd_set read_set, write_set, except_set; }; } // namespace utils } // namespace rabit -#endif +#endif // RABIT_SOCKET_H_