initial version of allreduce
This commit is contained in:
parent
5e5bdda491
commit
d37f38c455
3
.gitignore
vendored
3
.gitignore
vendored
@ -26,3 +26,6 @@
|
|||||||
*.exe
|
*.exe
|
||||||
*.out
|
*.out
|
||||||
*.app
|
*.app
|
||||||
|
*~
|
||||||
|
*.pyc
|
||||||
|
test
|
||||||
6
README.md
Normal file
6
README.md
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
AllReduce Abstraction
|
||||||
|
====
|
||||||
|
* Tianqi, Nacho, Tianyi
|
||||||
|
|
||||||
|
Go!
|
||||||
|
|
||||||
107
src/allreduce.h
Normal file
107
src/allreduce.h
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
/*!
|
||||||
|
* \file allreduce.h
|
||||||
|
* \brief This file defines template wrappers around the engine interface, so that callers work with typed buffers
|
||||||
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
|
*/
|
||||||
|
#include "./engine.h"
|
||||||
|
|
||||||
|
/*! \brief namespace of all reduce */
|
||||||
|
namespace sync {
|
||||||
|
/*! \brief namespace of operator */
|
||||||
|
namespace op {
|
||||||
|
/*! \brief reduction operator that leaves the larger of the two operands in dst */
struct Max {
  /*!
   * \brief fold src into dst, so that dst holds the larger value afterwards
   * \param dst accumulated value, overwritten only when it compares less than src
   * \param src incoming value
   * \tparam DType element type, must support operator< and assignment
   */
  template<typename DType>
  inline static void Reduce(DType &dst, const DType &src) {
    if (!(dst < src)) return;
    dst = src;
  }
};
|
||||||
|
/*! \brief reduction operator that accumulates the sum of the operands in dst */
struct Sum {
  /*!
   * \brief add src onto dst in place
   * \param dst accumulated value, updated with operator+=
   * \param src incoming value
   * \tparam DType element type, must support operator+=
   */
  template<typename DType>
  inline static void Reduce(DType &dst, const DType &src) {
    dst += src;
  }
};
|
||||||
|
/*! \brief reduction operator that accumulates the bitwise OR of the operands in dst */
struct BitOR {
  /*!
   * \brief OR the bits of src into dst in place
   * \param dst accumulated value, updated with operator|=
   * \param src incoming value
   * \tparam DType element type, must support operator|=
   */
  template<typename DType>
  inline static void Reduce(DType &dst, const DType &src) {
    dst |= src;
  }
};
|
||||||
|
template<typename OP, typename DType>
|
||||||
|
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
||||||
|
const DType *src = (const DType*)src_;
|
||||||
|
DType *dst = (DType*)dst_;
|
||||||
|
for (int i = 0; i < len; ++i) {
|
||||||
|
OP::Reduce(dst[i], src[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace op
|
||||||
|
|
||||||
|
/*!
 * \brief initialize the library by forwarding the arguments to engine::Init
 *  declared inline because this definition lives in a header: without it every
 *  translation unit including the header would emit its own external definition
 * \param argc number of entries in argv
 * \param argv argument strings handed to the engine
 */
inline void Init(int argc, char *argv[]) {
  engine::Init(argc, argv);
}
|
||||||
|
/*!
 * \brief shut the library down by calling engine::Finalize
 *  declared inline because this definition lives in a header
 */
inline void Finalize(void) {
  engine::Finalize();
}
|
||||||
|
/*! \brief get rank of current process */
|
||||||
|
inline int GetRank(void) {
|
||||||
|
return engine::GetEngine()->GetRank();
|
||||||
|
}
|
||||||
|
/*! \brief get total number of process */
|
||||||
|
int GetWorldSize(void) {
|
||||||
|
return engine::GetEngine()->GetWorldSize();
|
||||||
|
}
|
||||||
|
/*! \brief get name of processor */
|
||||||
|
std::string GetProcessorName(void) {
|
||||||
|
return engine::GetEngine()->GetHost();
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief broadcast an std::string to all others from root
|
||||||
|
* \param sendrecv_data the pointer to send or recive buffer,
|
||||||
|
* receive buffer does not need to be pre-allocated
|
||||||
|
* and string will be resized to correct length
|
||||||
|
* \param root the root of process
|
||||||
|
*/
|
||||||
|
inline void Bcast(std::string *sendrecv_data, int root) {
|
||||||
|
engine::IEngine *e = engine::GetEngine();
|
||||||
|
unsigned len = static_cast<unsigned>(sendrecv_data->length());
|
||||||
|
e->Broadcast(&len, sizeof(len), root);
|
||||||
|
sendrecv_data->resize(len);
|
||||||
|
if (len != 0) {
|
||||||
|
e->Broadcast(&(*sendrecv_data)[0], len, root);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
|
* this function is NOT thread-safe
|
||||||
|
* Example Usage: the following code gives sum of the result
|
||||||
|
* vector<int> data(10);
|
||||||
|
* ...
|
||||||
|
* AllReduce<op::Sum>(&data[0], data.size());
|
||||||
|
* ...
|
||||||
|
* \param sendrecvbuf buffer for both sending and recving data
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \tparam OP see namespace op, reduce operator
|
||||||
|
* \tparam DType type of data
|
||||||
|
*/
|
||||||
|
template<typename OP, typename DType>
|
||||||
|
inline void AllReduce(DType *sendrecvbuf, size_t count) {
|
||||||
|
engine::GetEngine()->AllReduce(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief load latest check point
|
||||||
|
* \param p_model pointer to the model
|
||||||
|
* \return true if there was stored checkpoint and load was successful
|
||||||
|
* false if there was no stored checkpoint, means we are start over gain
|
||||||
|
*/
|
||||||
|
inline bool LoadCheckPoint(utils::ISerializable *p_model) {
|
||||||
|
return engine::GetEngine()->LoadCheckPoint(p_model);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
|
* \param p_model pointer to the model
|
||||||
|
*/
|
||||||
|
inline void CheckPoint(const utils::ISerializable &model) {
|
||||||
|
engine::GetEngine()->CheckPoint(model);
|
||||||
|
}
|
||||||
|
} // namespace sync
|
||||||
80
src/engine.h
Normal file
80
src/engine.h
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
#ifndef ALLREDUCE_ENGINE_H
|
||||||
|
#define ALLREDUCE_ENGINE_H
|
||||||
|
/*!
|
||||||
|
* \file engine.h
|
||||||
|
* \brief This file defines the interface of allreduce library
|
||||||
|
* \author Tianqi Chen, Nacho, Tianyi
|
||||||
|
*/
|
||||||
|
#include "./io.h"
|
||||||
|
|
||||||
|
namespace MPI {
/*!
 * \brief forward declaration of the MPI data type, named here only so that
 *  reduce functions can share the signature of an MPI reduce function
 */
class Datatype;
}  // namespace MPI
|
||||||
|
|
||||||
|
/*! \brief namespace of allreduce functionality */
|
||||||
|
namespace engine {
|
||||||
|
/*! \brief interface of core AllReduce engine */
|
||||||
|
class IEngine {
|
||||||
|
public:
|
||||||
|
/*!
|
||||||
|
* \brief reduce function, the same form of MPI reduce function is used,
|
||||||
|
* to be compatible with MPI interface
|
||||||
|
* In all the functions, the memory is ensured to aligned to 64-bit
|
||||||
|
* which means it is OK to cast src,dst to double* int* etc
|
||||||
|
* \param src pointer to source space
|
||||||
|
* \param dst pointer to destination reduction
|
||||||
|
* \param count total number of elements to be reduced(note this is total number of elements instead of bytes)
|
||||||
|
* the definition of reduce function should be type aware
|
||||||
|
* \param dtype the data type object, to be compatible with MPI reduce
|
||||||
|
*/
|
||||||
|
typedef void (ReduceFunction) (const void *src,
|
||||||
|
void *dst, int count,
|
||||||
|
const MPI::Datatype &dtype);
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
|
* this function is NOT thread-safe
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param type_n4bytes the unit number of bytes the type have
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param reducer reduce function
|
||||||
|
*/
|
||||||
|
virtual void AllReduce(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer) = 0;
|
||||||
|
/*!
|
||||||
|
* \brief broadcast data from root to all nodes
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param size the size of the data to be broadcasted
|
||||||
|
* \param root the root worker id to broadcast the data
|
||||||
|
*/
|
||||||
|
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0;
|
||||||
|
/*!
|
||||||
|
* \brief load latest check point
|
||||||
|
* \param p_model pointer to the model
|
||||||
|
* \return true if there was stored checkpoint and load was successful
|
||||||
|
* false if there was no stored checkpoint, means we are start over gain
|
||||||
|
*/
|
||||||
|
virtual bool LoadCheckPoint(utils::ISerializable *p_model) = 0;
|
||||||
|
/*!
|
||||||
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
|
* \param p_model pointer to the model
|
||||||
|
*/
|
||||||
|
virtual void CheckPoint(const utils::ISerializable &model) = 0;
|
||||||
|
/*! \brief get rank of current node */
|
||||||
|
virtual int GetRank(void) const = 0;
|
||||||
|
/*! \brief get total number of */
|
||||||
|
virtual int GetWorldSize(void) const = 0;
|
||||||
|
/*! \brief get the host name of current node */
|
||||||
|
virtual std::string GetHost(void) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief intiialize the engine module */
|
||||||
|
void Init(int argc, char *argv[]);
|
||||||
|
/*! \brief finalize engine module */
|
||||||
|
void Finalize(void);
|
||||||
|
/*! \brief singleton method to get engine */
|
||||||
|
IEngine *GetEngine(void);
|
||||||
|
} // namespace engine
|
||||||
|
#endif // ALLREDUCE_ENGINE_H
|
||||||
448
src/engine_tcp.cpp
Normal file
448
src/engine_tcp.cpp
Normal file
@ -0,0 +1,448 @@
|
|||||||
|
/*!
|
||||||
|
* \file engine_tcp.cpp
|
||||||
|
* \brief implementation of sync AllReduce using TCP sockets
|
||||||
|
* using non-blocking sockets and tree-shaped reduction
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
|
#define NOMINMAX
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <cstring>
|
||||||
|
#include "./engine.h"
|
||||||
|
#include "./socket.h"
|
||||||
|
|
||||||
|
namespace MPI {
/*!
 * \brief definition of the data type object passed to reduce functions;
 *  it only records the size of one element
 */
class Datatype {
 public:
  /*! \brief size of one element, in bytes */
  size_t type_size;
  /*!
   * \brief constructor; explicit so a plain size is never silently converted
   * \param type_size size of one element, in bytes
   */
  explicit Datatype(size_t type_size) : type_size(type_size) {}
};
}  // namespace MPI
|
||||||
|
namespace engine {
|
||||||
|
/*! \brief implementation of sync goes to here */
|
||||||
|
class SyncManager : public IEngine {
|
||||||
|
public:
|
||||||
|
const static int kMagic = 0xff99;
|
||||||
|
SyncManager(void) {
|
||||||
|
master_uri = "NULL";
|
||||||
|
master_port = 9000;
|
||||||
|
host_uri = "";
|
||||||
|
slave_port = 9010;
|
||||||
|
nport_trial = 1000;
|
||||||
|
rank = 0;
|
||||||
|
world_size = 1;
|
||||||
|
this->SetParam("reduce_buffer", "256MB");
|
||||||
|
}
|
||||||
|
~SyncManager(void) {
|
||||||
|
}
|
||||||
|
inline void Shutdown(void) {
|
||||||
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
|
links[i].sock.Close();
|
||||||
|
}
|
||||||
|
links.clear();
|
||||||
|
utils::TCPSocket::Finalize();
|
||||||
|
}
|
||||||
|
/*! \brief set parameters to the sync manager */
|
||||||
|
inline void SetParam(const char *name, const char *val) {
|
||||||
|
if (!strcmp(name, "master_uri")) master_uri = val;
|
||||||
|
if (!strcmp(name, "master_port")) master_port = atoi(val);
|
||||||
|
if (!strcmp(name, "reduce_buffer")) {
|
||||||
|
char unit;
|
||||||
|
unsigned long amount;
|
||||||
|
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
||||||
|
switch (unit) {
|
||||||
|
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
|
||||||
|
case 'K': reduce_buffer_size = amount << 7UL; break;
|
||||||
|
case 'M': reduce_buffer_size = amount << 17UL; break;
|
||||||
|
case 'G': reduce_buffer_size = amount << 27UL; break;
|
||||||
|
default: utils::Error("invalid format for reduce buffer");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// initialize the manager
|
||||||
|
inline void Init(void) {
|
||||||
|
utils::TCPSocket::Startup();
|
||||||
|
// single node mode
|
||||||
|
if (master_uri == "NULL") return;
|
||||||
|
utils::Assert(links.size() == 0, "can only call Init once");
|
||||||
|
int magic = kMagic;
|
||||||
|
int nchild = 0, nparent = 0;
|
||||||
|
this->host_uri = utils::SockAddr::GetHostName();
|
||||||
|
// get information from master
|
||||||
|
utils::TCPSocket master;
|
||||||
|
master.Create();
|
||||||
|
master.Connect(utils::SockAddr(master_uri.c_str(), master_port));
|
||||||
|
utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1");
|
||||||
|
utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2");
|
||||||
|
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
|
||||||
|
utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3");
|
||||||
|
utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4");
|
||||||
|
utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5");
|
||||||
|
utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6");
|
||||||
|
utils::Assert(nchild >= 0, "in correct number of childs");
|
||||||
|
utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent");
|
||||||
|
|
||||||
|
// create listen
|
||||||
|
utils::TCPSocket sock_listen;
|
||||||
|
sock_listen.Create();
|
||||||
|
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
|
||||||
|
utils::Check(port != -1, "sync::Init fail to bind the ports specified");
|
||||||
|
sock_listen.Listen();
|
||||||
|
|
||||||
|
if (nparent != 0) {
|
||||||
|
parent_index = 0;
|
||||||
|
links.push_back(LinkRecord());
|
||||||
|
int len, hport;
|
||||||
|
std::string hname;
|
||||||
|
utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9");
|
||||||
|
hname.resize(len);
|
||||||
|
utils::Assert(len != 0, "string must not be empty");
|
||||||
|
utils::Assert(master.RecvAll(&hname[0], len) == static_cast<size_t>(len), "sync::Init failure 10");
|
||||||
|
utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11");
|
||||||
|
links[0].sock.Create();
|
||||||
|
links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
|
||||||
|
utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12");
|
||||||
|
utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13");
|
||||||
|
utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch");
|
||||||
|
parent_index = 0;
|
||||||
|
} else {
|
||||||
|
parent_index = -1;
|
||||||
|
}
|
||||||
|
// send back socket listening port to master
|
||||||
|
utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14");
|
||||||
|
// close connection to master
|
||||||
|
master.Close();
|
||||||
|
// accept links from childs
|
||||||
|
for (int i = 0; i < nchild; ++i) {
|
||||||
|
LinkRecord r;
|
||||||
|
while (true) {
|
||||||
|
r.sock = sock_listen.Accept();
|
||||||
|
if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) {
|
||||||
|
utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15");
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
// not a valid child
|
||||||
|
r.sock.Close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
links.push_back(r);
|
||||||
|
}
|
||||||
|
// close listening sockets
|
||||||
|
sock_listen.Close();
|
||||||
|
// setup selecter
|
||||||
|
selecter.Clear();
|
||||||
|
for (size_t i = 0; i < links.size(); ++i) {
|
||||||
|
// set the socket to non-blocking mode
|
||||||
|
links[i].sock.SetNonBlock(true);
|
||||||
|
selecter.WatchRead(links[i].sock);
|
||||||
|
selecter.WatchWrite(links[i].sock);
|
||||||
|
}
|
||||||
|
// done
|
||||||
|
}
|
||||||
|
/*! \brief get rank */
|
||||||
|
virtual int GetRank(void) const {
|
||||||
|
return rank;
|
||||||
|
}
|
||||||
|
/*! \brief get rank */
|
||||||
|
virtual int GetWorldSize(void) const {
|
||||||
|
return world_size;
|
||||||
|
}
|
||||||
|
/*! \brief get rank */
|
||||||
|
virtual std::string GetHost(void) const {
|
||||||
|
return host_uri;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief perform in-place allreduce, on sendrecvbuf
|
||||||
|
* this function is NOT thread-safe
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param type_n4bytes the unit number of bytes the type have
|
||||||
|
* \param count number of elements to be reduced
|
||||||
|
* \param reducer reduce function
|
||||||
|
*/
|
||||||
|
virtual void AllReduce(void *sendrecvbuf_,
|
||||||
|
size_t type_nbytes,
|
||||||
|
size_t count,
|
||||||
|
ReduceFunction reducer) {
|
||||||
|
if (links.size() == 0) return;
|
||||||
|
// total size of message
|
||||||
|
const size_t total_size = type_nbytes * count;
|
||||||
|
// number of links
|
||||||
|
const int nlink = static_cast<int>(links.size());
|
||||||
|
// send recv buffer
|
||||||
|
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||||
|
// size of space that we already performs reduce in up pass
|
||||||
|
size_t size_up_reduce = 0;
|
||||||
|
// size of space that we have already passed to parent
|
||||||
|
size_t size_up_out = 0;
|
||||||
|
// size of message we received, and send in the down pass
|
||||||
|
size_t size_down_in = 0;
|
||||||
|
|
||||||
|
// initialize the link ring-buffer and pointer
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
if (i != parent_index) {
|
||||||
|
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
|
||||||
|
}
|
||||||
|
links[i].ResetSize();
|
||||||
|
}
|
||||||
|
// if no childs, no need to reduce
|
||||||
|
if (nlink == static_cast<int>(parent_index != -1)) {
|
||||||
|
size_up_reduce = total_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// while we have not passed the messages out
|
||||||
|
while(true) {
|
||||||
|
selecter.Select();
|
||||||
|
// read data from childs
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
|
||||||
|
links[i].ReadToRingBuffer(size_up_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// this node have childs, peform reduce
|
||||||
|
if (nlink > static_cast<int>(parent_index != -1)) {
|
||||||
|
size_t buffer_size = 0;
|
||||||
|
// do upstream reduce
|
||||||
|
size_t max_reduce = total_size;
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
if (i != parent_index) {
|
||||||
|
max_reduce= std::min(max_reduce, links[i].size_read);
|
||||||
|
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
|
||||||
|
"buffer size inconsistent");
|
||||||
|
buffer_size = links[i].buffer_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
utils::Assert(buffer_size != 0, "must assign buffer_size");
|
||||||
|
// round to type_n4bytes
|
||||||
|
max_reduce = (max_reduce / type_nbytes * type_nbytes);
|
||||||
|
// peform reduce, can be at most two rounds
|
||||||
|
while (size_up_reduce < max_reduce) {
|
||||||
|
// start position
|
||||||
|
size_t start = size_up_reduce % buffer_size;
|
||||||
|
// peform read till end of buffer
|
||||||
|
size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce);
|
||||||
|
utils::Assert(nread % type_nbytes == 0, "AllReduce: size check");
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
if (i != parent_index) {
|
||||||
|
reducer(links[i].buffer_head + start,
|
||||||
|
sendrecvbuf + size_up_reduce,
|
||||||
|
static_cast<int>(nread / type_nbytes),
|
||||||
|
MPI::Datatype(type_nbytes));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
size_up_reduce += nread;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (parent_index != -1) {
|
||||||
|
// pass message up to parent, can pass data that are already been reduced
|
||||||
|
if (selecter.CheckWrite(links[parent_index].sock)) {
|
||||||
|
size_up_out += links[parent_index].sock.
|
||||||
|
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
|
||||||
|
}
|
||||||
|
// read data from parent
|
||||||
|
if (selecter.CheckRead(links[parent_index].sock)) {
|
||||||
|
size_down_in += links[parent_index].sock.
|
||||||
|
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||||
|
utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// this is root, can use reduce as most recent point
|
||||||
|
size_down_in = size_up_out = size_up_reduce;
|
||||||
|
}
|
||||||
|
// check if we finished the job of message passing
|
||||||
|
size_t nfinished = size_down_in;
|
||||||
|
// can pass message down to childs
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
if (i != parent_index) {
|
||||||
|
if (selecter.CheckWrite(links[i].sock)) {
|
||||||
|
links[i].WriteFromArray(sendrecvbuf, size_down_in);
|
||||||
|
}
|
||||||
|
nfinished = std::min(links[i].size_write, nfinished);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check boundary condition
|
||||||
|
if (nfinished >= total_size) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief broadcast data from root to all nodes
|
||||||
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||||
|
* \param size the size of the data to be broadcasted
|
||||||
|
* \param root the root worker id to broadcast the data
|
||||||
|
*/
|
||||||
|
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
|
if (links.size() == 0) return;
|
||||||
|
// number of links
|
||||||
|
const int nlink = static_cast<int>(links.size());
|
||||||
|
// size of space already read from data
|
||||||
|
size_t size_in = 0;
|
||||||
|
// input link, -2 means unknown yet, -1 means this is root
|
||||||
|
int in_link = -2;
|
||||||
|
|
||||||
|
// initialize the link statistics
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
links[i].ResetSize();
|
||||||
|
}
|
||||||
|
// root have all the data
|
||||||
|
if (this->rank == root) {
|
||||||
|
size_in = total_size;
|
||||||
|
in_link = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// while we have not passed the messages out
|
||||||
|
while(true) {
|
||||||
|
selecter.Select();
|
||||||
|
if (in_link == -2) {
|
||||||
|
// probe in-link
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
if (selecter.CheckRead(links[i].sock)) {
|
||||||
|
links[i].ReadToArray(sendrecvbuf_, total_size);
|
||||||
|
size_in = links[i].size_read;
|
||||||
|
if (size_in != 0) {
|
||||||
|
in_link = i; break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// read from in link
|
||||||
|
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
|
||||||
|
links[in_link].ReadToArray(sendrecvbuf_, total_size);
|
||||||
|
size_in = links[in_link].size_read;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
size_t nfinished = total_size;
|
||||||
|
// send data to all out-link
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
if (i != in_link) {
|
||||||
|
if (selecter.CheckWrite(links[i].sock)) {
|
||||||
|
links[i].WriteFromArray(sendrecvbuf_, size_in);
|
||||||
|
}
|
||||||
|
nfinished = std::min(nfinished, links[i].size_write);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check boundary condition
|
||||||
|
if (nfinished >= total_size) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
 * \brief load latest check point; this implementation never loads anything
 * \param p_model pointer to the model, not used here
 * \return always false
 */
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
  const bool loaded = false;
  return loaded;
}
|
||||||
|
/*!
 * \brief checkpoint the model; this implementation is an empty placeholder
 * \param model the model to be checkpointed, not used here
 */
virtual void CheckPoint(const utils::ISerializable &model) {
  // intentionally empty
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// an independent child record
|
||||||
|
struct LinkRecord {
|
||||||
|
public:
|
||||||
|
// socket to get data from/to link
|
||||||
|
utils::TCPSocket sock;
|
||||||
|
// size of data readed from link
|
||||||
|
size_t size_read;
|
||||||
|
// size of data sent to the link
|
||||||
|
size_t size_write;
|
||||||
|
// pointer to buffer head
|
||||||
|
char *buffer_head;
|
||||||
|
// buffer size, in bytes
|
||||||
|
size_t buffer_size;
|
||||||
|
// initialize buffer
|
||||||
|
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
|
||||||
|
size_t n = (type_nbytes * count + 7)/ 8;
|
||||||
|
buffer_.resize(std::min(reduce_buffer_size, n));
|
||||||
|
// make sure align to type_nbytes
|
||||||
|
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||||
|
utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size);
|
||||||
|
// set buffer head
|
||||||
|
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
||||||
|
}
|
||||||
|
// reset the recv and sent size
|
||||||
|
inline void ResetSize(void) {
|
||||||
|
size_write = size_read = 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief read data into ring-buffer, with care not to existing useful override data
|
||||||
|
* position after protect_start
|
||||||
|
* \param protect_start all data start from protect_start is still needed in buffer
|
||||||
|
* read shall not override this
|
||||||
|
*/
|
||||||
|
inline void ReadToRingBuffer(size_t protect_start) {
|
||||||
|
size_t ngap = size_read - protect_start;
|
||||||
|
utils::Assert(ngap <= buffer_size, "AllReduce: boundary check");
|
||||||
|
size_t offset = size_read % buffer_size;
|
||||||
|
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
||||||
|
size_read += sock.Recv(buffer_head + offset, nmax);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief read data into array,
|
||||||
|
* this function can not be used together with ReadToRingBuffer
|
||||||
|
* a link can either read into the ring buffer, or existing array
|
||||||
|
* \param max_size maximum size of array
|
||||||
|
*/
|
||||||
|
inline void ReadToArray(void *recvbuf_, size_t max_size) {
|
||||||
|
char *p = static_cast<char*>(recvbuf_);
|
||||||
|
size_read += sock.Recv(p + size_read, max_size - size_read);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief write data in array to sock
|
||||||
|
* \param sendbuf_ head of array
|
||||||
|
* \param max_size maximum size of array
|
||||||
|
*/
|
||||||
|
inline void WriteFromArray(const void *sendbuf_, size_t max_size) {
|
||||||
|
const char *p = static_cast<const char*>(sendbuf_);
|
||||||
|
size_write += sock.Send(p + size_write, max_size - size_write);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// recv buffer to get data from child
|
||||||
|
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
||||||
|
std::vector<uint64_t> buffer_;
|
||||||
|
};
|
||||||
|
//------------------
|
||||||
|
// uri of current host, to be set by Init
|
||||||
|
std::string host_uri;
|
||||||
|
// uri of master
|
||||||
|
std::string master_uri;
|
||||||
|
// port of master address
|
||||||
|
int master_port;
|
||||||
|
// port of slave process
|
||||||
|
int slave_port, nport_trial;
|
||||||
|
// reduce buffer size
|
||||||
|
size_t reduce_buffer_size;
|
||||||
|
// current rank
|
||||||
|
int rank;
|
||||||
|
// world size
|
||||||
|
int world_size;
|
||||||
|
// index of parent link, can be -1, meaning this is root of the tree
|
||||||
|
int parent_index;
|
||||||
|
// sockets of all links
|
||||||
|
std::vector<LinkRecord> links;
|
||||||
|
// select helper
|
||||||
|
utils::SelectHelper selecter;
|
||||||
|
};
|
||||||
|
|
||||||
|
// singleton sync manager
|
||||||
|
SyncManager manager;
|
||||||
|
|
||||||
|
/*! \brief intiialize the synchronization module */
|
||||||
|
void Init(int argc, char *argv[]) {
|
||||||
|
for (int i = 1; i < argc; ++i) {
|
||||||
|
char name[256], val[256];
|
||||||
|
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||||
|
manager.SetParam(name, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.Init();
|
||||||
|
}
|
||||||
|
/*! \brief finalize syncrhonization module */
|
||||||
|
void Finalize(void) {
|
||||||
|
manager.Shutdown();
|
||||||
|
}
|
||||||
|
/*! \brief singleton method to get engine */
|
||||||
|
IEngine *GetEngine(void) {
|
||||||
|
return &manager;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace engine
|
||||||
214
src/io.h
Normal file
214
src/io.h
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
#ifndef ALLREDUCE_UTILS_IO_H
|
||||||
|
#define ALLREDUCE_UTILS_IO_H
|
||||||
|
#include <cstdio>
|
||||||
|
#include <vector>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include "./utils.h"
|
||||||
|
/*!
|
||||||
|
* \file io.h
|
||||||
|
* \brief general stream interface for serialization, I/O
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
namespace utils {
|
||||||
|
/*!
 * \brief interface of stream I/O, used to serialize model
 *  implementations provide the two raw byte operations; the helper overloads
 *  below build length-prefixed encodings of vectors and strings on top of them
 */
class IStream {
 public:
  /*!
   * \brief read data from stream
   * \param ptr pointer to memory buffer
   * \param size size of block
   * \return usually is the size of data read
   */
  virtual size_t Read(void *ptr, size_t size) = 0;
  /*!
   * \brief write data to stream
   * \param ptr pointer to memory buffer
   * \param size size of block
   */
  virtual void Write(const void *ptr, size_t size) = 0;
  /*! \brief virtual destructor */
  virtual ~IStream(void) {}

 public:
  // helper functions to write various of data structures
  /*!
   * \brief binary serialize a vector: a 64-bit element count,
   *  followed by the raw bytes of the elements when the vector is not empty
   * \param vec vector to be serialized
   */
  template<typename T>
  inline void Write(const std::vector<T> &vec) {
    const uint64_t count = static_cast<uint64_t>(vec.size());
    this->Write(&count, sizeof(count));
    if (count == 0) return;
    this->Write(&vec[0], sizeof(T) * count);
  }
  /*!
   * \brief binary load a vector written by the matching Write
   * \param out_vec vector to be loaded, resized to the stored element count
   * \return false when a raw read returns 0, true otherwise
   */
  template<typename T>
  inline bool Read(std::vector<T> *out_vec) {
    uint64_t count;
    if (this->Read(&count, sizeof(count)) == 0) return false;
    out_vec->resize(count);
    if (count == 0) return true;
    return this->Read(&(*out_vec)[0], sizeof(T) * count) != 0;
  }
  /*!
   * \brief binary serialize a string: a 64-bit length,
   *  followed by the characters when the string is not empty
   * \param str the string to be serialized
   */
  inline void Write(const std::string &str) {
    const uint64_t length = static_cast<uint64_t>(str.length());
    this->Write(&length, sizeof(length));
    if (length == 0) return;
    this->Write(str.data(), sizeof(char) * length);
  }
  /*!
   * \brief binary load a string written by the matching Write
   * \param out_str string to be loaded, resized to the stored length
   * \return false when a raw read returns 0, true otherwise
   */
  inline bool Read(std::string *out_str) {
    uint64_t length;
    if (this->Read(&length, sizeof(length)) == 0) return false;
    out_str->resize(length);
    if (length == 0) return true;
    return this->Read(&(*out_str)[0], sizeof(char) * length) != 0;
  }
};
|
||||||
|
|
||||||
|
/*! \brief interface of se*/
|
||||||
|
class ISerializable {
|
||||||
|
/*! \brief load the model from file */
|
||||||
|
virtual void Load(IStream &fi) = 0;
|
||||||
|
/*! \brief save the model to the stream*/
|
||||||
|
virtual void Save(IStream &fo) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief interface of i/o stream that support seek */
|
||||||
|
class ISeekStream: public IStream {
|
||||||
|
public:
|
||||||
|
/*! \brief seek to certain position of the file */
|
||||||
|
virtual void Seek(size_t pos) = 0;
|
||||||
|
/*! \brief tell the position of the stream */
|
||||||
|
virtual size_t Tell(void) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief fixed size memory buffer */
|
||||||
|
struct MemoryFixSizeBuffer : public ISeekStream {
|
||||||
|
public:
|
||||||
|
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
||||||
|
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
|
||||||
|
curr_ptr_ = 0;
|
||||||
|
}
|
||||||
|
virtual ~MemoryFixSizeBuffer(void) {}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
||||||
|
"read can not have position excceed buffer length");
|
||||||
|
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
||||||
|
if (nread != 0) memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
||||||
|
curr_ptr_ += nread;
|
||||||
|
return nread;
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
if (size == 0) return;
|
||||||
|
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
||||||
|
"write position exceed fixed buffer size");
|
||||||
|
memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
||||||
|
curr_ptr_ += size;
|
||||||
|
}
|
||||||
|
virtual void Seek(size_t pos) {
|
||||||
|
curr_ptr_ = static_cast<size_t>(pos);
|
||||||
|
}
|
||||||
|
virtual size_t Tell(void) {
|
||||||
|
return curr_ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief in memory buffer */
|
||||||
|
char *p_buffer_;
|
||||||
|
/*! \brief current pointer */
|
||||||
|
size_t buffer_size_;
|
||||||
|
/*! \brief current pointer */
|
||||||
|
size_t curr_ptr_;
|
||||||
|
}; // class MemoryFixSizeBuffer
|
||||||
|
|
||||||
|
/*! \brief a in memory buffer that can be read and write as stream interface */
|
||||||
|
struct MemoryBufferStream : public ISeekStream {
|
||||||
|
public:
|
||||||
|
MemoryBufferStream(std::string *p_buffer)
|
||||||
|
: p_buffer_(p_buffer) {
|
||||||
|
curr_ptr_ = 0;
|
||||||
|
}
|
||||||
|
virtual ~MemoryBufferStream(void) {}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
||||||
|
"read can not have position excceed buffer length");
|
||||||
|
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
||||||
|
if (nread != 0) memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
||||||
|
curr_ptr_ += nread;
|
||||||
|
return nread;
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
if (size == 0) return;
|
||||||
|
if (curr_ptr_ + size > p_buffer_->length()) {
|
||||||
|
p_buffer_->resize(curr_ptr_+size);
|
||||||
|
}
|
||||||
|
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||||
|
curr_ptr_ += size;
|
||||||
|
}
|
||||||
|
virtual void Seek(size_t pos) {
|
||||||
|
curr_ptr_ = static_cast<size_t>(pos);
|
||||||
|
}
|
||||||
|
virtual size_t Tell(void) {
|
||||||
|
return curr_ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*! \brief in memory buffer */
|
||||||
|
std::string *p_buffer_;
|
||||||
|
/*! \brief current pointer */
|
||||||
|
size_t curr_ptr_;
|
||||||
|
}; // class MemoryBufferStream
|
||||||
|
|
||||||
|
/*! \brief implementation of file i/o stream */
|
||||||
|
class FileStream : public ISeekStream {
|
||||||
|
public:
|
||||||
|
explicit FileStream(FILE *fp) : fp(fp) {}
|
||||||
|
explicit FileStream(void) {
|
||||||
|
this->fp = NULL;
|
||||||
|
}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
return std::fread(ptr, size, 1, fp);
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
std::fwrite(ptr, size, 1, fp);
|
||||||
|
}
|
||||||
|
virtual void Seek(size_t pos) {
|
||||||
|
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
|
||||||
|
}
|
||||||
|
virtual size_t Tell(void) {
|
||||||
|
return std::ftell(fp);
|
||||||
|
}
|
||||||
|
inline void Close(void) {
|
||||||
|
if (fp != NULL){
|
||||||
|
std::fclose(fp); fp = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FILE *fp;
|
||||||
|
};
|
||||||
|
} // namespace utils
|
||||||
|
#endif
|
||||||
387
src/socket.h
Normal file
387
src/socket.h
Normal file
@ -0,0 +1,387 @@
|
|||||||
|
#ifndef ALLREDUCE_SOCKET_H
|
||||||
|
#define ALLREDUCE_SOCKET_H
|
||||||
|
/*!
|
||||||
|
* \file socket.h
|
||||||
|
* \brief this file aims to provide a wrapper of sockets
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#include <winsock2.h>
|
||||||
|
#include <ws2tcpip.h>
|
||||||
|
#else
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <netdb.h>
|
||||||
|
#include <errno.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <netinet/in.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <sys/select.h>
|
||||||
|
#endif
|
||||||
|
#include <string>
|
||||||
|
#include <cstring>
|
||||||
|
#include "./utils.h"
|
||||||
|
|
||||||
|
namespace utils {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
typedef int ssize_t;
|
||||||
|
typedef int sock_size_t;
|
||||||
|
#else
|
||||||
|
typedef int SOCKET;
|
||||||
|
typedef size_t sock_size_t;
|
||||||
|
const int INVALID_SOCKET = -1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/*! \brief data structure for network address */
|
||||||
|
struct SockAddr {
|
||||||
|
sockaddr_in addr;
|
||||||
|
// constructor
|
||||||
|
SockAddr(void) {}
|
||||||
|
SockAddr(const char *url, int port) {
|
||||||
|
this->Set(url, port);
|
||||||
|
}
|
||||||
|
inline static std::string GetHostName(void) {
|
||||||
|
std::string buf; buf.resize(256);
|
||||||
|
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
||||||
|
return std::string(buf.c_str());
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief set the address
|
||||||
|
* \param url the url of the address
|
||||||
|
* \param port the port of address
|
||||||
|
*/
|
||||||
|
inline void Set(const char *host, int port) {
|
||||||
|
hostent *hp = gethostbyname(host);
|
||||||
|
Check(hp != NULL, "cannot obtain address of %s", host);
|
||||||
|
memset(&addr, 0, sizeof(addr));
|
||||||
|
addr.sin_family = AF_INET;
|
||||||
|
addr.sin_port = htons(port);
|
||||||
|
memcpy(&addr.sin_addr, hp->h_addr_list[0], hp->h_length);
|
||||||
|
}
|
||||||
|
/*! \brief return port of the address*/
|
||||||
|
inline int port(void) const {
|
||||||
|
return ntohs(addr.sin_port);
|
||||||
|
}
|
||||||
|
/*! \return a string representation of the address */
|
||||||
|
inline std::string AddrStr(void) const {
|
||||||
|
std::string buf; buf.resize(256);
|
||||||
|
#ifdef _WIN32
|
||||||
|
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, &buf[0], buf.length());
|
||||||
|
#else
|
||||||
|
const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length());
|
||||||
|
#endif
|
||||||
|
Assert(s != NULL, "cannot decode address");
|
||||||
|
return std::string(s);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
/*!
|
||||||
|
* \brief a wrapper of TCP socket that hopefully be cross platform
|
||||||
|
*/
|
||||||
|
class TCPSocket {
|
||||||
|
public:
|
||||||
|
/*! \brief the file descriptor of socket */
|
||||||
|
SOCKET sockfd;
|
||||||
|
// constructor
|
||||||
|
TCPSocket(void) : sockfd(INVALID_SOCKET) {
|
||||||
|
}
|
||||||
|
explicit TCPSocket(SOCKET sockfd) : sockfd(sockfd) {
|
||||||
|
}
|
||||||
|
~TCPSocket(void) {
|
||||||
|
// do nothing in destructor
|
||||||
|
// user need to take care of close
|
||||||
|
}
|
||||||
|
// default conversion to int
|
||||||
|
inline operator SOCKET() const {
|
||||||
|
return sockfd;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief create the socket, call this before using socket
|
||||||
|
* \param af domain
|
||||||
|
*/
|
||||||
|
inline void Create(int af = PF_INET) {
|
||||||
|
sockfd = socket(PF_INET, SOCK_STREAM, 0);
|
||||||
|
if (sockfd == INVALID_SOCKET) {
|
||||||
|
SockError("Create");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief start up the socket module
|
||||||
|
* call this before using the sockets
|
||||||
|
*/
|
||||||
|
inline static void Startup(void) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
WSADATA wsa_data;
|
||||||
|
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != -1) {
|
||||||
|
SockError("Startup");
|
||||||
|
}
|
||||||
|
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
||||||
|
WSACleanup();
|
||||||
|
utils::Error("Could not find a usable version of Winsock.dll\n");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief shutdown the socket module after use, all sockets need to be closed
|
||||||
|
*/
|
||||||
|
inline static void Finalize(void) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
WSACleanup();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief set this socket to use non-blocking mode
|
||||||
|
* \param non_block whether set it to be non-block, if it is false
|
||||||
|
* it will set it back to block mode
|
||||||
|
*/
|
||||||
|
inline void SetNonBlock(bool non_block) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
u_long mode = non_block ? 1 : 0;
|
||||||
|
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
||||||
|
SockError("SetNonBlock");
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
int flag = fcntl(sockfd, F_GETFL, 0);
|
||||||
|
if (flag == -1) {
|
||||||
|
SockError("SetNonBlock-1");
|
||||||
|
}
|
||||||
|
if (non_block) {
|
||||||
|
flag |= O_NONBLOCK;
|
||||||
|
} else {
|
||||||
|
flag &= ~O_NONBLOCK;
|
||||||
|
}
|
||||||
|
if (fcntl(sockfd, F_SETFL, flag) == -1) {
|
||||||
|
SockError("SetNonBlock-2");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief perform listen of the socket
|
||||||
|
* \param backlog backlog parameter
|
||||||
|
*/
|
||||||
|
inline void Listen(int backlog = 16) {
|
||||||
|
listen(sockfd, backlog);
|
||||||
|
}
|
||||||
|
/*! \brief get a new connection */
|
||||||
|
TCPSocket Accept(void) {
|
||||||
|
SOCKET newfd = accept(sockfd, NULL, NULL);
|
||||||
|
if (newfd == INVALID_SOCKET) {
|
||||||
|
SockError("Accept");
|
||||||
|
}
|
||||||
|
return TCPSocket(newfd);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief bind the socket to an address
|
||||||
|
* \param addr
|
||||||
|
*/
|
||||||
|
inline void Bind(const SockAddr &addr) {
|
||||||
|
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) {
|
||||||
|
SockError("Bind");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief try bind the socket to host, from start_port to end_port
|
||||||
|
* \param start_port starting port number to try
|
||||||
|
* \param end_port ending port number to try
|
||||||
|
* \param out_addr the binding address, if successful
|
||||||
|
* \return whether the binding is successful
|
||||||
|
*/
|
||||||
|
inline int TryBindHost(int start_port, int end_port) {
|
||||||
|
for (int port = start_port; port < end_port; ++port) {
|
||||||
|
SockAddr addr("0.0.0.0", port);
|
||||||
|
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) {
|
||||||
|
return port;
|
||||||
|
}
|
||||||
|
if (errno != EADDRINUSE) {
|
||||||
|
SockError("TryBindHost");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief connect to an address
|
||||||
|
* \param addr the address to connect to
|
||||||
|
*/
|
||||||
|
inline void Connect(const SockAddr &addr) {
|
||||||
|
if (connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) {
|
||||||
|
SockError("Connect");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*! \brief close the connection */
|
||||||
|
inline void Close(void) {
|
||||||
|
if (sockfd != -1) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
closesocket(sockfd);
|
||||||
|
#else
|
||||||
|
close(sockfd);
|
||||||
|
#endif
|
||||||
|
sockfd = INVALID_SOCKET;
|
||||||
|
} else {
|
||||||
|
Error("TCPSocket::Close double close the socket or close without create");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief send data using the socket
|
||||||
|
* \param buf the pointer to the buffer
|
||||||
|
* \param len the size of the buffer
|
||||||
|
* \param flags extra flags
|
||||||
|
* \return size of data actually sent
|
||||||
|
*/
|
||||||
|
inline size_t Send(const void *buf_, size_t len, int flag = 0) {
|
||||||
|
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||||
|
if (len == 0) return 0;
|
||||||
|
ssize_t ret = send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
||||||
|
if (ret == -1) {
|
||||||
|
if (errno == EAGAIN || errno == EWOULDBLOCK) return 0;
|
||||||
|
SockError("Send");
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief receive data using the socket
|
||||||
|
* \param buf_ the pointer to the buffer
|
||||||
|
* \param len the size of the buffer
|
||||||
|
* \param flags extra flags
|
||||||
|
* \return size of data actually received
|
||||||
|
*/
|
||||||
|
inline size_t Recv(void *buf_, size_t len, int flags = 0) {
|
||||||
|
char *buf = reinterpret_cast<char*>(buf_);
|
||||||
|
if (len == 0) return 0;
|
||||||
|
ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
||||||
|
if (ret == -1) {
|
||||||
|
if (errno == EAGAIN || errno == EWOULDBLOCK) return 0;
|
||||||
|
SockError("Recv");
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief peform block write that will attempt to send all data out
|
||||||
|
* can still return smaller than request when error occurs
|
||||||
|
* \param buf the pointer to the buffer
|
||||||
|
* \param len the size of the buffer
|
||||||
|
* \return size of data actually sent
|
||||||
|
*/
|
||||||
|
inline size_t SendAll(const void *buf_, size_t len) {
|
||||||
|
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||||
|
size_t ndone = 0;
|
||||||
|
while (ndone < len) {
|
||||||
|
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
|
||||||
|
if (ret == -1) {
|
||||||
|
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
|
||||||
|
SockError("Recv");
|
||||||
|
}
|
||||||
|
buf += ret;
|
||||||
|
ndone += ret;
|
||||||
|
}
|
||||||
|
return ndone;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief peforma block read that will attempt to read all data
|
||||||
|
* can still return smaller than request when error occurs
|
||||||
|
* \param buf_ the buffer pointer
|
||||||
|
* \param len length of data to recv
|
||||||
|
* \return size of data actually sent
|
||||||
|
*/
|
||||||
|
inline size_t RecvAll(void *buf_, size_t len) {
|
||||||
|
char *buf = reinterpret_cast<char*>(buf_);
|
||||||
|
size_t ndone = 0;
|
||||||
|
while (ndone < len) {
|
||||||
|
ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
||||||
|
if (ret == -1) {
|
||||||
|
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
|
||||||
|
SockError("Recv");
|
||||||
|
}
|
||||||
|
if (ret == 0) return ndone;
|
||||||
|
buf += ret;
|
||||||
|
ndone += ret;
|
||||||
|
}
|
||||||
|
return ndone;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// report an socket error
|
||||||
|
inline static void SockError(const char *msg) {
|
||||||
|
int errsv = errno;
|
||||||
|
Error("Socket %s Error:%s", msg, strerror(errsv));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
/*! \brief helper data structure to perform select */
|
||||||
|
struct SelectHelper {
|
||||||
|
public:
|
||||||
|
SelectHelper(void) {
|
||||||
|
this->Clear();
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief add file descriptor to watch for read
|
||||||
|
* \param fd file descriptor to be watched
|
||||||
|
*/
|
||||||
|
inline void WatchRead(SOCKET fd) {
|
||||||
|
read_fds.push_back(fd);
|
||||||
|
if (fd > maxfd) maxfd = fd;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief add file descriptor to watch for write
|
||||||
|
* \param fd file descriptor to be watched
|
||||||
|
*/
|
||||||
|
inline void WatchWrite(SOCKET fd) {
|
||||||
|
write_fds.push_back(fd);
|
||||||
|
if (fd > maxfd) maxfd = fd;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief Check if the descriptor is ready for read
|
||||||
|
* \param fd file descriptor to check status
|
||||||
|
*/
|
||||||
|
inline bool CheckRead(SOCKET fd) const {
|
||||||
|
return FD_ISSET(fd, &read_set) != 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief Check if the descriptor is ready for write
|
||||||
|
* \param fd file descriptor to check status
|
||||||
|
*/
|
||||||
|
inline bool CheckWrite(SOCKET fd) const {
|
||||||
|
return FD_ISSET(fd, &write_set) != 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief clear all the monitored descriptors
|
||||||
|
*/
|
||||||
|
inline void Clear(void) {
|
||||||
|
read_fds.clear();
|
||||||
|
write_fds.clear();
|
||||||
|
maxfd = 0;
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief peform select on the set defined
|
||||||
|
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
|
||||||
|
* \return number of active descriptors selected
|
||||||
|
*/
|
||||||
|
inline int Select(long timeout = 0) {
|
||||||
|
FD_ZERO(&read_set);
|
||||||
|
FD_ZERO(&write_set);
|
||||||
|
for (size_t i = 0; i < read_fds.size(); ++i) {
|
||||||
|
FD_SET(read_fds[i], &read_set);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < write_fds.size(); ++i) {
|
||||||
|
FD_SET(write_fds[i], &write_set);
|
||||||
|
}
|
||||||
|
int ret;
|
||||||
|
if (timeout == 0) {
|
||||||
|
ret = select(static_cast<int>(maxfd + 1), &read_set, &write_set, NULL, NULL);
|
||||||
|
} else {
|
||||||
|
timeval tm;
|
||||||
|
tm.tv_usec = (timeout % 1000) * 1000;
|
||||||
|
tm.tv_sec = timeout / 1000;
|
||||||
|
ret = select(static_cast<int>(maxfd + 1), &read_set, &write_set, NULL, &tm);
|
||||||
|
}
|
||||||
|
if (ret == -1) {
|
||||||
|
int errsv = errno;
|
||||||
|
Error("Select Error: %s", strerror(errsv));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SOCKET maxfd;
|
||||||
|
fd_set read_set, write_set;
|
||||||
|
std::vector<SOCKET> read_fds, write_fds;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
#endif
|
||||||
106
src/tcp_master.py
Normal file
106
src/tcp_master.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
Master script for xgboost, tcp_master
|
||||||
|
This script can be used to start jobs of multi-node xgboost using sync_tcp
|
||||||
|
|
||||||
|
Tianqi Chen
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import subprocess
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
class ExSocket:
    """Wrapper of a connected socket that adds helpers to exchange
    native integers and length-prefixed strings."""
    def __init__(self, sock):
        self.sock = sock

    def recvall(self, nbytes):
        """Receive exactly nbytes bytes, reading at most 1024 bytes per call.

        Raises socket.error if the peer closes the connection before
        nbytes bytes have arrived.
        """
        res = []
        nread = 0
        while nread < nbytes:
            chunk = self.sock.recv(min(nbytes - nread, 1024), socket.MSG_WAITALL)
            if not chunk:
                # an empty chunk means the connection is closed; without this
                # check the loop would never terminate
                raise socket.error('connection closed after %d of %d bytes'
                                   % (nread, nbytes))
            nread += len(chunk)
            res.append(chunk)
        return b''.join(res)

    def recvint(self):
        """Receive one integer in the native 'int' layout of struct."""
        return struct.unpack('@i', self.recvall(4))[0]

    def sendint(self, n):
        """Send one integer in the native 'int' layout of struct."""
        self.sock.sendall(struct.pack('@i', n))

    def sendstr(self, s):
        """Send the length of s as an integer, followed by s itself."""
        self.sendint(len(s))
        self.sock.sendall(s)
|
||||||
|
|
||||||
|
# magic number used to verify existence of data
|
||||||
|
kMagic = 0xff99
|
||||||
|
|
||||||
|
class Master:
    """Rendezvous server of the job: every slave connects to it once to learn
    its rank, the size of the job and the address of its parent."""
    def __init__(self, port = 9000, port_end = 9999):
        """Listen on the first port of [port, port_end) that can be bound.

        Raises socket.error if none of the ports can be bound.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        bound_port = None
        for candidate in range(port, port_end):
            try:
                sock.bind(('', candidate))
            except socket.error:
                continue
            bound_port = candidate
            break
        if bound_port is None:
            # fail with a clear error instead of listening on an unbound
            # socket and crashing later on the missing self.port
            sock.close()
            raise socket.error('cannot bind to any port in [%d, %d)' % (port, port_end))
        sock.listen(16)
        self.port = bound_port
        self.sock = sock
        print('start listen on %s:%d' % (socket.gethostname(), self.port))

    def __del__(self):
        # self.sock does not exist if __init__ raised
        sock = getattr(self, 'sock', None)
        if sock is not None:
            sock.close()

    def slave_args(self):
        """Arguments to append to the command of each slave so that it can
        reach this master."""
        return ['master_uri=%s' % socket.gethostname(),
                'master_port=%s' % self.port]

    def accept_slaves(self, nslave):
        """Accept nslave connections and assign ranks in order of arrival.

        The ranks form a binary tree: rank 0 has no parent, and the children
        of rank r are ranks (r + 1) * 2 - 1 and (r + 1) * 2 when they exist.
        A connection that fails the magic number handshake is dropped and
        the rank is offered to the next connection.
        """
        slave_addrs = []
        for rank in range(nslave):
            # the tree position only depends on the rank, compute it once
            nparent = int(rank != 0)
            nchild = 0
            if (rank + 1) * 2 - 1 < nslave:
                nchild += 1
            if (rank + 1) * 2 < nslave:
                nchild += 1
            while True:
                fd, s_addr = self.sock.accept()
                slave = ExSocket(fd)
                try:
                    magic = slave.recvint()
                    if magic != kMagic:
                        print('invalid magic number=%d from %s' % (magic, s_addr[0]))
                        slave.sock.close()
                        continue
                except socket.error:
                    print('sock error in %s' % (s_addr[0]))
                    slave.sock.close()
                    continue
                slave.sendint(kMagic)
                slave.sendint(rank)
                slave.sendint(nslave)
                slave.sendint(nparent)
                slave.sendint(nchild)
                if nparent != 0:
                    # floor division keeps the index an integer on Python 3 too
                    parent_index = (rank + 1) // 2 - 1
                    ptuple = slave_addrs[parent_index]
                    slave.sendstr(ptuple[0])
                    slave.sendint(ptuple[1])
                # the slave answers with the port it listens on
                s_port = slave.recvint()
                assert rank == len(slave_addrs)
                slave_addrs.append((s_addr[0], s_port))
                slave.sock.close()
                print('finish starting rank=%d at %s' % (rank, s_addr[0]))
                break
        print('all slaves setup complete')
|
||||||
|
|
||||||
|
def mpi_submit(nslave, args):
    """Start nslave copies of the job with mpirun.

    The command line is 'mpirun -n <nslave>' followed by args, joined with
    spaces and executed through the shell; it is printed before it runs.
    Returns the result of subprocess.check_call.
    """
    launcher = ['mpirun -n %d' % nslave]
    cmd = ' '.join(launcher + args)
    print(cmd)
    return subprocess.check_call(cmd, shell = True)
|
||||||
|
|
||||||
|
def submit(nslave, args, fun_submit = mpi_submit):
    """Run a job of nslave slaves.

    A Master is created, fun_submit(nslave, args + master arguments) is run
    in a background thread to launch the slaves, and the calling thread
    serves the rendezvous until all slaves are set up, then waits for the
    launcher thread to finish.
    """
    master = Master()
    job_args = args + master.slave_args()
    launcher = Thread(target = fun_submit, args = (nslave, job_args))
    launcher.start()
    master.accept_slaves(nslave)
    launcher.join()
|
||||||
176
src/utils.h
Normal file
176
src/utils.h
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
#ifndef ALLREDUCE_UTILS_H_
|
||||||
|
#define ALLREDUCE_UTILS_H_
|
||||||
|
/*!
|
||||||
|
* \file utils.h
|
||||||
|
* \brief simple utils to support the code
|
||||||
|
* \author Tianqi Chen
|
||||||
|
*/
|
||||||
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
|
#include <cstdio>
|
||||||
|
#include <string>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#ifndef ALLREDUCE_STRICT_CXX98_
|
||||||
|
#include <cstdarg>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if !defined(__GNUC__)
|
||||||
|
#define fopen64 std::fopen
|
||||||
|
#endif
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
// NOTE: sprintf_s is not equivalent to snprintf,
|
||||||
|
// they are equivalent when success, which is sufficient for our case
|
||||||
|
#define snprintf sprintf_s
|
||||||
|
#define vsnprintf vsprintf_s
|
||||||
|
#else
|
||||||
|
#ifdef _FILE_OFFSET_BITS
|
||||||
|
#if _FILE_OFFSET_BITS == 32
|
||||||
|
#pragma message ("Warning: FILE OFFSET BITS defined to be 32 bit")
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#define off64_t off_t
|
||||||
|
#define fopen64 std::fopen
|
||||||
|
#endif
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
#include <sys/types.h>
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
// fixed width integer types for MSVC, declared by hand in place of <inttypes.h>
typedef unsigned char uint8_t;
typedef unsigned short int uint16_t;
typedef unsigned int uint32_t;
// long is only 32 bits wide with MSVC, __int64 is the 64 bit type
typedef unsigned __int64 uint64_t;
typedef __int64 int64_t;
#else
#include <inttypes.h>
#endif
|
||||||
|
|
||||||
|
/*! \brief namespace for helper utils of the project */
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
/*! \brief error message buffer length */
|
||||||
|
const int kPrintBuffer = 1 << 12;
|
||||||
|
|
||||||
|
#ifndef ALLREDUCE_CUSTOMIZE_MSG_
|
||||||
|
/*!
 * \brief default handler of a failed Assert:
 *  print the message to stderr with the prefix "AssertError:" and exit with -1
 * \param msg error message
 */
inline void HandleAssertError(const char *msg) {
  std::fprintf(stderr, "AssertError:%s\n", msg);
  std::exit(-1);
}
|
||||||
|
/*!
 * \brief default handler of a failed Check:
 *  print the message to stderr and exit with -1
 * \param msg error message
 */
inline void HandleCheckError(const char *msg) {
  std::fprintf(stderr, "%s\n", msg);
  std::exit(-1);
}
|
||||||
|
/*!
 * \brief default handler of Printf: write the message to stdout as it is
 * \param msg the formatted message
 */
inline void HandlePrint(const char *msg) {
  std::fputs(msg, stdout);
}
|
||||||
|
#else
|
||||||
|
#ifndef ALLREDUCE_STRICT_CXX98_
|
||||||
|
// include declarations, some one must implement this
|
||||||
|
void HandleAssertError(const char *msg);
|
||||||
|
void HandleCheckError(const char *msg);
|
||||||
|
void HandlePrint(const char *msg);
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
#ifdef ALLREDUCE_STRICT_CXX98_
|
||||||
|
// these function pointers are to be assigned
|
||||||
|
extern "C" void (*Printf)(const char *fmt, ...);
|
||||||
|
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
|
||||||
|
extern "C" void (*Assert)(int exp, const char *fmt, ...);
|
||||||
|
extern "C" void (*Check)(int exp, const char *fmt, ...);
|
||||||
|
extern "C" void (*Error)(const char *fmt, ...);
|
||||||
|
#else
|
||||||
|
/*! \brief printf, print message to the console */
|
||||||
|
inline void Printf(const char *fmt, ...) {
|
||||||
|
std::string msg(kPrintBuffer, '\0');
|
||||||
|
va_list args;
|
||||||
|
va_start(args, fmt);
|
||||||
|
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||||
|
va_end(args);
|
||||||
|
HandlePrint(msg.c_str());
|
||||||
|
}
|
||||||
|
/*!
 * \brief portable version of snprintf
 * \param buf the buffer to write to
 * \param size the size of the buffer
 * \param fmt printf style format, followed by its arguments
 * \return the value returned by vsnprintf
 */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
  va_list ap;
  va_start(ap, fmt);
  const int nwritten = vsnprintf(buf, size, fmt, ap);
  va_end(ap);
  return nwritten;
}
|
||||||
|
|
||||||
|
/*! \brief assert an condition is true, use this to handle debug information */
|
||||||
|
inline void Assert(bool exp, const char *fmt, ...) {
|
||||||
|
if (!exp) {
|
||||||
|
std::string msg(kPrintBuffer, '\0');
|
||||||
|
va_list args;
|
||||||
|
va_start(args, fmt);
|
||||||
|
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||||
|
va_end(args);
|
||||||
|
HandleAssertError(msg.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*!\brief same as assert, but this is intended to be used as message for user*/
|
||||||
|
inline void Check(bool exp, const char *fmt, ...) {
|
||||||
|
if (!exp) {
|
||||||
|
std::string msg(kPrintBuffer, '\0');
|
||||||
|
va_list args;
|
||||||
|
va_start(args, fmt);
|
||||||
|
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||||
|
va_end(args);
|
||||||
|
HandleCheckError(msg.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*! \brief report error message, same as check */
|
||||||
|
inline void Error(const char *fmt, ...) {
|
||||||
|
{
|
||||||
|
std::string msg(kPrintBuffer, '\0');
|
||||||
|
va_list args;
|
||||||
|
va_start(args, fmt);
|
||||||
|
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
|
||||||
|
va_end(args);
|
||||||
|
HandleCheckError(msg.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/*! \brief replace fopen, report error when the file open fails */
|
||||||
|
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
|
||||||
|
std::FILE *fp = fopen64(fname, flag);
|
||||||
|
Check(fp != NULL, "can not open file \"%s\"\n", fname);
|
||||||
|
return fp;
|
||||||
|
}
|
||||||
|
} // namespace utils
|
||||||
|
// easy utils that can be directly accessed in xgboost
|
||||||
|
/*!
 * \brief get the beginning address of a vector
 * \return pointer to the first element, NULL if the vector is empty
 */
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) {
  if (vec.empty()) return NULL;
  return &vec[0];
}
|
||||||
|
/*!
 * \brief get the beginning address of a constant vector
 * \return pointer to the first element, NULL if the vector is empty
 */
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) {
  if (vec.empty()) return NULL;
  return &vec[0];
}
|
||||||
|
#endif // ALLREDUCE_UTILS_H_
|
||||||
36
submit_job_tcp.py
Executable file
36
submit_job_tcp.py
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
#!/usr/bin/python
|
||||||
|
"""
|
||||||
|
This is an example script to create a customized job submit
|
||||||
|
script using xgboost sync_tcp mode
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
# import the tcp_master.py
|
||||||
|
# add path to sync
|
||||||
|
sys.path.append(os.path.dirname(__file__)+'/src/')
|
||||||
|
import tcp_master as master
|
||||||
|
|
||||||
|
#
|
||||||
|
# Note: this submit script is only used for example purpose
|
||||||
|
# It does not have to be mpirun, it can be any job submission script that starts the job, qsub, hadoop streaming etc.
|
||||||
|
#
|
||||||
|
def mpi_submit(nslave, args):
    """
      customized submit function: launches nslave copies of the job with mpirun,
      each of them receives args as parameters
      note this can be a lambda function containing additional parameters in input
      Parameters
         nslave number of slave processes to start up
         args arguments to launch each job
              this usually includes the parameters of master_uri and parameters passed into submit
    """
    launcher = ['mpirun -n %d' % nslave]
    cmd = ' '.join(launcher + args)
    print(cmd)
    subprocess.check_call(cmd, shell = True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # both <nslave> and at least one word of <cmd> are needed,
    # otherwise mpirun would be started without a program to run
    if len(sys.argv) < 3:
        print('Usage: <nslave> <cmd>')
        sys.exit(0)
    # call submit, with nslave, the commands to run each job and submit function
    master.submit(int(sys.argv[1]), sys.argv[2:], fun_submit= mpi_submit)
|
||||||
33
test/Makefile
Normal file
33
test/Makefile
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
export CC  = gcc
export CXX = g++
export MPICXX = mpicxx
export LDFLAGS= -pthread -lm
export CFLAGS = -Wall -O3 -msse2  -Wno-unknown-pragmas -fPIC -I../src

# pass no_omp=1 to build without OpenMP
ifeq ($(no_omp),1)
  CFLAGS += -DDISABLE_OPENMP
else
  CFLAGS += -fopenmp
endif

# binaries and objects built by this Makefile
BIN = test_allreduce
OBJ = engine_tcp.o
.PHONY: clean all

all: $(BIN) $(MPIBIN)

engine_tcp.o: ../src/engine_tcp.cpp ../src/*.h
test_allreduce: test_allreduce.cpp ../src/*.h engine_tcp.o

# libraries are listed after the inputs, so that linkers which resolve
# symbols from left to right can find them
$(BIN) :
	$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) $(LDFLAGS)

$(OBJ) :
	$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )

$(MPIBIN) :
	$(MPICXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) $(LDFLAGS)

# object files are removed too, otherwise stale objects survive a clean
clean:
	$(RM) $(BIN) $(MPIBIN) $(OBJ) *~ ../src/*~
|
||||||
7
test/test.sh
Executable file
7
test/test.sh
Executable file
@ -0,0 +1,7 @@
|
|||||||
|
#!/bin/bash
# Run test_allreduce through the tcp job submit script.
# Usage: test.sh <nslave> <ndata>
if [ "$#" -ne 2 ];
then
    echo "Usage <nslave> <ndata>"
    # exit status has to be in the range 0-255
    exit 1
fi

# quote the arguments so that they are passed on unchanged
../submit_job_tcp.py "$1" test_allreduce "$2"
|
||||||
80
test/test_allreduce.cpp
Normal file
80
test/test_allreduce.cpp
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
#include <allreduce.h>
|
||||||
|
#include <utils.h>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
using namespace sync;
|
||||||
|
|
||||||
|
inline void TestMax(size_t n) {
|
||||||
|
int rank = sync::GetRank();
|
||||||
|
int nproc = sync::GetWorldSize();
|
||||||
|
|
||||||
|
std::vector<float> ndata(n);
|
||||||
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
|
ndata[i] = (i * (rank+1)) % 111;
|
||||||
|
}
|
||||||
|
sync::AllReduce<op::Max>(&ndata[0], ndata.size());
|
||||||
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
|
float rmax = (i * 1) % 111;
|
||||||
|
for (int r = 0; r < nproc; ++r) {
|
||||||
|
rmax = std::max(rmax, (float)((i * (r+1)) % 111));
|
||||||
|
}
|
||||||
|
utils::Check(rmax == ndata[i], "[%d] TestMax check failure", rank);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TestSum(size_t n) {
|
||||||
|
int rank = sync::GetRank();
|
||||||
|
int nproc = sync::GetWorldSize();
|
||||||
|
const int z = 131;
|
||||||
|
|
||||||
|
std::vector<float> ndata(n);
|
||||||
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
|
ndata[i] = (i * (rank+1)) % z;
|
||||||
|
}
|
||||||
|
sync::AllReduce<op::Sum>(&ndata[0], ndata.size());
|
||||||
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
|
float rsum = 0.0f;
|
||||||
|
for (int r = 0; r < nproc; ++r) {
|
||||||
|
rsum += (float)((i * (r+1)) % z);
|
||||||
|
}
|
||||||
|
utils::Check(fabsf(rsum - ndata[i]) < 1e-5 ,
|
||||||
|
"[%d] TestSum check failure, local=%g, allreduce=%g", rank, rsum, ndata[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*!
 * \brief check Bcast: the root broadcasts a string of n characters that every
 *  process can also generate locally, and each process compares the two
 * \param n length of the string
 * \param root rank of the process that owns the data
 */
inline void TestBcast(size_t n, int root) {
  const int rank = sync::GetRank();
  std::string expected(n, '\0');
  for (size_t i = 0; i < n; ++i) {
    expected[i] = char(i % 126 + 1);
  }
  // only the root starts with the content, the others receive it
  std::string res;
  if (root == rank) res = expected;
  sync::Bcast(&res, root);
  utils::Check(res == expected, "[%d] TestBcast fail", rank);
}
|
||||||
|
|
||||||
|
/*!
 * \brief entry of the test: runs TestMax and TestSum on <ndata> elements
 *  between sync::Init and sync::Finalize, printing the progress of this rank
 */
int main(int argc, char *argv[]) {
  if (argc < 2) {
    printf("Usage: <ndata>\n");
    return 0;
  }
  const int ndata = atoi(argv[1]);
  sync::Init(argc, argv);
  const int rank = sync::GetRank();
  const std::string host = sync::GetProcessorName();
  printf("[%d] start at %s\n", rank, host.c_str());

  TestMax(ndata);
  printf("[%d] TestMax pass\n", rank);

  TestSum(ndata);
  printf("[%d] TestSum pass\n", rank);

  sync::Finalize();
  printf("[%d] all check pass\n", rank);
  return 0;
}
|
||||||
Loading…
x
Reference in New Issue
Block a user