start testing allreduce
This commit is contained in:
parent
cb1c34aef0
commit
c499dd0f0c
108
src/sync/submit_tcp.py
Executable file
108
src/sync/submit_tcp.py
Executable file
@ -0,0 +1,108 @@
|
||||
#!/usr/bin/python
|
||||
"""
|
||||
Master script for xgboost submit_tcp
|
||||
This script can be used to start jobs of multi-node xgboost using sync_tcp
|
||||
|
||||
Tianqi Chen
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
|
||||
class ExSocket:
    """Helper around a connected stream socket for the handshake wire format.

    Integers are exchanged as 4-byte signed values in network byte order
    (struct format '!i'). Strings are exchanged as an integer byte length
    followed by the raw bytes.
    """

    def __init__(self, sock):
        """Wrap *sock*, an already-connected socket object."""
        self.sock = sock

    def recvall(self, nbytes):
        """Receive exactly *nbytes* bytes and return them as one byte string.

        Raises:
            socket.error: if the peer closes the connection before *nbytes*
                bytes have been received.
        """
        res = []
        nread = 0
        while nread < nbytes:
            chunk = self.sock.recv(min(nbytes - nread, 1024), socket.MSG_WAITALL)
            if not chunk:
                # recv() yields an empty result once the peer has closed the
                # connection; without this check the loop would never end.
                raise socket.error(
                    'connection closed after %d of %d bytes' % (nread, nbytes))
            nread += len(chunk)
            res.append(chunk)
        # b'' is the same as '' on Python 2 and is required on Python 3,
        # where recv() returns bytes.
        return b''.join(res)

    def recvint(self):
        """Receive one 4-byte network-order signed integer."""
        return struct.unpack('!i', self.recvall(4))[0]

    def sendint(self, n):
        """Send *n* as a 4-byte network-order signed integer."""
        self.sock.sendall(struct.pack('!i', n))

    def sendstr(self, s):
        """Send *s* as a length-prefixed byte string.

        Text that is not already a byte string is encoded as UTF-8 first so
        that the length prefix counts the bytes actually sent.
        """
        if not isinstance(s, bytes):
            s = s.encode('utf-8')
        self.sendint(len(s))
        self.sock.sendall(s)
|
||||
|
||||
# magic number exchanged during the handshake so master and slave can verify each other
|
||||
kMagic = 0xff99
|
||||
|
||||
class Master:
    """Rendezvous server that assigns ranks and tree links to slave processes.

    The master listens on the first port it can bind in a range. Each slave
    that connects and presents the magic number is given its rank, the world
    size, its number of parent/child links in a binary tree laid out over
    the ranks, and (for every rank except 0) the address of its parent.
    """

    def __init__(self, port = 9000, port_end = 9999):
        """Bind to the first available port in [port, port_end) and listen.

        Raises:
            socket.error: if no port in the range can be bound.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Store the socket before anything can fail so __del__ can close it.
        self.sock = sock
        for candidate in range(port, port_end):
            try:
                sock.bind(('', candidate))
            except socket.error:
                continue
            self.port = candidate
            break
        else:
            sock.close()
            raise socket.error(
                'cannot bind any port in range [%d, %d)' % (port, port_end))
        sock.listen(16)
        print('start listen on %s:%d' % (socket.gethostname(), self.port))

    def __del__(self):
        """Close the listening socket, if one was created."""
        sock = getattr(self, 'sock', None)
        if sock is not None:
            sock.close()

    def slave_args(self):
        """Return the name=value arguments that point a slave at this master."""
        return ['master_uri=%s' % socket.gethostname(),
                'master_port=%s' % self.port]

    @staticmethod
    def _tree_links(rank, nslave):
        """Return (nparent, nchild) for *rank* in a binary tree of *nslave* nodes.

        Rank 0 is the root; the children of rank r are 2r+1 and 2r+2 when
        those ranks exist.
        """
        nparent = int(rank != 0)
        nchild = 0
        if (rank + 1) * 2 - 1 < nslave:
            nchild += 1
        if (rank + 1) * 2 < nslave:
            nchild += 1
        return nparent, nchild

    def accept_slaves(self, nslave):
        """Accept *nslave* slave connections and hand each its tree position.

        Ranks are assigned in connection order. A connection that fails the
        magic-number check, or errors while sending it, is closed and the
        same rank is offered to the next connection.
        """
        # slave_addrs[rank] == (host, listening port reported by that slave)
        slave_addrs = []
        for rank in range(nslave):
            nparent, nchild = self._tree_links(rank, nslave)
            while True:
                fd, s_addr = self.sock.accept()
                print('accept connection from %s:%d' % (s_addr[0], s_addr[1]))
                slave = ExSocket(fd)
                try:
                    magic = slave.recvint()
                except socket.error:
                    slave.sock.close()
                    continue
                if magic != kMagic:
                    slave.sock.close()
                    continue
                slave.sendint(kMagic)
                slave.sendint(rank)
                slave.sendint(nslave)
                slave.sendint(nparent)
                slave.sendint(nchild)
                if nparent != 0:
                    # Floor division keeps the index an int on Python 3 too.
                    parent_index = (rank + 1) // 2 - 1
                    ptuple = slave_addrs[parent_index]
                    slave.sendstr(ptuple[0])
                    slave.sendint(ptuple[1])
                s_port = slave.recvint()
                assert rank == len(slave_addrs)
                # Keep the host from the accepted address together with the
                # port the slave reported, matching what is sent to children.
                slave_addrs.append((s_addr[0], s_port))
                break
        print('all slaves setup complete')
|
||||
|
||||
def mpi_submit(nslave, args):
    """Launch *nslave* copies of the command in *args* through mpirun.

    The command line is echoed to stdout and then run through the shell;
    this call blocks until the launched command finishes.

    Args:
        nslave: number of processes passed to ``mpirun -n``.
        args: list of command words appended after the mpirun options.

    Returns:
        The value returned by os.system for the command.
    """
    cmd = ' '.join(['mpirun -n %d' % nslave] + args)
    # Single-argument print() behaves the same on Python 2 and Python 3.
    print(cmd)
    return os.system(cmd)
|
||||
|
||||
def submit(nslave, args, fun_submit = mpi_submit):
    """Start a master, launch the slaves and complete their handshake.

    Args:
        nslave: number of slave processes to launch and wait for.
        args: list of command words for the slave program; the master's
            address arguments are appended to it.
        fun_submit: callable ``fun_submit(nslave, args)`` that launches the
            slaves. It may block until the slaves exit.
    """
    import threading

    master = Master()
    # The launcher may block until every slave exits, while the slaves
    # cannot make progress until the master answers their handshake. Run the
    # launcher in the background so accept_slaves can serve them.
    launcher = threading.Thread(target=fun_submit,
                                args=(nslave, args + master.slave_args()))
    # Do not keep the process alive on the launcher if the handshake fails.
    launcher.daemon = True
    launcher.start()
    master.accept_slaves(nslave)
    launcher.join()
|
||||
|
||||
if __name__ == '__main__':
    # Both <nslave> and at least one word of <cmd> are required; with only
    # <nslave> there would be no program for the launcher to run.
    if len(sys.argv) < 3:
        print('Usage: <nslave> <cmd>')
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(0)
    submit(int(sys.argv[1]), sys.argv[2:])
|
||||
|
||||
@ -23,16 +23,64 @@ class SyncManager {
|
||||
public:
|
||||
const static int kMagic = 0xff99;
|
||||
SyncManager(void) {
|
||||
master_uri = "localhost";
|
||||
master_uri = "NULL";
|
||||
master_port = 9000;
|
||||
host_uri = "";
|
||||
slave_port = 9010;
|
||||
nport_trial = 1000;
|
||||
rank = 0;
|
||||
world_size = 1;
|
||||
reduce_buffer_size = 128;
|
||||
}
|
||||
~SyncManager(void) {
|
||||
this->Shutdown();
|
||||
}
|
||||
inline void Shutdown(void) {
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
links[i].sock.Close();
|
||||
}
|
||||
links.clear();
|
||||
}
|
||||
/*! \brief set parameters to the sync manager */
|
||||
inline void SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "master_uri")) master_uri = val;
|
||||
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
||||
if (!strcmp(name, "reduce_buffer")) {
|
||||
char unit;
|
||||
unsigned long amount;
|
||||
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
||||
switch (unit) {
|
||||
case 'B': reduce_buffer_size = amount; break;
|
||||
case 'K': reduce_buffer_size = amount << 10UL; break;
|
||||
case 'M': reduce_buffer_size = amount << 20UL; break;
|
||||
case 'G': reduce_buffer_size = amount << 30UL; break;
|
||||
default: utils::Error("invalid format for reduce buffer");
|
||||
}
|
||||
} else {
|
||||
utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
||||
}
|
||||
}
|
||||
}
|
||||
/*! \brief get rank */
|
||||
inline int GetRank(void) const {
|
||||
return rank;
|
||||
}
|
||||
/*! \brief check whether the manager is running in distributed mode */
|
||||
inline bool IsDistributed(void) const {
|
||||
return links.size() != 0;
|
||||
}
|
||||
/*! \brief get world size */
|
||||
inline int GetWorldSize(void) const {
|
||||
return world_size;
|
||||
}
|
||||
/*! \brief get host uri */
|
||||
inline std::string GetHost(void) const {
|
||||
return host_uri;
|
||||
}
|
||||
// initialize the manager
|
||||
inline void Init(void) {
|
||||
// single node mode
|
||||
if (master_uri == "NULL") return;
|
||||
utils::Assert(links.size() == 0, "can only call Init once");
|
||||
int magic = kMagic;
|
||||
int nchild = 0, nparent = 0;
|
||||
@ -108,29 +156,6 @@ class SyncManager {
|
||||
}
|
||||
// done
|
||||
}
|
||||
inline void Shutdown(void) {
|
||||
for (size_t i = 0; i < links.size(); ++i) {
|
||||
links[i].sock.Close();
|
||||
}
|
||||
links.clear();
|
||||
}
|
||||
/*! \brief set parameters to the sync manager */
|
||||
inline void SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "master_uri")) master_uri = val;
|
||||
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
||||
}
|
||||
/*! \brief get rank */
|
||||
inline int GetRank(void) const {
|
||||
return rank;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
inline int GetWorldSize(void) const {
|
||||
return world_size;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
inline std::string GetHost(void) const {
|
||||
return host_uri;
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
@ -159,7 +184,9 @@ class SyncManager {
|
||||
|
||||
// initialize the link ring-buffer and pointer
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) links[i].InitBuffer(type_nbytes, count);
|
||||
if (i != parent_index) {
|
||||
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
|
||||
}
|
||||
links[i].ResetSize();
|
||||
}
|
||||
// if no childs, no need to reduce
|
||||
@ -301,8 +328,6 @@ class SyncManager {
|
||||
}
|
||||
}
|
||||
private:
|
||||
// 128 MB
|
||||
const static size_t kBufferSize = 128;
|
||||
// an independent child record
|
||||
struct LinkRecord {
|
||||
public:
|
||||
@ -317,10 +342,10 @@ class SyncManager {
|
||||
// buffer size, in bytes
|
||||
size_t buffer_size;
|
||||
// initialize buffer
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count) {
|
||||
utils::Assert(type_nbytes < kBufferSize, "too large type_nbytes");
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
|
||||
utils::Assert(type_nbytes < reduce_buffer_size, "too large type_nbytes");
|
||||
size_t n = (type_nbytes * count + 7)/ 8;
|
||||
buffer_.resize(std::min(kBufferSize, n));
|
||||
buffer_.resize(std::min(reduce_buffer_size, n));
|
||||
// make sure align to type_nbytes
|
||||
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||
// set buffer head
|
||||
@ -377,6 +402,8 @@ class SyncManager {
|
||||
int master_port;
|
||||
// port of slave process
|
||||
int slave_port, nport_trial;
|
||||
// reduce buffer size
|
||||
size_t reduce_buffer_size;
|
||||
// current rank
|
||||
int rank;
|
||||
// world size
|
||||
@ -405,9 +432,8 @@ int GetWorldSize(void) {
|
||||
std::string GetProcessorName(void) {
|
||||
return manager.GetHost();
|
||||
}
|
||||
|
||||
bool IsDistributed(void) {
|
||||
return true;
|
||||
return manager.IsDistributed();
|
||||
}
|
||||
/*! \brief initialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
|
||||
@ -11,19 +11,24 @@ else
|
||||
endif
|
||||
|
||||
# test binaries and object files to build
|
||||
BIN = test_group_data test_quantile test_sock
|
||||
|
||||
BIN = test_group_data test_quantile test_allreduce
|
||||
OBJ = sync_tcp.o
|
||||
.PHONY: clean all
|
||||
|
||||
all: $(BIN) $(MPIBIN)
|
||||
|
||||
sync_tcp.o: ../src/sync/sync_tcp.cpp ../src/utils/*.h
|
||||
|
||||
test_group_data: test_group_data.cpp ../src/utils/*.h
|
||||
test_quantile: test_quantile.cpp ../src/utils/*.h
|
||||
test_sock: test_sock.cpp ../src/utils/*.h
|
||||
test_allreduce: test_allreduce.cpp ../src/utils/*.h ../src/sync/sync.h sync_tcp.o
|
||||
|
||||
$(BIN) :
|
||||
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
||||
|
||||
$(OBJ) :
|
||||
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )
|
||||
|
||||
$(MPIBIN) :
|
||||
$(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
||||
|
||||
|
||||
22
test/test_allreduce.cpp
Normal file
22
test/test_allreduce.cpp
Normal file
@ -0,0 +1,22 @@
|
||||
#include <sync/sync.h>
|
||||
|
||||
using namespace xgboost;
|
||||
|
||||
/*!
 * \brief smoke test for AllReduce: every process fills a 16-element buffer
 *  with i + rank, reduces it with the kMax operation, and prints the result
 */
int main(int argc, char *argv[]) {
  sync::Init(argc, argv);
  const int rank = sync::GetRank();
  const std::string name = sync::GetProcessorName();
  printf("start %s rank=%d\n", name.c_str(), rank);

  std::vector<float> ndata(16);
  for (size_t i = 0; i < ndata.size(); ++i) {
    // spell out the conversions the original expression performed implicitly
    ndata[i] = static_cast<float>(i + static_cast<size_t>(rank));
  }
  sync::AllReduce(&ndata[0], ndata.size(), sync::kMax);
  sync::Finalize();
  for (size_t i = 0; i < ndata.size(); ++i) {
    // %lu expects unsigned long; size_t is a different type on some
    // platforms, so convert explicitly instead of passing it directly
    printf("%lu: %f\n", static_cast<unsigned long>(i), ndata[i]);
  }
  printf("all end\n");
  return 0;
}
|
||||
Loading…
x
Reference in New Issue
Block a user