add more broadcast and basic broadcast
This commit is contained in:
parent
20b51cc9ce
commit
34f2f887b1
@ -19,90 +19,20 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
host_uri = "";
|
host_uri = "";
|
||||||
slave_port = 9010;
|
slave_port = 9010;
|
||||||
nport_trial = 1000;
|
nport_trial = 1000;
|
||||||
rank = 0;
|
rank = -1;
|
||||||
world_size = 1;
|
world_size = 1;
|
||||||
version_number = 0;
|
version_number = 0;
|
||||||
|
job_id = "NULL";
|
||||||
this->SetParam("reduce_buffer", "256MB");
|
this->SetParam("reduce_buffer", "256MB");
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialization function
|
// initialization function
|
||||||
void AllreduceBase::Init(void) {
|
void AllreduceBase::Init(void) {
|
||||||
utils::Socket::Startup();
|
utils::Socket::Startup();
|
||||||
// single node mode
|
|
||||||
if (master_uri == "NULL") return;
|
|
||||||
utils::Assert(links.size() == 0, "can only call Init once");
|
utils::Assert(links.size() == 0, "can only call Init once");
|
||||||
int magic = kMagic;
|
|
||||||
int nchild = 0, nparent = 0;
|
|
||||||
this->host_uri = utils::SockAddr::GetHostName();
|
this->host_uri = utils::SockAddr::GetHostName();
|
||||||
// get information from master
|
// get information from master
|
||||||
utils::TCPSocket master;
|
this->ReConnectLinks();
|
||||||
master.Create();
|
|
||||||
if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) {
|
|
||||||
utils::Socket::Error("Connect");
|
|
||||||
}
|
|
||||||
utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1");
|
|
||||||
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2");
|
|
||||||
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
|
|
||||||
utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3");
|
|
||||||
utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4");
|
|
||||||
utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5");
|
|
||||||
utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6");
|
|
||||||
utils::Assert(nchild >= 0, "in correct number of childs");
|
|
||||||
utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent");
|
|
||||||
|
|
||||||
// create listen
|
|
||||||
utils::TCPSocket sock_listen;
|
|
||||||
sock_listen.Create();
|
|
||||||
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
|
|
||||||
utils::Check(port != -1, "sync::Init fail to bind the ports specified");
|
|
||||||
sock_listen.Listen();
|
|
||||||
|
|
||||||
if (nparent != 0) {
|
|
||||||
parent_index = 0;
|
|
||||||
links.push_back(LinkRecord());
|
|
||||||
int len, hport;
|
|
||||||
std::string hname;
|
|
||||||
utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9");
|
|
||||||
hname.resize(len);
|
|
||||||
utils::Assert(len != 0, "string must not be empty");
|
|
||||||
utils::Assert(master.RecvAll(&hname[0], len) == static_cast<size_t>(len), "sync::Init failure 10");
|
|
||||||
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11");
|
|
||||||
links[0].sock.Create();
|
|
||||||
links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
|
|
||||||
utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12");
|
|
||||||
utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13");
|
|
||||||
utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch");
|
|
||||||
parent_index = 0;
|
|
||||||
} else {
|
|
||||||
parent_index = -1;
|
|
||||||
}
|
|
||||||
// send back socket listening port to master
|
|
||||||
utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14");
|
|
||||||
// close connection to master
|
|
||||||
master.Close();
|
|
||||||
// accept links from childs
|
|
||||||
for (int i = 0; i < nchild; ++i) {
|
|
||||||
LinkRecord r;
|
|
||||||
while (true) {
|
|
||||||
r.sock = sock_listen.Accept();
|
|
||||||
if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) {
|
|
||||||
utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15");
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
// not a valid child
|
|
||||||
r.sock.Close();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
links.push_back(r);
|
|
||||||
}
|
|
||||||
// close listening sockets
|
|
||||||
sock_listen.Close();
|
|
||||||
// setup selecter
|
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
|
||||||
// set the socket to non-blocking mode
|
|
||||||
links[i].sock.SetNonBlock(true);
|
|
||||||
}
|
|
||||||
// done
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllreduceBase::Shutdown(void) {
|
void AllreduceBase::Shutdown(void) {
|
||||||
@ -110,6 +40,22 @@ void AllreduceBase::Shutdown(void) {
|
|||||||
links[i].sock.Close();
|
links[i].sock.Close();
|
||||||
}
|
}
|
||||||
links.clear();
|
links.clear();
|
||||||
|
|
||||||
|
if (master_uri == "NULL") return;
|
||||||
|
int magic = kMagic;
|
||||||
|
// notify master rank i have shutdown
|
||||||
|
utils::TCPSocket master;
|
||||||
|
master.Create();
|
||||||
|
if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) {
|
||||||
|
utils::Socket::Error("Connect Master");
|
||||||
|
}
|
||||||
|
utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
|
||||||
|
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
|
||||||
|
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
|
||||||
|
|
||||||
|
utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||||
|
master.SendStr(job_id);
|
||||||
|
master.SendStr(std::string("shutdown"));
|
||||||
utils::TCPSocket::Finalize();
|
utils::TCPSocket::Finalize();
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -120,6 +66,7 @@ void AllreduceBase::Shutdown(void) {
|
|||||||
void AllreduceBase::SetParam(const char *name, const char *val) {
|
void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||||
if (!strcmp(name, "master_uri")) master_uri = val;
|
if (!strcmp(name, "master_uri")) master_uri = val;
|
||||||
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
||||||
|
if (!strcmp(name, "job_id")) job_id = val;
|
||||||
if (!strcmp(name, "reduce_buffer")) {
|
if (!strcmp(name, "reduce_buffer")) {
|
||||||
char unit;
|
char unit;
|
||||||
unsigned long amount;
|
unsigned long amount;
|
||||||
@ -136,7 +83,129 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief connect to the master to fix the the missing links
|
||||||
|
* this function is also used when the engine start up
|
||||||
|
*/
|
||||||
|
void AllreduceBase::ReConnectLinks(void) {
|
||||||
|
// single node mode
|
||||||
|
if (master_uri == "NULL") {
|
||||||
|
rank = 0; return;
|
||||||
|
}
|
||||||
|
int magic = kMagic;
|
||||||
|
// get information from master
|
||||||
|
utils::TCPSocket master;
|
||||||
|
master.Create();
|
||||||
|
if (!master.Connect(utils::SockAddr(master_uri.c_str(), master_port))) {
|
||||||
|
utils::Socket::Error("Connect");
|
||||||
|
}
|
||||||
|
utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
|
||||||
|
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
|
||||||
|
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
|
||||||
|
|
||||||
|
utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||||
|
master.SendStr(job_id);
|
||||||
|
master.SendStr(std::string("start"));
|
||||||
|
{// get new ranks
|
||||||
|
int newrank;
|
||||||
|
utils::Assert(master.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||||
|
"ReConnectLink failure 4");
|
||||||
|
utils::Assert(master.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
|
||||||
|
"ReConnectLink failure 4");
|
||||||
|
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
|
||||||
|
rank = newrank;
|
||||||
|
}
|
||||||
|
|
||||||
|
// create listening socket
|
||||||
|
utils::TCPSocket sock_listen;
|
||||||
|
sock_listen.Create();
|
||||||
|
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
|
||||||
|
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||||
|
sock_listen.Listen();
|
||||||
|
|
||||||
|
// get number of to connect and number of to accept nodes from master
|
||||||
|
int num_conn, num_accept, num_error = 1;
|
||||||
|
|
||||||
|
do {
|
||||||
|
// send over good links
|
||||||
|
std::vector<int> good_link;
|
||||||
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
|
if (!links[i].sock.BadSocket()) {
|
||||||
|
good_link.push_back(static_cast<int>(links[i].rank));
|
||||||
|
} else {
|
||||||
|
if (!links[i].sock.IsClosed()) links[i].sock.Close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int ngood = static_cast<int>(good_link.size());
|
||||||
|
utils::Assert(master.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||||
|
"ReConnectLink failure 5");
|
||||||
|
for (size_t i = 0; i < good_link.size(); ++i) {
|
||||||
|
utils::Assert(master.SendAll(&good_link[i], sizeof(good_link[i])) == sizeof(good_link[i]),
|
||||||
|
"ReConnectLink failure 6");
|
||||||
|
}
|
||||||
|
utils::Assert(master.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||||
|
"ReConnectLink failure 7");
|
||||||
|
utils::Assert(master.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
|
||||||
|
"ReConnectLink failure 8");
|
||||||
|
num_error = 0;
|
||||||
|
for (int i = 0; i < num_conn; ++i) {
|
||||||
|
LinkRecord r;
|
||||||
|
int hport, hrank;
|
||||||
|
std::string hname;
|
||||||
|
master.RecvStr(&hname);
|
||||||
|
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||||
|
utils::Assert(master.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||||
|
r.sock.Create();
|
||||||
|
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
||||||
|
num_error += 1; r.sock.Close(); continue;
|
||||||
|
}
|
||||||
|
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12");
|
||||||
|
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
|
||||||
|
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
|
||||||
|
bool match = false;
|
||||||
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
|
if (links[i].rank == hrank) {
|
||||||
|
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
|
||||||
|
links[i].sock = r.sock; match = true; break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!match) links.push_back(r);
|
||||||
|
}
|
||||||
|
utils::Assert(master.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
|
||||||
|
} while (num_error != 0);
|
||||||
|
// send back socket listening port to master
|
||||||
|
utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
||||||
|
// close connection to master
|
||||||
|
master.Close();
|
||||||
|
// listen to incoming links
|
||||||
|
for (int i = 0; i < num_accept; ++i) {
|
||||||
|
LinkRecord r;
|
||||||
|
r.sock = sock_listen.Accept();
|
||||||
|
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
|
||||||
|
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
|
||||||
|
bool match = false;
|
||||||
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
|
if (links[i].rank == r.rank) {
|
||||||
|
utils::Assert(links[i].sock.IsClosed(), "Override a link that is active");
|
||||||
|
links[i].sock = r.sock; match = true; break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!match) links.push_back(r);
|
||||||
|
}
|
||||||
|
// close listening sockets
|
||||||
|
sock_listen.Close();
|
||||||
|
this->parent_index = -1;
|
||||||
|
// setup selecter
|
||||||
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
|
utils::Assert(!links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
|
// set the socket to non-blocking mode
|
||||||
|
links[i].sock.SetNonBlock(true);
|
||||||
|
if (links[i].rank == parent_rank) parent_index = static_cast<int>(i);
|
||||||
|
}
|
||||||
|
if (parent_rank != -1) {
|
||||||
|
utils::Assert(parent_index != -1, "cannot find parent in the link");
|
||||||
|
}
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||||
*
|
*
|
||||||
@ -209,7 +278,6 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
|||||||
finished = false;
|
finished = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
// finish runing allreduce
|
// finish runing allreduce
|
||||||
if (finished) break;
|
if (finished) break;
|
||||||
|
|||||||
@ -138,6 +138,8 @@ class AllreduceBase : public IEngine {
|
|||||||
public:
|
public:
|
||||||
// socket to get data from/to link
|
// socket to get data from/to link
|
||||||
utils::TCPSocket sock;
|
utils::TCPSocket sock;
|
||||||
|
// rank of the node in this link
|
||||||
|
int rank;
|
||||||
// size of data readed from link
|
// size of data readed from link
|
||||||
size_t size_read;
|
size_t size_read;
|
||||||
// size of data sent to the link
|
// size of data sent to the link
|
||||||
@ -222,6 +224,11 @@ class AllreduceBase : public IEngine {
|
|||||||
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
||||||
std::vector<uint64_t> buffer_;
|
std::vector<uint64_t> buffer_;
|
||||||
};
|
};
|
||||||
|
/*!
|
||||||
|
* \brief connect to the master to fix the the missing links
|
||||||
|
* this function is also used when the engine start up
|
||||||
|
*/
|
||||||
|
void ReConnectLinks(void);
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||||
*
|
*
|
||||||
@ -255,9 +262,14 @@ class AllreduceBase : public IEngine {
|
|||||||
//---- local data related to link ----
|
//---- local data related to link ----
|
||||||
// index of parent link, can be -1, meaning this is root of the tree
|
// index of parent link, can be -1, meaning this is root of the tree
|
||||||
int parent_index;
|
int parent_index;
|
||||||
|
// rank of parent node, can be -1
|
||||||
|
int parent_rank;
|
||||||
// sockets of all links
|
// sockets of all links
|
||||||
std::vector<LinkRecord> links;
|
std::vector<LinkRecord> links;
|
||||||
//----- meta information-----
|
//----- meta information-----
|
||||||
|
// unique identifier of the possible job this process is doing
|
||||||
|
// used to assign ranks, optional, default to NULL
|
||||||
|
std::string job_id;
|
||||||
// uri of current host, to be set by Init
|
// uri of current host, to be set by Init
|
||||||
std::string host_uri;
|
std::string host_uri;
|
||||||
// uri of master
|
// uri of master
|
||||||
|
|||||||
@ -42,7 +42,7 @@ public:
|
|||||||
|
|
||||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||||
utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank);
|
utils::Assert(verify(broadcast), "[%d] error when broadcasting", rank);
|
||||||
rabit::Bcast(sendrecv_data, root);
|
rabit::Broadcast(sendrecv_data, root);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -91,16 +91,32 @@ inline int GetWorldSize(void) {
|
|||||||
inline std::string GetProcessorName(void) {
|
inline std::string GetProcessorName(void) {
|
||||||
return engine::GetEngine()->GetHost();
|
return engine::GetEngine()->GetHost();
|
||||||
}
|
}
|
||||||
// broadcast an std::string to all others from root
|
// broadcast data to all other nodes from root
|
||||||
inline void Bcast(std::string *sendrecv_data, int root) {
|
inline void Broadcast(void *sendrecv_data, size_t size, int root) {
|
||||||
engine::IEngine *e = engine::GetEngine();
|
engine::GetEngine()->Broadcast(sendrecv_data, size, root);
|
||||||
unsigned len = static_cast<unsigned>(sendrecv_data->length());
|
}
|
||||||
e->Broadcast(&len, sizeof(len), root);
|
template<typename DType>
|
||||||
sendrecv_data->resize(len);
|
inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
|
||||||
if (len != 0) {
|
size_t size = sendrecv_data->size();
|
||||||
e->Broadcast(&(*sendrecv_data)[0], len, root);
|
Broadcast(&size, sizeof(size), root);
|
||||||
|
if (sendrecv_data->size() != size) {
|
||||||
|
sendrecv_data->resize(size);
|
||||||
|
}
|
||||||
|
if (size != 0) {
|
||||||
|
Broadcast(&sendrecv_data[0], size * sizeof(DType), root);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||||
|
size_t size = sendrecv_data->length();
|
||||||
|
Broadcast(&size, sizeof(size), root);
|
||||||
|
if (sendrecv_data->length() != size) {
|
||||||
|
sendrecv_data->resize(size);
|
||||||
|
}
|
||||||
|
if (size != 0) {
|
||||||
|
Broadcast(&sendrecv_data[0], size * sizeof(char), root);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// perform inplace Allreduce
|
// perform inplace Allreduce
|
||||||
template<typename OP, typename DType>
|
template<typename OP, typename DType>
|
||||||
inline void Allreduce(DType *sendrecvbuf, size_t count) {
|
inline void Allreduce(DType *sendrecvbuf, size_t count) {
|
||||||
|
|||||||
24
src/rabit.h
24
src/rabit.h
@ -8,6 +8,8 @@
|
|||||||
*
|
*
|
||||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||||
*/
|
*/
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
#include "./engine.h"
|
#include "./engine.h"
|
||||||
|
|
||||||
/*! \brief namespace of rabit */
|
/*! \brief namespace of rabit */
|
||||||
@ -43,11 +45,27 @@ inline std::string GetProcessorName(void);
|
|||||||
/*!
|
/*!
|
||||||
* \brief broadcast an std::string to all others from root
|
* \brief broadcast an std::string to all others from root
|
||||||
* \param sendrecv_data the pointer to send or recive buffer,
|
* \param sendrecv_data the pointer to send or recive buffer,
|
||||||
* receive buffer does not need to be pre-allocated
|
* \param size the size of the data
|
||||||
* and string will be resized to correct length
|
|
||||||
* \param root the root of process
|
* \param root the root of process
|
||||||
*/
|
*/
|
||||||
inline void Bcast(std::string *sendrecv_data, int root);
|
inline void Broadcast(void *sendrecv_data, size_t size, int root);
|
||||||
|
/*!
|
||||||
|
* \brief broadcast an std::vector<DType> to all others from root
|
||||||
|
* \param sendrecv_data the pointer to send or recive vector,
|
||||||
|
* for receiver, the vector does not need to be pre-allocated
|
||||||
|
* \param root the root of process
|
||||||
|
* \tparam DType the data type stored in vector, have to be simple data type
|
||||||
|
* that can be directly send by sending the sizeof(DType) data
|
||||||
|
*/
|
||||||
|
template<typename DType>
|
||||||
|
inline void Broadcast(std::vector<DType> *sendrecv_data, int root);
|
||||||
|
/*!
|
||||||
|
* \brief broadcast an std::string to all others from root
|
||||||
|
* \param sendrecv_data the pointer to send or recive vector,
|
||||||
|
* for receiver, the vector does not need to be pre-allocated
|
||||||
|
* \param root the root of process
|
||||||
|
*/
|
||||||
|
inline void Broadcast(std::string *sendrecv_data, int root);
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
|
|||||||
25
src/socket.h
25
src/socket.h
@ -331,6 +331,31 @@ class TCPSocket : public Socket{
|
|||||||
}
|
}
|
||||||
return ndone;
|
return ndone;
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief send a string over network
|
||||||
|
* \param str the string to be sent
|
||||||
|
*/
|
||||||
|
inline void SendStr(const std::string &str) {
|
||||||
|
unsigned len = static_cast<int>(str.length());
|
||||||
|
utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
|
||||||
|
"error during send SendStr");
|
||||||
|
utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
|
||||||
|
"error during send SendStr");
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief recv a string from network
|
||||||
|
* \param out_str the string to receive
|
||||||
|
*/
|
||||||
|
inline void RecvStr(std::string *out_str) {
|
||||||
|
unsigned len;
|
||||||
|
utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
|
||||||
|
"error during send RecvStr");
|
||||||
|
out_str->resize(len);
|
||||||
|
if (len != 0) {
|
||||||
|
utils::Assert(this->RecvAll(&(*out_str)[0], len) == len,
|
||||||
|
"error during send SendStr");
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*! \brief helper data structure to perform select */
|
/*! \brief helper data structure to perform select */
|
||||||
|
|||||||
@ -1,106 +0,0 @@
|
|||||||
"""
|
|
||||||
Master script for xgboost, tcp_master
|
|
||||||
This script can be used to start jobs of multi-node xgboost using sync_tcp
|
|
||||||
|
|
||||||
Tianqi Chen
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import struct
|
|
||||||
import subprocess
|
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
class ExSocket:
|
|
||||||
def __init__(self, sock):
|
|
||||||
self.sock = sock
|
|
||||||
def recvall(self, nbytes):
|
|
||||||
res = []
|
|
||||||
sock = self.sock
|
|
||||||
nread = 0
|
|
||||||
while nread < nbytes:
|
|
||||||
chunk = self.sock.recv(min(nbytes - nread, 1024), socket.MSG_WAITALL)
|
|
||||||
nread += len(chunk)
|
|
||||||
res.append(chunk)
|
|
||||||
return ''.join(res)
|
|
||||||
def recvint(self):
|
|
||||||
return struct.unpack('@i', self.recvall(4))[0]
|
|
||||||
def sendint(self, n):
|
|
||||||
self.sock.sendall(struct.pack('@i', n))
|
|
||||||
def sendstr(self, s):
|
|
||||||
self.sendint(len(s))
|
|
||||||
self.sock.sendall(s)
|
|
||||||
|
|
||||||
# magic number used to verify existence of data
|
|
||||||
kMagic = 0xff99
|
|
||||||
|
|
||||||
class Master:
|
|
||||||
def __init__(self, port = 9000, port_end = 9999):
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
for port in range(port, port_end):
|
|
||||||
try:
|
|
||||||
sock.bind(('', port))
|
|
||||||
self.port = port
|
|
||||||
break
|
|
||||||
except socket.error:
|
|
||||||
continue
|
|
||||||
sock.listen(16)
|
|
||||||
self.sock = sock
|
|
||||||
print 'start listen on %s:%d' % (socket.gethostname(), self.port)
|
|
||||||
def __del__(self):
|
|
||||||
self.sock.close()
|
|
||||||
def slave_args(self):
|
|
||||||
return ['master_uri=%s' % socket.gethostname(),
|
|
||||||
'master_port=%s' % self.port]
|
|
||||||
def accept_slaves(self, nslave):
|
|
||||||
slave_addrs = []
|
|
||||||
for rank in range(nslave):
|
|
||||||
while True:
|
|
||||||
fd, s_addr = self.sock.accept()
|
|
||||||
slave = ExSocket(fd)
|
|
||||||
nparent = int(rank != 0)
|
|
||||||
nchild = 0
|
|
||||||
if (rank + 1) * 2 - 1 < nslave:
|
|
||||||
nchild += 1
|
|
||||||
if (rank + 1) * 2 < nslave:
|
|
||||||
nchild += 1
|
|
||||||
try:
|
|
||||||
magic = slave.recvint()
|
|
||||||
if magic != kMagic:
|
|
||||||
print 'invalid magic number=%d from %s' % (magic, s_addr[0])
|
|
||||||
slave.sock.close()
|
|
||||||
continue
|
|
||||||
except socket.error:
|
|
||||||
print 'sock error in %s' % (s_addr[0])
|
|
||||||
slave.sock.close()
|
|
||||||
continue
|
|
||||||
slave.sendint(kMagic)
|
|
||||||
slave.sendint(rank)
|
|
||||||
slave.sendint(nslave)
|
|
||||||
slave.sendint(nparent)
|
|
||||||
slave.sendint(nchild)
|
|
||||||
if nparent != 0:
|
|
||||||
parent_index = (rank + 1) / 2 - 1
|
|
||||||
ptuple = slave_addrs[parent_index]
|
|
||||||
slave.sendstr(ptuple[0])
|
|
||||||
slave.sendint(ptuple[1])
|
|
||||||
s_port = slave.recvint()
|
|
||||||
assert rank == len(slave_addrs)
|
|
||||||
slave_addrs.append((s_addr[0], s_port))
|
|
||||||
slave.sock.close()
|
|
||||||
print 'finish starting rank=%d at %s' % (rank, s_addr[0])
|
|
||||||
break
|
|
||||||
print 'all slaves setup complete'
|
|
||||||
|
|
||||||
def mpi_submit(nslave, args):
|
|
||||||
cmd = ' '.join(['mpirun -n %d' % nslave] + args)
|
|
||||||
print cmd
|
|
||||||
return subprocess.check_call(cmd, shell = True)
|
|
||||||
|
|
||||||
def submit(nslave, args, fun_submit = mpi_submit):
|
|
||||||
master = Master()
|
|
||||||
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
|
|
||||||
submit_thread.start()
|
|
||||||
master.accept_slaves(nslave)
|
|
||||||
submit_thread.join()
|
|
||||||
Loading…
x
Reference in New Issue
Block a user