From 34f2f887b1556da0ac80d7429cd2df4c4e3ac72e Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 3 Dec 2014 09:59:13 -0800 Subject: [PATCH] add more broadcast and basic broadcast --- src/allreduce_base.cc | 216 +++++++++++++++++++++++++++--------------- src/allreduce_base.h | 12 +++ src/mock.h | 2 +- src/rabit-inl.h | 32 +++++-- src/rabit.h | 24 ++++- src/socket.h | 25 +++++ src/tcp_master.py | 106 --------------------- 7 files changed, 225 insertions(+), 192 deletions(-) delete mode 100644 src/tcp_master.py diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 0cfb9fd34..42c7afaab 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -19,90 +19,20 @@ AllreduceBase::AllreduceBase(void) { host_uri = ""; slave_port = 9010; nport_trial = 1000; - rank = 0; + rank = -1; world_size = 1; version_number = 0; + job_id = "NULL"; this->SetParam("reduce_buffer", "256MB"); } // initialization function void AllreduceBase::Init(void) { utils::Socket::Startup(); - // single node mode - if (master_uri == "NULL") return; utils::Assert(links.size() == 0, "can only call Init once"); - int magic = kMagic; - int nchild = 0, nparent = 0; this->host_uri = utils::SockAddr::GetHostName(); // get information from master - utils::TCPSocket master; - master.Create(); - if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) { - utils::Socket::Error("Connect"); - } - utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1"); - utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2"); - utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); - utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3"); - utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4"); - utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5"); - utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6"); - utils::Assert(nchild >= 0, "in correct number of childs"); - utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent"); - - // create listen - utils::TCPSocket sock_listen; - sock_listen.Create(); - int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial); - utils::Check(port != -1, "sync::Init fail to bind the ports specified"); - sock_listen.Listen(); - - if (nparent != 0) { - parent_index = 0; - links.push_back(LinkRecord()); - int len, hport; - std::string hname; - utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9"); - hname.resize(len); - utils::Assert(len != 0, "string must not be empty"); - utils::Assert(master.RecvAll(&hname[0], len) == static_cast(len), "sync::Init failure 10"); - utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11"); - links[0].sock.Create(); - links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport)); - utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12"); - utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13"); - utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch"); - parent_index = 0; - } else { - parent_index = -1; - } - // send back socket listening port to master - utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14"); - // close connection to master - master.Close(); - // accept links from childs - for (int i = 0; i < nchild; ++i) { - LinkRecord r; - while (true) { - r.sock = sock_listen.Accept(); - if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) { - utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15"); - break; - } else { - // not a valid child - r.sock.Close(); - } - } - links.push_back(r); - } - // close listening sockets - sock_listen.Close(); - // setup selecter - for (size_t i = 0; i < links.size(); ++i) { - // set the socket to non-blocking mode - links[i].sock.SetNonBlock(true); - } - // done + this->ReConnectLinks(); } void AllreduceBase::Shutdown(void) { @@ -110,6 +40,22 @@ void AllreduceBase::Shutdown(void) { links[i].sock.Close(); } links.clear(); + + if (master_uri == "NULL") return; + int magic = kMagic; + // notify master rank i have shutdown + utils::TCPSocket master; + master.Create(); + if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) { + utils::Socket::Error("Connect Master"); + } + utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1"); + utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2"); + utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); + + utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); + master.SendStr(job_id); + master.SendStr(std::string("shutdown")); utils::TCPSocket::Finalize(); } /*! @@ -120,6 +66,7 @@ void AllreduceBase::Shutdown(void) { void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "master_uri")) master_uri = val; if (!strcmp(name, "master_port")) master_port = atoi(val); + if (!strcmp(name, "job_id")) job_id = val; if (!strcmp(name, "reduce_buffer")) { char unit; unsigned long amount; @@ -136,7 +83,129 @@ void AllreduceBase::SetParam(const char *name, const char *val) { } } } +/*! + * \brief connect to the master to fix the the missing links + * this function is also used when the engine start up + */ +void AllreduceBase::ReConnectLinks(void) { + // single node mode + if (master_uri == "NULL") { + rank = 0; return; + } + int magic = kMagic; + // get information from master + utils::TCPSocket master; + master.Create(); + if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) { + utils::Socket::Error("Connect"); + } + utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1"); + utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2"); + utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); + utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); + master.SendStr(job_id); + master.SendStr(std::string("start")); + {// get new ranks + int newrank; + utils::Assert(master.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), + "ReConnectLink failure 4"); + utils::Assert(master.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank), + "ReConnectLink failure 4"); + utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one"); + rank = newrank; + } + + // create listening socket + utils::TCPSocket sock_listen; + sock_listen.Create(); + int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial); + utils::Check(port != -1, "ReConnectLink fail to bind the ports specified"); + sock_listen.Listen(); + + // get number of to connect and number of to accept nodes from master + int num_conn, num_accept, num_error = 1; + + do { + // send over good links + std::vector good_link; + for (size_t i = 0; i < links.size(); ++i) { + if (!links[i].sock.BadSocket()) { + good_link.push_back(static_cast(links[i].rank)); + } else { + if (!links[i].sock.IsClosed()) links[i].sock.Close(); + } + } + int ngood = static_cast(good_link.size()); + utils::Assert(master.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), + "ReConnectLink failure 5"); + for (size_t i = 0; i < good_link.size(); ++i) { + utils::Assert(master.SendAll(&good_link[i], sizeof(good_link[i])) == sizeof(good_link[i]), + "ReConnectLink failure 6"); + } + utils::Assert(master.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), + "ReConnectLink failure 7"); + utils::Assert(master.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept), + "ReConnectLink failure 8"); + num_error = 0; + for (int i = 0; i < num_conn; ++i) { + LinkRecord r; + int hport, hrank; + std::string hname; + master.RecvStr(&hname); + utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); + utils::Assert(master.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); + r.sock.Create(); + if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) { + num_error += 1; r.sock.Close(); continue; + } + utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12"); + utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13"); + utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent"); + bool match = false; + for (size_t i = 0; i < links.size(); ++i) { + if (links[i].rank == hrank) { + utils::Assert(links[i].sock.IsClosed(), "Override a link that is active"); + links[i].sock = r.sock; match = true; break; + } + } + if (!match) links.push_back(r); + } + utils::Assert(master.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14"); + } while (num_error != 0); + // send back socket listening port to master + utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); + // close connection to master + master.Close(); + // listen to incoming links + for (int i = 0; i < num_accept; ++i) { + LinkRecord r; + r.sock = sock_listen.Accept(); + utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15"); + utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15"); + bool match = false; + for (size_t i = 0; i < links.size(); ++i) { + if (links[i].rank == r.rank) { + utils::Assert(links[i].sock.IsClosed(), "Override a link that is active"); + links[i].sock = r.sock; match = true; break; + } + } + if (!match) links.push_back(r); + } + // close listening sockets + sock_listen.Close(); + this->parent_index = -1; + // setup selecter + for (size_t i = 0; i < links.size(); ++i) { + utils::Assert(!links[i].sock.BadSocket(), "ReConnectLink: bad socket"); + // set the socket to non-blocking mode + links[i].sock.SetNonBlock(true); + if (links[i].rank == parent_rank) parent_index = static_cast(i); + } + if (parent_rank != -1) { + utils::Assert(parent_index != -1, "cannot find parent in the link"); + } +} /*! * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * @@ -209,7 +278,6 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_, finished = false; } } - } // finish runing allreduce if (finished) break; diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 578b941f1..d5172f9f7 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -138,6 +138,8 @@ class AllreduceBase : public IEngine { public: // socket to get data from/to link utils::TCPSocket sock; + // rank of the node in this link + int rank; // size of data readed from link size_t size_read; // size of data sent to the link @@ -222,6 +224,11 @@ class AllreduceBase : public IEngine { // aligned with 64 bits, will be able to perform 64 bits operations freely std::vector buffer_; }; + /*! + * \brief connect to the master to fix the the missing links + * this function is also used when the engine start up + */ + void ReConnectLinks(void); /*! * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * @@ -255,9 +262,14 @@ class AllreduceBase : public IEngine { //---- local data related to link ---- // index of parent link, can be -1, meaning this is root of the tree int parent_index; + // rank of parent node, can be -1 + int parent_rank; // sockets of all links std::vector links; //----- meta information----- + // unique identifier of the possible job this process is doing + // used to assign ranks, optional, default to NULL + std::string job_id; // uri of current host, to be set by Init std::string host_uri; // uri of master diff --git a/src/mock.h b/src/mock.h index 5c85b841f..31c93d113 100644 --- a/src/mock.h +++ b/src/mock.h @@ -42,7 +42,7 @@ public: inline void Broadcast(std::string *sendrecv_data, int root) { utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank); - rabit::Bcast(sendrecv_data, root); + rabit::Broadcast(sendrecv_data, root); } diff --git a/src/rabit-inl.h b/src/rabit-inl.h index 4ea741efe..c38766bbb 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -91,16 +91,32 @@ inline int GetWorldSize(void) { inline std::string GetProcessorName(void) { return engine::GetEngine()->GetHost(); } -// broadcast an std::string to all others from root -inline void Bcast(std::string *sendrecv_data, int root) { - engine::IEngine *e = engine::GetEngine(); - unsigned len = static_cast(sendrecv_data->length()); - e->Broadcast(&len, sizeof(len), root); - sendrecv_data->resize(len); - if (len != 0) { - e->Broadcast(&(*sendrecv_data)[0], len, root); +// broadcast data to all other nodes from root +inline void Broadcast(void *sendrecv_data, size_t size, int root) { + engine::GetEngine()->Broadcast(sendrecv_data, size, root); +} +template +inline void Broadcast(std::vector *sendrecv_data, int root) { + size_t size = sendrecv_data->size(); + Broadcast(&size, sizeof(size), root); + if (sendrecv_data->size() != size) { + sendrecv_data->resize(size); + } + if (size != 0) { + Broadcast(&sendrecv_data[0], size * sizeof(DType), root); } } +inline void Broadcast(std::string *sendrecv_data, int root) { + size_t size = sendrecv_data->length(); + Broadcast(&size, sizeof(size), root); + if (sendrecv_data->length() != size) { + sendrecv_data->resize(size); + } + if (size != 0) { + Broadcast(&sendrecv_data[0], size * sizeof(char), root); + } +} + // perform inplace Allreduce template inline void Allreduce(DType *sendrecvbuf, size_t count) { diff --git a/src/rabit.h b/src/rabit.h index 859b5488a..99dd0b4d9 100644 --- a/src/rabit.h +++ b/src/rabit.h @@ -8,6 +8,8 @@ * * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou */ +#include +#include #include "./engine.h" /*! \brief namespace of rabit */ @@ -43,11 +45,27 @@ inline std::string GetProcessorName(void); /*! * \brief broadcast an std::string to all others from root * \param sendrecv_data the pointer to send or recive buffer, - * receive buffer does not need to be pre-allocated - * and string will be resized to correct length + * \param size the size of the data * \param root the root of process */ -inline void Bcast(std::string *sendrecv_data, int root); +inline void Broadcast(void *sendrecv_data, size_t size, int root); +/*! + * \brief broadcast an std::vector to all others from root + * \param sendrecv_data the pointer to send or recive vector, + * for receiver, the vector does not need to be pre-allocated + * \param root the root of process + * \tparam DType the data type stored in vector, have to be simple data type + * that can be directly send by sending the sizeof(DType) data + */ +template +inline void Broadcast(std::vector *sendrecv_data, int root); +/*! + * \brief broadcast an std::string to all others from root + * \param sendrecv_data the pointer to send or recive vector, + * for receiver, the vector does not need to be pre-allocated + * \param root the root of process + */ +inline void Broadcast(std::string *sendrecv_data, int root); /*! * \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe diff --git a/src/socket.h b/src/socket.h index 296b8aeea..3386b7d1d 100644 --- a/src/socket.h +++ b/src/socket.h @@ -331,6 +331,31 @@ class TCPSocket : public Socket{ } return ndone; } + /*! + * \brief send a string over network + * \param str the string to be sent + */ + inline void SendStr(const std::string &str) { + unsigned len = static_cast(str.length()); + utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len), + "error during send SendStr"); + utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(), + "error during send SendStr"); + } + /*! + * \brief recv a string from network + * \param out_str the string to receive + */ + inline void RecvStr(std::string *out_str) { + unsigned len; + utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len), + "error during send RecvStr"); + out_str->resize(len); + if (len != 0) { + utils::Assert(this->RecvAll(&(*out_str)[0], len) == len, + "error during send SendStr"); + } + } }; /*! \brief helper data structure to perform select */ diff --git a/src/tcp_master.py b/src/tcp_master.py deleted file mode 100644 index 015b48784..000000000 --- a/src/tcp_master.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -Master script for xgboost, tcp_master -This script can be used to start jobs of multi-node xgboost using sync_tcp - -Tianqi Chen -""" - -import sys -import os -import socket -import struct -import subprocess -from threading import Thread - -class ExSocket: - def __init__(self, sock): - self.sock = sock - def recvall(self, nbytes): - res = [] - sock = self.sock - nread = 0 - while nread < nbytes: - chunk = self.sock.recv(min(nbytes - nread, 1024), socket.MSG_WAITALL) - nread += len(chunk) - res.append(chunk) - return ''.join(res) - def recvint(self): - return struct.unpack('@i', self.recvall(4))[0] - def sendint(self, n): - self.sock.sendall(struct.pack('@i', n)) - def sendstr(self, s): - self.sendint(len(s)) - self.sock.sendall(s) - -# magic number used to verify existence of data -kMagic = 0xff99 - -class Master: - def __init__(self, port = 9000, port_end = 9999): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - for port in range(port, port_end): - try: - sock.bind(('', port)) - self.port = port - break - except socket.error: - continue - sock.listen(16) - self.sock = sock - print 'start listen on %s:%d' % (socket.gethostname(), self.port) - def __del__(self): - self.sock.close() - def slave_args(self): - return ['master_uri=%s' % socket.gethostname(), - 'master_port=%s' % self.port] - def accept_slaves(self, nslave): - slave_addrs = [] - for rank in range(nslave): - while True: - fd, s_addr = self.sock.accept() - slave = ExSocket(fd) - nparent = int(rank != 0) - nchild = 0 - if (rank + 1) * 2 - 1 < nslave: - nchild += 1 - if (rank + 1) * 2 < nslave: - nchild += 1 - try: - magic = slave.recvint() - if magic != kMagic: - print 'invalid magic number=%d from %s' % (magic, s_addr[0]) - slave.sock.close() - continue - except socket.error: - print 'sock error in %s' % (s_addr[0]) - slave.sock.close() - continue - slave.sendint(kMagic) - slave.sendint(rank) - slave.sendint(nslave) - slave.sendint(nparent) - slave.sendint(nchild) - if nparent != 0: - parent_index = (rank + 1) / 2 - 1 - ptuple = slave_addrs[parent_index] - slave.sendstr(ptuple[0]) - slave.sendint(ptuple[1]) - s_port = slave.recvint() - assert rank == len(slave_addrs) - slave_addrs.append((s_addr[0], s_port)) - slave.sock.close() - print 'finish starting rank=%d at %s' % (rank, s_addr[0]) - break - print 'all slaves setup complete' - -def mpi_submit(nslave, args): - cmd = ' '.join(['mpirun -n %d' % nslave] + args) - print cmd - return subprocess.check_call(cmd, shell = True) - -def submit(nslave, args, fun_submit = mpi_submit): - master = Master() - submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args())) - submit_thread.start() - master.accept_slaves(nslave) - submit_thread.join()