From 2f1ba40786259f2aa15e85fa320dd44962cc78ce Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 27 Nov 2014 16:17:07 -0800 Subject: [PATCH] change in socket, to pass out error code --- src/engine_tcp.cpp | 68 ++++++++++---- src/socket.h | 223 +++++++++++++++++++++++++-------------------- 2 files changed, 173 insertions(+), 118 deletions(-) diff --git a/src/engine_tcp.cpp b/src/engine_tcp.cpp index 957319db4..e00b70f1c 100644 --- a/src/engine_tcp.cpp +++ b/src/engine_tcp.cpp @@ -76,7 +76,7 @@ class SyncManager : public IEngine { } // initialize the manager inline void Init(void) { - utils::TCPSocket::Startup(); + utils::Socket::Startup(); // single node mode if (master_uri == "NULL") return; utils::Assert(links.size() == 0, "can only call Init once"); @@ -86,7 +86,9 @@ class SyncManager : public IEngine { // get information from master utils::TCPSocket master; master.Create(); - master.Connect(utils::SockAddr(master_uri.c_str(), master_port)); + if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) { + utils::Socket::Error("Connect"); + } utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1"); utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2"); utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); @@ -213,7 +215,9 @@ class SyncManager : public IEngine { // read data from childs for (int i = 0; i < nlink; ++i) { if (i != parent_index && selecter.CheckRead(links[i].sock)) { - links[i].ReadToRingBuffer(size_up_out); + if (!links[i].ReadToRingBuffer(size_up_out)) { + utils::Socket::Error("Recv"); + } } } // this node have childs, peform reduce @@ -252,15 +256,25 @@ class SyncManager : public IEngine { } if (parent_index != -1) { // pass message up to parent, can pass data that are already been reduced - if (selecter.CheckWrite(links[parent_index].sock)) { - size_up_out += links[parent_index].sock. + if (selecter.CheckWrite(links[parent_index].sock)) { + ssize_t len = links[parent_index].sock. Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); + if (len != -1) { + size_up_out += static_cast(len); + } else { + if (errno != EAGAIN && errno != EWOULDBLOCK) utils::Socket::Error("Recv"); + } } // read data from parent if (selecter.CheckRead(links[parent_index].sock)) { - size_down_in += links[parent_index].sock. + ssize_t len = links[parent_index].sock. Recv(sendrecvbuf + size_down_in, total_size - size_down_in); - utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error"); + if (len != -1) { + size_down_in += static_cast(len); + utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error"); + } else { + if (errno != EAGAIN && errno != EWOULDBLOCK) utils::Socket::Error("Recv"); + } } } else { // this is root, can use reduce as most recent point @@ -272,7 +286,9 @@ class SyncManager : public IEngine { for (int i = 0; i < nlink; ++i) { if (i != parent_index) { if (selecter.CheckWrite(links[i].sock)) { - links[i].WriteFromArray(sendrecvbuf, size_down_in); + if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) { + utils::Socket::Error("Send"); + } } nfinished = std::min(links[i].size_write, nfinished); } @@ -317,7 +333,9 @@ class SyncManager : public IEngine { // probe in-link for (int i = 0; i < nlink; ++i) { if (selecter.CheckRead(links[i].sock)) { - links[i].ReadToArray(sendrecvbuf_, total_size); + if (!links[i].ReadToArray(sendrecvbuf_, total_size)) { + utils::Socket::Error("Recv"); + } size_in = links[i].size_read; if (size_in != 0) { in_link = i; break; @@ -327,7 +345,9 @@ class SyncManager : public IEngine { } else { // read from in link if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { - links[in_link].ReadToArray(sendrecvbuf_, total_size); + if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) { + utils::Socket::Error("Recv"); + } size_in = links[in_link].size_read; } } @@ -336,7 +356,9 @@ class SyncManager : public IEngine { for (int i = 0; i < nlink; ++i) { if (i != in_link) { if (selecter.CheckWrite(links[i].sock)) { - links[i].WriteFromArray(sendrecvbuf_, size_in); + if (!links[i].WriteFromArray(sendrecvbuf_, size_in)) { + utils::Socket::Error("Send"); + } } nfinished = std::min(nfinished, links[i].size_write); } @@ -390,32 +412,44 @@ class SyncManager : public IEngine { * position after protect_start * \param protect_start all data start from protect_start is still needed in buffer * read shall not override this + * \return true if it is an successful read, false if there is some error happens, check errno */ - inline void ReadToRingBuffer(size_t protect_start) { + inline bool ReadToRingBuffer(size_t protect_start) { size_t ngap = size_read - protect_start; utils::Assert(ngap <= buffer_size, "AllReduce: boundary check"); size_t offset = size_read % buffer_size; size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); - size_read += sock.Recv(buffer_head + offset, nmax); + ssize_t len = sock.Recv(buffer_head + offset, nmax); + if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; + size_read += static_cast(len); + return true; } /*! * \brief read data into array, * this function can not be used together with ReadToRingBuffer * a link can either read into the ring buffer, or existing array * \param max_size maximum size of array + * \return true if it is an successful read, false if there is some error happens, check errno */ - inline void ReadToArray(void *recvbuf_, size_t max_size) { + inline bool ReadToArray(void *recvbuf_, size_t max_size) { char *p = static_cast(recvbuf_); - size_read += sock.Recv(p + size_read, max_size - size_read); + ssize_t len = sock.Recv(p + size_read, max_size - size_read); + if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; + size_read += static_cast(len); + return true; } /*! * \brief write data in array to sock * \param sendbuf_ head of array * \param max_size maximum size of array + * \return true if it is an successful write, false if there is some error happens, check errno */ - inline void WriteFromArray(const void *sendbuf_, size_t max_size) { + inline bool WriteFromArray(const void *sendbuf_, size_t max_size) { const char *p = static_cast(sendbuf_); - size_write += sock.Send(p + size_write, max_size - size_write); + ssize_t len = sock.Send(p + size_write, max_size - size_write); + if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; + size_write += static_cast(len); + return true; } private: diff --git a/src/socket.h b/src/socket.h index a18d9a576..307ec89df 100644 --- a/src/socket.h +++ b/src/socket.h @@ -22,7 +22,6 @@ #include #include "./utils.h" -namespace utils { #if defined(_WIN32) typedef int ssize_t; typedef int sock_size_t; @@ -32,6 +31,7 @@ typedef size_t sock_size_t; const int INVALID_SOCKET = -1; #endif +namespace utils { /*! \brief data structure for network address */ struct SockAddr { sockaddr_in addr; @@ -74,36 +74,18 @@ struct SockAddr { return std::string(s); } }; + /*! - * \brief a wrapper of TCP socket that hopefully be cross platform + * \brief base class containing common operations of TCP and UDP sockets */ -class TCPSocket { +class Socket { public: /*! \brief the file descriptor of socket */ SOCKET sockfd; - // constructor - TCPSocket(void) : sockfd(INVALID_SOCKET) { - } - explicit TCPSocket(SOCKET sockfd) : sockfd(sockfd) { - } - ~TCPSocket(void) { - // do nothing in destructor - // user need to take care of close - } // default conversion to int inline operator SOCKET() const { return sockfd; } - /*! - * \brief create the socket, call this before using socket - * \param af domain - */ - inline void Create(int af = PF_INET) { - sockfd = socket(PF_INET, SOCK_STREAM, 0); - if (sockfd == INVALID_SOCKET) { - SockError("Create"); - } - } /*! * \brief start up the socket module * call this before using the sockets @@ -112,7 +94,7 @@ class TCPSocket { #ifdef _WIN32 WSADATA wsa_data; if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != -1) { - SockError("Startup"); + Socket::Error("Startup"); } if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { WSACleanup(); @@ -137,12 +119,12 @@ class TCPSocket { #ifdef _WIN32 u_long mode = non_block ? 1 : 0; if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { - SockError("SetNonBlock"); + Socket::Error("SetNonBlock"); } #else int flag = fcntl(sockfd, F_GETFL, 0); if (flag == -1) { - SockError("SetNonBlock-1"); + Socket::Error("SetNonBlock-1"); } if (non_block) { flag |= O_NONBLOCK; @@ -150,10 +132,81 @@ class TCPSocket { flag &= ~O_NONBLOCK; } if (fcntl(sockfd, F_SETFL, flag) == -1) { - SockError("SetNonBlock-2"); + Socket::Error("SetNonBlock-2"); } #endif } + /*! + * \brief bind the socket to an address + * \param addr + */ + inline void Bind(const SockAddr &addr) { + if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) { + Socket::Error("Bind"); + } + } + /*! + * \brief try bind the socket to host, from start_port to end_port + * \param start_port starting port number to try + * \param end_port ending port number to try + * \return the port successfully bind to, return -1 if failed to bind any port + */ + inline int TryBindHost(int start_port, int end_port) { + for (int port = start_port; port < end_port; ++port) { + SockAddr addr("0.0.0.0", port); + if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) { + return port; + } + if (errno != EADDRINUSE) { + Socket::Error("TryBindHost"); + } + } + return -1; + } + /*! \brief close the socket */ + inline void Close(void) { + if (sockfd != INVALID_SOCKET) { +#ifdef _WIN32 + closesocket(sockfd); +#else + close(sockfd); +#endif + sockfd = INVALID_SOCKET; + } else { + Error("Socket::Close double close the socket or close without create"); + } + } + + // report an socket error + inline static void Error(const char *msg) { + int errsv = errno; + utils::Error("Socket %s Error:%s", msg, strerror(errsv)); + } + protected: + explicit Socket(SOCKET sockfd) : sockfd(sockfd) { + } +}; + +/*! + * \brief a wrapper of TCP socket that hopefully be cross platform + */ +class TCPSocket : public Socket{ + public: + // constructor + TCPSocket(void) : Socket(INVALID_SOCKET) { + } + explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { + } + /*! + * \brief create the socket, call this before using socket + * \param af domain + */ + inline void Create(int af = PF_INET) { + sockfd = socket(PF_INET, SOCK_STREAM, 0); + if (sockfd == INVALID_SOCKET) { + Socket::Error("Create"); + } + } /*! * \brief perform listen of the socket * \param backlog backlog parameter @@ -165,93 +218,43 @@ class TCPSocket { TCPSocket Accept(void) { SOCKET newfd = accept(sockfd, NULL, NULL); if (newfd == INVALID_SOCKET) { - SockError("Accept"); + Socket::Error("Accept"); } return TCPSocket(newfd); } - /*! - * \brief bind the socket to an address - * \param addr - */ - inline void Bind(const SockAddr &addr) { - if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) { - SockError("Bind"); - } - } - /*! - * \brief try bind the socket to host, from start_port to end_port - * \param start_port starting port number to try - * \param end_port ending port number to try - * \param out_addr the binding address, if successful - * \return whether the binding is successful - */ - inline int TryBindHost(int start_port, int end_port) { - for (int port = start_port; port < end_port; ++port) { - SockAddr addr("0.0.0.0", port); - if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) { - return port; - } - if (errno != EADDRINUSE) { - SockError("TryBindHost"); - } - } - return -1; - } /*! * \brief connect to an address * \param addr the address to connect to + * \return whether connect is successful */ - inline void Connect(const SockAddr &addr) { - if (connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) { - SockError("Connect"); - } - } - /*! \brief close the connection */ - inline void Close(void) { - if (sockfd != -1) { -#ifdef _WIN32 - closesocket(sockfd); -#else - close(sockfd); -#endif - sockfd = INVALID_SOCKET; - } else { - Error("TCPSocket::Close double close the socket or close without create"); - } + inline bool Connect(const SockAddr &addr) { + return connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0; } /*! - * \brief send data using the socket + * \brief send data using the socket * \param buf the pointer to the buffer * \param len the size of the buffer * \param flags extra flags * \return size of data actually sent + * return -1 if error occurs */ - inline size_t Send(const void *buf_, size_t len, int flag = 0) { + inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { const char *buf = reinterpret_cast(buf_); if (len == 0) return 0; - ssize_t ret = send(sockfd, buf, static_cast(len), flag); - if (ret == -1) { - if (errno == EAGAIN || errno == EWOULDBLOCK) return 0; - SockError("Send"); - } - return ret; - } + return send(sockfd, buf, static_cast(len), flag); + } /*! * \brief receive data using the socket * \param buf_ the pointer to the buffer * \param len the size of the buffer * \param flags extra flags - * \return size of data actually received + * \return size of data actually received + * return -1 if error occurs */ - inline size_t Recv(void *buf_, size_t len, int flags = 0) { + inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { char *buf = reinterpret_cast(buf_); - if (len == 0) return 0; - ssize_t ret = recv(sockfd, buf, static_cast(len), flags); - if (ret == -1) { - if (errno == EAGAIN || errno == EWOULDBLOCK) return 0; - SockError("Recv"); - } - return ret; + if (len == 0) return 0; + return recv(sockfd, buf, static_cast(len), flags); } /*! * \brief peform block write that will attempt to send all data out @@ -267,7 +270,7 @@ class TCPSocket { ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); if (ret == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone; - SockError("Recv"); + Socket::Error("SendAll"); } buf += ret; ndone += ret; @@ -288,7 +291,7 @@ class TCPSocket { ssize_t ret = recv(sockfd, buf, static_cast(len - ndone), MSG_WAITALL); if (ret == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone; - SockError("Recv"); + Socket::Error("RecvAll"); } if (ret == 0) return ndone; buf += ret; @@ -298,12 +301,8 @@ class TCPSocket { } private: - // report an socket error - inline static void SockError(const char *msg) { - int errsv = errno; - Error("Socket %s Error:%s", msg, strerror(errsv)); - } }; + /*! \brief helper data structure to perform select */ struct SelectHelper { public: @@ -326,6 +325,14 @@ struct SelectHelper { write_fds.push_back(fd); if (fd > maxfd) maxfd = fd; } + /*! + * \brief add file descriptor to watch for exception + * \param fd file descriptor to be watched + */ + inline void WatchException(SOCKET fd) { + except_fds.push_back(fd); + if (fd > maxfd) maxfd = fd; + } /*! * \brief Check if the descriptor is ready for read * \param fd file descriptor to check status @@ -340,12 +347,20 @@ struct SelectHelper { inline bool CheckWrite(SOCKET fd) const { return FD_ISSET(fd, &write_set) != 0; } + /*! + * \brief Check if the descriptor has any exception + * \param fd file descriptor to check status + */ + inline bool CheckExcept(SOCKET fd) const { + return FD_ISSET(fd, &except_set) != 0; + } /*! * \brief clear all the monitored descriptors */ inline void Clear(void) { read_fds.clear(); write_fds.clear(); + except_fds.clear(); maxfd = 0; } /*! @@ -356,20 +371,26 @@ struct SelectHelper { inline int Select(long timeout = 0) { FD_ZERO(&read_set); FD_ZERO(&write_set); + FD_ZERO(&except_set); for (size_t i = 0; i < read_fds.size(); ++i) { FD_SET(read_fds[i], &read_set); } for (size_t i = 0; i < write_fds.size(); ++i) { FD_SET(write_fds[i], &write_set); } + for (size_t i = 0; i < except_fds.size(); ++i) { + FD_SET(except_fds[i], &except_set); + } int ret; if (timeout == 0) { - ret = select(static_cast(maxfd + 1), &read_set, &write_set, NULL, NULL); + ret = select(static_cast(maxfd + 1), &read_set, + &write_set, &except_set, NULL); } else { timeval tm; tm.tv_usec = (timeout % 1000) * 1000; tm.tv_sec = timeout / 1000; - ret = select(static_cast(maxfd + 1), &read_set, &write_set, NULL, &tm); + ret = select(static_cast(maxfd + 1), &read_set, + &write_set, &except_set, &tm); } if (ret == -1) { int errsv = errno; @@ -380,8 +401,8 @@ struct SelectHelper { private: SOCKET maxfd; - fd_set read_set, write_set; - std::vector read_fds, write_fds; + fd_set read_set, write_set, except_set; + std::vector read_fds, write_fds, except_fds; }; } #endif