check in allreduce tcp, check if there could be more concise form
This commit is contained in:
parent
b6e1b19205
commit
7ec3fc936a
6
Makefile
6
Makefile
@ -12,7 +12,7 @@ endif
|
||||
|
||||
# specify tensor path
|
||||
BIN = xgboost
|
||||
OBJ = updater.o gbm.o io.o main.o sync_empty.o
|
||||
OBJ = updater.o gbm.o io.o main.o sync_empty.o sync_tcp.o
|
||||
MPIOBJ = sync_mpi.o
|
||||
MPIBIN = xgboost-mpi
|
||||
SLIB = wrapper/libxgboostwrapper.so
|
||||
@ -24,11 +24,11 @@ mpi: $(MPIBIN)
|
||||
|
||||
python: wrapper/libxgboostwrapper.so
|
||||
# now the wrapper takes in two files. io and wrapper part
|
||||
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp $(OBJ)
|
||||
updater.o: src/tree/updater.cpp src/tree/*.hpp src/*.h src/tree/*.h src/utils/*.h
|
||||
gbm.o: src/gbm/gbm.cpp src/gbm/*.hpp src/gbm/*.h
|
||||
io.o: src/io/io.cpp src/io/*.hpp src/utils/*.h src/learner/dmatrix.h src/*.h
|
||||
sync_mpi.o: src/sync/sync_mpi.cpp
|
||||
sync_mpi.o: src/sync/sync_mpi.cpp
|
||||
sync_tcp.o: src/sync/sync_tcp.cpp
|
||||
sync_empty.o: src/sync/sync_empty.cpp
|
||||
main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h
|
||||
xgboost-mpi: updater.o gbm.o io.o main.o sync_mpi.o
|
||||
|
||||
280
src/sync/sync_tcp.cpp
Normal file
280
src/sync/sync_tcp.cpp
Normal file
@ -0,0 +1,280 @@
|
||||
/*!
|
||||
* \file sync_tcp.cpp
|
||||
* \brief implementation of sync AllReduce using TCP sockets
|
||||
* with use async socket and tree-shape reduction
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include "./sync.h"
|
||||
#include "../utils/socket.h"
|
||||
|
||||
namespace MPI {
|
||||
struct Datatype {
|
||||
size_t type_size;
|
||||
Datatype(size_t type_size) : type_size(type_size) {}
|
||||
};
|
||||
}
|
||||
namespace xgboost {
|
||||
namespace sync {
|
||||
/*! \brief implementation of sync goes to here */
|
||||
class SyncManager {
|
||||
public:
|
||||
// initialize the manager
|
||||
inline void Init(int argc, char *argv[]) {
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is not thread-safe
|
||||
* \param sendrecvbuf buffer for both sending and recving data
|
||||
* \param type_n4bytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
*/
|
||||
inline void AllReduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceHandle::ReduceFunction reducer) {
|
||||
if (parent.size() == 0 && childs.size() == 0) return;
|
||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
// total size of message
|
||||
const size_t total_size = type_nbytes * count;
|
||||
// size of space that we already performs reduce in up pass
|
||||
size_t size_up_reduce = 0;
|
||||
// size of space that we have already passed to parent
|
||||
size_t size_up_out = 0;
|
||||
// size of message we received, and send in the down pass
|
||||
size_t size_down_in = 0;
|
||||
// initialize the send buffer
|
||||
for (size_t i = 0; i < childs.size(); ++i) {
|
||||
childs[i].Init(type_nbytes, count);
|
||||
}
|
||||
// if no childs, no need to reduce
|
||||
if (childs.size() == 0) size_up_reduce = total_size;
|
||||
// while we have not passed the messages out
|
||||
while(true) {
|
||||
selecter.Select();
|
||||
// read data from childs
|
||||
for (size_t i = 0; i < childs.size(); ++i) {
|
||||
if (selecter.CheckRead(childs[i].sock)) {
|
||||
childs[i].Read(size_up_out);
|
||||
}
|
||||
}
|
||||
// peform reduce
|
||||
if (childs.size() != 0) {
|
||||
const size_t buffer_size = childs[0].buffer_size;
|
||||
// do upstream reduce
|
||||
size_t min_read = childs[0].size_read;
|
||||
for (size_t i = 1; i < childs.size(); ++i) {
|
||||
min_read = std::min(min_read, childs[i].size_read);
|
||||
}
|
||||
// align to type_nbytes
|
||||
min_read = (min_read / type_nbytes * type_nbytes);
|
||||
// start position
|
||||
size_t start = size_up_reduce % buffer_size;
|
||||
// peform read till end of buffer
|
||||
if (start + min_read - size_up_reduce > buffer_size) {
|
||||
const size_t nread = buffer_size - start;
|
||||
utils::Assert(nread % type_nbytes == 0, "AllReduce: size check");
|
||||
for (size_t i = 0; i < childs.size(); ++i) {
|
||||
reducer(childs[i].buffer_head + start,
|
||||
sendrecvbuf + size_up_reduce,
|
||||
nread / type_nbytes,
|
||||
MPI::Datatype(type_nbytes));
|
||||
}
|
||||
size_up_reduce += nread;
|
||||
start = 0;
|
||||
}
|
||||
// peform second phase of reduce
|
||||
const size_t nread = min_read - size_up_reduce;
|
||||
if (nread != 0) {
|
||||
utils::Assert(nread % type_nbytes == 0, "AllReduce: size check");
|
||||
for (size_t i = 0; i < childs.size(); ++i) {
|
||||
reducer(childs[i].buffer_head + start,
|
||||
sendrecvbuf + size_up_reduce,
|
||||
nread / type_nbytes,
|
||||
MPI::Datatype(type_nbytes));
|
||||
}
|
||||
}
|
||||
size_up_reduce += nread;
|
||||
}
|
||||
if (parent.size() != 0) {
|
||||
// can pass message up to parent
|
||||
if (selecter.CheckWrite(parent[0])) {
|
||||
size_up_out += parent[0]
|
||||
.Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
|
||||
}
|
||||
// read data from parent
|
||||
if (selecter.CheckRead(parent[0])) {
|
||||
size_down_in += parent[0]
|
||||
.Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||
utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error");
|
||||
}
|
||||
} else {
|
||||
// this is root, can use reduce as most recent point
|
||||
size_down_in = size_up_out = size_up_reduce;
|
||||
}
|
||||
// check if we finished the job of message passing
|
||||
size_t nfinished = size_down_in;
|
||||
// can pass message down to childs
|
||||
for (size_t i = 0; i < childs.size(); ++i) {
|
||||
if (selecter.CheckWrite(childs[i].sock)) {
|
||||
childs[i].size_write += childs[i].sock
|
||||
.Send(sendrecvbuf + childs[i].size_write, size_down_in - childs[i].size_write);
|
||||
}
|
||||
nfinished = std::min(childs[i].size_write, nfinished);
|
||||
}
|
||||
// check boundary condition
|
||||
if (nfinished >= total_size) {
|
||||
utils::Assert(nfinished == total_size, "AllReduce: nfinished check");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void Bcast(std::string *sendrecv_data, int root) {
|
||||
if (parent.size() == 0 && childs.size() == 0) return;
|
||||
// message send to parent
|
||||
size_t size_up_out = 0;
|
||||
// all messages received
|
||||
size_t size_in = 0;
|
||||
// all headers received so far
|
||||
size_t header_in = 0;
|
||||
// total size of data
|
||||
size_t total_size;
|
||||
// input channel, -1 means parent, -2 means unknown yet
|
||||
// otherwise its child index
|
||||
int in_channel = -2;
|
||||
// root already reads all data in
|
||||
if (root == rank) {
|
||||
in_channel = -3;
|
||||
total_size = size_in = sendrecv_data->length();
|
||||
header_in = sizeof(total_size);
|
||||
}
|
||||
// initialize write position
|
||||
for (size_t i = 0; i < childs.size(); ++i) {
|
||||
childs[i].size_write = 0;
|
||||
}
|
||||
const int nchilds = static_cast<int>(childs.size());
|
||||
|
||||
while (true) {
|
||||
selecter.Select();
|
||||
if (selecter.CheckRead(parent[0])) {
|
||||
utils::Assert(in_channel == -2 || in_channel == -1, "invalid in channel");
|
||||
this->BcastRecvData(parent[0], sendrecv_data,
|
||||
header_in, size_in, total_size);
|
||||
if (header_in != 0) in_channel = -1;
|
||||
}
|
||||
for (int i = 0; i < nchilds; ++i) {
|
||||
if (selecter.CheckRead(childs[i].sock)) {
|
||||
utils::Assert(in_channel == -2 || in_channel == i, "invalid in channel");
|
||||
this->BcastRecvData(parent[0], sendrecv_data,
|
||||
header_in, size_in, total_size);
|
||||
if (header_in != 0) in_channel = i;
|
||||
}
|
||||
}
|
||||
if (in_channel == -2) continue;
|
||||
if (in_channel != -1) {
|
||||
if (selecter.CheckWrite(parent[0])) {
|
||||
size_t nsend = size_in - size_up_out;
|
||||
if (nsend != 0) {
|
||||
size_up_out += parent[0].Send(&(*sendrecv_data)[0] + size_up_out, nsend);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
size_up_out = size_in;
|
||||
}
|
||||
size_t nfinished = size_up_out;
|
||||
for (int i = 0; i < nchilds; ++i) {
|
||||
if (in_channel != i) {
|
||||
if (selecter.CheckWrite(childs[i].sock)) {
|
||||
size_t nsend = size_in - childs[i].size_write;
|
||||
if (nsend != 0) {
|
||||
childs[i].size_write += childs[i].sock
|
||||
.Send(&(*sendrecv_data)[0] + childs[i].size_write, nsend);
|
||||
}
|
||||
}
|
||||
nfinished = std::min(nfinished, childs[i].size_write);
|
||||
}
|
||||
}
|
||||
// check boundary condition
|
||||
if (nfinished >= total_size) {
|
||||
utils::Assert(nfinished == total_size, "Bcast: nfinished check");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
inline void BcastRecvData(utils::TCPSocket &sock,
|
||||
std::string *sendrecv_data,
|
||||
size_t &header_in,
|
||||
size_t &size_in,
|
||||
size_t &total_size) {
|
||||
if (header_in < sizeof(total_size)) {
|
||||
char *p = reinterpret_cast<char*>(&total_size);
|
||||
header_in += sock.Recv(p + size_in, sizeof(total_size) - header_in);
|
||||
if (header_in == sizeof(total_size)) {
|
||||
sendrecv_data->resize(total_size);
|
||||
}
|
||||
} else {
|
||||
size_t nread = total_size - size_in;
|
||||
if (nread != 0) {
|
||||
size_in += sock
|
||||
.Recv(&(*sendrecv_data)[0] + size_in, nread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 128 MB
|
||||
const static size_t kBufferSize = 128;
|
||||
// an independent child record
|
||||
struct ChildRecord {
|
||||
public:
|
||||
// socket to get data from child
|
||||
utils::TCPSocket sock;
|
||||
// size of data readed from child
|
||||
size_t size_read;
|
||||
// size of data write into child
|
||||
size_t size_write;
|
||||
// pointer to buffer head
|
||||
char *buffer_head;
|
||||
// buffer size, in bytes
|
||||
size_t buffer_size;
|
||||
// initialize buffer
|
||||
inline void Init(size_t type_nbytes, size_t count) {
|
||||
utils::Assert(type_nbytes < kBufferSize, "too large type_nbytes");
|
||||
size_t n = (type_nbytes * count + 7)/ 8;
|
||||
buffer_.resize(std::min(kBufferSize, n));
|
||||
// make sure align to type_nbytes
|
||||
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||
// set buffer head
|
||||
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
||||
// set write head
|
||||
size_write = size_read = 0;
|
||||
}
|
||||
// maximum number of bytes we are able to read
|
||||
// currently without corrupt the data
|
||||
inline void Read(size_t size_up_out) {
|
||||
size_t ngap = size_read - size_up_out;
|
||||
utils::Assert(ngap <= buffer_size, "AllReduce: boundary check");
|
||||
size_t offset = size_read % buffer_size;
|
||||
size_t nmax = std::min(ngap, buffer_size - offset);
|
||||
size_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
size_read += len;
|
||||
}
|
||||
|
||||
private:
|
||||
// recv buffer to get data from child
|
||||
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
||||
std::vector<uint64_t> buffer_;
|
||||
};
|
||||
// current rank
|
||||
int rank;
|
||||
// parent socket, can be of size 0 or 1
|
||||
std::vector<utils::TCPSocket> parent;
|
||||
// sockets of all childs, can be of size 0, 1, 2 or more
|
||||
std::vector<ChildRecord> childs;
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
};
|
||||
|
||||
} // namespace sync
|
||||
} // namespace xgboost
|
||||
@ -2,8 +2,7 @@
|
||||
#define XGBOOST_UTILS_SOCKET_H
|
||||
/*!
|
||||
* \file socket.h
|
||||
* \brief this file aims to provide a platform independent wrapper
|
||||
* of socket
|
||||
* \brief this file aims to provide a wrapper of sockets
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <fcntl.h>
|
||||
@ -63,7 +62,7 @@ class TCPSocket {
|
||||
// constructor
|
||||
TCPSocket(void) {}
|
||||
// default conversion to int
|
||||
inline int operator()() const {
|
||||
inline operator int() const {
|
||||
return sockfd;
|
||||
}
|
||||
/*!
|
||||
@ -122,6 +121,7 @@ class TCPSocket {
|
||||
* \return size of data actually sent
|
||||
*/
|
||||
inline size_t Send(const void *buf, size_t len, int flag = 0) {
|
||||
if (len == 0) return 0;
|
||||
ssize_t ret = send(sockfd, buf, len, flag);
|
||||
if (ret == -1) SockError("Send", errno);
|
||||
return ret;
|
||||
@ -134,6 +134,7 @@ class TCPSocket {
|
||||
* \return size of data actually received
|
||||
*/
|
||||
inline size_t Recv(void *buf, size_t len, int flags = 0) {
|
||||
if (len == 0) return 0;
|
||||
ssize_t ret = recv(sockfd, buf, len, flags);
|
||||
if (ret == -1) SockError("Recv", errno);
|
||||
return ret;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user