From 67c5d8a2e6b0063c7cc7c586dfd64ea60cb1356c Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 22 Nov 2014 17:12:19 -0800 Subject: [PATCH] allreduce server side ok, need to add master --- src/sync/sync_tcp.cpp | 227 +++++++++++++++++++++++++++++++++++++++++- src/utils/socket.h | 145 +++++++++++++++++++++++---- 2 files changed, 350 insertions(+), 22 deletions(-) diff --git a/src/sync/sync_tcp.cpp b/src/sync/sync_tcp.cpp index 5e46a7119..cfd0d57cd 100644 --- a/src/sync/sync_tcp.cpp +++ b/src/sync/sync_tcp.cpp @@ -5,6 +5,8 @@ * \author Tianqi Chen */ #include +#include +#include #include "./sync.h" #include "../utils/socket.h" @@ -19,8 +21,113 @@ namespace sync { /*! \brief implementation of sync goes to here */ class SyncManager { public: + const static int kMagic = 0xff99; + SyncManager(void) { + master_uri = "localhost"; + master_port = 9000; + slave_port = 9010; + nport_trial = 1000; + } + ~SyncManager(void) { + this->Shutdown(); + } // initialize the manager - inline void Init(int argc, char *argv[]) { + inline void Init(void) { + utils::Assert(links.size() == 0, "can only call Init once"); + int magic = kMagic; + int nchild = 0, nparent = 0; + this->host_uri = utils::SockAddr::GetHostName(); + // get information from master: exchange the magic number, then receive rank, world size and the parent/child counts + utils::TCPSocket master; + master.Create(); + master.Connect(utils::SockAddr(master_uri.c_str(), master_port)); + utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1"); + utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2"); + utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); + utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3"); + utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4"); + utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5"); + utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) 
== sizeof(nchild), "sync::Init failure 6"); + utils::Assert(nchild >= 0, "in correct number of childs"); + utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent"); + + // create listen + utils::TCPSocket sock_listen; + sock_listen.Create(); + int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial); + utils::Check(port != -1, "sync::Init fail to bind the ports specified"); + sock_listen.Listen(); + + if (nparent != 0) { + parent_index = 0; + links.push_back(LinkRecord()); + int len, hport; + std::string hname; + utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9"); + hname.resize(len); + utils::Assert(len != 0, "string must not be empty"); + utils::Assert(master.RecvAll(&hname[0], len) == static_cast(len), "sync::Init failure 10"); + utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11"); + links[0].sock.Create(); + links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport)); + utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure"); + utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure"); + utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch"); + parent_index = 0; + } else { + parent_index = -1; + } + // send back socket listening port to master + utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 12"); + // close connection to master + master.Close(); + // accept links from childs; the magic handshake is read from the newly accepted socket, not from the parent link + for (int i = 0; i < nchild; ++i) { + LinkRecord r; + while (true) { + r.sock = sock_listen.Accept(); + if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) { + utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure"); + break; + } else { + // not a valid child, drop the connection and wait for the next one + r.sock.Close(); + } + } + links.push_back(r); + } + // close listening sockets + 
sock_listen.Close(); + // setup selecter + selecter.Clear(); + for (size_t i = 0; i < links.size(); ++i) { + selecter.WatchRead(links[i].sock); + selecter.WatchWrite(links[i].sock); + } + // done + } + inline void Shutdown(void) { + for (size_t i = 0; i < links.size(); ++i) { + links[i].sock.Close(); + } + links.clear(); + } + /*! \brief set parameters to the sync manager, recognized names are master_uri and master_port */ + inline void SetParam(const char *name, const char *val) { + if (!strcmp(name, "master_uri")) master_uri = val; + if (!strcmp(name, "master_port")) master_port = atoi(val); + } + /*! \brief get rank */ + inline int GetRank(void) const { + return rank; + } + /*! \brief get world size */ + inline int GetWorldSize(void) const { + return world_size; + } + /*! \brief get uri of current host, as set by Init */ + inline std::string GetHost(void) const { + return host_uri; } /*! * \brief perform in-place allreduce, on sendrecvbuf @@ -259,8 +366,19 @@ class SyncManager { // aligned with 64 bits, will be able to perform 64 bits operations freely std::vector buffer_; }; + //------------------ + // uri of current host, to be set by Init + std::string host_uri; + // uri of master + std::string master_uri; + // port of master address + int master_port; + // first listening port to try for this process, and number of consecutive ports to try + int slave_port, nport_trial; // current rank int rank; + // world size + int world_size; // index of parent link, can be -1, meaning this is root of the tree int parent_index; // sockets of all links @@ -269,5 +387,112 @@ class SyncManager { utils::SelectHelper selecter; }; +// singleton sync manager +SyncManager manager; + +/*! \brief get rank of current process */ +int GetRank(void) { + return manager.GetRank(); +} +/*! \brief get total number of process */ +int GetWorldSize(void) { + return manager.GetWorldSize(); +} + +/*! \brief get name of processor */ +std::string GetProcessorName(void) { + return manager.GetHost(); +} + +bool IsDistributed(void) { + return true; +} +/*! 
\brief initialize the synchronization module, argv entries of the form name=value are forwarded to the manager as parameters */ +void Init(int argc, char *argv[]) { + for (int i = 1; i < argc; ++i) { + char name[256], val[256]; + if (sscanf(argv[i], "%255[^=]=%255s", name, val) == 2) { + manager.SetParam(name, val); + } + } + manager.Init(); +} + +/*! \brief finalize synchronization module */ +void Finalize(void) { + manager.Shutdown(); +} + +// this can only be used for data that was smaller than 64 bit +template +inline void ReduceSum(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { + const DType *src = (const DType*)src_; + DType *dst = (DType*)dst_; + for (int i = 0; i < len; ++i) { + dst[i] += src[i]; + } +} +template +inline void ReduceMax(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { + const DType *src = (const DType*)src_; + DType *dst = (DType*)dst_; + for (int i = 0; i < len; ++i) { + if (src[i] > dst[i]) dst[i] = src[i]; + } +} +template +inline void ReduceBitOR(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { + const DType *src = (const DType*)src_; + DType *dst = (DType*)dst_; + for (int i = 0; i < len; ++i) { + dst[i] |= src[i]; + } +} + +template<> +void AllReduce(uint32_t *sendrecvbuf, int count, ReduceOp op) { + typedef uint32_t DType; + switch(op) { + case kBitwiseOR: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceBitOR); return; + default: utils::Error("reduce op not supported"); + } +} + +template<> +void AllReduce(float *sendrecvbuf, int count, ReduceOp op) { + typedef float DType; + switch(op) { + case kSum: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceSum); return; + case kMax: manager.AllReduce(sendrecvbuf, sizeof(DType), count, ReduceMax); return; + default: utils::Error("unknown ReduceOp"); + } +} + +void Bcast(std::string *sendrecv_data, int root) { + unsigned len = static_cast(sendrecv_data->length()); + manager.Bcast(&len, sizeof(len), root); + sendrecv_data->resize(len); + if (len != 0) { + manager.Bcast(&(*sendrecv_data)[0], len, 
root); + } +} + +// code for reduce handle: Init stores the reduce function pointer, AllReduce forwards it to the manager +ReduceHandle::ReduceHandle(void) : handle(NULL), htype(NULL) { +} +ReduceHandle::~ReduceHandle(void) {} + +int ReduceHandle::TypeSize(const MPI::Datatype &dtype) { + return dtype.type_size; +} +void ReduceHandle::Init(ReduceFunction redfunc, size_t type_n4bytes, bool commute) { + utils::Assert(handle == NULL, "cannot initialize reduce handle twice"); + handle = reinterpret_cast(redfunc); +} +void ReduceHandle::AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count) { + utils::Assert(handle != NULL, "must intialize handle to call AllReduce"); + manager.AllReduce(sendrecvbuf, type_n4bytes * 4, count, reinterpret_cast(handle)); +} + } // namespace sync } // namespace xgboost diff --git a/src/utils/socket.h b/src/utils/socket.h index 299c5468e..48b917d6a 100644 --- a/src/utils/socket.h +++ b/src/utils/socket.h @@ -28,28 +28,34 @@ struct SockAddr { SockAddr(const char *url, int port) { this->Set(url, port); } + inline static std::string GetHostName(void) { + std::string buf; buf.resize(256); + utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); + return std::string(buf.c_str()); + } /*! * \brief set the address * \param url the url of the address * \param port the port of address */ - inline void Set(const char *url, int port) { - hostent *hp = gethostbyname(url); - Check(hp != NULL, "cannot obtain address of %s", url); + inline void Set(const char *host, int port) { + hostent *hp = gethostbyname(host); + Check(hp != NULL, "cannot obtain address of %s", host); memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_port = htons(port); memcpy(&addr.sin_addr, hp->h_addr_list[0], hp->h_length); } + /*! \brief return port of the address, converted to host byte order */ + inline int port(void) const { + return ntohs(addr.sin_port); + } /*! 
\return a string representation of the address */ - inline std::string ToString(void) const { + inline std::string AddrStr(void) const { std::string buf; buf.resize(256); - const char *s = inet_ntop(AF_INET, &addr, &buf[0], buf.length()); + const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length()); Assert(s != NULL, "cannot decode address"); - std::string res = s; - sprintf(&buf[0], "%u", ntohs(addr.sin_port)); - res += ":" + buf; - return res; + return std::string(s); } }; /*! @@ -60,11 +66,27 @@ class TCPSocket { /*! \brief the file descriptor of socket */ int sockfd; // constructor - TCPSocket(void) {} + TCPSocket(void) : sockfd(-1) { + } + explicit TCPSocket(int sockfd) : sockfd(sockfd) { + } + ~TCPSocket(void) { + // intentionally empty: TCPSocket is copied by value (Accept result, LinkRecord in a vector), so closing here would close the fd still used by the copies; call Close explicitly + } // default conversion to int inline operator int() const { return sockfd; } + /*! + * \brief create the socket, call this before using socket + * \param af address family passed to socket() + */ + inline void Create(int af = PF_INET) { + sockfd = socket(af, SOCK_STREAM, 0); + if (sockfd == -1) { + SockError("Create", errno); + } + } /*! * \brief start up the socket module * call this before using the sockets @@ -79,9 +101,9 @@ class TCPSocket { /*! * \brief set this socket to use async I/O */ - inline void SetAsync(void) { + inline void SetNonBlock(void) { if (fcntl(sockfd, fcntl(sockfd, F_GETFL) | O_NONBLOCK) == -1) { - SockError("SetAsync", errno); + SockError("SetNonBlock", errno); } } /*! @@ -91,15 +113,42 @@ class TCPSocket { inline void Listen(int backlog = 16) { listen(sockfd, backlog); } + /*! \brief accept a pending connection on this listening socket and return it as a new TCPSocket */ + TCPSocket Accept(void) { + int newfd = accept(sockfd, NULL, NULL); + if (newfd == -1) { + SockError("Accept", errno); + } + return TCPSocket(newfd); + } /*! * \brief bind the socket to an address - * \param 3 + * \param addr */ inline void Bind(const SockAddr &addr) { if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) { SockError("Bind", errno); } } + /*! 
+ * \brief try bind the socket to host, from start_port to end_port + * \param start_port starting port number to try + * \param end_port ending port number to try, exclusive + * \note a bind failure other than EADDRINUSE is reported through SockError + * \return the port that was bound, or -1 if no port in the range could be bound + */ + inline int TryBindHost(int start_port, int end_port) { + for (int port = start_port; port < end_port; ++port) { + SockAddr addr("0.0.0.0", port); + if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) { + return port; + } + if (errno != EADDRINUSE) { + SockError("TryBindHost", errno); + } + } + return -1; + } /*! * \brief connect to an address * \param addr the address to connect to @@ -111,7 +160,11 @@ class TCPSocket { } /*! \brief close the connection */ inline void Close(void) { - close(sockfd); + if (sockfd != -1) { + close(sockfd); sockfd = -1; + } else { + Error("TCPSocket::Close double close the socket or close without create"); + } } /*! * \brief send data using the socket @@ -123,22 +176,72 @@ class TCPSocket { inline size_t Send(const void *buf, size_t len, int flag = 0) { if (len == 0) return 0; ssize_t ret = send(sockfd, buf, len, flag); - if (ret == -1) SockError("Send", errno); + if (ret == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return 0; + SockError("Send", errno); + } return ret; - } + } /*! - * \brief send data using the socket + * \brief receive data using the socket * \param buf the pointer to the buffer * \param len the size of the buffer * \param flags extra flags * \return size of data actually received */ inline size_t Recv(void *buf, size_t len, int flags = 0) { - if (len == 0) return 0; + if (len == 0) return 0; ssize_t ret = recv(sockfd, buf, len, flags); - if (ret == -1) SockError("Recv", errno); + if (ret == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return 0; + SockError("Recv", errno); + } return ret; - } + } + /*! 
+ * \brief perform a blocking write that will attempt to send all data out + * can still return smaller than request when error occurs + * \param buf_ the pointer to the buffer + * \param len the size of the buffer + * \return size of data actually sent + */ + inline size_t SendAll(const void *buf_, size_t len) { + const char *buf = reinterpret_cast(buf_); + size_t ndone = 0; + while (ndone < len) { + ssize_t ret = send(sockfd, buf, len - ndone, 0); + if (ret == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone; + SockError("Send", errno); + } + buf += ret; + ndone += ret; + } + return ndone; + } + /*! + * \brief perform a blocking read that will attempt to read all data + * can still return smaller than request when error occurs or the peer closes + * \param buf_ the buffer pointer + * \param len length of data to recv + * \return size of data actually received + */ + inline size_t RecvAll(void *buf_, size_t len) { + char *buf = reinterpret_cast(buf_); + size_t ndone = 0; + while (ndone < len) { + ssize_t ret = recv(sockfd, buf, len - ndone, MSG_WAITALL); + if (ret == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone; + SockError("Recv", errno); + } + if (ret == 0) return ndone; + buf += ret; + ndone += ret; + } + return ndone; + } + private: // report an socket error inline static void SockError(const char *msg, int errsv) { @@ -216,7 +319,7 @@ struct SelectHelper { if (ret == -1) { int errsv = errno; char buf[256]; - Error("Select Error:%s", strerror_r(errsv, buf, sizeof(buf))); + Error("Select Error: %s", strerror_r(errsv, buf, sizeof(buf))); } return ret; }