diff --git a/.gitignore b/.gitignore
index b8bd0267b..2922a01e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,3 +26,6 @@
 *.exe
 *.out
 *.app
+*~
+*.pyc
+test
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..d4fa97339
--- /dev/null
+++ b/README.md
@@ -0,0 +1,6 @@
+AllReduce Abstraction
+====
+* Tianqi, Nacho, Tianyi
+
+Go!
+
diff --git a/src/allreduce.h b/src/allreduce.h
new file mode 100644
index 000000000..c9bd0e579
--- /dev/null
+++ b/src/allreduce.h
@@ -0,0 +1,107 @@
+#ifndef ALLREDUCE_H
+#define ALLREDUCE_H
+/*!
+ * \file allreduce.h
+ * \brief This file defines a thin template wrapper around the engine
+ *   interface declared in engine.h, so that callers can use typed
+ *   AllReduce / Bcast calls instead of raw byte buffers
+ * \author Tianqi Chen, Nacho, Tianyi
+ */
+#include <cstddef>
+#include <string>
+#include "./engine.h"
+
+/*!
+ * \brief namespace of all reduce
+ * NOTE(review): POSIX <unistd.h> declares a function named sync();
+ *   confirm this namespace name does not clash where both are visible
+ */
+namespace sync {
+/*! \brief namespace of operator */
+namespace op {
+/*! \brief reduce operator that keeps the larger value in dst */
+struct Max {
+  template<typename DType>
+  inline static void Reduce(DType &dst, const DType &src) {
+    if (dst < src) dst = src;
+  }
+};
+/*! \brief reduce operator that accumulates src into dst with += */
+struct Sum {
+  template<typename DType>
+  inline static void Reduce(DType &dst, const DType &src) {
+    dst += src;
+  }
+};
+/*! \brief reduce operator that accumulates src into dst with |= */
+struct BitOR {
+  template<typename DType>
+  inline static void Reduce(DType &dst, const DType &src) {
+    dst |= src;
+  }
+};
+/*!
+ * \brief element-wise reduce function matching engine::IEngine::ReduceFunction
+ * \param src_ pointer to len source elements of type DType
+ * \param dst_ pointer to len destination elements of type DType, updated in place
+ * \param len number of elements (not bytes)
+ * \param dtype unused, only present to match the MPI-style signature
+ * \tparam OP see namespace op, reduce operator
+ * \tparam DType type of data
+ */
+template<typename OP, typename DType>
+inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
+  const DType *src = static_cast<const DType*>(src_);
+  DType *dst = static_cast<DType*>(dst_);
+  for (int i = 0; i < len; ++i) {
+    OP::Reduce(dst[i], src[i]);
+  }
+}
+}  // namespace op
+
+// The wrappers below are defined in a header, so all of them are inline;
+// a non-inline definition would be emitted once per including source file.
+/*! \brief initialize the engine, forwards the command line arguments to it */
+inline void Init(int argc, char *argv[]) {
+  engine::Init(argc, argv);
+}
+/*! \brief finalize the engine */
+inline void Finalize(void) {
+  engine::Finalize();
+}
+/*! \brief get rank of current process */
+inline int GetRank(void) {
+  return engine::GetEngine()->GetRank();
+}
+/*! \brief get total number of process */
+inline int GetWorldSize(void) {
+  return engine::GetEngine()->GetWorldSize();
+}
+/*! \brief get name of processor */
+inline std::string GetProcessorName(void) {
+  return engine::GetEngine()->GetHost();
+}
+/*!
+ * \brief broadcast an std::string to all others from root
+ * \param sendrecv_data the pointer to send or receive buffer,
+ *   receive buffer does not need to be pre-allocated
+ *   and string will be resized to correct length
+ * \param root the root of process
+ */
+inline void Bcast(std::string *sendrecv_data, int root) {
+  engine::IEngine *e = engine::GetEngine();
+  // the length goes first so that receivers can size their buffer
+  unsigned len = static_cast<unsigned>(sendrecv_data->length());
+  e->Broadcast(&len, sizeof(len), root);
+  sendrecv_data->resize(len);
+  if (len != 0) {
+    e->Broadcast(&(*sendrecv_data)[0], len, root);
+  }
+}
+/*!
+ * \brief perform in-place allreduce, on sendrecvbuf
+ *        this function is NOT thread-safe
+ * Example Usage: the following code gives sum of the result
+ *     vector<int> data(10);
+ *     ...
+ *     AllReduce<op::Sum>(&data[0], data.size());
+ *     ...
+ * \param sendrecvbuf buffer for both sending and recving data
+ * \param count number of elements to be reduced
+ * \tparam OP see namespace op, reduce operator
+ * \tparam DType type of data
+ */
+template<typename OP, typename DType>
+inline void AllReduce(DType *sendrecvbuf, size_t count) {
+  engine::GetEngine()->AllReduce(sendrecvbuf, sizeof(DType), count, op::Reducer<OP, DType>);
+}
+/*!
+ * \brief load latest check point
+ * \param p_model pointer to the model
+ * \return true if there was stored checkpoint and load was successful
+ *   false if there was no stored checkpoint, means we are starting over again
+ */
+inline bool LoadCheckPoint(utils::ISerializable *p_model) {
+  return engine::GetEngine()->LoadCheckPoint(p_model);
+}
+/*!
+ * \brief checkpoint the model, meaning we finished a stage of execution
+ * \param model the model to be checkpointed
+ */
+inline void CheckPoint(const utils::ISerializable &model) {
+  engine::GetEngine()->CheckPoint(model);
+}
+}  // namespace sync
+#endif  // ALLREDUCE_H
diff --git a/src/engine.h b/src/engine.h
new file mode 100644
index 000000000..ca928b22a
--- /dev/null
+++ b/src/engine.h
@@ -0,0 +1,80 @@
+#ifndef ALLREDUCE_ENGINE_H
+#define ALLREDUCE_ENGINE_H
+/*!
+ * \file engine.h
+ * \brief This file defines the interface of allreduce library
+ * \author Tianqi Chen, Nacho, Tianyi
+ */
+#include <cstddef>
+#include <string>
+#include "./io.h"
+
+namespace MPI {
+/*! \brief MPI data type just to be compatible with MPI reduce function*/
+class Datatype;
+}
+
+/*! \brief namespace of allreduce functionality */
+namespace engine {
+/*! \brief interface of core AllReduce engine */
+class IEngine {
+ public:
+  /*!
+   * \brief reduce function, the same form of MPI reduce function is used,
+   *        to be compatible with MPI interface
+   *        In all the functions, the memory is ensured to aligned to 64-bit
+   *        which means it is OK to cast src,dst to double* int* etc
+   * \param src pointer to source space
+   * \param dst pointer to destination reduction
+   * \param count total number of elements to be reduced(note this is total number of elements instead of bytes)
+   *              the definition of reduce function should be type aware
+   * \param dtype the data type object, to be compatible with MPI reduce
+   */
+  typedef void (ReduceFunction) (const void *src,
+                                 void *dst, int count,
+                                 const MPI::Datatype &dtype);
+  /*! \brief virtual destructor, so an engine can be destroyed through this interface */
+  virtual ~IEngine(void) {}
+  /*!
+   * \brief perform in-place allreduce, on sendrecvbuf
+   *        this function is NOT thread-safe
+   * \param sendrecvbuf_ buffer for both sending and recving data
+   * \param type_nbytes the number of bytes one element of the type takes
+   * \param count number of elements to be reduced
+   * \param reducer reduce function
+   */
+  virtual void AllReduce(void *sendrecvbuf_,
+                         size_t type_nbytes,
+                         size_t count,
+                         ReduceFunction reducer) = 0;
+  /*!
+   * \brief broadcast data from root to all nodes
+   * \param sendrecvbuf_ buffer for both sending and recving data
+   * \param size the size of the data to be broadcasted
+   * \param root the root worker id to broadcast the data
+   */
+  virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0;
+  /*!
+   * \brief load latest check point
+   * \param p_model pointer to the model
+   * \return true if there was stored checkpoint and load was successful
+   *   false if there was no stored checkpoint, means we are starting over again
+   */
+  virtual bool LoadCheckPoint(utils::ISerializable *p_model) = 0;
+  /*!
+   * \brief checkpoint the model, meaning we finished a stage of execution
+   * \param model the model to be checkpointed
+   */
+  virtual void CheckPoint(const utils::ISerializable &model) = 0;
+  /*! \brief get rank of current node */
+  virtual int GetRank(void) const = 0;
+  /*! \brief get total number of nodes */
+  virtual int GetWorldSize(void) const = 0;
+  /*! \brief get the host name of current node */
+  virtual std::string GetHost(void) const = 0;
+};
+
+/*! \brief initialize the engine module */
+void Init(int argc, char *argv[]);
+/*! \brief finalize engine module */
+void Finalize(void);
+/*! \brief singleton method to get engine */
+IEngine *GetEngine(void);
+}  // namespace engine
+#endif  // ALLREDUCE_ENGINE_H
diff --git a/src/engine_tcp.cpp b/src/engine_tcp.cpp
new file mode 100644
index 000000000..a0506129d
--- /dev/null
+++ b/src/engine_tcp.cpp
@@ -0,0 +1,448 @@
+/*!
+ * \file engine_tcp.cpp
+ * \brief implementation of sync AllReduce using TCP sockets
+ *   with use non-block socket and tree-shape reduction
+ * \author Tianqi Chen
+ */
+#define _CRT_SECURE_NO_WARNINGS
+#define _CRT_SECURE_NO_DEPRECATE
+#define NOMINMAX
+#include <algorithm>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <vector>
+#include "./engine.h"
+#include "./socket.h"
+
+namespace MPI {
+/*! \brief stand-in for the MPI data type, it only records the size of one element */
+class Datatype {
+ public:
+  /*! \brief size of one element, as passed to the constructor */
+  size_t type_size;
+  explicit Datatype(size_t type_size) : type_size(type_size) {}
+};
+}  // namespace MPI
+namespace engine {
+/*!
+ \brief implementation of sync goes to here: an IEngine that talks to its
+ *  parent and child nodes over non-blocking TCP links
+ */
+class SyncManager : public IEngine {
+ public:
+  /*! \brief magic number exchanged during every handshake in Init */
+  const static int kMagic = 0xff99;
+  SyncManager(void) {
+    master_uri = "NULL";
+    master_port = 9000;
+    host_uri = "";
+    slave_port = 9010;
+    nport_trial = 1000;
+    rank = 0;
+    world_size = 1;
+    // no parent until Init says otherwise
+    parent_index = -1;
+    this->SetParam("reduce_buffer", "256MB");
+  }
+  ~SyncManager(void) {
+  }
+  /*! \brief close all the links and finalize the socket module */
+  inline void Shutdown(void) {
+    for (size_t i = 0; i < links.size(); ++i) {
+      links[i].sock.Close();
+    }
+    links.clear();
+    utils::TCPSocket::Finalize();
+  }
+  /*!
+   * \brief set parameters to the sync manager
+   * \param name parameter name, one of master_uri, master_port, reduce_buffer;
+   *        other names are ignored
+   * \param val parameter value; reduce_buffer takes {integer}{unit}
+   */
+  inline void SetParam(const char *name, const char *val) {
+    if (!strcmp(name, "master_uri")) master_uri = val;
+    if (!strcmp(name, "master_port")) master_port = atoi(val);
+    if (!strcmp(name, "reduce_buffer")) {
+      char unit;
+      unsigned long amount;
+      if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
+        // reduce_buffer_size counts 8-byte words, hence the shifts are 3 bits short
+        switch (unit) {
+          case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
+          case 'K': reduce_buffer_size = amount << 7UL; break;
+          case 'M': reduce_buffer_size = amount << 17UL; break;
+          case 'G': reduce_buffer_size = amount << 27UL; break;
+          default: utils::Error("invalid format for reduce buffer");
+        }
+      } else {
+        utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
+      }
+    }
+  }
+  /*!
+   * \brief initialize the manager: fetch rank and tree position from the master,
+   *        connect to the parent, then accept the connections of all children;
+   *        does nothing beyond socket startup when master_uri is "NULL"
+   */
+  inline void Init(void) {
+    utils::TCPSocket::Startup();
+    // single node mode
+    if (master_uri == "NULL") return;
+    utils::Assert(links.size() == 0, "can only call Init once");
+    int magic = kMagic;
+    int nchild = 0, nparent = 0;
+    this->host_uri = utils::SockAddr::GetHostName();
+    // get information from master
+    utils::TCPSocket master;
+    master.Create();
+    master.Connect(utils::SockAddr(master_uri.c_str(), master_port));
+    utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 1");
+    utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 2");
+    utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
+    utils::Assert(master.RecvAll(&rank, sizeof(rank)) == sizeof(rank), "sync::Init failure 3");
+    utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "sync::Init failure 4");
+    utils::Assert(master.RecvAll(&nparent, sizeof(nparent)) == sizeof(nparent), "sync::Init failure 5");
+    utils::Assert(master.RecvAll(&nchild, sizeof(nchild)) == sizeof(nchild), "sync::Init failure 6");
+    utils::Assert(nchild >= 0, "in correct number of childs");
+    utils::Assert(nparent == 1 || nparent == 0, "in correct number of parent");
+
+    // create listen
+    utils::TCPSocket sock_listen;
+    sock_listen.Create();
+    int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
+    utils::Check(port != -1, "sync::Init fail to bind the ports specified");
+    sock_listen.Listen();
+
+    if (nparent != 0) {
+      parent_index = 0;
+      links.push_back(LinkRecord());
+      int len, hport;
+      std::string hname;
+      utils::Assert(master.RecvAll(&len, sizeof(len)) == sizeof(len), "sync::Init failure 9");
+      // validate before resizing: len comes from the network and may be zero or negative
+      utils::Assert(len > 0, "string must not be empty");
+      hname.resize(len);
+      utils::Assert(master.RecvAll(&hname[0], len) == static_cast<size_t>(len), "sync::Init failure 10");
+      utils::Assert(master.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "sync::Init failure 11");
+      links[0].sock.Create();
+      links[0].sock.Connect(utils::SockAddr(hname.c_str(), hport));
+      utils::Assert(links[0].sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 12");
+      utils::Assert(links[0].sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 13");
+      utils::Check(magic == kMagic, "sync::Init failure, parent magic number mismatch");
+      parent_index = 0;
+    } else {
+      parent_index = -1;
+    }
+    // send back socket listening port to master
+    utils::Assert(master.SendAll(&port, sizeof(port)) == sizeof(port), "sync::Init failure 14");
+    // close connection to master
+    master.Close();
+    // accept links from childs
+    for (int i = 0; i < nchild; ++i) {
+      LinkRecord r;
+      while (true) {
+        r.sock = sock_listen.Accept();
+        if (r.sock.RecvAll(&magic, sizeof(magic)) == sizeof(magic) && magic == kMagic) {
+          utils::Assert(r.sock.SendAll(&magic, sizeof(magic)) == sizeof(magic), "sync::Init failure 15");
+          break;
+        } else {
+          // not a valid child
+          r.sock.Close();
+        }
+      }
+      links.push_back(r);
+    }
+    // close listening sockets
+    sock_listen.Close();
+    // setup selecter
+    selecter.Clear();
+    for (size_t i = 0; i < links.size(); ++i) {
+      // set the socket to non-blocking mode
+      links[i].sock.SetNonBlock(true);
+      selecter.WatchRead(links[i].sock);
+      selecter.WatchWrite(links[i].sock);
+    }
+    // done
+  }
+  /*! \brief get rank */
+  virtual int GetRank(void) const {
+    return rank;
+  }
+  /*! \brief get total number of nodes */
+  virtual int GetWorldSize(void) const {
+    return world_size;
+  }
+  /*! \brief get the host name of current node, set by Init */
+  virtual std::string GetHost(void) const {
+    return host_uri;
+  }
+  /*!
+   * \brief perform in-place allreduce, on sendrecvbuf
+   *        this function is NOT thread-safe
+   * \param sendrecvbuf_ buffer for both sending and recving data
+   * \param type_nbytes the number of bytes one element of the type takes
+   * \param count number of elements to be reduced
+   * \param reducer reduce function
+   */
+  virtual void AllReduce(void *sendrecvbuf_,
+                         size_t type_nbytes,
+                         size_t count,
+                         ReduceFunction reducer) {
+    // nothing to exchange in single node mode or for an empty message;
+    // an empty message would otherwise trip the size check in InitBuffer
+    if (links.size() == 0 || count == 0) return;
+    // total size of message
+    const size_t total_size = type_nbytes * count;
+    // number of links
+    const int nlink = static_cast<int>(links.size());
+    // send recv buffer
+    char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
+    // size of space that we already performs reduce in up pass
+    size_t size_up_reduce = 0;
+    // size of space that we have already passed to parent
+    size_t size_up_out = 0;
+    // size of message we received, and send in the down pass
+    size_t size_down_in = 0;
+
+    // initialize the link ring-buffer and pointer
+    for (int i = 0; i < nlink; ++i) {
+      if (i != parent_index) {
+        links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
+      }
+      links[i].ResetSize();
+    }
+    // if no childs, no need to reduce
+    if (nlink == static_cast<int>(parent_index != -1)) {
+      size_up_reduce = total_size;
+    }
+
+    // while we have not passed the messages out
+    while (true) {
+      selecter.Select();
+      // read data from childs
+      for (int i = 0; i < nlink; ++i) {
+        if (i != parent_index && selecter.CheckRead(links[i].sock)) {
+          links[i].ReadToRingBuffer(size_up_out);
+        }
+      }
+      // this node have childs, peform reduce
+      if (nlink > static_cast<int>(parent_index != -1)) {
+        size_t buffer_size = 0;
+        // do upstream reduce
+        size_t max_reduce = total_size;
+        for (int i = 0; i < nlink; ++i) {
+          if (i != parent_index) {
+            max_reduce = std::min(max_reduce, links[i].size_read);
+            utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
+                          "buffer size inconsistent");
+            buffer_size = links[i].buffer_size;
+          }
+        }
+        utils::Assert(buffer_size != 0, "must assign buffer_size");
+        // round to type_nbytes
+        max_reduce = (max_reduce / type_nbytes * type_nbytes);
+        // peform reduce, can be at most two rounds
+        while (size_up_reduce < max_reduce) {
+          // start position
+          size_t start = size_up_reduce % buffer_size;
+          // peform read till end of buffer
+          size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce);
+          utils::Assert(nread % type_nbytes == 0, "AllReduce: size check");
+          for (int i = 0; i < nlink; ++i) {
+            if (i != parent_index) {
+              reducer(links[i].buffer_head + start,
+                      sendrecvbuf + size_up_reduce,
+                      static_cast<int>(nread / type_nbytes),
+                      MPI::Datatype(type_nbytes));
+            }
+          }
+          size_up_reduce += nread;
+        }
+      }
+      if (parent_index != -1) {
+        // pass message up to parent, can pass data that are already been reduced
+        if (selecter.CheckWrite(links[parent_index].sock)) {
+          size_up_out += links[parent_index].sock.
+              Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
+        }
+        // read data from parent
+        if (selecter.CheckRead(links[parent_index].sock)) {
+          size_down_in += links[parent_index].sock.
+              Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
+          utils::Assert(size_down_in <= size_up_out, "AllReduce: boundary error");
+        }
+      } else {
+        // this is root, can use reduce as most recent point
+        size_down_in = size_up_out = size_up_reduce;
+      }
+      // check if we finished the job of message passing
+      size_t nfinished = size_down_in;
+      // can pass message down to childs
+      for (int i = 0; i < nlink; ++i) {
+        if (i != parent_index) {
+          if (selecter.CheckWrite(links[i].sock)) {
+            links[i].WriteFromArray(sendrecvbuf, size_down_in);
+          }
+          nfinished = std::min(links[i].size_write, nfinished);
+        }
+      }
+      // check boundary condition
+      if (nfinished >= total_size) break;
+    }
+  }
+  /*!
+   * \brief broadcast data from root to all nodes
+   * \param sendrecvbuf_ buffer for both sending and recving data
+   * \param total_size the size of the data to be broadcasted
+   * \param root the root worker id to broadcast the data
+   */
+  virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
+    if (links.size() == 0) return;
+    // number of links
+    const int nlink = static_cast<int>(links.size());
+    // size of space already read from data
+    size_t size_in = 0;
+    // input link, -2 means unknown yet, -1 means this is root
+    int in_link = -2;
+
+    // initialize the link statistics
+    for (int i = 0; i < nlink; ++i) {
+      links[i].ResetSize();
+    }
+    // root have all the data
+    if (this->rank == root) {
+      size_in = total_size;
+      in_link = -1;
+    }
+
+    // while we have not passed the messages out
+    while (true) {
+      selecter.Select();
+      if (in_link == -2) {
+        // probe in-link
+        for (int i = 0; i < nlink; ++i) {
+          if (selecter.CheckRead(links[i].sock)) {
+            links[i].ReadToArray(sendrecvbuf_, total_size);
+            size_in = links[i].size_read;
+            if (size_in != 0) {
+              in_link = i; break;
+            }
+          }
+        }
+      } else {
+        // read from in link
+        if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
+          links[in_link].ReadToArray(sendrecvbuf_, total_size);
+          size_in = links[in_link].size_read;
+        }
+      }
+      size_t nfinished = total_size;
+      // send data to all out-link
+      for (int i = 0; i < nlink; ++i) {
+        if (i != in_link) {
+          if (selecter.CheckWrite(links[i].sock)) {
+            links[i].WriteFromArray(sendrecvbuf_, size_in);
+          }
+          nfinished = std::min(nfinished, links[i].size_write);
+        }
+      }
+      // check boundary condition
+      if (nfinished >= total_size) break;
+    }
+  }
+  /*! \brief checkpoint loading is not implemented by this engine, always returns false */
+  virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
+    return false;
+  }
+  /*! \brief checkpointing is not implemented by this engine, this is a no-op */
+  virtual void CheckPoint(const utils::ISerializable &model) {
+  }
+
+ private:
+  // an independent child record
+  struct LinkRecord {
+   public:
+    // socket to get data from/to link
+    utils::TCPSocket sock;
+    // size of data readed from link
+    size_t size_read;
+    // size of data sent to the link
+    size_t size_write;
+    // pointer to buffer head
+    char *buffer_head;
+    // buffer size, in bytes
+    size_t buffer_size;
+    // start from a defined state, the buffer is only set up by InitBuffer
+    LinkRecord(void)
+        : size_read(0), size_write(0), buffer_head(NULL), buffer_size(0) {
+    }
+    // initialize buffer
+    inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
+      // number of 8-byte words needed to hold the whole message
+      size_t n = (type_nbytes * count + 7)/ 8;
+      buffer_.resize(std::min(reduce_buffer_size, n));
+      // make sure align to type_nbytes
+      buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
+      utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size);
+      // set buffer head
+      buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
+    }
+    // reset the recv and sent size
+    inline void ResetSize(void) {
+      size_write = size_read = 0;
+    }
+    /*!
+     * \brief read data into ring-buffer, taking care not to overwrite data
+     *        at positions after protect_start that is still needed
+     * \param protect_start all data start from protect_start is still needed in buffer
+     *                      read shall not override this
+     */
+    inline void ReadToRingBuffer(size_t protect_start) {
+      size_t ngap = size_read - protect_start;
+      utils::Assert(ngap <= buffer_size, "AllReduce: boundary check");
+      size_t offset = size_read % buffer_size;
+      // limited by the free space and by the distance to the end of the buffer
+      size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
+      size_read += sock.Recv(buffer_head + offset, nmax);
+    }
+    /*!
+     * \brief read data into array,
+     * this function can not be used together with ReadToRingBuffer
+     * a link can either read into the ring buffer, or existing array
+     * \param recvbuf_ head of array
+     * \param max_size maximum size of array
+     */
+    inline void ReadToArray(void *recvbuf_, size_t max_size) {
+      char *p = static_cast<char*>(recvbuf_);
+      size_read += sock.Recv(p + size_read, max_size - size_read);
+    }
+    /*!
+     * \brief write data in array to sock
+     * \param sendbuf_ head of array
+     * \param max_size maximum size of array
+     */
+    inline void WriteFromArray(const void *sendbuf_, size_t max_size) {
+      const char *p = static_cast<const char*>(sendbuf_);
+      size_write += sock.Send(p + size_write, max_size - size_write);
+    }
+
+   private:
+    // recv buffer to get data from child
+    // aligned with 64 bits, will be able to perform 64 bits operations freely
+    std::vector<uint64_t> buffer_;
+  };
+  //------------------
+  // uri of current host, to be set by Init
+  std::string host_uri;
+  // uri of master
+  std::string master_uri;
+  // port of master address
+  int master_port;
+  // port of slave process
+  int slave_port, nport_trial;
+  // reduce buffer size, counted in 8-byte words
+  size_t reduce_buffer_size;
+  // current rank
+  int rank;
+  // world size
+  int world_size;
+  // index of parent link, can be -1, meaning this is root of the tree
+  int parent_index;
+  // sockets of all links
+  std::vector<LinkRecord> links;
+  // select helper
+  utils::SelectHelper selecter;
+};
+
+// singleton sync
+// manager: the one engine instance of this process
+SyncManager manager;
+
+/*!
+ * \brief initialize the synchronization module
+ *  every argument of form name=value is passed to the manager as a parameter
+ */
+void Init(int argc, char *argv[]) {
+  for (int i = 1; i < argc; ++i) {
+    char name[256], val[256];
+    // field widths keep an over-long argument from overflowing name and val
+    if (sscanf(argv[i], "%255[^=]=%255s", name, val) == 2) {
+      manager.SetParam(name, val);
+    }
+  }
+  manager.Init();
+}
+/*! \brief finalize synchronization module */
+void Finalize(void) {
+  manager.Shutdown();
+}
+/*! \brief singleton method to get engine */
+IEngine *GetEngine(void) {
+  return &manager;
+}
+
+}  // namespace engine
diff --git a/src/io.h b/src/io.h
new file mode 100644
index 000000000..97a33f163
--- /dev/null
+++ b/src/io.h
@@ -0,0 +1,214 @@
+#ifndef ALLREDUCE_UTILS_IO_H
+#define ALLREDUCE_UTILS_IO_H
+#include <algorithm>
+#include <cstdio>
+#include <cstring>
+#include <string>
+#include <vector>
+#include "./utils.h"
+/*!
+ * \file io.h
+ * \brief general stream interface for serialization, I/O
+ * \author Tianqi Chen
+ */
+namespace utils {
+/*!
+ * \brief interface of stream I/O, used to serialize model
+ */
+class IStream {
+ public:
+  /*!
+   * \brief read data from stream
+   * \param ptr pointer to memory buffer
+   * \param size size of block
+   * \return usually is the size of data readed
+   */
+  virtual size_t Read(void *ptr, size_t size) = 0;
+  /*!
+   * \brief write data to stream
+   * \param ptr pointer to memory buffer
+   * \param size size of block
+   */
+  virtual void Write(const void *ptr, size_t size) = 0;
+  /*! \brief virtual destructor */
+  virtual ~IStream(void) {}
+
+ public:
+  // helper functions to write various of data structures
+  /*!
+   * \brief binary serialize a vector: a 64-bit element count
+   *        followed by the raw bytes of the elements
+   * \param vec vector to be serialized
+   */
+  template<typename T>
+  inline void Write(const std::vector<T> &vec) {
+    uint64_t sz = static_cast<uint64_t>(vec.size());
+    this->Write(&sz, sizeof(sz));
+    if (sz != 0) {
+      this->Write(&vec[0], sizeof(T) * sz);
+    }
+  }
+  /*!
+   * \brief binary load a vector
+   * \param out_vec vector to be loaded
+   * \return whether load is successful; a read that returns 0 counts as failure
+   */
+  template<typename T>
+  inline bool Read(std::vector<T> *out_vec) {
+    uint64_t sz;
+    if (this->Read(&sz, sizeof(sz)) == 0) return false;
+    out_vec->resize(sz);
+    if (sz != 0) {
+      if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
+    }
+    return true;
+  }
+  /*!
+   * \brief binary serialize a string: a 64-bit length followed by the characters
+   * \param str the string to be serialized
+   */
+  inline void Write(const std::string &str) {
+    uint64_t sz = static_cast<uint64_t>(str.length());
+    this->Write(&sz, sizeof(sz));
+    if (sz != 0) {
+      this->Write(&str[0], sizeof(char) * sz);
+    }
+  }
+  /*!
+   * \brief binary load a string
+   * \param out_str string to be loaded
+   * \return whether load is successful; a read that returns 0 counts as failure
+   */
+  inline bool Read(std::string *out_str) {
+    uint64_t sz;
+    if (this->Read(&sz, sizeof(sz)) == 0) return false;
+    out_str->resize(sz);
+    if (sz != 0) {
+      if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
+    }
+    return true;
+  }
+};
+
+/*! \brief interface of objects that can be loaded from and saved to a stream */
+class ISerializable {
+ public:
+  /*! \brief virtual destructor */
+  virtual ~ISerializable(void) {}
+  /*! \brief load the model from file */
+  virtual void Load(IStream &fi) = 0;
+  /*! \brief save the model to the stream*/
+  virtual void Save(IStream &fo) const = 0;
+};
+
+/*! \brief interface of i/o stream that support seek */
+class ISeekStream: public IStream {
+ public:
+  /*! \brief seek to certain position of the file */
+  virtual void Seek(size_t pos) = 0;
+  /*! \brief tell the position of the stream */
+  virtual size_t Tell(void) = 0;
+};
+
+/*!
+ \brief fixed size memory buffer: a seekable stream over memory owned by
+ *  the caller, the buffer is never grown or freed by this class
+ */
+struct MemoryFixSizeBuffer : public ISeekStream {
+ public:
+  /*!
+   * \param p_buffer head of the memory to read from / write to
+   * \param buffer_size number of bytes available at p_buffer
+   */
+  MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
+      : p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
+    curr_ptr_ = 0;
+  }
+  virtual ~MemoryFixSizeBuffer(void) {}
+  /*!
+   * \brief copy size bytes from the current position into ptr
+   *  the request must fit into the buffer, this is checked by Assert
+   * \return number of bytes copied
+   */
+  virtual size_t Read(void *ptr, size_t size) {
+    utils::Assert(curr_ptr_ + size <= buffer_size_,
+                  "read can not have position excceed buffer length");
+    size_t nread = std::min(buffer_size_ - curr_ptr_, size);
+    if (nread != 0) memcpy(ptr, p_buffer_ + curr_ptr_, nread);
+    curr_ptr_ += nread;
+    return nread;
+  }
+  /*!
+   * \brief copy size bytes from ptr to the current position
+   *  the request must fit into the buffer, this is checked by Assert
+   */
+  virtual void Write(const void *ptr, size_t size) {
+    if (size == 0) return;
+    utils::Assert(curr_ptr_ + size <= buffer_size_,
+                  "write position exceed fixed buffer size");
+    memcpy(p_buffer_ + curr_ptr_, ptr, size);
+    curr_ptr_ += size;
+  }
+  /*! \brief set the current position; pos is not checked against the buffer size */
+  virtual void Seek(size_t pos) {
+    curr_ptr_ = static_cast<size_t>(pos);
+  }
+  /*! \brief get the current position */
+  virtual size_t Tell(void) {
+    return curr_ptr_;
+  }
+
+ private:
+  /*! \brief in memory buffer */
+  char *p_buffer_;
+  /*! \brief total size of the buffer, in bytes */
+  size_t buffer_size_;
+  /*! \brief current pointer */
+  size_t curr_ptr_;
+};  // class MemoryFixSizeBuffer
+
+/*!
+ \brief a in memory buffer that can be read and write as stream interface
+ *  the std::string passed in is used as storage and grows on write
+ */
+struct MemoryBufferStream : public ISeekStream {
+ public:
+  /*! \param p_buffer the string used as storage, not owned by the stream */
+  MemoryBufferStream(std::string *p_buffer)
+      : p_buffer_(p_buffer) {
+    curr_ptr_ = 0;
+  }
+  virtual ~MemoryBufferStream(void) {}
+  /*!
+   * \brief copy up to size bytes from the current position into ptr
+   * \return number of bytes copied, smaller than size at the end of the buffer
+   */
+  virtual size_t Read(void *ptr, size_t size) {
+    utils::Assert(curr_ptr_ <= p_buffer_->length(),
+                  "read can not have position excceed buffer length");
+    size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
+    if (nread != 0) memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
+    curr_ptr_ += nread;
+    return nread;
+  }
+  /*! \brief copy size bytes from ptr to the current position, growing the string if needed */
+  virtual void Write(const void *ptr, size_t size) {
+    if (size == 0) return;
+    if (curr_ptr_ + size > p_buffer_->length()) {
+      p_buffer_->resize(curr_ptr_+size);
+    }
+    memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
+    curr_ptr_ += size;
+  }
+  /*! \brief set the current position */
+  virtual void Seek(size_t pos) {
+    curr_ptr_ = static_cast<size_t>(pos);
+  }
+  /*! \brief get the current position */
+  virtual size_t Tell(void) {
+    return curr_ptr_;
+  }
+
+ private:
+  /*! \brief in memory buffer */
+  std::string *p_buffer_;
+  /*! \brief current pointer */
+  size_t curr_ptr_;
+};  // class MemoryBufferStream
+
+/*! \brief implementation of file i/o stream */
+class FileStream : public ISeekStream {
+ public:
+  explicit FileStream(FILE *fp) : fp(fp) {}
+  explicit FileStream(void) {
+    this->fp = NULL;
+  }
+  /*!
+   * \brief read up to size bytes from the file
+   * \return number of bytes actually read
+   */
+  virtual size_t Read(void *ptr, size_t size) {
+    // element size 1, so that the return value of fread counts bytes
+    return std::fread(ptr, 1, size, fp);
+  }
+  /*! \brief write size bytes to the file, the result of fwrite is not checked */
+  virtual void Write(const void *ptr, size_t size) {
+    std::fwrite(ptr, size, 1, fp);
+  }
+  /*! \brief seek to pos, counted from the start of the file */
+  virtual void Seek(size_t pos) {
+    std::fseek(fp, static_cast<long>(pos), SEEK_SET);
+  }
+  /*! \brief get the current position in the file */
+  virtual size_t Tell(void) {
+    return std::ftell(fp);
+  }
+  /*! \brief close the file if one is open; calling it again is a no-op */
+  inline void Close(void) {
+    if (fp != NULL) {
+      std::fclose(fp); fp = NULL;
+    }
+  }
+
+ private:
+  FILE *fp;
+};
+}  // namespace utils
+#endif  // ALLREDUCE_UTILS_IO_H
diff --git a/src/socket.h b/src/socket.h
new file mode 100644
index 000000000..a18d9a576
--- /dev/null
+++ b/src/socket.h
@@ -0,0 +1,387 @@
+#ifndef ALLREDUCE_SOCKET_H
+#define ALLREDUCE_SOCKET_H
+/*!
+ * \file socket.h + * \brief this file aims to provide a wrapper of sockets + * \author Tianqi Chen + */ +#if defined(_WIN32) +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif +#include +#include +#include "./utils.h" + +namespace utils { +#if defined(_WIN32) +typedef int ssize_t; +typedef int sock_size_t; +#else +typedef int SOCKET; +typedef size_t sock_size_t; +const int INVALID_SOCKET = -1; +#endif + +/*! \brief data structure for network address */ +struct SockAddr { + sockaddr_in addr; + // constructor + SockAddr(void) {} + SockAddr(const char *url, int port) { + this->Set(url, port); + } + inline static std::string GetHostName(void) { + std::string buf; buf.resize(256); + utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); + return std::string(buf.c_str()); + } + /*! + * \brief set the address + * \param url the url of the address + * \param port the port of address + */ + inline void Set(const char *host, int port) { + hostent *hp = gethostbyname(host); + Check(hp != NULL, "cannot obtain address of %s", host); + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + memcpy(&addr.sin_addr, hp->h_addr_list[0], hp->h_length); + } + /*! \brief return port of the address*/ + inline int port(void) const { + return ntohs(addr.sin_port); + } + /*! \return a string representation of the address */ + inline std::string AddrStr(void) const { + std::string buf; buf.resize(256); +#ifdef _WIN32 + const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, &buf[0], buf.length()); +#else + const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length()); +#endif + Assert(s != NULL, "cannot decode address"); + return std::string(s); + } +}; +/*! + * \brief a wrapper of TCP socket that hopefully be cross platform + */ +class TCPSocket { + public: + /*! 
\brief the file descriptor of socket */ + SOCKET sockfd; + // constructor + TCPSocket(void) : sockfd(INVALID_SOCKET) { + } + explicit TCPSocket(SOCKET sockfd) : sockfd(sockfd) { + } + ~TCPSocket(void) { + // do nothing in destructor + // user need to take care of close + } + // default conversion to int + inline operator SOCKET() const { + return sockfd; + } + /*! + * \brief create the socket, call this before using socket + * \param af domain + */ + inline void Create(int af = PF_INET) { + sockfd = socket(PF_INET, SOCK_STREAM, 0); + if (sockfd == INVALID_SOCKET) { + SockError("Create"); + } + } + /*! + * \brief start up the socket module + * call this before using the sockets + */ + inline static void Startup(void) { +#ifdef _WIN32 + WSADATA wsa_data; + if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != -1) { + SockError("Startup"); + } + if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { + WSACleanup(); + utils::Error("Could not find a usable version of Winsock.dll\n"); + } +#endif + } + /*! + * \brief shutdown the socket module after use, all sockets need to be closed + */ + inline static void Finalize(void) { +#ifdef _WIN32 + WSACleanup(); +#endif + } + /*! + * \brief set this socket to use non-blocking mode + * \param non_block whether set it to be non-block, if it is false + * it will set it back to block mode + */ + inline void SetNonBlock(bool non_block) { +#ifdef _WIN32 + u_long mode = non_block ? 1 : 0; + if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { + SockError("SetNonBlock"); + } +#else + int flag = fcntl(sockfd, F_GETFL, 0); + if (flag == -1) { + SockError("SetNonBlock-1"); + } + if (non_block) { + flag |= O_NONBLOCK; + } else { + flag &= ~O_NONBLOCK; + } + if (fcntl(sockfd, F_SETFL, flag) == -1) { + SockError("SetNonBlock-2"); + } +#endif + } + /*! + * \brief perform listen of the socket + * \param backlog backlog parameter + */ + inline void Listen(int backlog = 16) { + listen(sockfd, backlog); + } + /*! 
\brief get a new connection */ + TCPSocket Accept(void) { + SOCKET newfd = accept(sockfd, NULL, NULL); + if (newfd == INVALID_SOCKET) { + SockError("Accept"); + } + return TCPSocket(newfd); + } + /*! + * \brief bind the socket to an address + * \param addr + */ + inline void Bind(const SockAddr &addr) { + if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) { + SockError("Bind"); + } + } + /*! + * \brief try bind the socket to host, from start_port to end_port + * \param start_port starting port number to try + * \param end_port ending port number to try + * \param out_addr the binding address, if successful + * \return whether the binding is successful + */ + inline int TryBindHost(int start_port, int end_port) { + for (int port = start_port; port < end_port; ++port) { + SockAddr addr("0.0.0.0", port); + if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) { + return port; + } + if (errno != EADDRINUSE) { + SockError("TryBindHost"); + } + } + return -1; + } + /*! + * \brief connect to an address + * \param addr the address to connect to + */ + inline void Connect(const SockAddr &addr) { + if (connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) { + SockError("Connect"); + } + } + /*! \brief close the connection */ + inline void Close(void) { + if (sockfd != -1) { +#ifdef _WIN32 + closesocket(sockfd); +#else + close(sockfd); +#endif + sockfd = INVALID_SOCKET; + } else { + Error("TCPSocket::Close double close the socket or close without create"); + } + } + /*! 
+   * \brief send data using the socket
+   * \param buf_ the pointer to the buffer
+   * \param len the size of the buffer
+   * \param flag extra flags
+   * \return size of data actually sent, 0 when the call would block
+   */
+  inline size_t Send(const void *buf_, size_t len, int flag = 0) {
+    const char *buf = reinterpret_cast<const char*>(buf_);
+    if (len == 0) return 0;
+    // NOTE(review): the cast around the length argument had lost its target type
+    // in the reviewed text (here and in Recv/SendAll); the length is passed as is
+    // and converts implicitly - restore the original cast if the patch has one
+    ssize_t ret = send(sockfd, buf, len, flag);
+    if (ret == -1) {
+      // nothing could be sent right now on a non-blocking socket
+      if (errno == EAGAIN || errno == EWOULDBLOCK) return 0;
+      SockError("Send");
+    }
+    return ret;
+  }
+  /*!
+   * \brief receive data using the socket
+   * \param buf_ the pointer to the buffer
+   * \param len the size of the buffer
+   * \param flags extra flags
+   * \return size of data actually received, 0 when the call would block
+   */
+  inline size_t Recv(void *buf_, size_t len, int flags = 0) {
+    char *buf = reinterpret_cast<char*>(buf_);
+    if (len == 0) return 0;
+    ssize_t ret = recv(sockfd, buf, len, flags);
+    if (ret == -1) {
+      // nothing is available right now on a non-blocking socket
+      if (errno == EAGAIN || errno == EWOULDBLOCK) return 0;
+      SockError("Recv");
+    }
+    return ret;
+  }
+  /*!
+   * \brief perform block write that will attempt to send all data out
+   *    can still return smaller than request when error occurs
+   * \param buf_ the pointer to the buffer
+   * \param len the size of the buffer
+   * \return size of data actually sent
+   */
+  inline size_t SendAll(const void *buf_, size_t len) {
+    const char *buf = reinterpret_cast<const char*>(buf_);
+    size_t ndone = 0;
+    while (ndone < len) {
+      ssize_t ret = send(sockfd, buf, len - ndone, 0);
+      if (ret == -1) {
+        // stop early and report the partial count when the socket would block
+        if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
+        // name this function in the report, it used to be labelled "Recv"
+        SockError("SendAll");
+      }
+      buf += ret;
+      ndone += ret;
+    }
+    return ndone;
+  }
+  /*!
+   * \brief perform a blocking read that will attempt to read all data
+   *    can still return smaller than request when error occurs
+   * \param buf_ the buffer pointer
+   * \param len length of data to recv
+   * \return size of data actually received
+   */
+  inline size_t RecvAll(void *buf_, size_t len) {
+    char *buf = reinterpret_cast(buf_);
+    size_t ndone = 0;
+    while (ndone < len) {
+      ssize_t ret = recv(sockfd, buf, static_cast(len - ndone), MSG_WAITALL);
+      if (ret == -1) {
+        // stop early and report the partial count when the socket would block
+        if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
+        SockError("Recv");
+      }
+      // recv() returning 0 normally means the peer closed the connection,
+      // give back what was read so far
+      if (ret == 0) return ndone;
+      buf += ret;
+      ndone += ret;
+    }
+    return ndone;
+  }
+
+ private:
+  /*!
+   * \brief report a socket error through Error, appending strerror of the current errno
+   * \param msg name of the operation that failed
+   */
+  inline static void SockError(const char *msg) {
+    // save errno first so that later calls cannot overwrite it
+    int errsv = errno;
+    Error("Socket %s Error:%s", msg, strerror(errsv));
+  }
+};
+/*!
+ * \brief helper data structure to perform select:
+ *  register descriptors with WatchRead/WatchWrite, call Select,
+ *  then query the result with CheckRead/CheckWrite
+ */
+struct SelectHelper {
+ public:
+  SelectHelper(void) {
+    this->Clear();
+  }
+  /*!
+   * \brief add file descriptor to watch for read
+   * \param fd file descriptor to be watched
+   */
+  inline void WatchRead(SOCKET fd) {
+    read_fds.push_back(fd);
+    // keep track of the largest descriptor, Select needs it
+    if (fd > maxfd) maxfd = fd;
+  }
+  /*!
+   * \brief add file descriptor to watch for write
+   * \param fd file descriptor to be watched
+   */
+  inline void WatchWrite(SOCKET fd) {
+    write_fds.push_back(fd);
+    // keep track of the largest descriptor, Select needs it
+    if (fd > maxfd) maxfd = fd;
+  }
+  /*!
+   * \brief Check if the descriptor is ready for read
+   * \param fd file descriptor to check status
+   * \return whether fd is set in the read set filled by the last Select call
+   */
+  inline bool CheckRead(SOCKET fd) const {
+    return FD_ISSET(fd, &read_set) != 0;
+  }
+  /*!
+   * \brief Check if the descriptor is ready for write
+   * \param fd file descriptor to check status
+   * \return whether fd is set in the write set filled by the last Select call
+   */
+  inline bool CheckWrite(SOCKET fd) const {
+    return FD_ISSET(fd, &write_set) != 0;
+  }
+  /*!
+   * \brief clear all the monitored descriptors and reset the tracked maximum to 0
+   */
+  inline void Clear(void) {
+    read_fds.clear();
+    write_fds.clear();
+    maxfd = 0;
+  }
+  /*!
+   * \brief perform select on the set defined
+   * \param timeout specify timeout in milliseconds (ms), if equals 0, means select will always block
+   * \return number of active descriptors selected
+   */
+  inline int Select(long timeout = 0) {
+    // rebuild both sets from the watch lists, select() overwrites them on return
+    FD_ZERO(&read_set);
+    FD_ZERO(&write_set);
+    for (size_t i = 0; i < read_fds.size(); ++i) {
+      FD_SET(read_fds[i], &read_set);
+    }
+    for (size_t i = 0; i < write_fds.size(); ++i) {
+      FD_SET(write_fds[i], &write_set);
+    }
+    int ret;
+    if (timeout == 0) {
+      // a NULL timeout makes select() wait until a descriptor is ready
+      ret = select(static_cast<int>(maxfd + 1), &read_set, &write_set, NULL, NULL);
+    } else {
+      // split the millisecond value into whole seconds and leftover microseconds
+      timeval tm;
+      tm.tv_usec = (timeout % 1000) * 1000;
+      tm.tv_sec = timeout / 1000;
+      ret = select(static_cast<int>(maxfd + 1), &read_set, &write_set, NULL, &tm);
+    }
+    if (ret == -1) {
+      int errsv = errno;
+      Error("Select Error: %s", strerror(errsv));
+    }
+    return ret;
+  }
+
+ private:
+  /*! \brief largest descriptor registered so far */
+  SOCKET maxfd;
+  /*! \brief descriptor sets handed to select, they hold the result afterwards */
+  fd_set read_set, write_set;
+  /*! \brief descriptors registered through WatchRead and WatchWrite */
+  std::vector<SOCKET> read_fds, write_fds;
+};
+}
+#endif
diff --git a/src/tcp_master.py b/src/tcp_master.py
new file mode 100644
index 000000000..c0820f14b
--- /dev/null
+++ b/src/tcp_master.py
@@ -0,0 +1,106 @@
+"""
+Master script for xgboost, tcp_master
+This script can be used to start jobs of multi-node xgboost using sync_tcp
+
+Tianqi Chen
+"""
+
+import sys
+import os
+import socket
+import struct
+import subprocess
+from threading import Thread
+
+class ExSocket:
+    """Thin wrapper of a connected socket with helpers to exchange ints and strings."""
+    def __init__(self, sock):
+        self.sock = sock
+    def recvall(self, nbytes):
+        """Receive exactly nbytes bytes and return them as one string.
+
+        Raises socket.error when the peer closes the connection before
+        nbytes bytes have arrived.
+        """
+        res = []
+        nread = 0
+        while nread < nbytes:
+            chunk = self.sock.recv(min(nbytes - nread, 1024), socket.MSG_WAITALL)
+            # recv returns an empty string once the peer has closed the connection;
+            # without this check the loop would never terminate
+            if not chunk:
+                raise socket.error('connection closed after %d of %d bytes' % (nread, nbytes))
+            nread += len(chunk)
+            res.append(chunk)
+        return ''.join(res)
+    def recvint(self):
+        """Receive one int packed in native byte order ('@i')."""
+        return struct.unpack('@i', self.recvall(4))[0]
+    def sendint(self, n):
+        """Send one int packed in native byte order ('@i')."""
+        self.sock.sendall(struct.pack('@i', n))
+    def sendstr(self, s):
+        """Send the length of s as an int, followed by the bytes of s."""
+        self.sendint(len(s))
+        self.sock.sendall(s)
+
+# magic number used to verify existence of data
+kMagic = 0xff99
+
+class Master:
+    def __init__(self, port = 9000, port_end = 9999):
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        # take the first port in [port, port_end) that can be bound
+        for cand in range(port, port_end):
+            try:
+                sock.bind(('', cand))
+                self.port = cand
+                break
+            except socket.error:
+                continue
+        else:
+            # no port could be bound: fail here with a clear message instead of
+            # hitting a missing self.port attribute further down
+            sock.close()
+            raise socket.error('cannot bind to any port in [%d, %d)' % (port, port_end))
+        sock.listen(16)
+        self.sock = sock
+        print 'start listen on %s:%d' % (socket.gethostname(), self.port)
+    def __del__(self):
+        # the constructor may have failed before self.sock was assigned
+        sock = getattr(self, 'sock', None)
+        if sock is not None:
+            sock.close()
+    def slave_args(self):
+        """Arguments that tell a slave process where this master listens."""
+        return ['master_uri=%s' % socket.gethostname(),
+                'master_port=%s' % self.port]
+    def accept_slaves(self, nslave):
+        """Accept nslave connections, one per rank, and send each its setup data.
+
+        Ranks are handed out in connection order and laid out as a binary tree:
+        rank 0 has no parent, the parent of any other rank is (rank + 1) // 2 - 1.
+        NOTE(review): the nesting below was reconstructed from the control flow
+        (continue/break need the while loop), confirm against the original file.
+        """
+        slave_addrs = []
+        for rank in range(nslave):
+            # retry until a connection with a valid magic number shows up
+            while True:
+                fd, s_addr = self.sock.accept()
+                slave = ExSocket(fd)
+                nparent = int(rank != 0)
+                # children of rank are 2 * rank + 1 and 2 * rank + 2, when they exist
+                nchild = 0
+                if (rank + 1) * 2 - 1 < nslave:
+                    nchild += 1
+                if (rank + 1) * 2 < nslave:
+                    nchild += 1
+                try:
+                    magic = slave.recvint()
+                    if magic != kMagic:
+                        print 'invalid magic number=%d from %s' % (magic, s_addr[0])
+                        slave.sock.close()
+                        continue
+                except socket.error:
+                    print 'sock error in %s' % (s_addr[0])
+                    slave.sock.close()
+                    continue
+                slave.sendint(kMagic)
+                slave.sendint(rank)
+                slave.sendint(nslave)
+                slave.sendint(nparent)
+                slave.sendint(nchild)
+                if nparent != 0:
+                    # send host and port of the parent, which registered earlier
+                    parent_index = (rank + 1) // 2 - 1
+                    ptuple = slave_addrs[parent_index]
+                    slave.sendstr(ptuple[0])
+                    slave.sendint(ptuple[1])
+                # the slave answers with the port it listens on
+                s_port = slave.recvint()
+                assert rank == len(slave_addrs)
+                slave_addrs.append((s_addr[0], s_port))
+                slave.sock.close()
+                print 'finish starting rank=%d at %s' % (rank, s_addr[0])
+                break
+        print 'all slaves setup complete'
+
+def mpi_submit(nslave, args):
+    """Default submit function: start nslave copies of the command in args with mpirun."""
+    cmd = ' '.join(['mpirun -n %d' % nslave] + args)
+    print cmd
+    return subprocess.check_call(cmd, shell = True)
+
+def submit(nslave, args, fun_submit = mpi_submit):
+    """Start a Master, launch the job in a thread and serve the slaves until all are set up.
+
+    fun_submit is called as fun_submit(nslave, args + master.slave_args()).
+    """
+    master = Master()
+    submit_thread = Thread(target = fun_submit, args = (nslave, args + master.slave_args()))
+    submit_thread.start()
+    master.accept_slaves(nslave)
+    submit_thread.join()
diff --git a/src/utils.h b/src/utils.h
new file mode 100644
index 000000000..2c529c449
--- /dev/null
+++ b/src/utils.h
@@ -0,0 +1,176 @@
+#ifndef
ALLREDUCE_UTILS_H_
+#define ALLREDUCE_UTILS_H_
+/*!
+ * \file utils.h
+ * \brief simple utils to support the code
+ * \author Tianqi Chen
+ */
+#define _CRT_SECURE_NO_WARNINGS
+// NOTE(review): the header names of all #include lines in this file were missing
+// in the reviewed text; they are filled in from the symbols used below
+// (fprintf/FILE, std::string, exit, std::vector, va_list, off64_t, uint8_t...)
+// - confirm against the original patch
+#include <cstdio>
+#include <string>
+#include <cstdlib>
+#include <vector>
+
+#ifndef ALLREDUCE_STRICT_CXX98_
+#include <cstdarg>
+#endif
+
+#if !defined(__GNUC__)
+#define fopen64 std::fopen
+#endif
+#ifdef _MSC_VER
+// NOTE: sprintf_s is not equivalent to snprintf,
+// they are equivalent when success, which is sufficient for our case
+#define snprintf sprintf_s
+#define vsnprintf vsprintf_s
+#else
+#ifdef _FILE_OFFSET_BITS
+#if _FILE_OFFSET_BITS == 32
+#pragma message ("Warning: FILE OFFSET BITS defined to be 32 bit")
+#endif
+#endif
+
+#ifdef __APPLE__
+#define off64_t off_t
+#define fopen64 std::fopen
+#endif
+
+extern "C" {
+#include <sys/types.h>
+}
+#endif
+
+#ifdef _MSC_VER
+typedef unsigned char uint8_t;
+typedef unsigned short int uint16_t;
+typedef unsigned int uint32_t;
+// long is only 32 bits wide with MSVC, use the compiler's 64 bit integer type
+typedef unsigned __int64 uint64_t;
+typedef __int64 int64_t;
+#else
+#include <inttypes.h>
+#endif
+
+/*! \brief namespace for helper utils of the project */
+namespace utils {
+
+/*! \brief error message buffer length */
+const int kPrintBuffer = 1 << 12;
+
+#ifndef ALLREDUCE_CUSTOMIZE_MSG_
+/*!
+ * \brief handling of Assert error, caused by inappropriate input;
+ *  prints the message to stderr and terminates the process with exit(-1)
+ * \param msg error message
+ */
+inline void HandleAssertError(const char *msg) {
+  fprintf(stderr, "AssertError:%s\n", msg);
+  exit(-1);
+}
+/*!
+ * \brief handling of Check error, caused by inappropriate input;
+ *  prints the message to stderr and terminates the process with exit(-1)
+ * \param msg error message
+ */
+inline void HandleCheckError(const char *msg) {
+  fprintf(stderr, "%s\n", msg);
+  exit(-1);
+}
+/*!
+ * \brief default print handler, writes the message to stdout
+ * \param msg message to print
+ */
+inline void HandlePrint(const char *msg) {
+  printf("%s", msg);
+}
+#else
+#ifndef ALLREDUCE_STRICT_CXX98_
+// ALLREDUCE_CUSTOMIZE_MSG_ is defined: only the declarations are included here,
+// someone must implement these handlers elsewhere
+void HandleAssertError(const char *msg);
+void HandleCheckError(const char *msg);
+void HandlePrint(const char *msg);
+#endif
+#endif
+#ifdef ALLREDUCE_STRICT_CXX98_
+// strict C++98 mode: these function pointers are only declared here
+// and are to be assigned elsewhere
+extern "C" void (*Printf)(const char *fmt, ...);
+extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
+extern "C" void (*Assert)(int exp, const char *fmt, ...);
+extern "C" void (*Check)(int exp, const char *fmt, ...);
+extern "C" void (*Error)(const char *fmt, ...);
+#else
+/*!
+ * \brief printf, print message to the console;
+ *  the message is formatted into a buffer of kPrintBuffer bytes
+ *  and then handed to HandlePrint
+ * \param fmt printf style format string, followed by its arguments
+ */
+inline void Printf(const char *fmt, ...) {
+  std::string msg(kPrintBuffer, '\0');
+  va_list args;
+  va_start(args, fmt);
+  vsnprintf(&msg[0], kPrintBuffer, fmt, args);
+  va_end(args);
+  HandlePrint(msg.c_str());
+}
+/*!
+ * \brief portable version of snprintf
+ * \param buf output buffer
+ * \param size size of the output buffer
+ * \param fmt printf style format string, followed by its arguments
+ * \return the value returned by vsnprintf
+ */
+inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
+  va_list args;
+  va_start(args, fmt);
+  int ret = vsnprintf(buf, size, fmt, args);
+  va_end(args);
+  return ret;
+}
+
+/*!
+ * \brief assert a condition is true, use this to handle debug information;
+ *  when exp is false the formatted message is passed to HandleAssertError
+ * \param exp the condition that is expected to hold
+ * \param fmt printf style format string, followed by its arguments
+ */
+inline void Assert(bool exp, const char *fmt, ...) {
+  if (!exp) {
+    std::string msg(kPrintBuffer, '\0');
+    va_list args;
+    va_start(args, fmt);
+    vsnprintf(&msg[0], kPrintBuffer, fmt, args);
+    va_end(args);
+    HandleAssertError(msg.c_str());
+  }
+}
+
+/*! \brief same as assert, but this is intended to be used as message for user */
+inline void Check(bool exp, const char *fmt, ...)
{
+  // only format the message when the check fails
+  if (!exp) {
+    std::string msg(kPrintBuffer, '\0');
+    va_list args;
+    va_start(args, fmt);
+    vsnprintf(&msg[0], kPrintBuffer, fmt, args);
+    va_end(args);
+    HandleCheckError(msg.c_str());
+  }
+}
+
+/*!
+ * \brief report error message, same as check but unconditional:
+ *  the formatted message is always passed to HandleCheckError
+ * \param fmt printf style format string, followed by its arguments
+ */
+inline void Error(const char *fmt, ...) {
+  {
+    std::string msg(kPrintBuffer, '\0');
+    va_list args;
+    va_start(args, fmt);
+    vsnprintf(&msg[0], kPrintBuffer, fmt, args);
+    va_end(args);
+    HandleCheckError(msg.c_str());
+  }
+}
+#endif
+
+/*!
+ * \brief replace fopen, report error when the file open fails
+ * \param fname name of the file to open
+ * \param flag open mode, passed on to fopen64
+ * \return the opened file; failure to open is reported through Check
+ */
+inline std::FILE *FopenCheck(const char *fname, const char *flag) {
+  std::FILE *fp = fopen64(fname, flag);
+  Check(fp != NULL, "can not open file \"%s\"\n", fname);
+  return fp;
+}
+}  // namespace utils
+// easy utils that can be directly accessed in xgboost
+/*!
+ * \brief get the beginning address of a vector
+ * \param vec the vector
+ * \return pointer to the first element, NULL when the vector is empty
+ */
+template
+inline T *BeginPtr(std::vector &vec) {
+  if (vec.size() == 0) {
+    return NULL;
+  } else {
+    return &vec[0];
+  }
+}
+/*!
+ * \brief get the beginning address of a vector, const version
+ * \param vec the vector
+ * \return pointer to the first element, NULL when the vector is empty
+ */
+template
+inline const T *BeginPtr(const std::vector &vec) {
+  if (vec.size() == 0) {
+    return NULL;
+  } else {
+    return &vec[0];
+  }
+}
+#endif  // ALLREDUCE_UTILS_H_
diff --git a/submit_job_tcp.py b/submit_job_tcp.py
new file mode 100755
index 000000000..d79ef53bf
--- /dev/null
+++ b/submit_job_tcp.py
@@ -0,0 +1,36 @@
+#!/usr/bin/python
+"""
+This is an example script to create a customized job submit
+script using xgboost sync_tcp mode
+"""
+import sys
+import os
+import subprocess
+# tcp_master.py lives in the src/ directory next to this script,
+# add that directory to the module search path before importing it
+sys.path.append(os.path.dirname(__file__)+'/src/')
+import tcp_master as master
+
+#
+# Note: this submit script is only used for example purpose
+# It does not have to be mpirun, it can be any job submission script that starts the job, qsub, hadoop streaming etc.
+#
+def mpi_submit(nslave, args):
+    """
+    customized submit function that submits nslave jobs, each of which must receive args as parameters
+    note this can be a lambda function containing additional parameters in input
+    Parameters
+      nslave number of slave processes to start up
+      args arguments to launch each job
+           this usually includes the parameters of master_uri and parameters passed into submit
+    """
+    # the command is joined into one string and run through the shell
+    cmd = ' '.join(['mpirun -n %d' % nslave] + args)
+    print cmd
+    subprocess.check_call(cmd, shell = True)
+
+if __name__ == '__main__':
+    # the first argument is the number of slaves, the rest is the command of each job
+    if len(sys.argv) < 2:
+        print 'Usage: '
+        exit(0)
+    # call submit, with nslave, the commands to run each job and submit function
+    master.submit(int(sys.argv[1]), sys.argv[2:], fun_submit= mpi_submit)
diff --git a/test/Makefile b/test/Makefile
new file mode 100644
index 000000000..18f8c7481
--- /dev/null
+++ b/test/Makefile
@@ -0,0 +1,33 @@
+export CC = gcc
+export CXX = g++
+export MPICXX = mpicxx
+export LDFLAGS= -pthread -lm
+export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fPIC -I../src
+
+# build with no_omp=1 to compile without OpenMP
+ifeq ($(no_omp),1)
+  CFLAGS += -DDISABLE_OPENMP
+else
+  CFLAGS += -fopenmp
+endif
+
+# binaries and object files to build
+# NOTE(review): MPIBIN is used below but never assigned in this Makefile
+BIN = test_allreduce
+OBJ = engine_tcp.o
+.PHONY: clean all
+
+all: $(BIN) $(MPIBIN)
+
+# prerequisites only, the recipes come from the pattern-less rules below
+engine_tcp.o: ../src/engine_tcp.cpp ../src/*.h
+test_allreduce: test_allreduce.cpp ../src/*.h engine_tcp.o
+
+# link every .cpp/.o/.c prerequisite into the binary
+$(BIN) :
+	$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
+
+# compile the first source prerequisite into the object file
+$(OBJ) :
+	$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) )
+
+# same as BIN, but built with the MPI compiler wrapper
+$(MPIBIN) :
+	$(MPICXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
+
+clean:
+	$(RM) $(BIN) $(MPIBIN) *~ ../src/*~
diff --git a/test/test.sh b/test/test.sh
new file mode 100755
index 000000000..5e5ef546d
--- /dev/null
+++ b/test/test.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+# run test_allreduce through submit_job_tcp.py
+# $1 is handed to submit_job_tcp.py as the number of slaves,
+# $2 is handed to test_allreduce as its first argument
+if [ "$#" -ne 2 ];
+then
+    echo "Usage "
+    exit -1
+fi
+../submit_job_tcp.py $1 test_allreduce $2
diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp
new file mode 100644
index 000000000..407abf139
--- /dev/null
+++ b/test/test_allreduce.cpp
@@ -0,0 +1,80 @@
+#include
+#include
+#include
+#include
+#include
+
+using namespace sync;
+
+/*!
+ * \brief fill a vector with values that depend on the rank, all-reduce it
+ *  and compare every entry with the maximum computed locally over all ranks
+ *  NOTE(review): the template arguments (element type, reduce operator) are
+ *  not visible in the reviewed text, presumably float and op::Max - confirm
+ * \param n number of elements to reduce
+ */
+inline void TestMax(size_t n) {
+  int rank = sync::GetRank();
+  int nproc = sync::GetWorldSize();
+
+  std::vector ndata(n);
+  for (size_t i = 0; i < ndata.size(); ++i) {
+    ndata[i] = (i * (rank+1)) % 111;
+  }
+  sync::AllReduce(&ndata[0], ndata.size());
+  for (size_t i = 0; i < ndata.size(); ++i) {
+    // recompute what every rank r contributed at position i
+    float rmax = (i * 1) % 111;
+    for (int r = 0; r < nproc; ++r) {
+      rmax = std::max(rmax, (float)((i * (r+1)) % 111));
+    }
+    utils::Check(rmax == ndata[i], "[%d] TestMax check failure", rank);
+  }
+}
+
+/*!
+ * \brief fill a vector with values that depend on the rank, all-reduce it
+ *  and compare every entry with the sum computed locally over all ranks
+ *  NOTE(review): the template arguments (element type, reduce operator) are
+ *  not visible in the reviewed text, presumably float and op::Sum - confirm
+ * \param n number of elements to reduce
+ */
+inline void TestSum(size_t n) {
+  int rank = sync::GetRank();
+  int nproc = sync::GetWorldSize();
+  const int z = 131;
+
+  std::vector ndata(n);
+  for (size_t i = 0; i < ndata.size(); ++i) {
+    ndata[i] = (i * (rank+1)) % z;
+  }
+  sync::AllReduce(&ndata[0], ndata.size());
+  for (size_t i = 0; i < ndata.size(); ++i) {
+    // recompute what every rank r contributed at position i
+    float rsum = 0.0f;
+    for (int r = 0; r < nproc; ++r) {
+      rsum += (float)((i * (r+1)) % z);
+    }
+    // the two sums are compared with an absolute tolerance of 1e-5
+    utils::Check(fabsf(rsum - ndata[i]) < 1e-5 ,
+                 "[%d] TestSum check failure, local=%g, allreduce=%g", rank, rsum, ndata[i]);
+  }
+}
+
+/*!
+ * \brief broadcast a string of length n from root and check that
+ *  every rank ends up with the same content
+ *  note: main below does not call this function
+ * \param n length of the string
+ * \param root rank that owns the data to broadcast
+ */
+inline void TestBcast(size_t n, int root) {
+  int rank = sync::GetRank();
+  // every rank builds the expected string, only root sends it
+  std::string s; s.resize(n);
+  for (size_t i = 0; i < n; ++i) {
+    s[i] = char(i % 126 + 1);
+  }
+  std::string res;
+  if (root == rank) {
+    res = s;
+    sync::Bcast(&res, root);
+  } else {
+    sync::Bcast(&res, root);
+  }
+  utils::Check(res == s, "[%d] TestBcast fail", rank);
+}
+
+/*!
+ * \brief test driver, argv[1] is the number of elements used by TestMax and TestSum
+ */
+int main(int argc, char *argv[]) {
+  if (argc < 2) {
+    printf("Usage: \n");
+    return 0;
+  }
+  int n = atoi(argv[1]);
+  sync::Init(argc, argv);
+  int rank = sync::GetRank();
+  std::string name = sync::GetProcessorName();
+  printf("[%d] start at %s\n", rank, name.c_str());
+  TestMax(n);
+  printf("[%d] TestMax pass\n", rank);
+  TestSum(n);
+  printf("[%d] TestSum pass\n", rank);
+  sync::Finalize();
+  printf("[%d] all check pass\n", rank);
+  return 0;
+}