diff --git a/src/engine.cc b/src/engine.cc new file mode 100644 index 000000000..17aacd5cf --- /dev/null +++ b/src/engine.cc @@ -0,0 +1,39 @@ +/*! + * \file engine.cc + * \brief this file governs which implementation of engine we are actually using + * provides an singleton of engine interface + * + * \author Tianqi, Nacho, Tianyi + */ +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_DEPRECATE +#define NOMINMAX + +#include "./engine.h" +#include "./engine_base.h" +#include "./engine_robust.h" + +namespace engine { +// singleton sync manager +AllReduceRobust manager; + +/*! \brief intiialize the synchronization module */ +void Init(int argc, char *argv[]) { + for (int i = 1; i < argc; ++i) { + char name[256], val[256]; + if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { + manager.SetParam(name, val); + } + } + manager.Init(); +} + +/*! \brief finalize syncrhonization module */ +void Finalize(void) { + manager.Shutdown(); +} +/*! \brief singleton method to get engine */ +IEngine *GetEngine(void) { + return &manager; +} +} // namespace engine diff --git a/src/engine.h b/src/engine.h index 510f3aabd..42e19c139 100644 --- a/src/engine.h +++ b/src/engine.h @@ -1,10 +1,10 @@ -#ifndef ALLREDUCE_ENGINE_H -#define ALLREDUCE_ENGINE_H /*! * \file engine.h * \brief This file defines the interface of allreduce library * \author Tianqi Chen, Nacho, Tianyi */ +#ifndef ALLREDUCE_ENGINE_H +#define ALLREDUCE_ENGINE_H #include "./io.h" @@ -49,7 +49,7 @@ class IEngine { * \param sendrecvbuf_ buffer for both sending and recving data * \param size the size of the data to be broadcasted * \param root the root worker id to broadcast the data - */ + */ virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0; /*! * \brief load latest check point diff --git a/src/engine_base.cc b/src/engine_base.cc new file mode 100644 index 000000000..e2eca014f --- /dev/null +++ b/src/engine_base.cc @@ -0,0 +1,349 @@ +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_DEPRECATE +#define NOMINMAX +#include +#include "./engine_base.h" + +namespace engine { +// constructor +AllReduceBase::AllReduceBase(void) { + master_uri = "NULL"; + master_port = 9000; + host_uri = ""; + slave_port = 9010; + nport_trial = 1000; + rank = 0; + world_size = 1; + this->SetParam("reduce_buffer", "256MB"); +} + +// initialization function +void AllReduceBase::Init(void) { + utils::Socket::Startup(); + // single node mode + if (master_uri == "NULL") return; + utils::Assert(links.size() == 0, "can only call Init once"); + int magic = kMagic; + int nchild = 0, nparent = 0; + this->host_uri = utils::SockAddr::GetHostName(); + // get information from master + utils::TCPSocket master; + master.Create(); + if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) { + utils::Socket::Error("Connect"); + } + utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1"); + utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2"); + utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); + utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3"); + utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4"); + utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5"); + utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6"); + utils::Assert(nchild >= 0, "in correct number of childs"); + utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent"); + + // create listen + utils::TCPSocket sock_listen; + sock_listen.Create(); + int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial); + utils::Check(port != -1, "sync::Init fail to bind the ports specified"); + sock_listen.Listen(); + + if (nparent != 0) { + parent_index = 0; + links.push_back(LinkRecord()); + int len, hport; + std::string hname; + utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9"); + hname.resize(len); + utils::Assert(len != 0, "string must not be empty"); + utils::Assert(master.RecvAll(&hname[0], len) == static_cast(len), "sync::Init failure 10"); + utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11"); + links[0].sock.Create(); + links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport)); + utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12"); + utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13"); + utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch"); + parent_index = 0; + } else { + parent_index = -1; + } + // send back socket listening port to master + utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14"); + // close connection to master + master.Close(); + // accept links from childs + for (int i = 0; i < nchild; ++i) { + LinkRecord r; + while (true) { + r.sock = sock_listen.Accept(); + if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) { + utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15"); + break; + } else { + // not a valid child + r.sock.Close(); + } + } + links.push_back(r); + } + // close listening sockets + sock_listen.Close(); + // setup selecter + for (size_t i = 0; i < links.size(); ++i) { + // set the socket to non-blocking mode + links[i].sock.SetNonBlock(true); + } + // done +} + +void AllReduceBase::Shutdown(void) { + for (size_t i = 0; i < links.size(); ++i) { + links[i].sock.Close(); + } + links.clear(); + utils::TCPSocket::Finalize(); +} +// set the parameters for AllReduce +void AllReduceBase::SetParam(const char *name, const char *val) { + if (!strcmp(name, "master_uri")) master_uri = val; + if (!strcmp(name, "master_port")) master_port = atoi(val); + if (!strcmp(name, "reduce_buffer")) { + char unit; + unsigned long amount; + if (sscanf(val, "%lu%c", &amount, &unit) == 2) { + switch (unit) { + case 'B': reduce_buffer_size = (amount + 7)/ 8; break; + case 'K': reduce_buffer_size = amount << 7UL; break; + case 'M': reduce_buffer_size = amount << 17UL; break; + case 'G': reduce_buffer_size = amount << 27UL; break; + default: utils::Error("invalid format for reduce buffer"); + } + } else { + utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}"); + } + } +} + +/*! + * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure + * + * NOTE on AllReduce: + * The kSuccess TryAllReduce does NOT mean every node have successfully finishes TryAllReduce. + * It only means the current node get the correct result of AllReduce. + * However, it means every node finishes LAST call(instead of this one) of AllReduce/Bcast + * + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + * \return this function can return + * - kSuccess: allreduce is success, + * - kSockError: a neighbor node go down, the connection is dropped + * - kGetExcept: another node which is not my neighbor go down, get Out-of-Band exception notification from my neighbor + */ +AllReduceBase::ReturnType +AllReduceBase::TryAllReduce(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer) { + if (links.size() == 0) return kSuccess; + // total size of message + const size_t total_size = type_nbytes * count; + // number of links + const int nlink = static_cast(links.size()); + // send recv buffer + char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); + // size of space that we already performs reduce in up pass + size_t size_up_reduce = 0; + // size of space that we have already passed to parent + size_t size_up_out = 0; + // size of message we received, and send in the down pass + size_t size_down_in = 0; + // initialize the link ring-buffer and pointer + for (int i = 0; i < nlink; ++i) { + if (i != parent_index) { + links[i].InitBuffer(type_nbytes, count, reduce_buffer_size); + } + links[i].ResetSize(); + } + // if no childs, no need to reduce + if (nlink == static_cast(parent_index != -1)) { + size_up_reduce = total_size; + } + + // while we have not passed the messages out + while (true) { + // select helper + utils::SelectHelper selecter; + for (size_t i = 0; i < links.size(); ++i) { + selecter.WatchRead(links[i].sock); + selecter.WatchWrite(links[i].sock); + selecter.WatchException(links[i].sock); + } + // select must return + selecter.Select(); + // exception handling + for (int i = 0; i < nlink; ++i) { + // recive OOB message from some link + if (selecter.CheckExcept(links[i].sock)) return kGetExcept; + } + // read data from childs + for (int i = 0; i < nlink; ++i) { + if (i != parent_index && selecter.CheckRead(links[i].sock)) { + if (!links[i].ReadToRingBuffer(size_up_out)) return kSockError; + } + } + // this node have childs, peform reduce + if (nlink > static_cast(parent_index != -1)) { + size_t buffer_size = 0; + // do upstream reduce + size_t max_reduce = total_size; + for (int i = 0; i < nlink; ++i) { + if (i != parent_index) { + max_reduce= std::min(max_reduce, links[i].size_read); + utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size, + "buffer size inconsistent"); + buffer_size = links[i].buffer_size; + } + } + utils::Assert(buffer_size != 0, "must assign buffer_size"); + // round to type_n4bytes + max_reduce = (max_reduce / type_nbytes * type_nbytes); + // peform reduce, can be at most two rounds + while (size_up_reduce < max_reduce) { + // start position + size_t start = size_up_reduce % buffer_size; + // peform read till end of buffer + size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce); + utils::Assert(nread % type_nbytes == 0, "AllReduce: size check"); + for (int i = 0; i < nlink; ++i) { + if (i != parent_index) { + reducer(links[i].buffer_head + start, + sendrecvbuf + size_up_reduce, + static_cast(nread / type_nbytes), + MPI::Datatype(type_nbytes)); + } + } + size_up_reduce += nread; + } + } + if (parent_index != -1) { + // pass message up to parent, can pass data that are already been reduced + if (selecter.CheckWrite(links[parent_index].sock)) { + ssize_t len = links[parent_index].sock. + Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); + if (len != -1) { + size_up_out += static_cast(len); + } else { + if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; + } + } + // read data from parent + if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) { + ssize_t len = links[parent_index].sock. + Recv(sendrecvbuf + size_down_in, total_size - size_down_in); + if (len == 0) { + links[parent_index].sock.Close(); return kSockError; + } + if (len != -1) { + size_down_in += static_cast(len); + utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error"); + } else { + if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; + } + } + } else { + // this is root, can use reduce as most recent point + size_down_in = size_up_out = size_up_reduce; + } + // check if we finished the job of message passing + size_t nfinished = size_down_in; + // can pass message down to childs + for (int i = 0; i < nlink; ++i) { + if (i != parent_index) { + if (selecter.CheckWrite(links[i].sock)) { + if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError; + } + nfinished = std::min(links[i].size_write, nfinished); + } + } + // check boundary condition + if (nfinished >= total_size) break; + } + return kSuccess; +} +/*! + * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure + * \param sendrecvbuf_ buffer for both sending and recving data + * \param total_size the size of the data to be broadcasted + * \param root the root worker id to broadcast the data + * \return this function can return three possible values, see detail in TryAllReduce + */ +AllReduceBase::ReturnType +AllReduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { + if (links.size() == 0) return kSuccess; + // number of links + const int nlink = static_cast(links.size()); + // size of space already read from data + size_t size_in = 0; + // input link, -2 means unknown yet, -1 means this is root + int in_link = -2; + + // initialize the link statistics + for (int i = 0; i < nlink; ++i) { + links[i].ResetSize(); + } + // root have all the data + if (this->rank == root) { + size_in = total_size; + in_link = -1; + } + // while we have not passed the messages out + while(true) { + // select helper + utils::SelectHelper selecter; + for (size_t i = 0; i < links.size(); ++i) { + selecter.WatchRead(links[i].sock); + selecter.WatchWrite(links[i].sock); + selecter.WatchException(links[i].sock); + } + // exception handling + for (int i = 0; i < nlink; ++i) { + // recive OOB message from some link + if (selecter.CheckExcept(links[i].sock)) return kGetExcept; + } + if (in_link == -2) { + // probe in-link + for (int i = 0; i < nlink; ++i) { + if (selecter.CheckRead(links[i].sock)) { + if (!links[i].ReadToArray(sendrecvbuf_, total_size)) return kSockError; + size_in = links[i].size_read; + if (size_in != 0) { + in_link = i; break; + } + } + } + } else { + // read from in link + if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { + if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) return kSockError; + size_in = links[in_link].size_read; + } + } + size_t nfinished = total_size; + // send data to all out-link + for (int i = 0; i < nlink; ++i) { + if (i != in_link) { + if (selecter.CheckWrite(links[i].sock)) { + if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) return kSockError; + } + nfinished = std::min(nfinished, links[i].size_write); + } + } + // check boundary condition + if (nfinished >= total_size) break; + } + return kSuccess; +} +} // namespace engine diff --git a/src/engine_base.h b/src/engine_base.h new file mode 100644 index 000000000..6c138529a --- /dev/null +++ b/src/engine_base.h @@ -0,0 +1,244 @@ +/*! + * \file engine_base.h + * \brief Basic implementation of AllReduce + * using TCP non-block socket and tree-shape reduction. + * + * This implementation provides basic utility of AllReduce and Broadcast + * without considering node failure + * + * \author Tianqi, Nacho, Tianyi + */ +#ifndef ALLREDUCE_ENGINE_BASE_H +#define ALLREDUCE_ENGINE_BASE_H + +#include +#include +#include "./utils.h" +#include "./socket.h" +#include "./engine.h" + +namespace MPI { +// MPI data type to be compatible with existing MPI interface +class Datatype { + public: + size_t type_size; + Datatype(size_t type_size) : type_size(type_size) {} +}; +} + +namespace engine { +/*! \brief implementation of basic AllReduce engine */ +class AllReduceBase : public IEngine { + public: + // magic number to verify server + const static int kMagic = 0xff99; + // constant one byte out of band message to indicate error happening + AllReduceBase(void); + virtual ~AllReduceBase(void) {} + // shutdown the engine + void Shutdown(void); + // initialize the manager + void Init(void); + /*! \brief set parameters to the sync manager */ + virtual void SetParam(const char *name, const char *val); + /*! \brief get rank */ + virtual int GetRank(void) const { + return rank; + } + /*! \brief get rank */ + virtual int GetWorldSize(void) const { + return world_size; + } + /*! \brief get rank */ + virtual std::string GetHost(void) const { + return host_uri; + } + /*! + * \brief perform in-place allreduce, on sendrecvbuf + * this function is NOT thread-safe + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + */ + virtual void AllReduce(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer) { + utils::Assert(TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess, + "AllReduce failed"); + } + /*! + * \brief broadcast data from root to all nodes + * \param sendrecvbuf_ buffer for both sending and recving data + * \param size the size of the data to be broadcasted + * \param root the root worker id to broadcast the data + */ + virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { + utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, + "AllReduce failed"); + } + /*! + * \brief load latest check point + * \param p_model pointer to the model + * \return true if there was stored checkpoint and load was successful + * false if there was no stored checkpoint, means we are start over gain + */ + virtual bool LoadCheckPoint(utils::ISerializable *p_model) { + return false; + } + /*! + * \brief checkpoint the model, meaning we finished a stage of execution + * \param p_model pointer to the model + */ + virtual void CheckPoint(const utils::ISerializable &model) { + } + + protected: + /*! \brief enumeration of possible returning results from Try functions */ + enum ReturnType { + kSuccess, + kSockError, + kGetExcept + }; + // link record to a neighbor + struct LinkRecord { + public: + // socket to get data from/to link + utils::TCPSocket sock; + // size of data readed from link + size_t size_read; + // size of data sent to the link + size_t size_write; + // pointer to buffer head + char *buffer_head; + // buffer size, in bytes + size_t buffer_size; + // constructor + LinkRecord(void) {} + // initialize buffer + inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) { + size_t n = (type_nbytes * count + 7)/ 8; + buffer_.resize(std::min(reduce_buffer_size, n)); + // make sure align to type_nbytes + buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; + utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size); + // set buffer head + buffer_head = reinterpret_cast(BeginPtr(buffer_)); + } + // reset the recv and sent size + inline void ResetSize(void) { + size_write = size_read = 0; + } + /*! + * \brief read data into ring-buffer, with care not to existing useful override data + * position after protect_start + * \param protect_start all data start from protect_start is still needed in buffer + * read shall not override this + * \return true if it is an successful read, false if there is some error happens, check errno + */ + inline bool ReadToRingBuffer(size_t protect_start) { + size_t ngap = size_read - protect_start; + utils::Assert(ngap <= buffer_size, "AllReduce: boundary check"); + size_t offset = size_read % buffer_size; + size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); + if (nmax == 0) return true; + ssize_t len = sock.Recv(buffer_head + offset, nmax); + // length equals 0, remote disconnected + if (len == 0) { + sock.Close(); return false; + } + if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; + size_read += static_cast(len); + return true; + } + /*! + * \brief read data into array, + * this function can not be used together with ReadToRingBuffer + * a link can either read into the ring buffer, or existing array + * \param max_size maximum size of array + * \return true if it is an successful read, false if there is some error happens, check errno + */ + inline bool ReadToArray(void *recvbuf_, size_t max_size) { + if (max_size == size_read) return true; + char *p = static_cast(recvbuf_); + ssize_t len = sock.Recv(p + size_read, max_size - size_read); + // length equals 0, remote disconnected + if (len == 0) { + sock.Close(); return false; + } + if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; + size_read += static_cast(len); + return true; + } + /*! + * \brief write data in array to sock + * \param sendbuf_ head of array + * \param max_size maximum size of array + * \return true if it is an successful write, false if there is some error happens, check errno + */ + inline bool WriteFromArray(const void *sendbuf_, size_t max_size) { + const char *p = static_cast(sendbuf_); + ssize_t len = sock.Send(p + size_write, max_size - size_write); + if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; + size_write += static_cast(len); + return true; + } + + private: + // recv buffer to get data from child + // aligned with 64 bits, will be able to perform 64 bits operations freely + std::vector buffer_; + }; + /*! + * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure + * + * NOTE on AllReduce: + * The kSuccess TryAllReduce does NOT mean every node have successfully finishes TryAllReduce. + * It only means the current node get the correct result of AllReduce. + * However, it means every node finishes LAST call(instead of this one) of AllReduce/Bcast + * + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + * \return this function can return + * - kSuccess: allreduce is success, + * - kSockError: a neighbor node go down, the connection is dropped + * - kGetExcept: another node which is not my neighbor go down, get Out-of-Band exception notification from my neighbor + */ + ReturnType TryAllReduce(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer); + /*! + * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure + * \param sendrecvbuf_ buffer for both sending and recving data + * \param size the size of the data to be broadcasted + * \param root the root worker id to broadcast the data + * \return this function can return three possible values, see detail in TryAllReduce + */ + ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root); + //---- local data related to link ---- + // index of parent link, can be -1, meaning this is root of the tree + int parent_index; + // sockets of all links + std::vector links; + //----- meta information----- + // uri of current host, to be set by Init + std::string host_uri; + // uri of master + std::string master_uri; + // port of master address + int master_port; + // port of slave process + int slave_port, nport_trial; + // reduce buffer size + size_t reduce_buffer_size; + // current rank + int rank; + // world size + int world_size; +}; +} // namespace engine +#endif // ALLREDUCE_ENGINE_BASE_H diff --git a/src/engine_robust.cc b/src/engine_robust.cc new file mode 100644 index 000000000..b969bc9f6 --- /dev/null +++ b/src/engine_robust.cc @@ -0,0 +1,181 @@ +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_DEPRECATE +#define NOMINMAX +#include "./utils.h" +#include "./engine_robust.h" + +namespace engine { +/*! + * \brief perform in-place allreduce, on sendrecvbuf + * this function is NOT thread-safe + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + */ +void AllReduceRobust::AllReduce(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer) { + utils::LogPrintf("[%d] call AllReduce", rank); + TryResetLinks(); + utils::LogPrintf("[%d] start work", rank); + while (true) { + ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer); + if (ret == kSuccess) return; + if (ret == kSockError) { + utils::Error("error occur during all reduce\n"); + } + utils::LogPrintf("[%d] receive except signal, start reset link", rank); + TryResetLinks(); + //utils::Check(TryResetLinks() == kSuccess, "error when reset links"); + } + // TODO +} +/*! + * \brief broadcast data from root to all nodes + * \param sendrecvbuf_ buffer for both sending and recving data + * \param size the size of the data to be broadcasted + * \param root the root worker id to broadcast the data + */ +void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) { + utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, + "AllReduce failed"); + // TODO +} +/*! + * \brief load latest check point + * \param p_model pointer to the model + * \return true if there was stored checkpoint and load was successful + * false if there was no stored checkpoint, means we are start over gain + */ +bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) { + // TODO + return false; +} +/*! + * \brief checkpoint the model, meaning we finished a stage of execution + * \param p_model pointer to the model + */ +void AllReduceRobust::CheckPoint(const utils::ISerializable &model) { + // TODO +} +/*! + * \brief reset the all the existing links by sending Out-of-Band message marker + * after this function finishes, all the messages received and sent before in all live links are discarded, + * This allows us to get a fresh start after error has happened + * + * \return this function can return kSuccess or kSockError + * when kSockError is returned, it simply means there are bad sockets in the links, + * and some link recovery proceduer is needed + */ +AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) { + utils::LogPrintf("[%d] TryResetLinks, start\n", rank); + // number of links + const int nlink = static_cast(links.size()); + for (int i = 0; i < nlink; ++i) { + links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); + links[i].ResetSize(); + } + // read and discard data from all channels until pass mark + while (true) { + for (int i = 0; i < nlink; ++i) { + if (links[i].sock.BadSocket()) continue; + if (links[i].size_write == 0) { + char sig = kOOBReset; + ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); + // error will be filtered in next loop + if (len == sizeof(sig)) links[i].size_write = 1; + } + if (links[i].size_write == 1) { + char sig = kResetMark; + ssize_t len = links[i].sock.Send(&sig, sizeof(sig)); + if (len == sizeof(sig)) links[i].size_write = 2; + } + if (links[i].size_read == 0) { + int atmark = links[i].sock.AtMark(); + if (atmark < 0) { + utils::Assert(links[i].sock.BadSocket(), "must already gone bad"); + } else if (atmark > 0) { + links[i].size_read = 1; + } else { + printf("buffer_size=%lu\n", links[i].buffer_size); + // no at mark, read and discard data + ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); + // zero length, remote closed the connection, close socket + if (len == 0) links[i].sock.Close(); + } + } + } + utils::SelectHelper rsel; + bool finished = true; + for (int i = 0; i < nlink; ++i) { + if (links[i].size_write != 2 && !links[i].sock.BadSocket()) { + rsel.WatchWrite(links[i].sock); finished = false; + } + if (links[i].size_read == 0 && !links[i].sock.BadSocket()) { + rsel.WatchRead(links[i].sock); finished = false; + } + } + if (finished) break; + // wait to read from the channels to discard data + rsel.Select(); + } + utils::LogPrintf("[%d] Finish discard data\n", rank); + // start synchronization, use blocking I/O to avoid select + for (int i = 0; i < nlink; ++i) { + if (!links[i].sock.BadSocket()) { + char oob_mark; + links[i].sock.SetNonBlock(false); + ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL); + if (len == 0) { + links[i].sock.Close(); continue; + } else if (len > 0) { + utils::Assert(oob_mark == kResetMark, "wrong oob msg"); + utils::Assert(!links[i].sock.AtMark(), "should already read past mark"); + } else { + utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); + } + // send out ack + char ack = kResetAck; + while (true) { + len = links[i].sock.Send(&ack, sizeof(ack)); + if (len == sizeof(ack)) break; + if (len == -1) { + if (errno != EAGAIN && errno != EWOULDBLOCK) break; + } + } + } + } + utils::LogPrintf("[%d] GGet all Acks\n", rank); + // wait all ack + for (int i = 0; i < nlink; ++i) { + if (!links[i].sock.BadSocket()) { + char ack; + ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL); + if (len == 0) { + links[i].sock.Close(); continue; + } else if (len > 0) { + utils::Assert(ack == kResetAck, "wrong Ack MSG"); + } else { + utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); + } + // set back to nonblock mode + links[i].sock.SetNonBlock(true); + } + } + for (int i = 0; i < nlink; ++i) { + if (links[i].sock.BadSocket()) return kSockError; + } + utils::LogPrintf("[%d] TryResetLinks,!! return\n", rank); + return kSuccess; +} + +bool AllReduceRobust::RecoverExec(void *sendrecvbuf_, size_t size, int flag, int seqno) { + if (flag != 0) { + utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations"); + } + ActionSummary act(flag, seqno); + return true; +} +} // namespace engine diff --git a/src/engine_robust.cpp b/src/engine_robust.cpp deleted file mode 100644 index 3382d189d..000000000 --- a/src/engine_robust.cpp +++ /dev/null @@ -1,702 +0,0 @@ -/*! - * \file engine_robust.cpp - * \brief Robust implementation of AllReduce - * using TCP non-block socket and tree-shape reduction. - * - * This implementation considers the failure of nodes - * - * \author Tianqi, Nacho, Tianyi - */ -#define _CRT_SECURE_NO_WARNINGS -#define _CRT_SECURE_NO_DEPRECATE -#define NOMINMAX -#include -#include -#include -#include "./utils.h" -#include "./engine.h" -#include "./socket.h" - -namespace MPI { -// MPI data type to be compatible with existing MPI interface -class Datatype { - public: - size_t type_size; - Datatype(size_t type_size) : type_size(type_size) {} -}; -} - -namespace engine { -/*! \brief implementation of fault tolerant all reduce engine */ -class AllReduceManager : public IEngine { - public: - // magic number to verify server - const static int kMagic = 0xff99; - // constant one byte out of band message to indicate error happening - // and mark for channel cleanup - const static char kOOBReset = 95; - // and mark for channel cleanup, after OOB signal - const static char kResetMark = 97; - // and mark for channel cleanup - const static char kResetAck = 97; - - AllReduceManager(void) { - master_uri = "NULL"; - master_port = 9000; - host_uri = ""; - slave_port = 9010; - nport_trial = 1000; - rank = 0; - world_size = 1; - this->SetParam("reduce_buffer", "256MB"); - } - ~AllReduceManager(void) { - } - inline void Shutdown(void) { - for (size_t i = 0; i < links.size(); ++i) { - links[i].sock.Close(); - } - links.clear(); - utils::TCPSocket::Finalize(); - } - /*! \brief set parameters to the sync manager */ - inline void SetParam(const char *name, const char *val) { - if (!strcmp(name, "master_uri")) master_uri = val; - if (!strcmp(name, "master_port")) master_port = atoi(val); - if (!strcmp(name, "reduce_buffer")) { - char unit; - unsigned long amount; - if (sscanf(val, "%lu%c", &amount, &unit) == 2) { - switch (unit) { - case 'B': reduce_buffer_size = (amount + 7)/ 8; break; - case 'K': reduce_buffer_size = amount << 7UL; break; - case 'M': reduce_buffer_size = amount << 17UL; break; - case 'G': reduce_buffer_size = amount << 27UL; break; - default: utils::Error("invalid format for reduce buffer"); - } - } else { - utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}"); - } - } - } - // initialize the manager - inline void Init(void) { - utils::Socket::Startup(); - // single node mode - if (master_uri == "NULL") return; - utils::Assert(links.size() == 0, "can only call Init once"); - int magic = kMagic; - int nchild = 0, nparent = 0; - this->host_uri = utils::SockAddr::GetHostName(); - // get information from master - utils::TCPSocket master; - master.Create(); - if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) { - utils::Socket::Error("Connect"); - } - utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1"); - utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2"); - utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); - utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3"); - utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4"); - utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5"); - utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6"); - utils::Assert(nchild >= 0, "in correct number of childs"); - utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent"); - - // create listen - utils::TCPSocket sock_listen; - sock_listen.Create(); - int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial); - utils::Check(port != -1, "sync::Init fail to bind the ports specified"); - sock_listen.Listen(); - - if (nparent != 0) { - parent_index = 0; - links.push_back(LinkRecord()); - int len, hport; - std::string hname; - utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9"); - hname.resize(len); - utils::Assert(len != 0, "string must not be empty"); - utils::Assert(master.RecvAll(&hname[0], len) == static_cast(len), "sync::Init failure 10"); - utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11"); - links[0].sock.Create(); - links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport)); - utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12"); - utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13"); - utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch"); - parent_index = 0; - } else { - parent_index = -1; - } - // send back socket listening port to master - utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14"); - // close connection to master - master.Close(); - // accept links from childs - for (int i = 0; i < nchild; ++i) { - LinkRecord r; - while (true) { - r.sock = sock_listen.Accept(); - if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) { - utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15"); - break; - } else { - // not a valid child - r.sock.Close(); - } - } - links.push_back(r); - } - // close listening sockets - sock_listen.Close(); - // setup selecter - for (size_t i = 0; i < links.size(); ++i) { - // set the socket to non-blocking mode - links[i].sock.SetNonBlock(true); - } - // done - } - /*! \brief get rank */ - virtual int GetRank(void) const { - return rank; - } - /*! \brief get rank */ - virtual int GetWorldSize(void) const { - return world_size; - } - /*! \brief get rank */ - virtual std::string GetHost(void) const { - return host_uri; - } - virtual void AllReduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - while (true) { - ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer); - if (ret == kSuccess) return; - if (ret == kSockError) { - utils::Error("error occur during all reduce\n"); - } - utils::Check(TryResetLinks() == kSuccess, "error when reset links"); - } - } - /*! - * \brief broadcast data from root to all nodes - * \param sendrecvbuf_ buffer for both sending and recving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - */ - virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { - if (links.size() == 0) return; - // number of links - const int nlink = static_cast(links.size()); - // size of space already read from data - size_t size_in = 0; - // input link, -2 means unknown yet, -1 means this is root - int in_link = -2; - - // initialize the link statistics - for (int i = 0; i < nlink; ++i) { - links[i].ResetSize(); - } - // root have all the data - if (this->rank == root) { - size_in = total_size; - in_link = -1; - } - - // while we have not passed the messages out - while(true) { - // select helper - utils::SelectHelper selecter; - for (size_t i = 0; i < links.size(); ++i) { - selecter.WatchRead(links[i].sock); - selecter.WatchWrite(links[i].sock); - selecter.WatchException(links[i].sock); - } - if (in_link == -2) { - // probe in-link - for (int i = 0; i < nlink; ++i) { - if (selecter.CheckRead(links[i].sock)) { - if (!links[i].ReadToArray(sendrecvbuf_, total_size)) { - utils::Socket::Error("Recv"); - } - size_in = links[i].size_read; - if (size_in != 0) { - in_link = i; break; - } - } - } - } else { - // read from in link - if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { - if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) { - utils::Socket::Error("Recv"); - } - size_in = links[in_link].size_read; - } - } - size_t nfinished = total_size; - // send data to all out-link - for (int i = 0; i < nlink; ++i) { - if (i != in_link) { - if (selecter.CheckWrite(links[i].sock)) { - if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) { - utils::Socket::Error("Send"); - } - } - nfinished = std::min(nfinished, links[i].size_write); - } - } - // check boundary condition - if (nfinished >= total_size) break; - } - } - virtual bool LoadCheckPoint(utils::ISerializable *p_model) { - return false; - } - virtual void CheckPoint(const utils::ISerializable &model) { - } - - protected: - // possible returning type from the Try Functions - enum ReturnType { - kSuccess, - kSockError, - kGetExcept - }; - // possible state of the server - enum ServerState { - kNormal, - kConnDrop, - kRecover - }; - // cleanup the links, by sending OOB message - inline ReturnType TryResetLinks(void) { - // number of links - const int nlink = static_cast(links.size()); - for (int i = 0; i < nlink; ++i) { - links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); - links[i].ResetSize(); - links[i].except = false; - } - // read and discard data from all channels until pass mark - while (true) { - for (int i = 0; i < nlink; ++i) { - if (links[i].sock.BadSocket()) continue; - if (links[i].size_write == 0) { - char sig = kOOBReset; - ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); - // error will be filtered in next loop - if (len == sizeof(sig)) links[i].size_write = 1; - } - if (links[i].size_write == 1) { - char sig = kResetMark; - ssize_t len = links[i].sock.Send(&sig, sizeof(sig)); - if (len == sizeof(sig)) links[i].size_write = 2; - } - if (links[i].size_read == 0) { - int atmark = links[i].sock.AtMark(); - if (atmark < 0) { - utils::Assert(links[i].sock.BadSocket(), "must already gone bad"); - } else if (atmark > 0) { - links[i].size_read = 1; - } else { - // no at mark, read and discard data - ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size); - // zero length, remote closed the connection, close socket - if (len == 0) links[i].sock.Close(); - } - } - } - utils::SelectHelper rsel; - bool finished = true; - for (int i = 0; i < nlink; ++i) { - if (links[i].size_write != 2 && !links[i].sock.BadSocket()) { - rsel.WatchWrite(links[i].sock); finished = false; - } - if (links[i].size_read == 0 && !links[i].sock.BadSocket()) { - rsel.WatchRead(links[i].sock); finished = false; - } - } - if (finished) break; - // wait to read from the channels to discard data - rsel.Select(); - } - // start synchronization, use blocking I/O to avoid select - for (int i = 0; i < nlink; ++i) { - if (!links[i].sock.BadSocket()) { - char oob_mark; - links[i].sock.SetNonBlock(false); - ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL); - if (len == 0) { - links[i].sock.Close(); continue; - } else if (len > 0) { - utils::Assert(oob_mark == kResetMark, "wrong oob msg"); - utils::Assert(!links[i].sock.AtMark(), "should already read past mark"); - } else { - utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); - } - // send out ack - char ack = kResetAck; - while (true) { - len = links[i].sock.Send(&ack, sizeof(ack)); - if (len == sizeof(ack)) break; - if (len == -1) { - if (errno != EAGAIN && errno != EWOULDBLOCK) break; - } - } - } - } - // wait all ack - for (int i = 0; i < nlink; ++i) { - if (!links[i].sock.BadSocket()) { - char ack; - ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL); - if (len == 0) { - links[i].sock.Close(); continue; - } else if (len > 0) { - utils::Assert(ack == kResetAck, "wrong Ack MSG"); - } else { - utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); - } - // set back to nonblock mode - links[i].sock.SetNonBlock(true); - } - } - for (int i = 0; i < nlink; ++i) { - if (links[i].sock.BadSocket()) return kSockError; - } - return kSuccess; - } - // Run AllReduce, return if success - inline ReturnType TryAllReduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - if (links.size() == 0) return kSuccess; - // total size of message - const size_t total_size = type_nbytes * count; - // number of links - const int nlink = static_cast(links.size()); - // send recv buffer - char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); - // size of space that we already performs reduce in up pass - size_t size_up_reduce = 0; - // size of space that we have already passed to parent - size_t size_up_out = 0; - // size of message we received, and send in the down pass - size_t size_down_in = 0; - // initialize the link ring-buffer and pointer - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - links[i].InitBuffer(type_nbytes, count, reduce_buffer_size); - } - links[i].ResetSize(); - } - // if no childs, no need to reduce - if (nlink == static_cast(parent_index != -1)) { - size_up_reduce = total_size; - } - - // while we have not passed the messages out - while (true) { - // select helper - utils::SelectHelper selecter; - for (size_t i = 0; i < links.size(); ++i) { - selecter.WatchRead(links[i].sock); - selecter.WatchWrite(links[i].sock); - selecter.WatchException(links[i].sock); - } - // select must return - selecter.Select(); - // exception handling - for (int i = 0; i < nlink; ++i) { - // recive OOB message from some link - if (selecter.CheckExcept(links[i].sock)) return kGetExcept; - } - // read data from childs - for (int i = 0; i < nlink; ++i) { - if (i != parent_index && selecter.CheckRead(links[i].sock)) { - if (!links[i].ReadToRingBuffer(size_up_out)) return kSockError; - } - } - // this node have childs, peform reduce - if (nlink > static_cast(parent_index != -1)) { - size_t buffer_size = 0; - // do upstream reduce - size_t max_reduce = total_size; - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - max_reduce= std::min(max_reduce, links[i].size_read); - utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size, - "buffer size inconsistent"); - buffer_size = links[i].buffer_size; - } - } - utils::Assert(buffer_size != 0, "must assign buffer_size"); - // round to type_n4bytes - max_reduce = (max_reduce / type_nbytes * type_nbytes); - // peform reduce, can be at most two rounds - while (size_up_reduce < max_reduce) { - // start position - size_t start = size_up_reduce % buffer_size; - // peform read till end of buffer - size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce); - utils::Assert(nread % type_nbytes == 0, "AllReduce: size check"); - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - reducer(links[i].buffer_head + start, - sendrecvbuf + size_up_reduce, - static_cast(nread / type_nbytes), - MPI::Datatype(type_nbytes)); - } - } - size_up_reduce += nread; - } - } - if (parent_index != -1) { - // pass message up to parent, can pass data that are already been reduced - if (selecter.CheckWrite(links[parent_index].sock)) { - ssize_t len = links[parent_index].sock. - Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); - if (len != -1) { - size_up_out += static_cast(len); - } else { - if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; - } - } - // read data from parent - if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) { - ssize_t len = links[parent_index].sock. - Recv(sendrecvbuf + size_down_in, total_size - size_down_in); - if (len == 0) { - links[parent_index].sock.Close(); return kSockError; - } - if (len != -1) { - size_down_in += static_cast(len); - utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error"); - } else { - if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; - } - } - } else { - // this is root, can use reduce as most recent point - size_down_in = size_up_out = size_up_reduce; - } - // check if we finished the job of message passing - size_t nfinished = size_down_in; - // can pass message down to childs - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - if (selecter.CheckWrite(links[i].sock)) { - if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError; - } - nfinished = std::min(links[i].size_write, nfinished); - } - } - // check boundary condition - if (nfinished >= total_size) break; - } - return kSuccess; - } - - private: - // link record to a neighbor - struct LinkRecord { - public: - // socket to get data from/to link - utils::TCPSocket sock; - // size of data readed from link - size_t size_read; - // size of data sent to the link - size_t size_write; - // pointer to buffer head - char *buffer_head; - // buffer size, in bytes - size_t buffer_size; - // exception - bool except; - // constructor - LinkRecord(void) {} - - // initialize buffer - inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) { - size_t n = (type_nbytes * count + 7)/ 8; - buffer_.resize(std::min(reduce_buffer_size, n)); - // make sure align to type_nbytes - buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; - utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size); - // set buffer head - buffer_head = reinterpret_cast(BeginPtr(buffer_)); - } - // reset the recv and sent size - inline void ResetSize(void) { - size_write = size_read = 0; - } - /*! - * \brief read data into ring-buffer, with care not to existing useful override data - * position after protect_start - * \param protect_start all data start from protect_start is still needed in buffer - * read shall not override this - * \return true if it is an successful read, false if there is some error happens, check errno - */ - inline bool ReadToRingBuffer(size_t protect_start) { - size_t ngap = size_read - protect_start; - utils::Assert(ngap <= buffer_size, "AllReduce: boundary check"); - size_t offset = size_read % buffer_size; - size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); - if (nmax == 0) return true; - ssize_t len = sock.Recv(buffer_head + offset, nmax); - // length equals 0, remote disconnected - if (len == 0) { - sock.Close(); return false; - } - if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; - size_read += static_cast(len); - return true; - } - /*! - * \brief read data into array, - * this function can not be used together with ReadToRingBuffer - * a link can either read into the ring buffer, or existing array - * \param max_size maximum size of array - * \return true if it is an successful read, false if there is some error happens, check errno - */ - inline bool ReadToArray(void *recvbuf_, size_t max_size) { - if (max_size == size_read) return true; - char *p = static_cast(recvbuf_); - ssize_t len = sock.Recv(p + size_read, max_size - size_read); - // length equals 0, remote disconnected - if (len == 0) { - sock.Close(); return false; - } - if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; - size_read += static_cast(len); - return true; - } - /*! - * \brief write data in array to sock - * \param sendbuf_ head of array - * \param max_size maximum size of array - * \return true if it is an successful write, false if there is some error happens, check errno - */ - inline bool WriteFromArray(const void *sendbuf_, size_t max_size) { - const char *p = static_cast(sendbuf_); - ssize_t len = sock.Send(p + size_write, max_size - size_write); - if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; - size_write += static_cast(len); - return true; - } - - private: - // recv buffer to get data from child - // aligned with 64 bits, will be able to perform 64 bits operations freely - std::vector buffer_; - }; - // data structure to remember result of Bcast and AllReduce calls - class ResultBuffer { - public: - // constructor - ResultBuffer(void) { - this->Clear(); - } - // clear the existing record - inline void Clear(void) { - seqno_.clear(); size_.clear(); - rptr_.clear(); rptr_.push_back(0); - data_.clear(); - } - // allocate temporal space for - inline void *AllocTemp(size_t type_nbytes, size_t count) { - size_t size = type_nbytes * count; - size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); - utils::Assert(nhop != 0, "cannot allocate 0 size memory"); - data_.resize(rptr_.back() + nhop); - return BeginPtr(data_) + rptr_.back(); - } - // push the result in temp to the - inline void PushTemp(int seqid, size_t type_nbytes, size_t count) { - size_t size = type_nbytes * count; - size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); - if (seqno_.size() != 0) { - utils::Assert(seqno_.back() < seqid, "PushTemp seqid inconsistent"); - } - seqno_.push_back(seqid); - rptr_.push_back(rptr_.back() + nhop); - size_.push_back(size); - utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent"); - } - // return the stored result of seqid, if any - inline void* Query(int seqid, size_t *p_size) { - size_t idx = std::lower_bound(seqno_.begin(), seqno_.end(), seqid) - seqno_.begin(); - if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL; - *p_size = size_[idx]; - return BeginPtr(data_) + rptr_[idx]; - } - private: - // sequence number of each - std::vector seqno_; - // pointer to the positions - std::vector rptr_; - // actual size of each buffer - std::vector size_; - // content of the buffer - std::vector data_; - }; - //---- recovery data structure ---- - // call sequence counter, records how many calls we made so far - // from last call to CheckPoint, LoadCheckPoint - int seq_counter; - // result buffer - ResultBuffer resbuf; - // model that is saved from last CheckPoint - std::string check_point; - //---- local data related to link ---- - // index of parent link, can be -1, meaning this is root of the tree - int parent_index; - // sockets of all links - std::vector links; - //----- meta information----- - // uri of current host, to be set by Init - std::string host_uri; - // uri of master - std::string master_uri; - // port of master address - int master_port; - // port of slave process - int slave_port, nport_trial; - // reduce buffer size - size_t reduce_buffer_size; - // current rank - int rank; - // world size - int world_size; -}; - -// singleton sync manager -AllReduceManager manager; - -/*! \brief intiialize the synchronization module */ -void Init(int argc, char *argv[]) { - for (int i = 1; i < argc; ++i) { - char name[256], val[256]; - if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { - manager.SetParam(name, val); - } - } - manager.Init(); -} - -/*! \brief finalize syncrhonization module */ -void Finalize(void) { - manager.Shutdown(); -} -/*! \brief singleton method to get engine */ -IEngine *GetEngine(void) { - return &manager; -} -} // namespace engine diff --git a/src/engine_robust.h b/src/engine_robust.h new file mode 100644 index 000000000..f1949e11a --- /dev/null +++ b/src/engine_robust.h @@ -0,0 +1,206 @@ +/*! + * \file engine_robust.h + * \brief Robust implementation of AllReduce + * using TCP non-block socket and tree-shape reduction. + * + * This implementation considers the failure of nodes + * + * \author Tianqi, Nacho, Tianyi + */ +#ifndef ALLREDUCE_ENGINE_ROBUST_H +#define ALLREDUCE_ENGINE_ROBUST_H +#include "./engine.h" +#include "./engine_base.h" + +namespace engine { +/*! \brief implementation of fault tolerant all reduce engine */ +class AllReduceRobust : public AllReduceBase { + public: + virtual ~AllReduceRobust(void) {} + /*! + * \brief perform in-place allreduce, on sendrecvbuf + * this function is NOT thread-safe + * \param sendrecvbuf_ buffer for both sending and recving data + * \param type_nbytes the unit number of bytes the type have + * \param count number of elements to be reduced + * \param reducer reduce function + */ + virtual void AllReduce(void *sendrecvbuf_, + size_t type_nbytes, + size_t count, + ReduceFunction reducer); + /*! + * \brief broadcast data from root to all nodes + * \param sendrecvbuf_ buffer for both sending and recving data + * \param size the size of the data to be broadcasted + * \param root the root worker id to broadcast the data + */ + virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root); + /*! + * \brief load latest check point + * \param p_model pointer to the model + * \return true if there was stored checkpoint and load was successful + * false if there was no stored checkpoint, means we are start over gain + */ + virtual bool LoadCheckPoint(utils::ISerializable *p_model); + /*! + * \brief checkpoint the model, meaning we finished a stage of execution + * \param p_model pointer to the model + */ + virtual void CheckPoint(const utils::ISerializable &model); + + private: + // constant one byte out of band message to indicate error happening + // and mark for channel cleanup + const static char kOOBReset = 95; + // and mark for channel cleanup, after OOB signal + const static char kResetMark = 97; + // and mark for channel cleanup + const static char kResetAck = 97; + /*! + * \brief summary of actions proposed in all nodes + * this data structure is used to make consensus decision + * about next action to take in the recovery mode + */ + struct ActionSummary { + // maximumly allowed sequence id + const static int kMaxSeq = 1 << 26; + //--------------------------------------------- + // The following are bit mask of flag used in + //---------------------------------------------- + // some node want to load check point + const static int kLoadCheck = 1; + // some node want to do check point + const static int kCheckPoint = 2; + // check point Ack, we use a two phase message in check point, + // this is the second phase of check pointing + const static int kCheckAck = 4; + // there are difference sequence number the nodes proposed + // this means we want to do recover execution of the lower sequence + // action instead of normal execution + const static int kDiffSeq = 8; + // constructor + ActionSummary(void) {} + // constructor of action + ActionSummary(int flag, int minseqno = kMaxSeq) { + seqcode = (minseqno << 4) | flag; + } + // minimum number of all operations + inline int min_seqno(void) const { + return seqcode >> 4; + } + // whether the operation set contains a check point + inline bool check_point(void) const { + return (seqcode & kCheckPoint) != 0; + } + // whether the operation set contains a check point + inline bool check_ack(void) const { + return (seqcode & kCheckAck) != 0; + } + // whether the operation set contains a check point + inline bool diff_seq(void) const { + return (seqcode & kDiffSeq) != 0; + } + // returns the operation flag of the result + inline int flag(void) const { + return seqcode & 15; + } + // reducer for AllReduce, used to get the result ActionSummary from all nodes + inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { + const ActionSummary *src = (const ActionSummary*)src_; + ActionSummary *dst = (ActionSummary*)dst_; + for (int i = 0; i < len; ++i) { + int src_seqno = src[i].min_seqno(); + int dst_seqno = dst[i].min_seqno(); + int flag = src[i].flag() | dst[i].flag(); + if (src_seqno == dst_seqno) { + dst[i] = ActionSummary(flag, src_seqno); + } else { + dst[i] = ActionSummary(flag | kDiffSeq, std::min(src_seqno, dst_seqno)); + } + } + } + + private: + // internel sequence code + int seqcode; + }; + /*! \brief data structure to remember result of Bcast and AllReduce calls */ + class ResultBuffer { + public: + // constructor + ResultBuffer(void) { + this->Clear(); + } + // clear the existing record + inline void Clear(void) { + seqno_.clear(); size_.clear(); + rptr_.clear(); rptr_.push_back(0); + data_.clear(); + } + // allocate temporal space for + inline void *AllocTemp(size_t type_nbytes, size_t count) { + size_t size = type_nbytes * count; + size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); + utils::Assert(nhop != 0, "cannot allocate 0 size memory"); + data_.resize(rptr_.back() + nhop); + return BeginPtr(data_) + rptr_.back(); + } + // push the result in temp to the + inline void PushTemp(int seqid, size_t type_nbytes, size_t count) { + size_t size = type_nbytes * count; + size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); + if (seqno_.size() != 0) { + utils::Assert(seqno_.back() < seqid, "PushTemp seqid inconsistent"); + } + seqno_.push_back(seqid); + rptr_.push_back(rptr_.back() + nhop); + size_.push_back(size); + utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent"); + } + // return the stored result of seqid, if any + inline void* Query(int seqid, size_t *p_size) { + size_t idx = std::lower_bound(seqno_.begin(), seqno_.end(), seqid) - seqno_.begin(); + if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL; + *p_size = size_[idx]; + return BeginPtr(data_) + rptr_[idx]; + } + private: + // sequence number of each + std::vector seqno_; + // pointer to the positions + std::vector rptr_; + // actual size of each buffer + std::vector size_; + // content of the buffer + std::vector data_; + }; + /*! + * \brief reset the all the existing links by sending Out-of-Band message marker + * after this function finishes, all the messages received and sent before in all live links are discarded, + * This allows us to get a fresh start after error has happened + * + * \return this function can return kSuccess or kSockError + * when kSockError is returned, it simply means there are bad sockets in the links, + * and some link recovery proceduer is needed + */ + ReturnType TryResetLinks(void); + /*! + * \brief Run recovery execution of a action specified by flag and seqno, + * there can be two outcome of the function + * + * \param sendrecvbuf_ + * + * \return if this function returns true, this means + * behind and we will be able to recover data from existing node + */ + bool RecoverExec(void *sendrecvbuf_, size_t size, int flag, int seqno); + //---- recovery data structure ---- + // call sequence counter, records how many calls we made so far + // from last call to CheckPoint, LoadCheckPoint + int seq_counter; + // result buffer + ResultBuffer resbuf; +}; +} // namespace engine +#endif // ALLREDUCE_ENGINE_ROBUST_H diff --git a/src/engine_tcp.cpp b/src/engine_tcp.cpp deleted file mode 100644 index 4cbbe384f..000000000 --- a/src/engine_tcp.cpp +++ /dev/null @@ -1,485 +0,0 @@ -/*! - * \file engine_tcp.cpp - * \brief implementation of sync AllReduce using TCP sockets - * with use non-block socket and tree-shape reduction - * \author Tianqi Chen - */ -#define _CRT_SECURE_NO_WARNINGS -#define _CRT_SECURE_NO_DEPRECATE -#define NOMINMAX -#include -#include -#include -#include "./engine.h" -#include "./socket.h" - -namespace MPI { -class Datatype { - public: - size_t type_size; - Datatype(size_t type_size) : type_size(type_size) {} -}; -} -namespace engine { -/*! \brief implementation of sync goes to here */ -class SyncManager : public IEngine { - public: - const static int kMagic = 0xff99; - SyncManager(void) { - master_uri = "NULL"; - master_port = 9000; - host_uri = ""; - slave_port = 9010; - nport_trial = 1000; - rank = 0; - world_size = 1; - this->SetParam("reduce_buffer", "256MB"); - } - ~SyncManager(void) { - } - - inline void Shutdown(void) { - for (size_t i = 0; i < links.size(); ++i) { - links[i].sock.Close(); - } - links.clear(); - utils::TCPSocket::Finalize(); - } - /*! \brief set parameters to the sync manager */ - inline void SetParam(const char *name, const char *val) { - if (!strcmp(name, "master_uri")) master_uri = val; - if (!strcmp(name, "master_port")) master_port = atoi(val); - if (!strcmp(name, "reduce_buffer")) { - char unit; - unsigned long amount; - if (sscanf(val, "%lu%c", &amount, &unit) == 2) { - switch (unit) { - case 'B': reduce_buffer_size = (amount + 7)/ 8; break; - case 'K': reduce_buffer_size = amount << 7UL; break; - case 'M': reduce_buffer_size = amount << 17UL; break; - case 'G': reduce_buffer_size = amount << 27UL; break; - default: utils::Error("invalid format for reduce buffer"); - } - } else { - utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}"); - } - } - } - // initialize the manager - inline void Init(void) { - utils::Socket::Startup(); - // single node mode - if (master_uri == "NULL") return; - utils::Assert(links.size() == 0, "can only call Init once"); - int magic = kMagic; - int nchild = 0, nparent = 0; - this->host_uri = utils::SockAddr::GetHostName(); - // get information from master - utils::TCPSocket master; - master.Create(); - if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) { - utils::Socket::Error("Connect"); - } - utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1"); - utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2"); - utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); - utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3"); - utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4"); - utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5"); - utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6"); - utils::Assert(nchild >= 0, "in correct number of childs"); - utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent"); - - // create listen - utils::TCPSocket sock_listen; - sock_listen.Create(); - int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial); - utils::Check(port != -1, "sync::Init fail to bind the ports specified"); - sock_listen.Listen(); - - if (nparent != 0) { - parent_index = 0; - links.push_back(LinkRecord()); - int len, hport; - std::string hname; - utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9"); - hname.resize(len); - utils::Assert(len != 0, "string must not be empty"); - utils::Assert(master.RecvAll(&hname[0], len) == static_cast(len), "sync::Init failure 10"); - utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11"); - links[0].sock.Create(); - links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport)); - utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12"); - utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13"); - utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch"); - parent_index = 0; - } else { - parent_index = -1; - } - // send back socket listening port to master - utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14"); - // close connection to master - master.Close(); - // accept links from childs - for (int i = 0; i < nchild; ++i) { - LinkRecord r; - while (true) { - r.sock = sock_listen.Accept(); - if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) { - utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15"); - break; - } else { - // not a valid child - r.sock.Close(); - } - } - links.push_back(r); - } - // close listening sockets - sock_listen.Close(); - // setup selecter - selecter.Clear(); - for (size_t i = 0; i < links.size(); ++i) { - // set the socket to non-blocking mode - links[i].sock.SetNonBlock(true); - selecter.WatchRead(links[i].sock); - selecter.WatchWrite(links[i].sock); - } - // done - } - /*! \brief get rank */ - virtual int GetRank(void) const { - return rank; - } - /*! \brief get rank */ - virtual int GetWorldSize(void) const { - return world_size; - } - /*! \brief get rank */ - virtual std::string GetHost(void) const { - return host_uri; - } - /*! - * \brief perform in-place allreduce, on sendrecvbuf - * this function is NOT thread-safe - * \param sendrecvbuf_ buffer for both sending and recving data - * \param type_n4bytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - */ - virtual void AllReduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - if (links.size() == 0) return; - // total size of message - const size_t total_size = type_nbytes * count; - // number of links - const int nlink = static_cast(links.size()); - // send recv buffer - char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); - // size of space that we already performs reduce in up pass - size_t size_up_reduce = 0; - // size of space that we have already passed to parent - size_t size_up_out = 0; - // size of message we received, and send in the down pass - size_t size_down_in = 0; - - // initialize the link ring-buffer and pointer - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - links[i].InitBuffer(type_nbytes, count, reduce_buffer_size); - } - links[i].ResetSize(); - } - // if no childs, no need to reduce - if (nlink == static_cast(parent_index != -1)) { - size_up_reduce = total_size; - } - - // while we have not passed the messages out - while(true) { - selecter.Select(); - // read data from childs - for (int i = 0; i < nlink; ++i) { - if (i != parent_index && selecter.CheckRead(links[i].sock)) { - if (!links[i].ReadToRingBuffer(size_up_out)) { - utils::Socket::Error("Recv"); - } - } - } - // this node have childs, peform reduce - if (nlink > static_cast(parent_index != -1)) { - size_t buffer_size = 0; - // do upstream reduce - size_t max_reduce = total_size; - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - max_reduce= std::min(max_reduce, links[i].size_read); - utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size, - "buffer size inconsistent"); - buffer_size = links[i].buffer_size; - } - } - utils::Assert(buffer_size != 0, "must assign buffer_size"); - // round to type_n4bytes - max_reduce = (max_reduce / type_nbytes * type_nbytes); - // peform reduce, can be at most two rounds - while (size_up_reduce < max_reduce) { - // start position - size_t start = size_up_reduce % buffer_size; - // peform read till end of buffer - size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce); - utils::Assert(nread % type_nbytes == 0, "AllReduce: size check"); - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - reducer(links[i].buffer_head + start, - sendrecvbuf + size_up_reduce, - static_cast(nread / type_nbytes), - MPI::Datatype(type_nbytes)); - } - } - size_up_reduce += nread; - } - } - if (parent_index != -1) { - // pass message up to parent, can pass data that are already been reduced - if (selecter.CheckWrite(links[parent_index].sock)) { - ssize_t len = links[parent_index].sock. - Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); - if (len != -1) { - size_up_out += static_cast(len); - } else { - if (errno != EAGAIN && errno != EWOULDBLOCK) utils::Socket::Error("Recv"); - } - } - // read data from parent - if (selecter.CheckRead(links[parent_index].sock)) { - ssize_t len = links[parent_index].sock. - Recv(sendrecvbuf + size_down_in, total_size - size_down_in); - if (len != -1) { - size_down_in += static_cast(len); - utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error"); - } else { - if (errno != EAGAIN && errno != EWOULDBLOCK) utils::Socket::Error("Recv"); - } - } - } else { - // this is root, can use reduce as most recent point - size_down_in = size_up_out = size_up_reduce; - } - // check if we finished the job of message passing - size_t nfinished = size_down_in; - // can pass message down to childs - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - if (selecter.CheckWrite(links[i].sock)) { - if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) { - utils::Socket::Error("Send"); - } - } - nfinished = std::min(links[i].size_write, nfinished); - } - } - // check boundary condition - if (nfinished >= total_size) break; - } - } - /*! - * \brief broadcast data from root to all nodes - * \param sendrecvbuf_ buffer for both sending and recving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - */ - virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { - if (links.size() == 0) return; - // number of links - const int nlink = static_cast(links.size()); - // size of space already read from data - size_t size_in = 0; - // input link, -2 means unknown yet, -1 means this is root - int in_link = -2; - - // initialize the link statistics - for (int i = 0; i < nlink; ++i) { - links[i].ResetSize(); - } - // root have all the data - if (this->rank == root) { - size_in = total_size; - in_link = -1; - } - - // while we have not passed the messages out - while(true) { - selecter.Select(); - if (in_link == -2) { - // probe in-link - for (int i = 0; i < nlink; ++i) { - if (selecter.CheckRead(links[i].sock)) { - if (!links[i].ReadToArray(sendrecvbuf_, total_size)) { - utils::Socket::Error("Recv"); - } - size_in = links[i].size_read; - if (size_in != 0) { - in_link = i; break; - } - } - } - } else { - // read from in link - if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { - if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) { - utils::Socket::Error("Recv"); - } - size_in = links[in_link].size_read; - } - } - size_t nfinished = total_size; - // send data to all out-link - for (int i = 0; i < nlink; ++i) { - if (i != in_link) { - if (selecter.CheckWrite(links[i].sock)) { - if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) { - utils::Socket::Error("Send"); - } - } - nfinished = std::min(nfinished, links[i].size_write); - } - } - // check boundary condition - if (nfinished >= total_size) break; - } - } - virtual bool LoadCheckPoint(utils::ISerializable *p_model) { - return false; - } - virtual void CheckPoint(const utils::ISerializable &model) { - } - - private: - // an independent child record - struct LinkRecord { - public: - // socket to get data from/to link - utils::TCPSocket sock; - // size of data readed from link - size_t size_read; - // size of data sent to the link - size_t size_write; - // pointer to buffer head - char *buffer_head; - // buffer size, in bytes - size_t buffer_size; - // initialize buffer - inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) { - size_t n = (type_nbytes * count + 7)/ 8; - buffer_.resize(std::min(reduce_buffer_size, n)); - // make sure align to type_nbytes - buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; - utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size); - // set buffer head - buffer_head = reinterpret_cast(BeginPtr(buffer_)); - } - // reset the recv and sent size - inline void ResetSize(void) { - size_write = size_read = 0; - } - /*! - * \brief read data into ring-buffer, with care not to existing useful override data - * position after protect_start - * \param protect_start all data start from protect_start is still needed in buffer - * read shall not override this - * \return true if it is an successful read, false if there is some error happens, check errno - */ - inline bool ReadToRingBuffer(size_t protect_start) { - size_t ngap = size_read - protect_start; - utils::Assert(ngap <= buffer_size, "AllReduce: boundary check"); - size_t offset = size_read % buffer_size; - size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); - ssize_t len = sock.Recv(buffer_head + offset, nmax); - if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; - size_read += static_cast(len); - return true; - } - /*! - * \brief read data into array, - * this function can not be used together with ReadToRingBuffer - * a link can either read into the ring buffer, or existing array - * \param max_size maximum size of array - * \return true if it is an successful read, false if there is some error happens, check errno - */ - inline bool ReadToArray(void *recvbuf_, size_t max_size) { - char *p = static_cast(recvbuf_); - ssize_t len = sock.Recv(p + size_read, max_size - size_read); - if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; - size_read += static_cast(len); - return true; - } - /*! - * \brief write data in array to sock - * \param sendbuf_ head of array - * \param max_size maximum size of array - * \return true if it is an successful write, false if there is some error happens, check errno - */ - inline bool WriteFromArray(const void *sendbuf_, size_t max_size) { - const char *p = static_cast(sendbuf_); - ssize_t len = sock.Send(p + size_write, max_size - size_write); - if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; - size_write += static_cast(len); - return true; - } - - private: - // recv buffer to get data from child - // aligned with 64 bits, will be able to perform 64 bits operations freely - std::vector buffer_; - }; - //------------------ - // uri of current host, to be set by Init - std::string host_uri; - // uri of master - std::string master_uri; - // port of master address - int master_port; - // port of slave process - int slave_port, nport_trial; - // reduce buffer size - size_t reduce_buffer_size; - // current rank - int rank; - // world size - int world_size; - // index of parent link, can be -1, meaning this is root of the tree - int parent_index; - // sockets of all links - std::vector links; - // select helper - utils::SelectHelper selecter; - -}; - -// singleton sync manager -SyncManager manager; - -/*! \brief intiialize the synchronization module */ -void Init(int argc, char *argv[]) { - for (int i = 1; i < argc; ++i) { - char name[256], val[256]; - if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { - manager.SetParam(name, val); - } - } - manager.Init(); -} - -/*! \brief finalize syncrhonization module */ -void Finalize(void) { - manager.Shutdown(); -} -/*! \brief singleton method to get engine */ -IEngine *GetEngine(void) { - return &manager; -} - -} // namespace engine diff --git a/src/utils.h b/src/utils.h index 2c529c449..81bba7dfd 100644 --- a/src/utils.h +++ b/src/utils.h @@ -76,6 +76,9 @@ inline void HandleCheckError(const char *msg) { inline void HandlePrint(const char *msg) { printf("%s", msg); } +inline void HandleLogPrint(const char *msg) { + fprintf(stderr, "%s", msg); +} #else #ifndef ALLREDUCE_STRICT_CXX98_ // include declarations, some one must implement this @@ -101,6 +104,15 @@ inline void Printf(const char *fmt, ...) { va_end(args); HandlePrint(msg.c_str()); } +/*! \brief printf, print message to the console */ +inline void LogPrintf(const char *fmt, ...) { + std::string msg(kPrintBuffer, '\0'); + va_list args; + va_start(args, fmt); + vsnprintf(&msg[0], kPrintBuffer, fmt, args); + va_end(args); + HandleLogPrint(msg.c_str()); +} /*! \brief portable version of snprintf */ inline int SPrintf(char *buf, size_t size, const char *fmt, ...) { va_list args; diff --git a/test/Makefile b/test/Makefile index c773fe45b..49aca06e1 100644 --- a/test/Makefile +++ b/test/Makefile @@ -12,23 +12,25 @@ endif # specify tensor path BIN = test_allreduce -OBJ = engine_robust.o engine_tcp.o +OBJ = engine_base.o engine_robust.o engine.o .PHONY: clean all all: $(BIN) $(MPIBIN) engine_tcp.o: ../src/engine_tcp.cpp ../src/*.h -engine_robust.o: ../src/engine_robust.cpp ../src/*.h -test_allreduce: test_allreduce.cpp ../src/*.h engine_robust.o +engine_base.o: ../src/engine_base.cc ../src/*.h +engine.o: ../src/engine.cc ../src/*.h +engine_robust.o: ../src/engine_robust.cc ../src/*.h +test_allreduce: test_allreduce.cpp ../src/*.h $(OBJ) $(BIN) : - $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) + $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(OBJ) : - $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) ) + $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(MPIBIN) : - $(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) + $(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) clean: $(RM) $(OBJ) $(BIN) $(MPIBIN) *~ ../src/*~ diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp index 40c85ea0b..3a2cc2a9d 100644 --- a/test/test_allreduce.cpp +++ b/test/test_allreduce.cpp @@ -74,11 +74,11 @@ int main(int argc, char *argv[]) { test::Mock mock(rank, argv[2], argv[3]); - printf("[%d] start at %s\n", rank, name.c_str()); + utils::LogPrintf("[%d] start at %s\n", rank, name.c_str()); TestMax(mock, n); - printf("[%d] !!!TestMax pass\n", rank); + utils::LogPrintf("[%d] !!!TestMax pass\n", rank); TestSum(mock, n); - printf("[%d] !!!TestSum pass\n", rank); + utils::LogPrintf("[%d] !!!TestSum pass\n", rank); sync::Finalize(); printf("[%d] all check pass\n", rank); return 0;