basic test pass

This commit is contained in:
tqchen
2014-11-23 11:15:48 -08:00
parent c499dd0f0c
commit 115424826b
7 changed files with 195 additions and 53 deletions

View File

@@ -85,7 +85,8 @@ class ReduceHandle {
void AllReduce(void *sendrecvbuf, size_t type_n4bytes, size_t count);
/*! \return the number of bytes occupied by the type */
static int TypeSize(const MPI::Datatype &dtype);
private:
protected:
// handle data field
void *handle;
// handle to the type field

View File

@@ -1,7 +1,7 @@
/*!
* \file sync_tcp.cpp
* \brief implementation of sync AllReduce using TCP sockets
* with use async socket and tree-shape reduction
* using non-blocking sockets and tree-shaped reduction
* \author Tianqi Chen
*/
#include <vector>
@@ -11,7 +11,8 @@
#include "../utils/socket.h"
namespace MPI {
struct Datatype {
// Minimal stand-in for an MPI datatype descriptor: it carries nothing but
// the size of the type (presumably in bytes -- confirm against TypeSize).
class Datatype {
 public:
  // Not marked explicit, so a plain size_t converts to Datatype implicitly.
  Datatype(size_t nbytes) : type_size(nbytes) {}
  // size recorded for this type
  size_t type_size;
};
@@ -30,7 +31,7 @@ class SyncManager {
nport_trial = 1000;
rank = 0;
world_size = 1;
reduce_buffer_size = 128;
this->SetParam("reduce_buffer", "256MB");
}
~SyncManager(void) {
this->Shutdown();
@@ -50,10 +51,10 @@ class SyncManager {
unsigned long amount;
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
switch (unit) {
case 'B': reduce_buffer_size = amount; break;
case 'K': reduce_buffer_size = amount << 10UL; break;
case 'M': reduce_buffer_size = amount << 20UL; break;
case 'G': reduce_buffer_size = amount << 30UL; break;
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
case 'K': reduce_buffer_size = amount << 7UL; break;
case 'M': reduce_buffer_size = amount << 17UL; break;
case 'G': reduce_buffer_size = amount << 27UL; break;
default: utils::Error("invalid format for reduce buffer");
}
} else {
@@ -117,16 +118,16 @@ class SyncManager {
utils::Assert(master.RecvAll(&hname[0], len) == static_cast<size_t>(len), "sync::Init failure 10");
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11");
links[0].sock.Create();
links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure");
utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure");
links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12");
utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13");
utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch");
parent_index = 0;
} else {
parent_index = -1;
}
// send back socket listening port to master
utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 12");
utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14");
// close connection to master
master.Close();
// accept links from children
@@ -134,10 +135,10 @@ class SyncManager {
LinkRecord r;
while (true) {
r.sock = sock_listen.Accept();
if (links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) {
utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure");
if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) {
utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15");
break;
} else {
} else {
// not a valid child
r.sock.Close();
}
@@ -150,7 +151,7 @@ class SyncManager {
selecter.Clear();
for (size_t i = 0; i < links.size(); ++i) {
// set the socket to non-blocking mode
links[i].sock.SetNonBlock();
links[i].sock.SetNonBlock(true);
selecter.WatchRead(links[i].sock);
selecter.WatchWrite(links[i].sock);
}
@@ -343,11 +344,11 @@ class SyncManager {
size_t buffer_size;
// initialize buffer
// Size the reduction buffer for one call.
//   type_nbytes        - size in bytes of a single element
//   count              - number of elements to be reduced
//   reduce_buffer_size - upper bound on the buffer length; it is compared
//                        against the payload size expressed in 8-byte units,
//                        so it is a count of 8-byte units, not of bytes
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
  // payload size rounded up to whole 8-byte units
  size_t n = (type_nbytes * count + 7) / 8;
  buffer_.resize(std::min(reduce_buffer_size, n));
  // usable bytes, truncated down to a whole number of elements
  buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
  // The size check is done here, once buffer_size is known, so that both
  // sides are in bytes; checking type_nbytes against reduce_buffer_size
  // directly would compare bytes with 8-byte units.
  // The message carries one %lu per value passed.
  utils::Assert(type_nbytes < buffer_size,
                "too large type_nbytes=%lu, buffer_size=%lu",
                static_cast<unsigned long>(type_nbytes),
                static_cast<unsigned long>(buffer_size));
  // set buffer head
  buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
}

30
src/sync/submit_tcp.py → src/sync/tcp_master.py Executable file → Normal file
View File

@@ -1,6 +1,5 @@
#!/usr/bin/python
"""
Master script for xgboost submit_tcp
Master script for xgboost, tcp_master
This script can be used to start jobs of multi-node xgboost using sync_tcp
Tianqi Chen
@@ -11,6 +10,7 @@ import os
import socket
import struct
import subprocess
from threading import Thread
class ExSocket:
def __init__(self, sock):
@@ -25,9 +25,9 @@ class ExSocket:
res.append(chunk)
return ''.join(res)
def recvint(self):
return struct.unpack('!i', self.recvall(4))[0]
return struct.unpack('@i', self.recvall(4))[0]
def sendint(self, n):
self.sock.sendall(struct.pack('!i', n))
self.sock.sendall(struct.pack('@i', n))
def sendstr(self, s):
self.sendint(len(s))
self.sock.sendall(s)
@@ -58,7 +58,6 @@ class Master:
for rank in range(nslave):
while True:
fd, s_addr = self.sock.accept()
print 'accept connection from %s' % s_addr
slave = ExSocket(fd)
nparent = int(rank != 0)
nchild = 0
@@ -67,11 +66,13 @@ class Master:
if (rank + 1) * 2 < nslave:
nchild += 1
try:
magic = slave.readint()
magic = slave.recvint()
if magic != kMagic:
print 'invalid magic number=%d from %s' % (magic, s_addr[0])
slave.sock.close()
continue
except socket.error:
print 'sock error in %s' % (s_addr[0])
slave.sock.close()
continue
slave.sendint(kMagic)
@@ -86,23 +87,20 @@ class Master:
slave.sendint(ptuple[1])
s_port = slave.recvint()
assert rank == len(slave_addrs)
slave_addrs.append(s_addr, s_port)
slave_addrs.append((s_addr[0], s_port))
slave.sock.close()
print 'finish starting rank=%d at %s' % (rank, s_addr[0])
break
print 'all slaves setup complete'
def mpi_submit(nslave, args):
cmd = ' '.join(['mpirun -n %d' % nslave] + args)
print cmd
os.system(cmd)
return subprocess.check_call(cmd, shell = True)
def submit(nslave, args, fun_submit = mpi_submit):
master = Master()
fun_submit(nslave, args + master.slave_args())
submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
submit_thread.start()
master.accept_slaves(nslave)
if __name__ == '__main__':
if len(sys.argv) < 2:
print 'Usage: <nslave> <cmd>'
exit(0)
submit(int(sys.argv[1]), sys.argv[2:])
submit_thread.join()

View File

@@ -71,7 +71,8 @@ class TCPSocket {
explicit TCPSocket(int sockfd) : sockfd(sockfd) {
}
~TCPSocket(void) {
if (sockfd != -1) this->Close();
// do nothing in destructor
// the user needs to take care of calling Close
}
// default conversion to int
inline operator int() const {
@@ -99,11 +100,22 @@ class TCPSocket {
inline static void Finalize(void) {
}
/*!
* \brief set this socket to use async I/O
* \brief set this socket to use non-blocking mode
* \param non_block whether to put the socket in non-blocking mode; if it is false,
* the socket is set back to blocking mode
*/
inline void SetNonBlock(void) {
if (fcntl(sockfd, fcntl(sockfd, F_GETFL) | O_NONBLOCK) == -1) {
SockError("SetNonBlock", errno);
inline void SetNonBlock(bool non_block) {
  // read the current file status flags of the descriptor
  const int current = fcntl(sockfd, F_GETFL, 0);
  if (current == -1) {
    SockError("SetNonBlock-1", errno);
  }
  // turn O_NONBLOCK on or off, leaving every other flag untouched
  const int updated = non_block ? (current | O_NONBLOCK)
                                : (current & ~O_NONBLOCK);
  // write the modified flags back
  if (fcntl(sockfd, F_SETFL, updated) == -1) {
    SockError("SetNonBlock-2", errno);
  }
}
/*!
@@ -209,7 +221,7 @@ class TCPSocket {
const char *buf = reinterpret_cast<const char*>(buf_);
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = send(sockfd, buf, len, 0);
ssize_t ret = send(sockfd, buf, len - ndone, 0);
if (ret == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
SockError("Recv", errno);
@@ -230,7 +242,7 @@ class TCPSocket {
char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = recv(sockfd, buf, len, MSG_WAITALL);
ssize_t ret = recv(sockfd, buf, len - ndone, MSG_WAITALL);
if (ret == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
SockError("Recv", errno);