add more broadcast and basic broadcast

This commit is contained in:
tqchen 2014-12-03 09:59:13 -08:00
parent 20b51cc9ce
commit 34f2f887b1
7 changed files with 225 additions and 192 deletions

View File

@ -19,90 +19,20 @@ AllreduceBase::AllreduceBase(void) {
host_uri = ""; host_uri = "";
slave_port = 9010; slave_port = 9010;
nport_trial = 1000; nport_trial = 1000;
rank = 0; rank = -1;
world_size = 1; world_size = 1;
version_number = 0; version_number = 0;
job_id = "NULL";
this->SetParam("reduce_buffer", "256MB"); this->SetParam("reduce_buffer", "256MB");
} }
// initialization function // initialization function
void AllreduceBase::Init(void) { void AllreduceBase::Init(void) {
utils::Socket::Startup(); utils::Socket::Startup();
// single node mode
if (master_uri == "NULL") return;
utils::Assert(links.size() == 0, "can only call Init once"); utils::Assert(links.size() == 0, "can only call Init once");
int magic = kMagic;
int nchild = 0, nparent = 0;
this->host_uri = utils::SockAddr::GetHostName(); this->host_uri = utils::SockAddr::GetHostName();
// get information from master // get information from master
utils::TCPSocket master; this->ReConnectLinks();
master.Create();
if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) {
utils::Socket::Error("Connect");
}
utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1");
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2");
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3");
utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4");
utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5");
utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6");
utils::Assert(nchild >= 0, "in correct number of childs");
utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent");
// create listen
utils::TCPSocket sock_listen;
sock_listen.Create();
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
utils::Check(port != -1, "sync::Init fail to bind the ports specified");
sock_listen.Listen();
if (nparent != 0) {
parent_index = 0;
links.push_back(LinkRecord());
int len, hport;
std::string hname;
utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9");
hname.resize(len);
utils::Assert(len != 0, "string must not be empty");
utils::Assert(master.RecvAll(&hname[0], len) == static_cast<size_t>(len), "sync::Init failure 10");
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11");
links[0].sock.Create();
links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12");
utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13");
utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch");
parent_index = 0;
} else {
parent_index = -1;
}
// send back socket listening port to master
utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14");
// close connection to master
master.Close();
// accept links from childs
for (int i = 0; i < nchild; ++i) {
LinkRecord r;
while (true) {
r.sock = sock_listen.Accept();
if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) {
utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15");
break;
} else {
// not a valid child
r.sock.Close();
}
}
links.push_back(r);
}
// close listening sockets
sock_listen.Close();
// setup selecter
for (size_t i = 0; i < links.size(); ++i) {
// set the socket to non-blocking mode
links[i].sock.SetNonBlock(true);
}
// done
} }
void AllreduceBase::Shutdown(void) { void AllreduceBase::Shutdown(void) {
@ -110,6 +40,22 @@ void AllreduceBase::Shutdown(void) {
links[i].sock.Close(); links[i].sock.Close();
} }
links.clear(); links.clear();
if (master_uri == "NULL") return;
int magic = kMagic;
// notify master rank i have shutdown
utils::TCPSocket master;
master.Create();
if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) {
utils::Socket::Error("Connect Master");
}
utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
master.SendStr(job_id);
master.SendStr(std::string("shutdown"));
utils::TCPSocket::Finalize(); utils::TCPSocket::Finalize();
} }
/*! /*!
@ -120,6 +66,7 @@ void AllreduceBase::Shutdown(void) {
void AllreduceBase::SetParam(const char *name, const char *val) { void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "master_uri")) master_uri = val; if (!strcmp(name, "master_uri")) master_uri = val;
if (!strcmp(name, "master_port")) master_port = atoi(val); if (!strcmp(name, "master_port")) master_port = atoi(val);
if (!strcmp(name, "job_id")) job_id = val;
if (!strcmp(name, "reduce_buffer")) { if (!strcmp(name, "reduce_buffer")) {
char unit; char unit;
unsigned long amount; unsigned long amount;
@ -136,7 +83,129 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
} }
} }
} }
/*!
* \brief connect to the master to fix the the missing links
* this function is also used when the engine start up
*/
void AllreduceBase::ReConnectLinks(void) {
// single node mode
if (master_uri == "NULL") {
rank = 0; return;
}
int magic = kMagic;
// get information from master
utils::TCPSocket master;
master.Create();
if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) {
utils::Socket::Error("Connect");
}
utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
master.SendStr(job_id);
master.SendStr(std::string("start"));
{// get new ranks
int newrank;
utils::Assert(master.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
"ReConnectLink failure 4");
utils::Assert(master.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
"ReConnectLink failure 4");
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
rank = newrank;
}
// create listening socket
utils::TCPSocket sock_listen;
sock_listen.Create();
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
// get number of to connect and number of to accept nodes from master
int num_conn, num_accept, num_error = 1;
do {
// send over good links
std::vector<int> good_link;
for (size_t i = 0; i < links.size(); ++i) {
if (!links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(links[i].rank));
} else {
if (!links[i].sock.IsClosed()) links[i].sock.Close();
}
}
int ngood = static_cast<int>(good_link.size());
utils::Assert(master.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) {
utils::Assert(master.SendAll(&good_link[i], sizeof(good_link[i])) == sizeof(good_link[i]),
"ReConnectLink failure 6");
}
utils::Assert(master.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
utils::Assert(master.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
"ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
master.RecvStr(&hname);
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
utils::Assert(master.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1; r.sock.Close(); continue;
}
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12");
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < links.size(); ++i) {
if (links[i].rank == hrank) {
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
links[i].sock = r.sock; match = true; break;
}
}
if (!match) links.push_back(r);
}
utils::Assert(master.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to master
utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to master
master.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
r.sock = sock_listen.Accept();
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < links.size(); ++i) {
if (links[i].rank == r.rank) {
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
links[i].sock = r.sock; match = true; break;
}
}
if (!match) links.push_back(r);
}
// close listening sockets
sock_listen.Close();
this->parent_index = -1;
// setup selecter
for (size_t i = 0; i < links.size(); ++i) {
utils::Assert(!links[i].sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode
links[i].sock.SetNonBlock(true);
if (links[i].rank == parent_rank) parent_index = static_cast<int>(i);
}
if (parent_rank != -1) {
utils::Assert(parent_index != -1, "cannot find parent in the link");
}
}
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
* *
@ -209,7 +278,6 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
finished = false; finished = false;
} }
} }
} }
// finish runing allreduce // finish runing allreduce
if (finished) break; if (finished) break;

View File

@ -138,6 +138,8 @@ class AllreduceBase : public IEngine {
public: public:
// socket to get data from/to link // socket to get data from/to link
utils::TCPSocket sock; utils::TCPSocket sock;
// rank of the node in this link
int rank;
// size of data readed from link // size of data readed from link
size_t size_read; size_t size_read;
// size of data sent to the link // size of data sent to the link
@ -222,6 +224,11 @@ class AllreduceBase : public IEngine {
// aligned with 64 bits, will be able to perform 64 bits operations freely // aligned with 64 bits, will be able to perform 64 bits operations freely
std::vector<uint64_t> buffer_; std::vector<uint64_t> buffer_;
}; };
/*!
* \brief connect to the master to fix the the missing links
* this function is also used when the engine start up
*/
void ReConnectLinks(void);
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
* *
@ -255,9 +262,14 @@ class AllreduceBase : public IEngine {
//---- local data related to link ---- //---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree // index of parent link, can be -1, meaning this is root of the tree
int parent_index; int parent_index;
// rank of parent node, can be -1
int parent_rank;
// sockets of all links // sockets of all links
std::vector<LinkRecord> links; std::vector<LinkRecord> links;
//----- meta information----- //----- meta information-----
// unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL
std::string job_id;
// uri of current host, to be set by Init // uri of current host, to be set by Init
std::string host_uri; std::string host_uri;
// uri of master // uri of master

View File

@ -42,7 +42,7 @@ public:
inline void Broadcast(std::string *sendrecv_data, int root) { inline void Broadcast(std::string *sendrecv_data, int root) {
utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank); utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank);
rabit::Bcast(sendrecv_data, root); rabit::Broadcast(sendrecv_data, root);
} }

View File

@ -91,16 +91,32 @@ inline int GetWorldSize(void) {
inline std::string GetProcessorName(void) { inline std::string GetProcessorName(void) {
return engine::GetEngine()->GetHost(); return engine::GetEngine()->GetHost();
} }
// broadcast an std::string to all others from root // broadcast data to all other nodes from root
inline void Bcast(std::string *sendrecv_data, int root) { inline void Broadcast(void *sendrecv_data, size_t size, int root) {
engine::IEngine *e = engine::GetEngine(); engine::GetEngine()->Broadcast(sendrecv_data, size, root);
unsigned len = static_cast<unsigned>(sendrecv_data->length()); }
e->Broadcast(&len, sizeof(len), root); template<typename DType>
sendrecv_data->resize(len); inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
if (len != 0) { size_t size = sendrecv_data->size();
e->Broadcast(&(*sendrecv_data)[0], len, root); Broadcast(&size, sizeof(size), root);
if (sendrecv_data->size() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&sendrecv_data[0], size * sizeof(DType), root);
} }
} }
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&sendrecv_data[0], size * sizeof(char), root);
}
}
// perform inplace Allreduce // perform inplace Allreduce
template<typename OP, typename DType> template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count) { inline void Allreduce(DType *sendrecvbuf, size_t count) {

View File

@ -8,6 +8,8 @@
* *
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#include <string>
#include <vector>
#include "./engine.h" #include "./engine.h"
/*! \brief namespace of rabit */ /*! \brief namespace of rabit */
@ -43,11 +45,27 @@ inline std::string GetProcessorName(void);
/*! /*!
* \brief broadcast an std::string to all others from root * \brief broadcast an std::string to all others from root
* \param sendrecv_data the pointer to send or recive buffer, * \param sendrecv_data the pointer to send or recive buffer,
* receive buffer does not need to be pre-allocated * \param size the size of the data
* and string will be resized to correct length
* \param root the root of process * \param root the root of process
*/ */
inline void Bcast(std::string *sendrecv_data, int root); inline void Broadcast(void *sendrecv_data, size_t size, int root);
/*!
* \brief broadcast an std::vector<DType> to all others from root
* \param sendrecv_data the pointer to send or recive vector,
* for receiver, the vector does not need to be pre-allocated
* \param root the root of process
* \tparam DType the data type stored in vector, have to be simple data type
* that can be directly send by sending the sizeof(DType) data
*/
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
/*!
* \brief broadcast an std::string to all others from root
* \param sendrecv_data the pointer to send or recive vector,
* for receiver, the vector does not need to be pre-allocated
* \param root the root of process
*/
inline void Broadcast(std::string *sendrecv_data, int root);
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe

View File

@ -331,6 +331,31 @@ class TCPSocket : public Socket{
} }
return ndone; return ndone;
} }
/*!
* \brief send a string over network
* \param str the string to be sent
*/
inline void SendStr(const std::string &str) {
unsigned len = static_cast<int>(str.length());
utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
"error during send SendStr");
utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
"error during send SendStr");
}
/*!
* \brief recv a string from network
* \param out_str the string to receive
*/
inline void RecvStr(std::string *out_str) {
unsigned len;
utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
"error during send RecvStr");
out_str->resize(len);
if (len != 0) {
utils::Assert(this->RecvAll(&(*out_str)[0], len) == len,
"error during send SendStr");
}
}
}; };
/*! \brief helper data structure to perform select */ /*! \brief helper data structure to perform select */

View File

@ -1,106 +0,0 @@
"""
Master script for xgboost, tcp_master
This script can be used to start jobs of multi-node xgboost using sync_tcp
Tianqi Chen
"""
import sys
import os
import socket
import struct
import subprocess
from threading import Thread
class ExSocket:
def __init__(self, sock):
self.sock = sock
def recvall(self, nbytes):
res = []
sock = self.sock
nread = 0
while nread < nbytes:
chunk = self.sock.recv(min(nbytes - nread, 1024), socket.MSG_WAITALL)
nread += len(chunk)
res.append(chunk)
return ''.join(res)
def recvint(self):
return struct.unpack('@i', self.recvall(4))[0]
def sendint(self, n):
self.sock.sendall(struct.pack('@i', n))
def sendstr(self, s):
self.sendint(len(s))
self.sock.sendall(s)
# magic number used to verify existence of data
kMagic = 0xff99
class Master:
def __init__(self, port = 9000, port_end = 9999):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
for port in range(port, port_end):
try:
sock.bind(('', port))
self.port = port
break
except socket.error:
continue
sock.listen(16)
self.sock = sock
print 'start listen on %s:%d' % (socket.gethostname(), self.port)
def __del__(self):
self.sock.close()
def slave_args(self):
return ['master_uri=%s' % socket.gethostname(),
'master_port=%s' % self.port]
def accept_slaves(self, nslave):
slave_addrs = []
for rank in range(nslave):
while True:
fd, s_addr = self.sock.accept()
slave = ExSocket(fd)
nparent = int(rank != 0)
nchild = 0
if (rank + 1) * 2 - 1 < nslave:
nchild += 1
if (rank + 1) * 2 < nslave:
nchild += 1
try:
magic = slave.recvint()
if magic != kMagic:
print 'invalid magic number=%d from %s' % (magic, s_addr[0])
slave.sock.close()
continue
except socket.error:
print 'sock error in %s' % (s_addr[0])
slave.sock.close()
continue
slave.sendint(kMagic)
slave.sendint(rank)
slave.sendint(nslave)
slave.sendint(nparent)
slave.sendint(nchild)
if nparent != 0:
parent_index = (rank + 1) / 2 - 1
ptuple = slave_addrs[parent_index]
slave.sendstr(ptuple[0])
slave.sendint(ptuple[1])
s_port = slave.recvint()
assert rank == len(slave_addrs)
slave_addrs.append((s_addr[0], s_port))
slave.sock.close()
print 'finish starting rank=%d at %s' % (rank, s_addr[0])
break
print 'all slaves setup complete'
def mpi_submit(nslave, args):
cmd = ' '.join(['mpirun -n %d' % nslave] + args)
print cmd
return subprocess.check_call(cmd, shell = True)
def submit(nslave, args, fun_submit = mpi_submit):
master = Master()
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
submit_thread.start()
master.accept_slaves(nslave)
submit_thread.join()