cpplint pass

This commit is contained in:
tqchen 2014-12-28 05:12:07 -08:00
parent 15836eb98e
commit 27d6977a3e
16 changed files with 406 additions and 328 deletions

View File

@ -36,4 +36,4 @@ $(ALIB):
ar cr $@ $+ ar cr $@ $+
clean: clean:
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ $(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~

View File

@ -1,6 +1,5 @@
#ifndef RABIT_RABIT_H
#define RABIT_RABIT_H
/*! /*!
* Copyright (c) 2014 by Contributors
* \file rabit.h * \file rabit.h
* \brief This file defines unified Allreduce/Broadcast interface of rabit * \brief This file defines unified Allreduce/Broadcast interface of rabit
* The actual implementation is redirected to rabit engine * The actual implementation is redirected to rabit engine
@ -9,12 +8,14 @@
* rabit.h and serializable.h is all the user need to use rabit interface * rabit.h and serializable.h is all the user need to use rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#ifndef RABIT_RABIT_H_
#define RABIT_RABIT_H_
#include <string> #include <string>
#include <vector> #include <vector>
// optionally support of lambda function in C++11, if available // optionally support of lambda function in C++11, if available
#if __cplusplus >= 201103L #if __cplusplus >= 201103L
#include <functional> #include <functional>
#endif // C++11 #endif // C++11
// contains definition of ISerializable // contains definition of ISerializable
#include "./rabit_serializable.h" #include "./rabit_serializable.h"
// engine definition of rabit, defines internal implementation // engine definition of rabit, defines internal implementation
@ -116,7 +117,7 @@ inline void Broadcast(std::string *sendrecv_data, int root);
*/ */
template<typename OP, typename DType> template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL, void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL); void *prepare_arg = NULL);
// C++11 support for lambda prepare function // C++11 support for lambda prepare function
@ -142,9 +143,9 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* \tparam DType type of data * \tparam DType type of data
*/ */
template<typename OP, typename DType> template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun); inline void Allreduce(DType *sendrecvbuf, size_t count,
#endif // C++11 std::function<void()> prepare_fun);
#endif // C++11
/*! /*!
* \brief load latest check point * \brief load latest check point
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
@ -228,6 +229,7 @@ class Reducer {
inline void Allreduce(DType *sendrecvbuf, size_t count, inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun); std::function<void()> prepare_fun);
#endif #endif
private: private:
/*! \brief function handle to do reduce */ /*! \brief function handle to do reduce */
engine::ReduceHandle handle_; engine::ReduceHandle handle_;
@ -274,6 +276,7 @@ class SerializeReducer {
size_t max_nbyte, size_t count, size_t max_nbyte, size_t count,
std::function<void()> prepare_fun); std::function<void()> prepare_fun);
#endif #endif
private: private:
/*! \brief function handle to do reduce */ /*! \brief function handle to do reduce */
engine::ReduceHandle handle_; engine::ReduceHandle handle_;
@ -283,4 +286,4 @@ class SerializeReducer {
} // namespace rabit } // namespace rabit
// implementation of template functions // implementation of template functions
#include "./rabit/rabit-inl.h" #include "./rabit/rabit-inl.h"
#endif // RABIT_ALLREDUCE_H #endif // RABIT_RABIT_H_

View File

@ -1,10 +1,12 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file engine.h * \file engine.h
* \brief This file defines the core interface of allreduce library * \brief This file defines the core interface of allreduce library
* \author Tianqi Chen, Nacho, Tianyi * \author Tianqi Chen, Nacho, Tianyi
*/ */
#ifndef RABIT_ENGINE_H #ifndef RABIT_ENGINE_H_
#define RABIT_ENGINE_H #define RABIT_ENGINE_H_
#include <string>
#include "../rabit_serializable.h" #include "../rabit_serializable.h"
namespace MPI { namespace MPI {
@ -122,7 +124,7 @@ class IEngine {
virtual int GetRank(void) const = 0; virtual int GetRank(void) const = 0;
/*! \brief get total number of */ /*! \brief get total number of */
virtual int GetWorldSize(void) const = 0; virtual int GetWorldSize(void) const = 0;
/*! \brief get the host name of current node */ /*! \brief get the host name of current node */
virtual std::string GetHost(void) const = 0; virtual std::string GetHost(void) const = 0;
/*! /*!
* \brief print the msg in the tracker, * \brief print the msg in the tracker,
@ -211,7 +213,7 @@ class ReduceHandle {
/*! \return the number of bytes occupied by the type */ /*! \return the number of bytes occupied by the type */
static int TypeSize(const MPI::Datatype &dtype); static int TypeSize(const MPI::Datatype &dtype);
private: protected:
// handle data field // handle data field
void *handle_; void *handle_;
// handle to the type field // handle to the type field
@ -221,5 +223,4 @@ class ReduceHandle {
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit
#endif // RABIT_ENGINE_H #endif // RABIT_ENGINE_H_

View File

@ -1,16 +1,19 @@
#ifndef RABIT_UTILS_IO_H
#define RABIT_UTILS_IO_H
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include "./utils.h"
#include "../rabit_serializable.h"
/*! /*!
* Copyright (c) 2014 by Contributors
* \file io.h * \file io.h
* \brief utilities that implements different serializable interface * \brief utilities that implements different serializable interface
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_UTILS_IO_H_
#define RABIT_UTILS_IO_H_
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include <algorithm>
#include "./utils.h"
#include "../rabit_serializable.h"
namespace rabit { namespace rabit {
namespace utils { namespace utils {
/*! \brief interface of i/o stream that support seek */ /*! \brief interface of i/o stream that support seek */
@ -25,8 +28,9 @@ class ISeekStream: public IStream {
/*! \brief fixed size memory buffer */ /*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public ISeekStream { struct MemoryFixSizeBuffer : public ISeekStream {
public: public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) { : p_buffer_(reinterpret_cast<char*>(p_buffer)),
buffer_size_(buffer_size) {
curr_ptr_ = 0; curr_ptr_ = 0;
} }
virtual ~MemoryFixSizeBuffer(void) {} virtual ~MemoryFixSizeBuffer(void) {}
@ -40,7 +44,7 @@ struct MemoryFixSizeBuffer : public ISeekStream {
} }
virtual void Write(const void *ptr, size_t size) { virtual void Write(const void *ptr, size_t size) {
if (size == 0) return; if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_, utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size"); "write position exceed fixed buffer size");
memcpy(p_buffer_ + curr_ptr_, ptr, size); memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size; curr_ptr_ += size;
@ -59,12 +63,12 @@ struct MemoryFixSizeBuffer : public ISeekStream {
size_t buffer_size_; size_t buffer_size_;
/*! \brief current pointer */ /*! \brief current pointer */
size_t curr_ptr_; size_t curr_ptr_;
}; // class MemoryFixSizeBuffer }; // class MemoryFixSizeBuffer
/*! \brief a in memory buffer that can be read and write as stream interface */ /*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public ISeekStream { struct MemoryBufferStream : public ISeekStream {
public: public:
MemoryBufferStream(std::string *p_buffer) explicit MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) { : p_buffer_(p_buffer) {
curr_ptr_ = 0; curr_ptr_ = 0;
} }
@ -82,7 +86,7 @@ struct MemoryBufferStream : public ISeekStream {
if (curr_ptr_ + size > p_buffer_->length()) { if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size); p_buffer_->resize(curr_ptr_+size);
} }
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size; curr_ptr_ += size;
} }
virtual void Seek(size_t pos) { virtual void Seek(size_t pos) {
@ -97,36 +101,7 @@ struct MemoryBufferStream : public ISeekStream {
std::string *p_buffer_; std::string *p_buffer_;
/*! \brief current pointer */ /*! \brief current pointer */
size_t curr_ptr_; size_t curr_ptr_;
}; // class MemoryBufferStream }; // class MemoryBufferStream
/*!
 * \brief implementation of file i/o stream on top of a C FILE handle
 *  the class defines no destructor, so the handle is only released
 *  when Close is called explicitly
 */
class FileStream : public ISeekStream {
 public:
  /*!
   * \brief wrap an existing file handle
   * \param fp the file handle used by all subsequent operations
   */
  explicit FileStream(FILE *fp) : fp(fp) {}
  /*! \brief create a stream that has no file handle attached */
  explicit FileStream(void) {
    this->fp = NULL;
  }
  /*!
   * \brief read data from the file
   * \param ptr pointer to the memory that receives the data
   * \param size number of bytes requested
   * \return number of bytes actually read, can be smaller than size
   *  when end of file or an error is encountered
   */
  virtual size_t Read(void *ptr, size_t size) {
    // element size is 1 and count is size, so that fread reports bytes;
    // with the arguments swapped it would only report 0 or 1 items
    return std::fread(ptr, 1, size, fp);
  }
  /*!
   * \brief write data to the file
   * \param ptr pointer to the data to be written
   * \param size number of bytes to write
   */
  virtual void Write(const void *ptr, size_t size) {
    // NOTE(review): the result of fwrite is not checked here
    std::fwrite(ptr, size, 1, fp);
  }
  /*!
   * \brief move the file position
   * \param pos absolute offset counted from the beginning of the file
   */
  virtual void Seek(size_t pos) {
    // fseek takes a long offset, so pos is narrowed on platforms
    // where long is smaller than size_t
    std::fseek(fp, static_cast<long>(pos), SEEK_SET);
  }
  /*! \return the current file position as reported by ftell */
  virtual size_t Tell(void) {
    return std::ftell(fp);
  }
  /*! \brief close the file handle if one is attached, safe to call twice */
  inline void Close(void) {
    if (fp != NULL) {
      std::fclose(fp);
      fp = NULL;
    }
  }

 private:
  /*! \brief the underlying file handle, NULL when nothing is attached */
  FILE *fp;
};
} // namespace utils } // namespace utils
} // namespace rabit } // namespace rabit
#endif #endif // RABIT_UTILS_IO_H_

View File

@ -1,10 +1,11 @@
#ifndef RABIT_UTILS_H_
#define RABIT_UTILS_H_
/*! /*!
* Copyright (c) 2014 by Contributors
* \file utils.h * \file utils.h
* \brief simple utils to support the code * \brief simple utils to support the code
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_UTILS_H_
#define RABIT_UTILS_H_
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#include <cstdio> #include <cstdio>
#include <string> #include <string>
@ -19,7 +20,7 @@
#define fopen64 std::fopen #define fopen64 std::fopen
#endif #endif
#ifdef _MSC_VER #ifdef _MSC_VER
// NOTE: sprintf_s is not equivalent to snprintf, // NOTE: sprintf_s is not equivalent to snprintf,
// they are equivalent when success, which is sufficient for our case // they are equivalent when success, which is sufficient for our case
#define snprintf sprintf_s #define snprintf sprintf_s
#define vsnprintf vsprintf_s #define vsnprintf vsprintf_s
@ -30,7 +31,7 @@
#endif #endif
#endif #endif
#ifdef __APPLE__ #ifdef __APPLE__
#define off64_t off_t #define off64_t off_t
#define fopen64 std::fopen #define fopen64 std::fopen
#endif #endif
@ -186,5 +187,5 @@ inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return NULL; if (str.length() == 0) return NULL;
return &str[0]; return &str[0];
} }
} // namespace rabit } // namespace rabit
#endif // RABIT_UTILS_H_ #endif // RABIT_UTILS_H_

View File

@ -1,13 +1,14 @@
#ifndef RABIT_RABIT_SERIALIZABLE_H
#define RABIT_RABIT_SERIALIZABLE_H
#include <vector>
#include <string>
#include "./rabit/utils.h"
/*! /*!
* Copyright (c) 2014 by Contributors
* \file serializable.h * \file serializable.h
* \brief defines serializable interface of rabit * \brief defines serializable interface of rabit
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_RABIT_SERIALIZABLE_H_
#define RABIT_RABIT_SERIALIZABLE_H_
#include <vector>
#include <string>
#include "./rabit/utils.h"
namespace rabit { namespace rabit {
/*! /*!
* \brief interface of stream I/O, used by ISerializable * \brief interface of stream I/O, used by ISerializable
@ -96,4 +97,4 @@ class ISerializable {
virtual void Save(IStream &fo) const = 0; virtual void Save(IStream &fo) const = 0;
}; };
} // namespace rabit } // namespace rabit
#endif #endif // RABIT_RABIT_SERIALIZABLE_H_

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.cc * \file allreduce_base.cc
* \brief Basic implementation of AllReduce * \brief Basic implementation of AllReduce
* *
@ -32,13 +33,15 @@ AllreduceBase::AllreduceBase(void) {
// initialization function // initialization function
void AllreduceBase::Init(void) { void AllreduceBase::Init(void) {
// setup from enviroment variables // setup from enviroment variables
{// handling for hadoop {
// handling for hadoop
const char *task_id = getenv("mapred_tip_id"); const char *task_id = getenv("mapred_tip_id");
if (task_id == NULL) { if (task_id == NULL) {
task_id = getenv("mapreduce_task_id"); task_id = getenv("mapreduce_task_id");
} }
if (hadoop_mode != 0) { if (hadoop_mode != 0) {
utils::Check(task_id != NULL, "hadoop_mode is set but cannot find mapred_task_id"); utils::Check(task_id != NULL,
"hadoop_mode is set but cannot find mapred_task_id");
} }
if (task_id != NULL) { if (task_id != NULL) {
this->SetParam("rabit_task_id", task_id); this->SetParam("rabit_task_id", task_id);
@ -48,7 +51,7 @@ void AllreduceBase::Init(void) {
if (attempt_id != 0) { if (attempt_id != 0) {
const char *att = strrchr(attempt_id, '_'); const char *att = strrchr(attempt_id, '_');
int num_trial; int num_trial;
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) { if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
this->SetParam("rabit_num_trial", att + 1); this->SetParam("rabit_num_trial", att + 1);
} }
} }
@ -58,7 +61,8 @@ void AllreduceBase::Init(void) {
num_task = getenv("mapreduce_job_maps"); num_task = getenv("mapreduce_job_maps");
} }
if (hadoop_mode != 0) { if (hadoop_mode != 0) {
utils::Check(num_task != NULL, "hadoop_mode is set but cannot find mapred_map_tasks"); utils::Check(num_task != NULL,
"hadoop_mode is set but cannot find mapred_map_tasks");
} }
if (num_task != NULL) { if (num_task != NULL) {
this->SetParam("rabit_world_size", num_task); this->SetParam("rabit_world_size", num_task);
@ -81,11 +85,11 @@ void AllreduceBase::Shutdown(void) {
} }
all_links.clear(); all_links.clear();
tree_links.plinks.clear(); tree_links.plinks.clear();
if (tracker_uri == "NULL") return; if (tracker_uri == "NULL") return;
// notify tracker rank i have shutdown // notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker(); utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown")); tracker.SendStr(std::string("shutdown"));
tracker.Close(); tracker.Close();
utils::TCPSocket::Finalize(); utils::TCPSocket::Finalize();
} }
@ -107,11 +111,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val; if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val); if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
if (!strcmp(name, "rabit_task_id")) task_id = val; if (!strcmp(name, "rabit_task_id")) task_id = val;
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val); if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
if (!strcmp(name, "rabit_reduce_buffer")) { if (!strcmp(name, "rabit_reduce_buffer")) {
char unit; char unit;
unsigned long amount; uint64_t amount;
if (sscanf(val, "%lu%c", &amount, &unit) == 2) { if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
switch (unit) { switch (unit) {
case 'B': reduce_buffer_size = (amount + 7)/ 8; break; case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
@ -121,7 +125,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
default: utils::Error("invalid format for reduce buffer"); default: utils::Error("invalid format for reduce buffer");
} }
} else { } else {
utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}"); utils::Error("invalid format for reduce_buffer,"\
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
} }
} }
} }
@ -137,11 +142,16 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) { if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
utils::Socket::Error("Connect"); utils::Socket::Error("Connect");
} }
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1"); using utils::Assert;
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2"); Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 1");
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 2");
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure"); utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3"); "ReConnectLink failure 3");
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
tracker.SendStr(task_id); tracker.SendStr(task_id);
return tracker; return tracker;
} }
@ -161,29 +171,30 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
int prev_rank, next_rank; int prev_rank, next_rank;
// the rank of neighbors // the rank of neighbors
std::map<int, int> tree_neighbors; std::map<int, int> tree_neighbors;
{// get new ranks using utils::Assert;
int newrank, num_neighbors; // get new ranks
utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), int newrank, num_neighbors;
"ReConnectLink failure 4"); Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank), "ReConnectLink failure 4");
"ReConnectLink failure 4"); Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\
utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), sizeof(parent_rank), "ReConnectLink failure 4");
"ReConnectLink failure 4"); Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one"); "ReConnectLink failure 4");
rank = newrank; Assert(rank == -1 || newrank == rank,
utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors), "must keep rank to same if the node already have one");
"ReConnectLink failure 4"); rank = newrank;
for (int i = 0; i < num_neighbors; ++i) { Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
int nrank; sizeof(num_neighbors), "ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank), for (int i = 0; i < num_neighbors; ++i) {
"ReConnectLink failure 4"); int nrank;
tree_neighbors[nrank] = 1; Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
} "ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank), tree_neighbors[nrank] = 1;
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
} }
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
// create listening socket // create listening socket
utils::TCPSocket sock_listen; utils::TCPSocket sock_listen;
sock_listen.Create(); sock_listen.Create();
@ -204,56 +215,67 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
} }
} }
int ngood = static_cast<int>(good_link.size()); int ngood = static_cast<int>(good_link.size());
utils::Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5"); "ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) { for (size_t i = 0; i < good_link.size(); ++i) {
utils::Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == sizeof(good_link[i]), Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
"ReConnectLink failure 6"); sizeof(good_link[i]), "ReConnectLink failure 6");
} }
utils::Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7"); "ReConnectLink failure 7");
utils::Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept), Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
"ReConnectLink failure 8"); sizeof(num_accept), "ReConnectLink failure 8");
num_error = 0; num_error = 0;
for (int i = 0; i < num_conn; ++i) { for (int i = 0; i < num_conn; ++i) {
LinkRecord r; LinkRecord r;
int hport, hrank; int hport, hrank;
std::string hname; std::string hname;
tracker.RecvStr(&hname); tracker.RecvStr(&hname);
utils::Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
utils::Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");
r.sock.Create(); r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) { if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1; r.sock.Close(); continue; num_error += 1; r.sock.Close(); continue;
} }
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12"); Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13"); "ReConnectLink failure 12");
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent"); Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 13");
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false; bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) { for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) { if (all_links[i].rank == hrank) {
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active"); Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break; all_links[i].sock = r.sock; match = true; break;
} }
} }
if (!match) all_links.push_back(r); if (!match) all_links.push_back(r);
} }
utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14"); Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
} while (num_error != 0); } while (num_error != 0);
// send back socket listening port to tracker // send back socket listening port to tracker
utils::Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
"ReConnectLink failure 14");
// close connection to tracker // close connection to tracker
tracker.Close(); tracker.Close();
// listen to incoming links // listen to incoming links
for (int i = 0; i < num_accept; ++i) { for (int i = 0; i < num_accept; ++i) {
LinkRecord r; LinkRecord r;
r.sock = sock_listen.Accept(); r.sock = sock_listen.Accept();
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15"); Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15"); "ReConnectLink failure 15");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15");
bool match = false; bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) { for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) { if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active"); utils::Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break; all_links[i].sock = r.sock; match = true; break;
} }
} }
@ -278,9 +300,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i]; if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i]; if (all_links[i].rank == next_rank) ring_next = &all_links[i];
} }
utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link"); Assert(parent_rank == -1 || parent_index != -1,
utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link"); "cannot find parent in the link");
utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link"); Assert(prev_rank == -1 || ring_prev != NULL,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
"cannot find next ring in the link");
} }
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
@ -326,7 +351,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// if no childs, no need to reduce // if no childs, no need to reduce
if (nlink == static_cast<int>(parent_index != -1)) { if (nlink == static_cast<int>(parent_index != -1)) {
size_up_reduce = total_size; size_up_reduce = total_size;
} }
// while we have not passed the messages out // while we have not passed the messages out
while (true) { while (true) {
// select helper // select helper
@ -347,7 +372,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
if (links[i].size_read != total_size) { if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); selecter.WatchRead(links[i].sock);
} }
// size_write <= size_read // size_write <= size_read
if (links[i].size_write != total_size) { if (links[i].size_write != total_size) {
selecter.WatchWrite(links[i].sock); selecter.WatchWrite(links[i].sock);
// only watch for exception in live channels // only watch for exception in live channels
@ -358,11 +383,11 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
} }
// finish runing allreduce // finish runing allreduce
if (finished) break; if (finished) break;
// select must return // select must return
selecter.Select(); selecter.Select();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link // recive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) return kGetExcept; if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
} }
// read data from childs // read data from childs
@ -392,7 +417,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// start position // start position
size_t start = size_up_reduce % buffer_size; size_t start = size_up_reduce % buffer_size;
// peform read till end of buffer // peform read till end of buffer
size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce); size_t nread = std::min(buffer_size - start,
max_reduce - size_up_reduce);
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check"); utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index) { if (i != parent_index) {
@ -407,7 +433,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
} }
if (parent_index != -1) { if (parent_index != -1) {
// pass message up to parent, can pass data that are already been reduced // pass message up to parent, can pass data that are already been reduced
if (selecter.CheckWrite(links[parent_index].sock)) { if (selecter.CheckWrite(links[parent_index].sock)) {
ssize_t len = links[parent_index].sock. ssize_t len = links[parent_index].sock.
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
if (len != -1) { if (len != -1) {
@ -417,7 +443,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
} }
} }
// read data from parent // read data from parent
if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) { if (selecter.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock. ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in); Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) { if (len == 0) {
@ -425,7 +452,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
} }
if (len != -1) { if (len != -1) {
size_down_in += static_cast<size_t>(len); size_down_in += static_cast<size_t>(len);
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error"); utils::Assert(size_down_in <= size_up_out,
"Allreduce: boundary error");
} else { } else {
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError; if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
} }
@ -437,11 +465,13 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// can pass message down to childs // can pass message down to childs
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckWrite(links[i].sock)) { if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError; if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) {
return kSockError;
}
} }
} }
} }
return kSuccess; return kSuccess;
} }
/*! /*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
@ -455,14 +485,15 @@ AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links; RefLinkVector &links = tree_links;
if (links.size() == 0 || total_size == 0) return kSuccess; if (links.size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size, "Broadcast: root should be smaller than world size"); utils::Check(root < world_size,
"Broadcast: root should be smaller than world size");
// number of links // number of links
const int nlink = static_cast<int>(links.size()); const int nlink = static_cast<int>(links.size());
// size of space already read from data // size of space already read from data
size_t size_in = 0; size_t size_in = 0;
// input link, -2 means unknown yet, -1 means this is root // input link, -2 means unknown yet, -1 means this is root
int in_link = -2; int in_link = -2;
// initialize the link statistics // initialize the link statistics
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
links[i].ResetSize(); links[i].ResetSize();
@ -471,9 +502,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
if (this->rank == root) { if (this->rank == root) {
size_in = total_size; size_in = total_size;
in_link = -1; in_link = -1;
} }
// while we have not passed the messages out // while we have not passed the messages out
while(true) { while (true) {
bool finished = true; bool finished = true;
// select helper // select helper
utils::SelectHelper selecter; utils::SelectHelper selecter;
@ -487,7 +518,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
if (in_link != -2 && i != in_link && links[i].size_write != total_size) { if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
selecter.WatchWrite(links[i].sock); finished = false; selecter.WatchWrite(links[i].sock); finished = false;
} }
selecter.WatchException(links[i].sock); selecter.WatchException(links[i].sock);
} }
// finish running // finish running
if (finished) break; if (finished) break;
@ -495,14 +526,16 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
selecter.Select(); selecter.Select();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link // recive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) return kGetExcept; if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
} }
if (in_link == -2) { if (in_link == -2) {
// probe in-link // probe in-link
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) { if (selecter.CheckRead(links[i].sock)) {
if (!links[i].ReadToArray(sendrecvbuf_, total_size)) return kSockError; if (!links[i].ReadToArray(sendrecvbuf_, total_size)) {
return kSockError;
}
size_in = links[i].size_read; size_in = links[i].size_read;
if (size_in != 0) { if (size_in != 0) {
in_link = i; break; in_link = i; break;
@ -512,7 +545,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
} else { } else {
// read from in link // read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) return kSockError; if (!links[in_link].ReadToArray(sendrecvbuf_, total_size)) {
return kSockError;
}
size_in = links[in_link].size_read; size_in = links[in_link].size_read;
} }
} }

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.h * \file allreduce_base.h
* \brief Basic implementation of AllReduce * \brief Basic implementation of AllReduce
* using TCP non-block socket and tree-shape reduction. * using TCP non-block socket and tree-shape reduction.
@ -8,13 +9,14 @@
* *
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#ifndef RABIT_ALLREDUCE_BASE_H #ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H #define RABIT_ALLREDUCE_BASE_H_
#include <vector> #include <vector>
#include <string> #include <string>
#include <rabit/utils.h> #include <algorithm>
#include <rabit/engine.h> #include "rabit/utils.h"
#include "rabit/engine.h"
#include "./socket.h" #include "./socket.h"
namespace MPI { namespace MPI {
@ -22,7 +24,7 @@ namespace MPI {
/*!
 * \brief datatype descriptor that records a single size value
 *  (presumably the element size in bytes -- confirm against callers)
 */
class Datatype {
 public:
  /*!
   * \brief build a descriptor from the given size;
   *  explicit, so a bare size_t never converts silently
   * \param type_size the size value to store
   */
  explicit Datatype(size_t type_size) : type_size(type_size) {}
  /*! \brief the size value supplied at construction */
  size_t type_size;
};
} }
namespace rabit { namespace rabit {
@ -31,7 +33,7 @@ namespace engine {
class AllreduceBase : public IEngine { class AllreduceBase : public IEngine {
public: public:
// magic number to verify server // magic number to verify server
const static int kMagic = 0xff99; static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening // constant one byte out of band message to indicate error happening
AllreduceBase(void); AllreduceBase(void);
virtual ~AllreduceBase(void) {} virtual ~AllreduceBase(void) {}
@ -79,12 +81,13 @@ class AllreduceBase : public IEngine {
*/ */
virtual void Allreduce(void *sendrecvbuf_, virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
ReduceFunction reducer, ReduceFunction reducer,
PreprocFunction prepare_fun = NULL, PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL) { void *prepare_arg = NULL) {
if (prepare_fun != NULL) prepare_fun(prepare_arg); if (prepare_fun != NULL) prepare_fun(prepare_arg);
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess, utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
"Allreduce failed"); "Allreduce failed");
} }
/*! /*!
@ -201,12 +204,16 @@ class AllreduceBase : public IEngine {
// constructor // constructor
LinkRecord(void) {} LinkRecord(void) {}
// initialize buffer // initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) { inline void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
size_t n = (type_nbytes * count + 7)/ 8; size_t n = (type_nbytes * count + 7)/ 8;
buffer_.resize(std::min(reduce_buffer_size, n)); buffer_.resize(std::min(reduce_buffer_size, n));
// make sure align to type_nbytes // make sure align to type_nbytes
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; buffer_size =
utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size); buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
utils::Assert(type_nbytes <= buffer_size,
"too large type_nbytes=%lu, buffer_size=%lu",
type_nbytes, buffer_size);
// set buffer head // set buffer head
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_)); buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
} }
@ -225,7 +232,7 @@ class AllreduceBase : public IEngine {
size_t ngap = size_read - protect_start; size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check"); utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
size_t offset = size_read % buffer_size; size_t offset = size_read % buffer_size;
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
if (nmax == 0) return true; if (nmax == 0) return true;
ssize_t len = sock.Recv(buffer_head + offset, nmax); ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected // length equals 0, remote disconnected
@ -235,7 +242,7 @@ class AllreduceBase : public IEngine {
if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK; if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
size_read += static_cast<size_t>(len); size_read += static_cast<size_t>(len);
return true; return true;
} }
/*! /*!
* \brief read data into array, * \brief read data into array,
* this function can not be used together with ReadToRingBuffer * this function can not be used together with ReadToRingBuffer

View File

@ -1,11 +1,13 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust-inl.h * \file allreduce_robust-inl.h
* \brief implementation of inline template function in AllreduceRobust * \brief implementation of inline template function in AllreduceRobust
* *
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_ENGINE_ROBUST_INL_H #ifndef RABIT_ENGINE_ROBUST_INL_H_
#define RABIT_ENGINE_ROBUST_INL_H #define RABIT_ENGINE_ROBUST_INL_H_
#include <vector>
namespace rabit { namespace rabit {
namespace engine { namespace engine {
@ -33,14 +35,14 @@ inline AllreduceRobust::ReturnType
AllreduceRobust::MsgPassing(const NodeType &node_value, AllreduceRobust::MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in, std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out, std::vector<EdgeType> *p_edge_out,
EdgeType (*func) (const NodeType &node_value, EdgeType (*func)
const std::vector<EdgeType> &edge_in, (const NodeType &node_value,
size_t out_index) const std::vector<EdgeType> &edge_in,
) { size_t out_index)) {
RefLinkVector &links = tree_links; RefLinkVector &links = tree_links;
if (links.size() == 0) return kSuccess; if (links.size() == 0) return kSuccess;
// number of links // number of links
const int nlink = static_cast<int>(links.size()); const int nlink = static_cast<int>(links.size());
// initialize the pointers // initialize the pointers
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
links[i].ResetSize(); links[i].ResetSize();
@ -58,7 +60,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
// if no childs, no need to, directly start passing message // if no childs, no need to, directly start passing message
if (nlink == static_cast<int>(parent_index != -1)) { if (nlink == static_cast<int>(parent_index != -1)) {
utils::Assert(parent_index == 0, "parent must be 0"); utils::Assert(parent_index == 0, "parent must be 0");
edge_out[parent_index] = func(node_value, edge_in, parent_index); edge_out[parent_index] = func(node_value, edge_in, parent_index);
stage = 1; stage = 1;
} }
// while we have not passed the messages out // while we have not passed the messages out
@ -94,7 +96,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
selecter.Select(); selecter.Select();
// exception handling // exception handling
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link // recive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) return kGetExcept; if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
} }
if (stage == 0) { if (stage == 0) {
@ -103,7 +105,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index) { if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) { if (selecter.CheckRead(links[i].sock)) {
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) return kSockError; if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) {
return kSockError;
}
} }
if (links[i].size_read != sizeof(EdgeType)) finished = false; if (links[i].size_read != sizeof(EdgeType)) finished = false;
} }
@ -124,13 +128,17 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (stage == 1) { if (stage == 1) {
const int pid = this->parent_index; const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage"); utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) return kSockError; if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) {
return kSockError;
}
if (links[pid].size_write == sizeof(EdgeType)) stage = 2; if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
} }
if (stage == 2) { if (stage == 2) {
const int pid = this->parent_index; const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage"); utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) return kSockError; if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) {
return kSockError;
}
if (links[pid].size_read == sizeof(EdgeType)) { if (links[pid].size_read == sizeof(EdgeType)) {
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != pid) edge_out[i] = func(node_value, edge_in, i); if (i != pid) edge_out[i] = func(node_value, edge_in, i);
@ -141,7 +149,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (stage == 3) { if (stage == 3) {
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError; if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) {
return kSockError;
}
} }
} }
} }
@ -150,4 +160,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
} }
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit
#endif // RABIT_ENGINE_ROBUST_INL_H #endif // RABIT_ENGINE_ROBUST_INL_H_

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust.cc * \file allreduce_robust.cc
* \brief Robust implementation of Allreduce * \brief Robust implementation of Allreduce
* *
@ -9,10 +10,10 @@
#define NOMINMAX #define NOMINMAX
#include <limits> #include <limits>
#include <utility> #include <utility>
#include <rabit/io.h> #include "rabit/io.h"
#include <rabit/utils.h> #include "rabit/utils.h"
#include <rabit/engine.h> #include "rabit/engine.h"
#include <rabit/rabit-inl.h> #include "rabit/rabit-inl.h"
#include "./allreduce_robust.h" #include "./allreduce_robust.h"
namespace rabit { namespace rabit {
@ -30,10 +31,10 @@ void AllreduceRobust::Shutdown(void) {
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
"Shutdown: check point must return true"); "Shutdown: check point must return true");
// reset result buffer // reset result buffer
resbuf.Clear(); seq_counter = 0; resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here // execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"Shutdown: check ack must return true"); "Shutdown: check ack must return true");
AllreduceBase::Shutdown(); AllreduceBase::Shutdown();
} }
/*! /*!
@ -89,7 +90,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
} else { } else {
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
} }
} }
} }
resbuf.PushTemp(seq_counter, type_nbytes, count); resbuf.PushTemp(seq_counter, type_nbytes, count);
seq_counter += 1; seq_counter += 1;
@ -102,7 +103,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
*/ */
void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) { void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
// skip action in single node // skip action in single node
if (world_size == 1) return; if (world_size == 1) return;
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter); bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
// now we are free to remove the last result, if any // now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 && if (resbuf.LastSeqNo() != -1 &&
@ -119,7 +120,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
} else { } else {
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter); recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
} }
} }
} }
resbuf.PushTemp(seq_counter, 1, total_size); resbuf.PushTemp(seq_counter, 1, total_size);
seq_counter += 1; seq_counter += 1;
@ -151,7 +152,8 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
// skip action in single node // skip action in single node
if (world_size == 1) return 0; if (world_size == 1) return 0;
if (num_local_replica == 0) { if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); utils::Check(local_model == NULL,
"need to set num_local_replica larger than 1 to checkpoint local_model");
} }
// check if we succesful // check if we succesful
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) { if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
@ -171,9 +173,10 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
// load from buffer // load from buffer
utils::MemoryBufferStream fs(&global_checkpoint); utils::MemoryBufferStream fs(&global_checkpoint);
if (global_checkpoint.length() == 0) { if (global_checkpoint.length() == 0) {
version_number = 0; version_number = 0;
} else { } else {
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number"); utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
"read in version number");
global_model->Load(fs); global_model->Load(fs);
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1, utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
"local model inconsistent, nlocal=%d", nlocal); "local model inconsistent, nlocal=%d", nlocal);
@ -212,9 +215,10 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
version_number += 1; return; version_number += 1; return;
} }
if (num_local_replica == 0) { if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model"); utils::Check(local_model == NULL,
} "need to set num_local_replica larger than 1 to checkpoint local_model");
if (num_local_replica != 0) { }
if (num_local_replica != 0) {
while (true) { while (true) {
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break; if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
// save model model to new version place // save model model to new version place
@ -247,10 +251,10 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
fs.Write(&version_number, sizeof(version_number)); fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs); global_model->Save(fs);
// reset result buffer // reset result buffer
resbuf.Clear(); seq_counter = 0; resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here // execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"check ack must return true"); "check ack must return true");
} }
/*! /*!
* \brief reset the all the existing links by sending Out-of-Band message marker * \brief reset the all the existing links by sending Out-of-Band message marker
@ -383,7 +387,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
*/ */
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
if (err_type == kSuccess) return true; if (err_type == kSuccess) return true;
{// simple way, shutdown all links {
// simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) { for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
} }
@ -392,8 +397,8 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
} }
// this was old way // this was old way
// TryResetLinks still causes possible errors, so not use this one // TryResetLinks still causes possible errors, so not use this one
while(err_type != kSuccess) { while (err_type != kSuccess) {
switch(err_type) { switch (err_type) {
case kGetExcept: err_type = TryResetLinks(); break; case kGetExcept: err_type = TryResetLinks(); break;
case kSockError: { case kSockError: {
TryResetLinks(); TryResetLinks();
@ -416,7 +421,7 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
* \param out_index the edge index of output link * \param out_index the edge index of output link
* \return the shortest distance result of out edge specified by out_index * \return the shortest distance result of out edge specified by out_index
*/ */
inline std::pair<int,size_t> inline std::pair<int, size_t>
ShortestDist(const std::pair<bool, size_t> &node_value, ShortestDist(const std::pair<bool, size_t> &node_value,
const std::vector< std::pair<int, size_t> > &dist_in, const std::vector< std::pair<int, size_t> > &dist_in,
size_t out_index) { size_t out_index) {
@ -484,8 +489,9 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
int *p_recvlink, int *p_recvlink,
std::vector<bool> *p_req_in) { std::vector<bool> *p_req_in) {
int best_link = -2; int best_link = -2;
{// get the shortest distance to the request point {
std::vector< std::pair<int,size_t> > dist_in, dist_out; // get the shortest distance to the request point
std::vector<std::pair<int, size_t> > dist_in, dist_out;
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size), ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
&dist_in, &dist_out, ShortestDist); &dist_in, &dist_out, ShortestDist);
if (succ != kSuccess) return succ; if (succ != kSuccess) return succ;
@ -512,7 +518,7 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
&req_in, &req_out, DataRequest); &req_in, &req_out, DataRequest);
if (succ != kSuccess) return succ; if (succ != kSuccess) return succ;
// set p_req_in // set p_req_in
p_req_in->resize(req_in.size()); p_req_in->resize(req_in.size());
for (size_t i = 0; i < req_in.size(); ++i) { for (size_t i = 0; i < req_in.size(); ++i) {
// set p_req_in // set p_req_in
(*p_req_in)[i] = (req_in[i] != 0); (*p_req_in)[i] = (req_in[i] != 0);
@ -591,19 +597,23 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (role == kRequestData) { if (role == kRequestData) {
const int pid = recv_link; const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) { if (selecter.CheckRead(links[pid].sock)) {
if(!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError; if (!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError;
} }
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (req_in[i] && links[i].size_write != links[pid].size_read && if (req_in[i] && links[i].size_write != links[pid].size_read &&
selecter.CheckWrite(links[i].sock)) { selecter.CheckWrite(links[i].sock)) {
if(!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) return kSockError; if (!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) {
return kSockError;
}
} }
} }
} }
if (role == kHaveData) { if (role == kHaveData) {
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (req_in[i] && selecter.CheckWrite(links[i].sock)) { if (req_in[i] && selecter.CheckWrite(links[i].sock)) {
if(!links[i].WriteFromArray(sendrecvbuf_, size)) return kSockError; if (!links[i].WriteFromArray(sendrecvbuf_, size)) {
return kSockError;
}
} }
} }
} }
@ -616,13 +626,14 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (req_in[i]) min_write = std::min(links[i].size_write, min_write); if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
} }
utils::Assert(min_write <= links[pid].size_read, "boundary check"); utils::Assert(min_write <= links[pid].size_read, "boundary check");
if (!links[pid].ReadToRingBuffer(min_write)) return kSockError; if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
} }
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (req_in[i] && selecter.CheckWrite(links[i].sock) && links[pid].size_read != links[i].size_write) { if (req_in[i] && selecter.CheckWrite(links[i].sock) &&
links[pid].size_read != links[i].size_write) {
size_t start = links[i].size_write % buffer_size; size_t start = links[i].size_write % buffer_size;
// send out data from ring buffer // send out data from ring buffer
size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write); size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite); ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite);
if (len != -1) { if (len != -1) {
links[i].size_write += len; links[i].size_write += len;
@ -648,15 +659,15 @@ AllreduceRobust::TryRecoverData(RecoverType role,
*/ */
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
// check in local data // check in local data
RecoverType role = requester ? kRequestData : kHaveData; RecoverType role = requester ? kRequestData : kHaveData;
ReturnType succ; ReturnType succ;
if (num_local_replica != 0) { if (num_local_replica != 0) {
if (requester) { if (requester) {
// clear existing history, if any, before load // clear existing history, if any, before load
local_rptr[local_chkpt_version].clear(); local_rptr[local_chkpt_version].clear();
local_chkpt[local_chkpt_version].clear(); local_chkpt[local_chkpt_version].clear();
} }
// recover local checkpoint // recover local checkpoint
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version], succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
&local_chkpt[local_chkpt_version]); &local_chkpt[local_chkpt_version]);
if (succ != kSuccess) return succ; if (succ != kSuccess) return succ;
@ -716,7 +727,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
// if we goes to this place, use must have already setup the state once // if we goes to this place, use must have already setup the state once
utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1, utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1,
"TryGetResult::Checkpoint"); "TryGetResult::Checkpoint");
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]); return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
} }
// handles normal data recovery // handles normal data recovery
RecoverType role; RecoverType role;
@ -735,8 +746,9 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
utils::Check(data_size != 0, "zero size check point is not allowed"); utils::Check(data_size != 0, "zero size check point is not allowed");
if (role == kRequestData || role == kHaveData) { if (role == kRequestData || role == kHaveData) {
utils::Check(data_size == size, utils::Check(data_size == size,
"Allreduce Recovered data size do not match the specification of function call\n"\ "Allreduce Recovered data size do not match the specification of function call.\n"\
"Please check if calling sequence of recovered program is the same the original one in current VersionNumber"); "Please check if calling sequence of recovered program is the " \
"same the original one in current VersionNumber");
} }
return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in); return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in);
} }
@ -766,7 +778,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
while (true) { while (true) {
this->ReportStatus(); this->ReportStatus();
// action // action
ActionSummary act = req; ActionSummary act = req;
// get the reduced action // get the reduced action
if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue; if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
if (act.check_ack()) { if (act.check_ack()) {
@ -816,7 +828,8 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue; if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
if (requester) return true; if (requester) return true;
} else { } else {
// all the request is same, this is most recent command that is yet to be executed // all the request is same,
// this is most recent command that is yet to be executed
return false; return false;
} }
} }
@ -855,7 +868,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent"); utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
} }
const int n = num_local_replica; const int n = num_local_replica;
{// backward passing, passing state in backward direction of the ring {
// backward passing, passing state in backward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1); const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica"); utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_back(n + 1); std::vector<int> msg_back(n + 1);
@ -897,10 +911,10 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
// update rptr // update rptr
rptr.resize(nread_end + 1); rptr.resize(nread_end + 1);
for (int i = nlocal; i < nread_end; ++i) { for (int i = nlocal; i < nread_end; ++i) {
rptr[i + 1] = rptr[i] + sizes[i]; rptr[i + 1] = rptr[i] + sizes[i];
} }
chkpt.resize(rptr.back()); chkpt.resize(rptr.back());
// pass data through the link // pass data through the link
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
rptr[nwrite_start], rptr[nread_end], rptr[nwrite_start], rptr[nread_end],
ring_next, ring_prev); ring_next, ring_prev);
@ -908,7 +922,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ; rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
} }
} }
{// forward passing, passing state in forward direction of the ring {
// forward passing, passing state in forward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1); const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica"); utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_forward(n + 1); std::vector<int> msg_forward(n + 1);
@ -926,7 +941,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
1 * sizeof(int), 2 * sizeof(int), 1 * sizeof(int), 2 * sizeof(int),
0 * sizeof(int), 1 * sizeof(int), 0 * sizeof(int), 1 * sizeof(int),
ring_next, ring_prev); ring_next, ring_prev);
if (succ != kSuccess) return succ; if (succ != kSuccess) return succ;
// calculate the number of things we can read from next link // calculate the number of things we can read from next link
int nread_end = nlocal, nwrite_end = 1; int nread_end = nlocal, nwrite_end = 1;
// have to have itself in order to get other data from prev link // have to have itself in order to get other data from prev link
@ -936,7 +951,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
nread_end = std::max(nread_end, i + 1); nread_end = std::max(nread_end, i + 1);
nwrite_end = i + 1; nwrite_end = i + 1;
} }
if (nwrite_end > n) nwrite_end = n; if (nwrite_end > n) nwrite_end = n;
} else { } else {
nread_end = 0; nwrite_end = 0; nread_end = 0; nwrite_end = 0;
} }
@ -963,7 +978,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
rptr[i + 1] = rptr[i] + sizes[i]; rptr[i + 1] = rptr[i] + sizes[i];
} }
chkpt.resize(rptr.back()); chkpt.resize(rptr.back());
// pass data through the link // pass data through the link
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
rptr[nwrite_start], rptr[nwrite_end], rptr[nwrite_start], rptr[nwrite_end],
ring_prev, ring_next); ring_prev, ring_next);
@ -995,7 +1010,8 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
if (num_local_replica == 0) return kSuccess; if (num_local_replica == 0) return kSuccess;
std::vector<size_t> &rptr = *p_local_rptr; std::vector<size_t> &rptr = *p_local_rptr;
std::string &chkpt = *p_local_chkpt; std::string &chkpt = *p_local_chkpt;
utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state"); utils::Assert(rptr.size() == 2,
"TryCheckinLocalState must have exactly 1 state");
const int n = num_local_replica; const int n = num_local_replica;
std::vector<size_t> sizes(n + 1); std::vector<size_t> sizes(n + 1);
sizes[0] = rptr[1] - rptr[0]; sizes[0] = rptr[1] - rptr[0];
@ -1012,9 +1028,9 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
rptr.resize(n + 2); rptr.resize(n + 2);
for (int i = 1; i <= n; ++i) { for (int i = 1; i <= n; ++i) {
rptr[i + 1] = rptr[i] + sizes[i]; rptr[i + 1] = rptr[i] + sizes[i];
} }
chkpt.resize(rptr.back()); chkpt.resize(rptr.back());
// pass data through the link // pass data through the link
succ = RingPassing(BeginPtr(chkpt), succ = RingPassing(BeginPtr(chkpt),
rptr[1], rptr[n + 1], rptr[1], rptr[n + 1],
rptr[0], rptr[n], rptr[0], rptr[n],
@ -1050,13 +1066,14 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
LinkRecord *read_link, LinkRecord *read_link,
LinkRecord *write_link) { LinkRecord *write_link) {
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess; if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
utils::Assert(write_end <= read_end, "RingPassing: boundary check1, write_end=%lu, read_end=%lu", write_end, read_end); utils::Assert(write_end <= read_end,
"RingPassing: boundary check1");
utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2"); utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2");
utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3"); utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3");
// take reference // take reference
LinkRecord &prev = *read_link, &next = *write_link; LinkRecord &prev = *read_link, &next = *write_link;
// send recv buffer // send recv buffer
char *buf = reinterpret_cast<char*>(sendrecvbuf_); char *buf = reinterpret_cast<char*>(sendrecvbuf_);
while (true) { while (true) {
bool finished = true; bool finished = true;
utils::SelectHelper selecter; utils::SelectHelper selecter;
@ -1066,7 +1083,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
} }
if (write_ptr < read_ptr && write_ptr != write_end) { if (write_ptr < read_ptr && write_ptr != write_end) {
selecter.WatchWrite(next.sock); selecter.WatchWrite(next.sock);
finished = false; finished = false;
} }
selecter.WatchException(prev.sock); selecter.WatchException(prev.sock);
selecter.WatchException(next.sock); selecter.WatchException(next.sock);
@ -1078,7 +1095,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr); ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
if (len == 0) { if (len == 0) {
prev.sock.Close(); return kSockError; prev.sock.Close(); return kSockError;
} }
if (len != -1) { if (len != -1) {
read_ptr += static_cast<size_t>(len); read_ptr += static_cast<size_t>(len);
} else { } else {

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust.h * \file allreduce_robust.h
* \brief Robust implementation of Allreduce * \brief Robust implementation of Allreduce
* using TCP non-block socket and tree-shape reduction. * using TCP non-block socket and tree-shape reduction.
@ -7,10 +8,12 @@
* *
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/ */
#ifndef RABIT_ALLREDUCE_ROBUST_H #ifndef RABIT_ALLREDUCE_ROBUST_H_
#define RABIT_ALLREDUCE_ROBUST_H #define RABIT_ALLREDUCE_ROBUST_H_
#include <vector> #include <vector>
#include <rabit/engine.h> #include <string>
#include <algorithm>
#include "rabit/engine.h"
#include "./allreduce_base.h" #include "./allreduce_base.h"
namespace rabit { namespace rabit {
@ -111,11 +114,11 @@ class AllreduceRobust : public AllreduceBase {
private: private:
// constant one byte out of band message to indicate error happening // constant one byte out of band message to indicate error happening
// and mark for channel cleanup // and mark for channel cleanup
const static char kOOBReset = 95; static const char kOOBReset = 95;
// and mark for channel cleanup, after OOB signal // and mark for channel cleanup, after OOB signal
const static char kResetMark = 97; static const char kResetMark = 97;
// and mark for channel cleanup // and mark for channel cleanup
const static char kResetAck = 97; static const char kResetAck = 97;
/*! \brief type of roles each node can play during recovery */ /*! \brief type of roles each node can play during recovery */
enum RecoverType { enum RecoverType {
/*! \brief current node have data */ /*! \brief current node have data */
@ -132,29 +135,29 @@ class AllreduceRobust : public AllreduceBase {
*/ */
struct ActionSummary { struct ActionSummary {
// maximumly allowed sequence id // maximumly allowed sequence id
const static int kSpecialOp = (1 << 26); static const int kSpecialOp = (1 << 26);
// special sequence number for local state checkpoint // special sequence number for local state checkpoint
const static int kLocalCheckPoint = (1 << 26) - 2; static const int kLocalCheckPoint = (1 << 26) - 2;
// special sequnce number for local state checkpoint ack signal // special sequnce number for local state checkpoint ack signal
const static int kLocalCheckAck = (1 << 26) - 1; static const int kLocalCheckAck = (1 << 26) - 1;
//--------------------------------------------- //---------------------------------------------
// The following are bit mask of flag used in // The following are bit mask of flag used in
//---------------------------------------------- //----------------------------------------------
// some node want to load check point // some node want to load check point
const static int kLoadCheck = 1; static const int kLoadCheck = 1;
// some node want to do check point // some node want to do check point
const static int kCheckPoint = 2; static const int kCheckPoint = 2;
// check point Ack, we use a two phase message in check point, // check point Ack, we use a two phase message in check point,
// this is the second phase of check pointing // this is the second phase of check pointing
const static int kCheckAck = 4; static const int kCheckAck = 4;
// there are difference sequence number the nodes proposed // there are difference sequence number the nodes proposed
// this means we want to do recover execution of the lower sequence // this means we want to do recover execution of the lower sequence
// action instead of normal execution // action instead of normal execution
const static int kDiffSeq = 8; static const int kDiffSeq = 8;
// constructor // constructor
ActionSummary(void) {} ActionSummary(void) {}
// constructor of action // constructor of action
ActionSummary(int flag, int minseqno = kSpecialOp) { explicit ActionSummary(int flag, int minseqno = kSpecialOp) {
seqcode = (minseqno << 4) | flag; seqcode = (minseqno << 4) | flag;
} }
// minimum number of all operations // minimum number of all operations
@ -181,10 +184,11 @@ class AllreduceRobust : public AllreduceBase {
inline int flag(void) const { inline int flag(void) const {
return seqcode & 15; return seqcode & 15;
} }
// reducer for Allreduce, used to get the result ActionSummary from all nodes // reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) { inline static void Reducer(const void *src_, void *dst_,
int len, const MPI::Datatype &dtype) {
const ActionSummary *src = (const ActionSummary*)src_; const ActionSummary *src = (const ActionSummary*)src_;
ActionSummary *dst = (ActionSummary*)dst_; ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
int src_seqno = src[i].min_seqno(); int src_seqno = src[i].min_seqno();
int dst_seqno = dst[i].min_seqno(); int dst_seqno = dst[i].min_seqno();
@ -192,7 +196,8 @@ class AllreduceRobust : public AllreduceBase {
if (src_seqno == dst_seqno) { if (src_seqno == dst_seqno) {
dst[i] = ActionSummary(flag, src_seqno); dst[i] = ActionSummary(flag, src_seqno);
} else { } else {
dst[i] = ActionSummary(flag | kDiffSeq, std::min(src_seqno, dst_seqno)); dst[i] = ActionSummary(flag | kDiffSeq,
std::min(src_seqno, dst_seqno));
} }
} }
} }
@ -222,7 +227,7 @@ class AllreduceRobust : public AllreduceBase {
data_.resize(rptr_.back() + nhop); data_.resize(rptr_.back() + nhop);
return BeginPtr(data_) + rptr_.back(); return BeginPtr(data_) + rptr_.back();
} }
// push the result in temp to the // push the result in temp to the
inline void PushTemp(int seqid, size_t type_nbytes, size_t count) { inline void PushTemp(int seqid, size_t type_nbytes, size_t count) {
size_t size = type_nbytes * count; size_t size = type_nbytes * count;
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
@ -234,13 +239,14 @@ class AllreduceRobust : public AllreduceBase {
size_.push_back(size); size_.push_back(size);
utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent"); utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent");
} }
// return the stored result of seqid, if any // return the stored result of seqid, if any
inline void* Query(int seqid, size_t *p_size) { inline void* Query(int seqid, size_t *p_size) {
size_t idx = std::lower_bound(seqno_.begin(), seqno_.end(), seqid) - seqno_.begin(); size_t idx = std::lower_bound(seqno_.begin(),
seqno_.end(), seqid) - seqno_.begin();
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL; if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
*p_size = size_[idx]; *p_size = size_[idx];
return BeginPtr(data_) + rptr_[idx]; return BeginPtr(data_) + rptr_[idx];
} }
// drop last stored result // drop last stored result
inline void DropLast(void) { inline void DropLast(void) {
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped"); utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
@ -254,15 +260,16 @@ class AllreduceRobust : public AllreduceBase {
if (seqno_.size() == 0) return -1; if (seqno_.size() == 0) return -1;
return seqno_.back(); return seqno_.back();
} }
private: private:
// sequence number of each // sequence number of each
std::vector<int> seqno_; std::vector<int> seqno_;
// pointer to the positions // pointer to the positions
std::vector<size_t> rptr_; std::vector<size_t> rptr_;
// actual size of each buffer // actual size of each buffer
std::vector<size_t> size_; std::vector<size_t> size_;
// content of the buffer // content of the buffer
std::vector<uint64_t> data_; std::vector<uint64_t> data_;
}; };
/*! /*!
* \brief reset the all the existing links by sending Out-of-Band message marker * \brief reset the all the existing links by sending Out-of-Band message marker
@ -291,14 +298,16 @@ class AllreduceRobust : public AllreduceBase {
* \param buf the buffer to store the result * \param buf the buffer to store the result
* \param size the total size of the buffer * \param size the total size of the buffer
* \param flag flag information about the action \sa ActionSummary * \param flag flag information about the action \sa ActionSummary
* \param seqno sequence number of the action, if it is special action with flag set, seqno needs to be set to ActionSummary::kSpecialOp * \param seqno sequence number of the action, if it is special action with flag set,
* seqno needs to be set to ActionSummary::kSpecialOp
* *
* \return if this function can return true or false * \return if this function can return true or false
* - true means buf already set to the * - true means buf already set to the
* result by recovering procedure, the action is complete, no further action is needed * result by recovering procedure, the action is complete, no further action is needed
* - false means this is the lastest action that has not yet been executed, need to execute the action * - false means this is the lastest action that has not yet been executed, need to execute the action
*/ */
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kSpecialOp); bool RecoverExec(void *buf, size_t size, int flag,
int seqno = ActionSummary::kSpecialOp);
/*! /*!
* \brief try to load check point * \brief try to load check point
* *
@ -363,7 +372,7 @@ class AllreduceRobust : public AllreduceBase {
void *sendrecvbuf_, void *sendrecvbuf_,
size_t size, size_t size,
int recv_link, int recv_link,
const std::vector<bool> &req_in); const std::vector<bool> &req_in);
/*! /*!
* \brief try to recover the local state, making each local state to be the result of itself * \brief try to recover the local state, making each local state to be the result of itself
* plus replication of states in previous num_local_replica hops in the ring * plus replication of states in previous num_local_replica hops in the ring
@ -446,17 +455,17 @@ o * the input state must exactly one saved state(local state of current node)
inline ReturnType MsgPassing(const NodeType &node_value, inline ReturnType MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in, std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out, std::vector<EdgeType> *p_edge_out,
EdgeType (*func) (const NodeType &node_value, EdgeType (*func)
const std::vector<EdgeType> &edge_in, (const NodeType &node_value,
size_t out_index) const std::vector<EdgeType> &edge_in,
); size_t out_index));
//---- recovery data structure ---- //---- recovery data structure ----
// the round of result buffer, used to mode the result // the round of result buffer, used to mode the result
int result_buffer_round; int result_buffer_round;
// result buffer of all reduce // result buffer of all reduce
ResultBuffer resbuf; ResultBuffer resbuf;
// last check point global model // last check point global model
std::string global_checkpoint; std::string global_checkpoint;
// number of replica for local state/model // number of replica for local state/model
int num_local_replica; int num_local_replica;
// --- recovery data structure for local checkpoint // --- recovery data structure for local checkpoint
@ -465,16 +474,15 @@ o * the input state must exactly one saved state(local state of current node)
// pointer to memory position in the local model // pointer to memory position in the local model
// local model is stored in CSR format(like a sparse matrices) // local model is stored in CSR format(like a sparse matrices)
// local_model[rptr[0]:rptr[1]] stores the model of current node // local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
std::vector<size_t> local_rptr[2]; std::vector<size_t> local_rptr[2];
// storage for local model replicas // storage for local model replicas
std::string local_chkpt[2]; std::string local_chkpt[2];
// version of local checkpoint can be 1 or 0 // version of local checkpoint can be 1 or 0
int local_chkpt_version; int local_chkpt_version;
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit
// implementation of inline template function // implementation of inline template function
#include "./allreduce_robust-inl.h" #include "./allreduce_robust-inl.h"
#endif // RABIT_ALLREDUCE_ROBUST_H_
#endif // RABIT_ALLREDUCE_ROBUST_H

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file engine.cc * \file engine.cc
* \brief this file governs which implementation of engine we are actually using * \brief this file governs which implementation of engine we are actually using
* provides an singleton of engine interface * provides an singleton of engine interface
@ -41,16 +42,17 @@ void Finalize(void) {
IEngine *GetEngine(void) { IEngine *GetEngine(void) {
return &manager; return &manager;
} }
// perform in-place allreduce, on sendrecvbuf // perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, void Allreduce_(void *sendrecvbuf,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
IEngine::ReduceFunction red, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::DataType dtype,
mpi::OpType op, mpi::OpType op,
IEngine::PreprocFunction prepare_fun, IEngine::PreprocFunction prepare_fun,
void *prepare_arg) { void *prepare_arg) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg); GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
red, prepare_fun, prepare_arg);
} }
// code for reduce handle // code for reduce handle

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file engine_empty.cc * \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing * \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies * this file provides a way to fall back to single node program without causing too many dependencies
@ -25,9 +26,10 @@ class EmptyEngine : public IEngine {
ReduceFunction reducer, ReduceFunction reducer,
PreprocFunction prepare_fun, PreprocFunction prepare_fun,
void *prepare_arg) { void *prepare_arg) {
utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead"); utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
} }
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) { virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
} }
virtual void InitAfterException(void) { virtual void InitAfterException(void) {
utils::Error("EmptyEngine is not fault tolerant"); utils::Error("EmptyEngine is not fault tolerant");
@ -51,7 +53,7 @@ class EmptyEngine : public IEngine {
virtual int GetWorldSize(void) const { virtual int GetWorldSize(void) const {
return 1; return 1;
} }
/*! \brief get the host name of current node */ /*! \brief get the host name of current node */
virtual std::string GetHost(void) const { virtual std::string GetHost(void) const {
return std::string(""); return std::string("");
} }
@ -59,6 +61,7 @@ class EmptyEngine : public IEngine {
// simply print information into the tracker // simply print information into the tracker
utils::Printf("%s", msg.c_str()); utils::Printf("%s", msg.c_str());
} }
private: private:
int version_number; int version_number;
}; };
@ -77,11 +80,11 @@ void Finalize(void) {
IEngine *GetEngine(void) { IEngine *GetEngine(void) {
return &manager; return &manager;
} }
// perform in-place allreduce, on sendrecvbuf // perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, void Allreduce_(void *sendrecvbuf,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
IEngine::ReduceFunction red, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::DataType dtype,
mpi::OpType op, mpi::OpType op,
IEngine::PreprocFunction prepare_fun, IEngine::PreprocFunction prepare_fun,

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc * \file engine_mock.cc
* \brief this is an engine implementation that will * \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure * insert failures in certain call point, to test if the engine is robust to failure

View File

@ -1,4 +1,5 @@
/*! /*!
* Copyright (c) 2014 by Contributors
* \file engine_mpi.cc * \file engine_mpi.cc
* \brief this file gives an implementation of engine interface using MPI, * \brief this file gives an implementation of engine interface using MPI,
* this will allow rabit program to run with MPI, but do not comes with fault tolerant * this will allow rabit program to run with MPI, but do not comes with fault tolerant
@ -8,10 +9,10 @@
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX #define NOMINMAX
#include <cstdio>
#include <rabit/engine.h>
#include <rabit/utils.h>
#include <mpi.h> #include <mpi.h>
#include <cstdio>
#include "rabit/engine.h"
#include "rabit/utils.h"
namespace rabit { namespace rabit {
namespace engine { namespace engine {
@ -27,9 +28,10 @@ class MPIEngine : public IEngine {
ReduceFunction reducer, ReduceFunction reducer,
PreprocFunction prepare_fun, PreprocFunction prepare_fun,
void *prepare_arg) { void *prepare_arg) {
utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead"); utils::Error("MPIEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
} }
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) { virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root); MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
} }
virtual void InitAfterException(void) { virtual void InitAfterException(void) {
@ -48,13 +50,13 @@ class MPIEngine : public IEngine {
} }
/*! \brief get rank of current node */ /*! \brief get rank of current node */
virtual int GetRank(void) const { virtual int GetRank(void) const {
return MPI::COMM_WORLD.Get_rank(); return MPI::COMM_WORLD.Get_rank();
} }
/*! \brief get total number of */ /*! \brief get total number of */
virtual int GetWorldSize(void) const { virtual int GetWorldSize(void) const {
return MPI::COMM_WORLD.Get_size(); return MPI::COMM_WORLD.Get_size();
} }
/*! \brief get the host name of current node */ /*! \brief get the host name of current node */
virtual std::string GetHost(void) const { virtual std::string GetHost(void) const {
int len; int len;
char name[MPI_MAX_PROCESSOR_NAME]; char name[MPI_MAX_PROCESSOR_NAME];
@ -68,6 +70,7 @@ class MPIEngine : public IEngine {
utils::Printf("%s", msg.c_str()); utils::Printf("%s", msg.c_str());
} }
} }
private: private:
int version_number; int version_number;
}; };
@ -91,7 +94,7 @@ IEngine *GetEngine(void) {
// transform enum to MPI data type // transform enum to MPI data type
inline MPI::Datatype GetType(mpi::DataType dtype) { inline MPI::Datatype GetType(mpi::DataType dtype) {
using namespace mpi; using namespace mpi;
switch(dtype) { switch (dtype) {
case kInt: return MPI::INT; case kInt: return MPI::INT;
case kUInt: return MPI::UNSIGNED; case kUInt: return MPI::UNSIGNED;
case kFloat: return MPI::FLOAT; case kFloat: return MPI::FLOAT;
@ -103,7 +106,7 @@ inline MPI::Datatype GetType(mpi::DataType dtype) {
// transform enum to MPI OP // transform enum to MPI OP
inline MPI::Op GetOp(mpi::OpType otype) { inline MPI::Op GetOp(mpi::OpType otype) {
using namespace mpi; using namespace mpi;
switch(otype) { switch (otype) {
case kMax: return MPI::MAX; case kMax: return MPI::MAX;
case kMin: return MPI::MIN; case kMin: return MPI::MIN;
case kSum: return MPI::SUM; case kSum: return MPI::SUM;
@ -112,17 +115,18 @@ inline MPI::Op GetOp(mpi::OpType otype) {
utils::Error("unknown mpi::OpType"); utils::Error("unknown mpi::OpType");
return MPI::MAX; return MPI::MAX;
} }
// perform in-place allreduce, on sendrecvbuf // perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf, void Allreduce_(void *sendrecvbuf,
size_t type_nbytes, size_t type_nbytes,
size_t count, size_t count,
IEngine::ReduceFunction red, IEngine::ReduceFunction red,
mpi::DataType dtype, mpi::DataType dtype,
mpi::OpType op, mpi::OpType op,
IEngine::PreprocFunction prepare_fun, IEngine::PreprocFunction prepare_fun,
void *prepare_arg) { void *prepare_arg) {
if (prepare_fun != NULL) prepare_fun(prepare_arg); if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op)); MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
count, GetType(dtype), GetOp(op));
} }
// code for reduce handle // code for reduce handle
@ -152,7 +156,7 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
created_type_nbytes_ = type_nbytes; created_type_nbytes_ = type_nbytes;
htype_ = dtype; htype_ = dtype;
} }
MPI::Op *op = new MPI::Op(); MPI::Op *op = new MPI::Op();
MPI::User_function *pf = redfunc; MPI::User_function *pf = redfunc;
op->Init(pf, true); op->Init(pf, true);
@ -175,7 +179,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
dtype->Commit(); dtype->Commit();
created_type_nbytes_ = type_nbytes; created_type_nbytes_ = type_nbytes;
} }
if (prepare_fun != NULL) prepare_fun(prepare_arg); if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op); MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
} }
} // namespace engine } // namespace engine

View File

@ -1,10 +1,11 @@
#ifndef RABIT_SOCKET_H
#define RABIT_SOCKET_H
/*! /*!
* Copyright (c) 2014 by Contributors
* \file socket.h * \file socket.h
* \brief this file aims to provide a wrapper of sockets * \brief this file aims to provide a wrapper of sockets
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_SOCKET_H_
#define RABIT_SOCKET_H_
#if defined(_WIN32) #if defined(_WIN32)
#include <winsock2.h> #include <winsock2.h>
#include <ws2tcpip.h> #include <ws2tcpip.h>
@ -21,7 +22,7 @@
#endif #endif
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <rabit/utils.h> #include "rabit/utils.h"
#if defined(_WIN32) #if defined(_WIN32)
typedef int ssize_t; typedef int ssize_t;
@ -68,9 +69,11 @@ struct SockAddr {
inline std::string AddrStr(void) const { inline std::string AddrStr(void) const {
std::string buf; buf.resize(256); std::string buf; buf.resize(256);
#ifdef _WIN32 #ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, &buf[0], buf.length()); const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
&buf[0], buf.length());
#else #else
const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length()); const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length());
#endif #endif
Assert(s != NULL, "cannot decode address"); Assert(s != NULL, "cannot decode address");
return std::string(s); return std::string(s);
@ -94,12 +97,12 @@ class Socket {
*/ */
inline static void Startup(void) { inline static void Startup(void) {
#ifdef _WIN32 #ifdef _WIN32
WSADATA wsa_data; WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != -1) { if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != -1) {
Socket::Error("Startup"); Socket::Error("Startup");
} }
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup(); WSACleanup();
utils::Error("Could not find a usable version of Winsock.dll\n"); utils::Error("Could not find a usable version of Winsock.dll\n");
} }
#endif #endif
@ -118,11 +121,11 @@ class Socket {
* it will set it back to block mode * it will set it back to block mode
*/ */
inline void SetNonBlock(bool non_block) { inline void SetNonBlock(bool non_block) {
#ifdef _WIN32 #ifdef _WIN32
u_long mode = non_block ? 1 : 0; u_long mode = non_block ? 1 : 0;
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
Socket::Error("SetNonBlock"); Socket::Error("SetNonBlock");
} }
#else #else
int flag = fcntl(sockfd, F_GETFL, 0); int flag = fcntl(sockfd, F_GETFL, 0);
if (flag == -1) { if (flag == -1) {
@ -143,7 +146,8 @@ class Socket {
* \param addr * \param addr
*/ */
inline void Bind(const SockAddr &addr) { inline void Bind(const SockAddr &addr) {
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) { if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == -1) {
Socket::Error("Bind"); Socket::Error("Bind");
} }
} }
@ -154,10 +158,11 @@ class Socket {
* \return the port successfully bind to, return -1 if failed to bind any port * \return the port successfully bind to, return -1 if failed to bind any port
*/ */
inline int TryBindHost(int start_port, int end_port) { inline int TryBindHost(int start_port, int end_port) {
// TODO, add prefix check // TODO(tqchen) add prefix check
for (int port = start_port; port < end_port; ++port) { for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port); SockAddr addr("0.0.0.0", port);
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) { if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0) {
return port; return port;
} }
if (errno != EADDRINUSE) { if (errno != EADDRINUSE) {
@ -179,22 +184,22 @@ class Socket {
inline bool BadSocket(void) const { inline bool BadSocket(void) const {
if (IsClosed()) return true; if (IsClosed()) return true;
int err = GetSockError(); int err = GetSockError();
if (err == EBADF || err == EINTR) return true; if (err == EBADF || err == EINTR) return true;
return false; return false;
} }
/*! \brief check if socket is already closed */ /*! \brief check if socket is already closed */
inline bool IsClosed(void) const { inline bool IsClosed(void) const {
return sockfd == INVALID_SOCKET; return sockfd == INVALID_SOCKET;
} }
/*! \brief close the socket */ /*! \brief close the socket */
inline void Close(void) { inline void Close(void) {
if (sockfd != INVALID_SOCKET) { if (sockfd != INVALID_SOCKET) {
#ifdef _WIN32 #ifdef _WIN32
closesocket(sockfd); closesocket(sockfd);
#else #else
close(sockfd); close(sockfd);
#endif #endif
sockfd = INVALID_SOCKET; sockfd = INVALID_SOCKET;
} else { } else {
Error("Socket::Close double close the socket or close without create"); Error("Socket::Close double close the socket or close without create");
} }
@ -204,6 +209,7 @@ class Socket {
int errsv = errno; int errsv = errno;
utils::Error("Socket %s Error:%s", msg, strerror(errsv)); utils::Error("Socket %s Error:%s", msg, strerror(errsv));
} }
protected: protected:
explicit Socket(SOCKET sockfd) : sockfd(sockfd) { explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
} }
@ -227,7 +233,7 @@ class TCPSocket : public Socket{
int opt = static_cast<int>(keepalive); int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) { if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive"); Socket::Error("SetKeepAlive");
} }
} }
/*! /*!
* \brief create the socket, call this before using socket * \brief create the socket, call this before using socket
@ -273,7 +279,8 @@ class TCPSocket : public Socket{
* \return whether connect is successful * \return whether connect is successful
*/ */
inline bool Connect(const SockAddr &addr) { inline bool Connect(const SockAddr &addr) {
return connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0; return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0;
} }
/*! /*!
* \brief send data using the socket * \brief send data using the socket
@ -284,7 +291,7 @@ class TCPSocket : public Socket{
* return -1 if error occurs * return -1 if error occurs
*/ */
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
const char *buf = reinterpret_cast<const char*>(buf_); const char *buf = reinterpret_cast<const char*>(buf_);
return send(sockfd, buf, static_cast<sock_size_t>(len), flag); return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
} }
/*! /*!
@ -296,7 +303,7 @@ class TCPSocket : public Socket{
* return -1 if error occurs * return -1 if error occurs
*/ */
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
char *buf = reinterpret_cast<char*>(buf_); char *buf = reinterpret_cast<char*>(buf_);
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
} }
/*! /*!
@ -331,7 +338,8 @@ class TCPSocket : public Socket{
char *buf = reinterpret_cast<char*>(buf_); char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0; size_t ndone = 0;
while (ndone < len) { while (ndone < len) {
ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); ssize_t ret = recv(sockfd, buf,
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
if (ret == -1) { if (ret == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone; if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
Socket::Error("RecvAll"); Socket::Error("RecvAll");
@ -385,7 +393,7 @@ struct SelectHelper {
* \param fd file descriptor to be watched * \param fd file descriptor to be watched
*/ */
inline void WatchRead(SOCKET fd) { inline void WatchRead(SOCKET fd) {
FD_SET(fd, &read_set); FD_SET(fd, &read_set);
if (fd > maxfd) maxfd = fd; if (fd > maxfd) maxfd = fd;
} }
/*! /*!
@ -403,7 +411,7 @@ struct SelectHelper {
inline void WatchException(SOCKET fd) { inline void WatchException(SOCKET fd) {
FD_SET(fd, &except_set); FD_SET(fd, &except_set);
if (fd > maxfd) maxfd = fd; if (fd > maxfd) maxfd = fd;
} }
/*! /*!
* \brief Check if the descriptor is ready for read * \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status * \param fd file descriptor to check status
@ -435,8 +443,9 @@ struct SelectHelper {
fd_set wait_set; fd_set wait_set;
FD_ZERO(&wait_set); FD_ZERO(&wait_set);
FD_SET(fd, &wait_set); FD_SET(fd, &wait_set);
return Select_(static_cast<int>(fd + 1), NULL, NULL, &wait_set, timeout); return Select_(static_cast<int>(fd + 1),
} NULL, NULL, &wait_set, timeout);
}
/*! /*!
* \brief peform select on the set defined * \brief peform select on the set defined
* \param select_read whether to watch for read event * \param select_read whether to watch for read event
@ -454,9 +463,10 @@ struct SelectHelper {
} }
return ret; return ret;
} }
private: private:
inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) { inline static int Select_(int maxfd, fd_set *rfds,
fd_set *wfds, fd_set *efds, long timeout) {
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE"); utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
if (timeout == 0) { if (timeout == 0) {
return select(maxfd, rfds, wfds, efds, NULL); return select(maxfd, rfds, wfds, efds, NULL);
@ -465,12 +475,12 @@ struct SelectHelper {
tm.tv_usec = (timeout % 1000) * 1000; tm.tv_usec = (timeout % 1000) * 1000;
tm.tv_sec = timeout / 1000; tm.tv_sec = timeout / 1000;
return select(maxfd, rfds, wfds, efds, &tm); return select(maxfd, rfds, wfds, efds, &tm);
} }
} }
SOCKET maxfd; SOCKET maxfd;
fd_set read_set, write_set, except_set; fd_set read_set, write_set, except_set;
}; };
} // namespace utils } // namespace utils
} // namespace rabit } // namespace rabit
#endif #endif // RABIT_SOCKET_H_