start testing allreduce
This commit is contained in:
parent
cb1c34aef0
commit
c499dd0f0c
108
src/sync/submit_tcp.py
Executable file
108
src/sync/submit_tcp.py
Executable file
@ -0,0 +1,108 @@
|
|||||||
|
#!/usr/bin/python
|
||||||
|
"""
|
||||||
|
Master script for xgboost submit_tcp
|
||||||
|
This script can be used to start jobs of multi-node xgboost using sync_tcp
|
||||||
|
|
||||||
|
Tianqi Chen
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
class ExSocket:
    """Thin wrapper around a connected socket for the sync_tcp handshake.

    Integers travel as 4-byte signed values in network byte order
    (struct format '!i'); strings travel as a length integer followed
    by the raw bytes.
    """

    def __init__(self, sock):
        """Wrap an already-connected socket object."""
        self.sock = sock

    def recvall(self, nbytes):
        """Read exactly nbytes bytes from the socket and return them.

        Raises:
            socket.error: if the peer closes the connection before
                nbytes bytes have arrived.
        """
        res = []
        nread = 0
        while nread < nbytes:
            # request at most 1024 bytes per recv() call
            chunk = self.sock.recv(min(nbytes - nread, 1024), socket.MSG_WAITALL)
            if not chunk:
                # an empty result means the peer closed the connection;
                # without this check the loop would never terminate
                raise socket.error(
                    'connection closed after %d of %d bytes' % (nread, nbytes))
            nread += len(chunk)
            res.append(chunk)
        return b''.join(res)

    def recvint(self):
        """Receive one 4-byte network-order signed integer."""
        return struct.unpack('!i', self.recvall(4))[0]

    def sendint(self, n):
        """Send n as one 4-byte network-order signed integer."""
        self.sock.sendall(struct.pack('!i', n))

    def sendstr(self, s):
        """Send a string as its byte length followed by its bytes."""
        if not isinstance(s, bytes):
            # encode text first so the announced length counts bytes on the wire
            s = s.encode('utf-8')
        self.sendint(len(s))
        self.sock.sendall(s)
|
||||||
|
|
||||||
|
# magic number exchanged during the connection handshake to verify the peer
|
||||||
|
kMagic = 0xff99
|
||||||
|
|
||||||
|
class Master:
    """Rendezvous server that hands every slave its place in a binary tree.

    The master listens on the first port it can bind in the requested
    range, accepts one connection per rank and tells each slave its rank,
    the world size, how many parents/children it has and, for non-root
    ranks, the address of its parent.
    """

    def __init__(self, port=9000, port_end=9999):
        """Bind to the first free port in [port, port_end) and listen.

        Raises:
            socket.error: if no port in the range can be bound.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        for candidate in range(port, port_end):
            try:
                sock.bind(('', candidate))
            except socket.error:
                continue
            self.port = candidate
            break
        else:
            # every bind failed (or the range is empty): fail loudly instead
            # of listening on an unbound socket with self.port undefined
            sock.close()
            raise socket.error(
                'cannot bind any port in range [%d, %d)' % (port, port_end))
        sock.listen(16)
        self.sock = sock
        print('start listen on %s:%d' % (socket.gethostname(), self.port))

    def __del__(self):
        """Close the listening socket if the constructor got far enough to set it."""
        sock = getattr(self, 'sock', None)
        if sock is not None:
            sock.close()

    def slave_args(self):
        """Return the name=value arguments that tell a slave where the master is."""
        return ['master_uri=%s' % socket.gethostname(),
                'master_port=%s' % self.port]

    def accept_slaves(self, nslave):
        """Accept nslave connections and send each slave its tree position.

        Ranks are assigned in connection order. A connection that does not
        open with kMagic is closed and the rank is offered to the next
        connection.
        """
        # slave_addrs[rank] == (host, port the slave reported back)
        slave_addrs = []
        for rank in range(nslave):
            # binary tree laid out in an array: rank 0 is the root and the
            # children of rank r are 2r+1 and 2r+2 when they exist
            nparent = int(rank != 0)
            nchild = 0
            if (rank + 1) * 2 - 1 < nslave:
                nchild += 1
            if (rank + 1) * 2 < nslave:
                nchild += 1
            while True:
                fd, s_addr = self.sock.accept()
                print('accept connection from %s:%d' % (s_addr[0], s_addr[1]))
                slave = ExSocket(fd)
                try:
                    magic = slave.recvint()
                except socket.error:
                    slave.sock.close()
                    continue
                if magic != kMagic:
                    slave.sock.close()
                    continue
                slave.sendint(kMagic)
                slave.sendint(rank)
                slave.sendint(nslave)
                slave.sendint(nparent)
                slave.sendint(nchild)
                if nparent != 0:
                    # integer division: the result is used as a list index
                    parent_index = (rank + 1) // 2 - 1
                    ptuple = slave_addrs[parent_index]
                    slave.sendstr(ptuple[0])
                    slave.sendint(ptuple[1])
                s_port = slave.recvint()
                assert rank == len(slave_addrs)
                # keep the peer host with the port the slave reported, not
                # the ephemeral port of this handshake connection
                slave_addrs.append((s_addr[0], s_port))
                break
        print('all slaves setup complete')
|
||||||
|
|
||||||
|
def mpi_submit(nslave, args):
    """Launch the job with `mpirun -n <nslave> <args...>` and wait for it.

    The command is passed to the OS as an argument list instead of a
    shell string, so arguments containing spaces or shell metacharacters
    reach the program unchanged.

    Returns:
        The exit status of mpirun.
    """
    cmd = ['mpirun', '-n', '%d' % nslave] + list(args)
    print(' '.join(cmd))
    return subprocess.call(cmd)
|
||||||
|
|
||||||
|
def submit(nslave, args, fun_submit = mpi_submit):
    """Start a master, launch nslave workers and wire them together.

    Args:
        nslave: number of slave processes to launch.
        args: command line of the slave program, as a list.
        fun_submit: callable(nslave, args) that launches the slaves.
    """
    from threading import Thread
    master = Master()
    # The launcher blocks until the job exits, while the slaves wait for the
    # master to answer their handshake. Run the launcher in the background so
    # the master can accept connections while the job is running.
    launcher = Thread(target=fun_submit,
                      args=(nslave, args + master.slave_args()))
    launcher.daemon = True
    launcher.start()
    master.accept_slaves(nslave)
    # keep the previous contract: return only after the launcher has finished
    launcher.join()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # both the slave count and the command to launch are required
    if len(sys.argv) < 3:
        print('Usage: <nslave> <cmd>')
        # a usage error is a failure, so report it with a non-zero status
        sys.exit(1)
    submit(int(sys.argv[1]), sys.argv[2:])
|
||||||
|
|
||||||
@ -23,16 +23,64 @@ class SyncManager {
|
|||||||
public:
|
public:
|
||||||
const static int kMagic = 0xff99;
|
const static int kMagic = 0xff99;
|
||||||
SyncManager(void) {
|
SyncManager(void) {
|
||||||
master_uri = "localhost";
|
master_uri = "NULL";
|
||||||
master_port = 9000;
|
master_port = 9000;
|
||||||
|
host_uri = "";
|
||||||
slave_port = 9010;
|
slave_port = 9010;
|
||||||
nport_trial = 1000;
|
nport_trial = 1000;
|
||||||
|
rank = 0;
|
||||||
|
world_size = 1;
|
||||||
|
reduce_buffer_size = 128;
|
||||||
}
|
}
|
||||||
~SyncManager(void) {
|
~SyncManager(void) {
|
||||||
this->Shutdown();
|
this->Shutdown();
|
||||||
}
|
}
|
||||||
|
inline void Shutdown(void) {
|
||||||
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
|
links[i].sock.Close();
|
||||||
|
}
|
||||||
|
links.clear();
|
||||||
|
}
|
||||||
|
/*! \brief set parameters to the sync manager */
|
||||||
|
inline void SetParam(const char *name, const char *val) {
|
||||||
|
if (!strcmp(name, "master_uri")) master_uri = val;
|
||||||
|
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
||||||
|
if (!strcmp(name, "reduce_buffer")) {
|
||||||
|
char unit;
|
||||||
|
unsigned long amount;
|
||||||
|
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
||||||
|
switch (unit) {
|
||||||
|
case 'B': reduce_buffer_size = amount; break;
|
||||||
|
case 'K': reduce_buffer_size = amount << 10UL; break;
|
||||||
|
case 'M': reduce_buffer_size = amount << 20UL; break;
|
||||||
|
case 'G': reduce_buffer_size = amount << 30UL; break;
|
||||||
|
default: utils::Error("invalid format for reduce buffer");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*! \brief get rank */
|
||||||
|
inline int GetRank(void) const {
|
||||||
|
return rank;
|
||||||
|
}
|
||||||
|
/*! \brief check whether it is running in distributed mode */
|
||||||
|
inline bool IsDistributed(void) const {
|
||||||
|
return links.size() != 0;
|
||||||
|
}
|
||||||
|
/*! \brief get world size */
|
||||||
|
inline int GetWorldSize(void) const {
|
||||||
|
return world_size;
|
||||||
|
}
|
||||||
|
/*! \brief get host uri */
|
||||||
|
inline std::string GetHost(void) const {
|
||||||
|
return host_uri;
|
||||||
|
}
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
inline void Init(void) {
|
inline void Init(void) {
|
||||||
|
// single node mode
|
||||||
|
if (master_uri == "NULL") return;
|
||||||
utils::Assert(links.size() == 0, "can only call Init once");
|
utils::Assert(links.size() == 0, "can only call Init once");
|
||||||
int magic = kMagic;
|
int magic = kMagic;
|
||||||
int nchild = 0, nparent = 0;
|
int nchild = 0, nparent = 0;
|
||||||
@ -108,29 +156,6 @@ class SyncManager {
|
|||||||
}
|
}
|
||||||
// done
|
// done
|
||||||
}
|
}
|
||||||
inline void Shutdown(void) {
|
|
||||||
for (size_t i = 0; i < links.size(); ++i) {
|
|
||||||
links[i].sock.Close();
|
|
||||||
}
|
|
||||||
links.clear();
|
|
||||||
}
|
|
||||||
/*! \brief set parameters to the sync manager */
|
|
||||||
inline void SetParam(const char *name, const char *val) {
|
|
||||||
if (!strcmp(name, "master_uri")) master_uri = val;
|
|
||||||
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
|
||||||
}
|
|
||||||
/*! \brief get rank */
|
|
||||||
inline int GetRank(void) const {
|
|
||||||
return rank;
|
|
||||||
}
|
|
||||||
/*! \brief get rank */
|
|
||||||
inline int GetWorldSize(void) const {
|
|
||||||
return world_size;
|
|
||||||
}
|
|
||||||
/*! \brief get rank */
|
|
||||||
inline std::string GetHost(void) const {
|
|
||||||
return host_uri;
|
|
||||||
}
|
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
* this function is NOT thread-safe
|
* this function is NOT thread-safe
|
||||||
@ -159,7 +184,9 @@ class SyncManager {
|
|||||||
|
|
||||||
// initialize the link ring-buffer and pointer
|
// initialize the link ring-buffer and pointer
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (i != parent_index) links[i].InitBuffer(type_nbytes, count);
|
if (i != parent_index) {
|
||||||
|
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
|
||||||
|
}
|
||||||
links[i].ResetSize();
|
links[i].ResetSize();
|
||||||
}
|
}
|
||||||
// if no childs, no need to reduce
|
// if no childs, no need to reduce
|
||||||
@ -301,8 +328,6 @@ class SyncManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
// 128 MB
|
|
||||||
const static size_t kBufferSize = 128;
|
|
||||||
// an independent child record
|
// an independent child record
|
||||||
struct LinkRecord {
|
struct LinkRecord {
|
||||||
public:
|
public:
|
||||||
@ -317,10 +342,10 @@ class SyncManager {
|
|||||||
// buffer size, in bytes
|
// buffer size, in bytes
|
||||||
size_t buffer_size;
|
size_t buffer_size;
|
||||||
// initialize buffer
|
// initialize buffer
|
||||||
inline void InitBuffer(size_t type_nbytes, size_t count) {
|
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
|
||||||
utils::Assert(type_nbytes < kBufferSize, "too large type_nbytes");
|
utils::Assert(type_nbytes < reduce_buffer_size, "too large type_nbytes");
|
||||||
size_t n = (type_nbytes * count + 7)/ 8;
|
size_t n = (type_nbytes * count + 7)/ 8;
|
||||||
buffer_.resize(std::min(kBufferSize, n));
|
buffer_.resize(std::min(reduce_buffer_size, n));
|
||||||
// make sure align to type_nbytes
|
// make sure align to type_nbytes
|
||||||
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||||
// set buffer head
|
// set buffer head
|
||||||
@ -377,6 +402,8 @@ class SyncManager {
|
|||||||
int master_port;
|
int master_port;
|
||||||
// port of slave process
|
// port of slave process
|
||||||
int slave_port, nport_trial;
|
int slave_port, nport_trial;
|
||||||
|
// reduce buffer size
|
||||||
|
size_t reduce_buffer_size;
|
||||||
// current rank
|
// current rank
|
||||||
int rank;
|
int rank;
|
||||||
// world size
|
// world size
|
||||||
@ -405,9 +432,8 @@ int GetWorldSize(void) {
|
|||||||
std::string GetProcessorName(void) {
|
std::string GetProcessorName(void) {
|
||||||
return manager.GetHost();
|
return manager.GetHost();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsDistributed(void) {
|
bool IsDistributed(void) {
|
||||||
return true;
|
return manager.IsDistributed();
|
||||||
}
|
}
|
||||||
/*! \brief intiialize the synchronization module */
|
/*! \brief intiialize the synchronization module */
|
||||||
void Init(int argc, char *argv[]) {
|
void Init(int argc, char *argv[]) {
|
||||||
|
|||||||
@ -11,19 +11,24 @@ else
|
|||||||
endif
|
endif
|
||||||
|
|
||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = test_group_data test_quantile test_sock
|
BIN = test_group_data test_quantile test_allreduce
|
||||||
|
OBJ = sync_tcp.o
|
||||||
.PHONY: clean all
|
.PHONY: clean all
|
||||||
|
|
||||||
all: $(BIN) $(MPIBIN)
|
all: $(BIN) $(MPIBIN)
|
||||||
|
|
||||||
|
sync_tcp.o: ../src/sync/sync_tcp.cpp ../src/utils/*.h
|
||||||
|
|
||||||
test_group_data: test_group_data.cpp ../src/utils/*.h
|
test_group_data: test_group_data.cpp ../src/utils/*.h
|
||||||
test_quantile: test_quantile.cpp ../src/utils/*.h
|
test_quantile: test_quantile.cpp ../src/utils/*.h
|
||||||
test_sock: test_sock.cpp ../src/utils/*.h
|
test_allreduce: test_allreduce.cpp ../src/utils/*.h ../src/sync/sync.h sync_tcp.o
|
||||||
|
|
||||||
$(BIN) :
|
$(BIN) :
|
||||||
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
||||||
|
|
||||||
|
$(OBJ) :
|
||||||
|
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )
|
||||||
|
|
||||||
$(MPIBIN) :
|
$(MPIBIN) :
|
||||||
$(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
$(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
||||||
|
|
||||||
|
|||||||
22
test/test_allreduce.cpp
Normal file
22
test/test_allreduce.cpp
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#include <sync/sync.h>
|
||||||
|
|
||||||
|
using namespace xgboost;
|
||||||
|
|
||||||
|
int main(int argc, char *argv[]) {
  // Smoke test: every rank fills a vector with i + rank, runs a kMax
  // AllReduce over it and prints the result.
  sync::Init(argc, argv);
  int rank = sync::GetRank();
  std::string name = sync::GetProcessorName();
  printf("start %s rank=%d\n", name.c_str(), rank);

  std::vector<float> ndata(16);
  for (size_t i = 0; i < ndata.size(); ++i) {
    ndata[i] = i + rank;
  }
  sync::AllReduce(&ndata[0], ndata.size(), sync::kMax);
  sync::Finalize();
  for (size_t i = 0; i < ndata.size(); ++i) {
    // size_t is not unsigned long on every platform; cast to match %lu
    printf("%lu: %f\n", static_cast<unsigned long>(i), ndata[i]);
  }
  printf("all end\n");
  return 0;
}
|
||||||
Loading…
x
Reference in New Issue
Block a user