cpplint pass
This commit is contained in:
parent
15836eb98e
commit
27d6977a3e
2
Makefile
2
Makefile
@ -36,4 +36,4 @@ $(ALIB):
|
||||
ar cr $@ $+
|
||||
|
||||
clean:
|
||||
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~
|
||||
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#ifndef RABIT_RABIT_H
|
||||
#define RABIT_RABIT_H
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file rabit.h
|
||||
* \brief This file defines unified Allreduce/Broadcast interface of rabit
|
||||
* The actual implementation is redirected to rabit engine
|
||||
@ -9,12 +8,14 @@
|
||||
* rabit.h and serializable.h are all the user needs to use the rabit interface
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_RABIT_H_
|
||||
#define RABIT_RABIT_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
// optionally support of lambda function in C++11, if available
|
||||
#if __cplusplus >= 201103L
|
||||
#include <functional>
|
||||
#endif // C++11
|
||||
#endif // C++11
|
||||
// contains definition of ISerializable
|
||||
#include "./rabit_serializable.h"
|
||||
// engine definition of rabit, defines internal implementation
|
||||
@ -116,7 +117,7 @@ inline void Broadcast(std::string *sendrecv_data, int root);
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
void (*prepare_fun)(void *arg) = NULL,
|
||||
void (*prepare_fun)(void *arg) = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
|
||||
// C++11 support for lambda prepare function
|
||||
@ -142,9 +143,9 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
* \tparam DType type of data
|
||||
*/
|
||||
template<typename OP, typename DType>
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun);
|
||||
#endif // C++11
|
||||
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun);
|
||||
#endif // C++11
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
@ -228,6 +229,7 @@ class Reducer {
|
||||
inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
std::function<void()> prepare_fun);
|
||||
#endif
|
||||
|
||||
private:
|
||||
/*! \brief function handle to do reduce */
|
||||
engine::ReduceHandle handle_;
|
||||
@ -274,6 +276,7 @@ class SerializeReducer {
|
||||
size_t max_nbyte, size_t count,
|
||||
std::function<void()> prepare_fun);
|
||||
#endif
|
||||
|
||||
private:
|
||||
/*! \brief function handle to do reduce */
|
||||
engine::ReduceHandle handle_;
|
||||
@ -283,4 +286,4 @@ class SerializeReducer {
|
||||
} // namespace rabit
|
||||
// implementation of template functions
|
||||
#include "./rabit/rabit-inl.h"
|
||||
#endif // RABIT_ALLREDUCE_H
|
||||
#endif // RABIT_RABIT_H_
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine.h
|
||||
* \brief This file defines the core interface of allreduce library
|
||||
* \author Tianqi Chen, Nacho, Tianyi
|
||||
*/
|
||||
#ifndef RABIT_ENGINE_H
|
||||
#define RABIT_ENGINE_H
|
||||
#ifndef RABIT_ENGINE_H_
|
||||
#define RABIT_ENGINE_H_
|
||||
#include <string>
|
||||
#include "../rabit_serializable.h"
|
||||
|
||||
namespace MPI {
|
||||
@ -122,7 +124,7 @@ class IEngine {
|
||||
virtual int GetRank(void) const = 0;
|
||||
/*! \brief get total number of nodes (the world size) */
|
||||
virtual int GetWorldSize(void) const = 0;
|
||||
/*! \brief get the host name of current node */
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const = 0;
|
||||
/*!
|
||||
* \brief print the msg in the tracker,
|
||||
@ -211,7 +213,7 @@ class ReduceHandle {
|
||||
/*! \return the number of bytes occupied by the type */
|
||||
static int TypeSize(const MPI::Datatype &dtype);
|
||||
|
||||
private:
|
||||
protected:
|
||||
// handle data field
|
||||
void *handle_;
|
||||
// handle to the type field
|
||||
@ -221,5 +223,4 @@ class ReduceHandle {
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ENGINE_H
|
||||
|
||||
#endif // RABIT_ENGINE_H_
|
||||
|
||||
@ -1,16 +1,19 @@
|
||||
#ifndef RABIT_UTILS_IO_H
|
||||
#define RABIT_UTILS_IO_H
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include "./utils.h"
|
||||
#include "../rabit_serializable.h"
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file io.h
|
||||
* \brief utilities that implement different serializable interfaces
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_UTILS_IO_H_
|
||||
#define RABIT_UTILS_IO_H_
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "./utils.h"
|
||||
#include "../rabit_serializable.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*! \brief interface of i/o stream that support seek */
|
||||
@ -25,8 +28,9 @@ class ISeekStream: public IStream {
|
||||
/*! \brief fixed size memory buffer */
|
||||
struct MemoryFixSizeBuffer : public ISeekStream {
|
||||
public:
|
||||
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
||||
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
|
||||
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
||||
: p_buffer_(reinterpret_cast<char*>(p_buffer)),
|
||||
buffer_size_(buffer_size) {
|
||||
curr_ptr_ = 0;
|
||||
}
|
||||
virtual ~MemoryFixSizeBuffer(void) {}
|
||||
@ -40,7 +44,7 @@ struct MemoryFixSizeBuffer : public ISeekStream {
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
if (size == 0) return;
|
||||
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
||||
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
||||
"write position exceed fixed buffer size");
|
||||
memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
@ -59,12 +63,12 @@ struct MemoryFixSizeBuffer : public ISeekStream {
|
||||
size_t buffer_size_;
|
||||
/*! \brief current pointer */
|
||||
size_t curr_ptr_;
|
||||
}; // class MemoryFixSizeBuffer
|
||||
}; // class MemoryFixSizeBuffer
|
||||
|
||||
/*! \brief an in-memory buffer that can be read and written through the stream interface */
|
||||
struct MemoryBufferStream : public ISeekStream {
|
||||
public:
|
||||
MemoryBufferStream(std::string *p_buffer)
|
||||
explicit MemoryBufferStream(std::string *p_buffer)
|
||||
: p_buffer_(p_buffer) {
|
||||
curr_ptr_ = 0;
|
||||
}
|
||||
@ -82,7 +86,7 @@ struct MemoryBufferStream : public ISeekStream {
|
||||
if (curr_ptr_ + size > p_buffer_->length()) {
|
||||
p_buffer_->resize(curr_ptr_+size);
|
||||
}
|
||||
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
||||
curr_ptr_ += size;
|
||||
}
|
||||
virtual void Seek(size_t pos) {
|
||||
@ -97,36 +101,7 @@ struct MemoryBufferStream : public ISeekStream {
|
||||
std::string *p_buffer_;
|
||||
/*! \brief current pointer */
|
||||
size_t curr_ptr_;
|
||||
}; // class MemoryBufferStream
|
||||
|
||||
/*! \brief implementation of file i/o stream */
|
||||
class FileStream : public ISeekStream {
|
||||
public:
|
||||
explicit FileStream(FILE *fp) : fp(fp) {}
|
||||
explicit FileStream(void) {
|
||||
this->fp = NULL;
|
||||
}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
return std::fread(ptr, size, 1, fp);
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
std::fwrite(ptr, size, 1, fp);
|
||||
}
|
||||
virtual void Seek(size_t pos) {
|
||||
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
|
||||
}
|
||||
virtual size_t Tell(void) {
|
||||
return std::ftell(fp);
|
||||
}
|
||||
inline void Close(void) {
|
||||
if (fp != NULL){
|
||||
std::fclose(fp); fp = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
FILE *fp;
|
||||
};
|
||||
}; // class MemoryBufferStream
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif
|
||||
#endif // RABIT_UTILS_IO_H_
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
#ifndef RABIT_UTILS_H_
|
||||
#define RABIT_UTILS_H_
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file utils.h
|
||||
* \brief simple utils to support the code
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_UTILS_H_
|
||||
#define RABIT_UTILS_H_
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
@ -19,7 +20,7 @@
|
||||
#define fopen64 std::fopen
|
||||
#endif
|
||||
#ifdef _MSC_VER
|
||||
// NOTE: sprintf_s is not equivalent to snprintf,
|
||||
// NOTE: sprintf_s is not equivalent to snprintf,
|
||||
// they are equivalent when success, which is sufficient for our case
|
||||
#define snprintf sprintf_s
|
||||
#define vsnprintf vsprintf_s
|
||||
@ -30,7 +31,7 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef __APPLE__
|
||||
#ifdef __APPLE__
|
||||
#define off64_t off_t
|
||||
#define fopen64 std::fopen
|
||||
#endif
|
||||
@ -186,5 +187,5 @@ inline const char* BeginPtr(const std::string &str) {
|
||||
if (str.length() == 0) return NULL;
|
||||
return &str[0];
|
||||
}
|
||||
} // namespace rabit
|
||||
} // namespace rabit
|
||||
#endif // RABIT_UTILS_H_
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
#ifndef RABIT_RABIT_SERIALIZABLE_H
|
||||
#define RABIT_RABIT_SERIALIZABLE_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "./rabit/utils.h"
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file serializable.h
|
||||
* \brief defines serializable interface of rabit
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_RABIT_SERIALIZABLE_H_
|
||||
#define RABIT_RABIT_SERIALIZABLE_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "./rabit/utils.h"
|
||||
namespace rabit {
|
||||
/*!
|
||||
* \brief interface of stream I/O, used by ISerializable
|
||||
@ -96,4 +97,4 @@ class ISerializable {
|
||||
virtual void Save(IStream &fo) const = 0;
|
||||
};
|
||||
} // namespace rabit
|
||||
#endif
|
||||
#endif // RABIT_RABIT_SERIALIZABLE_H_
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_base.cc
|
||||
* \brief Basic implementation of AllReduce
|
||||
*
|
||||
@ -32,13 +33,15 @@ AllreduceBase::AllreduceBase(void) {
|
||||
// initialization function
|
||||
void AllreduceBase::Init(void) {
|
||||
// setup from environment variables
|
||||
{// handling for hadoop
|
||||
{
|
||||
// handling for hadoop
|
||||
const char *task_id = getenv("mapred_tip_id");
|
||||
if (task_id == NULL) {
|
||||
task_id = getenv("mapreduce_task_id");
|
||||
}
|
||||
if (hadoop_mode != 0) {
|
||||
utils::Check(task_id != NULL, "hadoop_mode is set but cannot find mapred_task_id");
|
||||
utils::Check(task_id != NULL,
|
||||
"hadoop_mode is set but cannot find mapred_task_id");
|
||||
}
|
||||
if (task_id != NULL) {
|
||||
this->SetParam("rabit_task_id", task_id);
|
||||
@ -48,7 +51,7 @@ void AllreduceBase::Init(void) {
|
||||
if (attempt_id != 0) {
|
||||
const char *att = strrchr(attempt_id, '_');
|
||||
int num_trial;
|
||||
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
|
||||
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
|
||||
this->SetParam("rabit_num_trial", att + 1);
|
||||
}
|
||||
}
|
||||
@ -58,7 +61,8 @@ void AllreduceBase::Init(void) {
|
||||
num_task = getenv("mapreduce_job_maps");
|
||||
}
|
||||
if (hadoop_mode != 0) {
|
||||
utils::Check(num_task != NULL, "hadoop_mode is set but cannot find mapred_map_tasks");
|
||||
utils::Check(num_task != NULL,
|
||||
"hadoop_mode is set but cannot find mapred_map_tasks");
|
||||
}
|
||||
if (num_task != NULL) {
|
||||
this->SetParam("rabit_world_size", num_task);
|
||||
@ -81,11 +85,11 @@ void AllreduceBase::Shutdown(void) {
|
||||
}
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
|
||||
|
||||
if (tracker_uri == "NULL") return;
|
||||
// notify the tracker that this rank has shut down
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
tracker.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
}
|
||||
@ -107,11 +111,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
|
||||
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
|
||||
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
||||
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
||||
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
||||
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
||||
if (!strcmp(name, "rabit_reduce_buffer")) {
|
||||
char unit;
|
||||
unsigned long amount;
|
||||
uint64_t amount;
|
||||
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
||||
switch (unit) {
|
||||
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
|
||||
@ -121,7 +125,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
default: utils::Error("invalid format for reduce buffer");
|
||||
}
|
||||
} else {
|
||||
utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
||||
utils::Error("invalid format for reduce_buffer,"\
|
||||
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -137,11 +142,16 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
||||
utils::Socket::Error("Connect");
|
||||
}
|
||||
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
|
||||
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
|
||||
using utils::Assert;
|
||||
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||
"ReConnectLink failure 1");
|
||||
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||
"ReConnectLink failure 2");
|
||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
||||
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
|
||||
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
|
||||
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 3");
|
||||
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 3");
|
||||
tracker.SendStr(task_id);
|
||||
return tracker;
|
||||
}
|
||||
@ -161,29 +171,30 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
std::map<int, int> tree_neighbors;
|
||||
{// get new ranks
|
||||
int newrank, num_neighbors;
|
||||
utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||
"ReConnectLink failure 4");
|
||||
utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
|
||||
"ReConnectLink failure 4");
|
||||
utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 4");
|
||||
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
|
||||
rank = newrank;
|
||||
utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors),
|
||||
"ReConnectLink failure 4");
|
||||
for (int i = 0; i < num_neighbors; ++i) {
|
||||
int nrank;
|
||||
utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||
"ReConnectLink failure 4");
|
||||
tree_neighbors[nrank] = 1;
|
||||
}
|
||||
utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||
"ReConnectLink failure 4");
|
||||
utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||
"ReConnectLink failure 4");
|
||||
using utils::Assert;
|
||||
// get new ranks
|
||||
int newrank, num_neighbors;
|
||||
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\
|
||||
sizeof(parent_rank), "ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(rank == -1 || newrank == rank,
|
||||
"must keep rank to same if the node already have one");
|
||||
rank = newrank;
|
||||
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
||||
sizeof(num_neighbors), "ReConnectLink failure 4");
|
||||
for (int i = 0; i < num_neighbors; ++i) {
|
||||
int nrank;
|
||||
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||
"ReConnectLink failure 4");
|
||||
tree_neighbors[nrank] = 1;
|
||||
}
|
||||
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||
"ReConnectLink failure 4");
|
||||
// create listening socket
|
||||
utils::TCPSocket sock_listen;
|
||||
sock_listen.Create();
|
||||
@ -204,56 +215,67 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
}
|
||||
}
|
||||
int ngood = static_cast<int>(good_link.size());
|
||||
utils::Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||
"ReConnectLink failure 5");
|
||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||
"ReConnectLink failure 5");
|
||||
for (size_t i = 0; i < good_link.size(); ++i) {
|
||||
utils::Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == sizeof(good_link[i]),
|
||||
"ReConnectLink failure 6");
|
||||
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
|
||||
sizeof(good_link[i]), "ReConnectLink failure 6");
|
||||
}
|
||||
utils::Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
"ReConnectLink failure 7");
|
||||
utils::Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
|
||||
"ReConnectLink failure 8");
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
"ReConnectLink failure 7");
|
||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
|
||||
sizeof(num_accept), "ReConnectLink failure 8");
|
||||
num_error = 0;
|
||||
for (int i = 0; i < num_conn; ++i) {
|
||||
LinkRecord r;
|
||||
int hport, hrank;
|
||||
std::string hname;
|
||||
tracker.RecvStr(&hname);
|
||||
utils::Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
|
||||
utils::Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
|
||||
tracker.RecvStr(&hname);
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
|
||||
"ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
|
||||
"ReConnectLink failure 10");
|
||||
r.sock.Create();
|
||||
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
||||
num_error += 1; r.sock.Close(); continue;
|
||||
}
|
||||
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12");
|
||||
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
|
||||
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 12");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 13");
|
||||
utils::Check(hrank == r.rank,
|
||||
"ReConnectLink failure, link rank inconsistent");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == hrank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
|
||||
Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
}
|
||||
utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
|
||||
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
||||
"ReConnectLink failure 14");
|
||||
} while (num_error != 0);
|
||||
// send back socket listening port to tracker
|
||||
utils::Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
|
||||
"ReConnectLink failure 14");
|
||||
// close connection to tracker
|
||||
tracker.Close();
|
||||
tracker.Close();
|
||||
// listen to incoming links
|
||||
for (int i = 0; i < num_accept; ++i) {
|
||||
LinkRecord r;
|
||||
r.sock = sock_listen.Accept();
|
||||
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
|
||||
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 15");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 15");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == r.rank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
|
||||
utils::Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
}
|
||||
}
|
||||
@ -278,9 +300,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
||||
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
||||
}
|
||||
utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link");
|
||||
utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link");
|
||||
utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link");
|
||||
Assert(parent_rank == -1 || parent_index != -1,
|
||||
"cannot find parent in the link");
|
||||
Assert(prev_rank == -1 || ring_prev != NULL,
|
||||
"cannot find prev ring in the link");
|
||||
Assert(next_rank == -1 || ring_next != NULL,
|
||||
"cannot find next ring in the link");
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
@ -326,7 +351,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
// if there are no children, no need to reduce
|
||||
if (nlink == static_cast<int>(parent_index != -1)) {
|
||||
size_up_reduce = total_size;
|
||||
}
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
// select helper
|
||||
@ -347,7 +372,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
if (links[i].size_read != total_size) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
}
|
||||
// size_write <= size_read
|
||||
// size_write <= size_read
|
||||
if (links[i].size_write != total_size) {
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
// only watch for exception in live channels
|
||||
@ -358,11 +383,11 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
}
|
||||
// finished running allreduce
|
||||
if (finished) break;
|
||||
// select must return
|
||||
// select must return
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// recive OOB message from some link
|
||||
// recive OOB message from some link
|
||||
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
|
||||
}
|
||||
// read data from children
|
||||
@ -392,7 +417,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
// start position
|
||||
size_t start = size_up_reduce % buffer_size;
|
||||
// perform read till end of buffer
|
||||
size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce);
|
||||
size_t nread = std::min(buffer_size - start,
|
||||
max_reduce - size_up_reduce);
|
||||
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
@ -407,7 +433,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
}
|
||||
if (parent_index != -1) {
|
||||
// pass message up to parent, can pass data that has already been reduced
|
||||
if (selecter.CheckWrite(links[parent_index].sock)) {
|
||||
if (selecter.CheckWrite(links[parent_index].sock)) {
|
||||
ssize_t len = links[parent_index].sock.
|
||||
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
|
||||
if (len != -1) {
|
||||
@ -417,7 +443,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
}
|
||||
}
|
||||
// read data from parent
|
||||
if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) {
|
||||
if (selecter.CheckRead(links[parent_index].sock) &&
|
||||
total_size > size_down_in) {
|
||||
ssize_t len = links[parent_index].sock.
|
||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||
if (len == 0) {
|
||||
@ -425,7 +452,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
}
|
||||
if (len != -1) {
|
||||
size_down_in += static_cast<size_t>(len);
|
||||
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
|
||||
utils::Assert(size_down_in <= size_up_out,
|
||||
"Allreduce: boundary error");
|
||||
} else {
|
||||
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||
}
|
||||
@ -437,11 +465,13 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
// can pass message down to children
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
|
||||
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
|
||||
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) {
|
||||
return kSockError;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
|
||||
@ -455,14 +485,15 @@ AllreduceBase::ReturnType
|
||||
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
|
||||
utils::Check(root < world_size,
|
||||
"Broadcast: root should be smaller than world size");
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
// size of space already read from data
|
||||
size_t size_in = 0;
|
||||
// input link, -2 means unknown yet, -1 means this is root
|
||||
int in_link = -2;
|
||||
|
||||
|
||||
// initialize the link statistics
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
@ -471,9 +502,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
if (this->rank == root) {
|
||||
size_in = total_size;
|
||||
in_link = -1;
|
||||
}
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while(true) {
|
||||
while (true) {
|
||||
bool finished = true;
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
@ -487,7 +518,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
|
||||
selecter.WatchWrite(links[i].sock); finished = false;
|
||||
}
|
||||
selecter.WatchException(links[i].sock);
|
||||
selecter.WatchException(links[i].sock);
|
||||
}
|
||||
// finish running
|
||||
if (finished) break;
|
||||
@ -495,14 +526,16 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// recive OOB message from some link
|
||||
// recive OOB message from some link
|
||||
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
|
||||
}
|
||||
if (in_link == -2) {
|
||||
// probe in-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (selecter.CheckRead(links[i].sock)) {
|
||||
if (!links[i].ReadToArray(sendrecvbuf_, total_size)) return kSockError;
|
||||
if (!links[i].ReadToArray(sendrecvbuf_, total_size)) {
|
||||
return kSockError;
|
||||
}
|
||||
size_in = links[i].size_read;
|
||||
if (size_in != 0) {
|
||||
in_link = i; break;
|
||||
@ -512,7 +545,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
} else {
|
||||
// read from in link
|
||||
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
|
||||
if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) return kSockError;
|
||||
if (!links[in_link].ReadToArray(sendrecvbuf_, total_size)) {
|
||||
return kSockError;
|
||||
}
|
||||
size_in = links[in_link].size_read;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_base.h
|
||||
* \brief Basic implementation of AllReduce
|
||||
* using TCP non-block socket and tree-shape reduction.
|
||||
@ -8,13 +9,14 @@
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_BASE_H
|
||||
#define RABIT_ALLREDUCE_BASE_H
|
||||
#ifndef RABIT_ALLREDUCE_BASE_H_
|
||||
#define RABIT_ALLREDUCE_BASE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <rabit/utils.h>
|
||||
#include <rabit/engine.h>
|
||||
#include <algorithm>
|
||||
#include "rabit/utils.h"
|
||||
#include "rabit/engine.h"
|
||||
#include "./socket.h"
|
||||
|
||||
namespace MPI {
|
||||
@ -22,7 +24,7 @@ namespace MPI {
|
||||
class Datatype {
|
||||
public:
|
||||
size_t type_size;
|
||||
Datatype(size_t type_size) : type_size(type_size) {}
|
||||
explicit Datatype(size_t type_size) : type_size(type_size) {}
|
||||
};
|
||||
}
|
||||
namespace rabit {
|
||||
@ -31,7 +33,7 @@ namespace engine {
|
||||
class AllreduceBase : public IEngine {
|
||||
public:
|
||||
// magic number to verify server
|
||||
const static int kMagic = 0xff99;
|
||||
static const int kMagic = 0xff99;
|
||||
// constant one byte out of band message to indicate error happening
|
||||
AllreduceBase(void);
|
||||
virtual ~AllreduceBase(void) {}
|
||||
@ -79,12 +81,13 @@ class AllreduceBase : public IEngine {
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess,
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_,
|
||||
type_nbytes, count, reducer) == kSuccess,
|
||||
"Allreduce failed");
|
||||
}
|
||||
/*!
|
||||
@ -201,12 +204,16 @@ class AllreduceBase : public IEngine {
|
||||
// constructor
|
||||
LinkRecord(void) {}
|
||||
// initialize buffer
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count,
|
||||
size_t reduce_buffer_size) {
|
||||
size_t n = (type_nbytes * count + 7)/ 8;
|
||||
buffer_.resize(std::min(reduce_buffer_size, n));
|
||||
// make sure align to type_nbytes
|
||||
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||
utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size);
|
||||
buffer_size =
|
||||
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||
utils::Assert(type_nbytes <= buffer_size,
|
||||
"too large type_nbytes=%lu, buffer_size=%lu",
|
||||
type_nbytes, buffer_size);
|
||||
// set buffer head
|
||||
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
||||
}
|
||||
@ -225,7 +232,7 @@ class AllreduceBase : public IEngine {
|
||||
size_t ngap = size_read - protect_start;
|
||||
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||
size_t offset = size_read % buffer_size;
|
||||
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
||||
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
||||
if (nmax == 0) return true;
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
@ -235,7 +242,7 @@ class AllreduceBase : public IEngine {
|
||||
if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
|
||||
size_read += static_cast<size_t>(len);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief read data into array,
|
||||
* this function can not be used together with ReadToRingBuffer
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_robust-inl.h
|
||||
* \brief implementation of inline template function in AllreduceRobust
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_ENGINE_ROBUST_INL_H
|
||||
#define RABIT_ENGINE_ROBUST_INL_H
|
||||
#ifndef RABIT_ENGINE_ROBUST_INL_H_
|
||||
#define RABIT_ENGINE_ROBUST_INL_H_
|
||||
#include <vector>
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
@ -33,14 +35,14 @@ inline AllreduceRobust::ReturnType
|
||||
AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
std::vector<EdgeType> *p_edge_in,
|
||||
std::vector<EdgeType> *p_edge_out,
|
||||
EdgeType (*func) (const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)
|
||||
) {
|
||||
EdgeType (*func)
|
||||
(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0) return kSuccess;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
// initialize the pointers
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
@ -58,7 +60,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
// if no childs, no need to, directly start passing message
|
||||
if (nlink == static_cast<int>(parent_index != -1)) {
|
||||
utils::Assert(parent_index == 0, "parent must be 0");
|
||||
edge_out[parent_index] = func(node_value, edge_in, parent_index);
|
||||
edge_out[parent_index] = func(node_value, edge_in, parent_index);
|
||||
stage = 1;
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
@ -94,7 +96,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// recive OOB message from some link
|
||||
// recive OOB message from some link
|
||||
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
|
||||
}
|
||||
if (stage == 0) {
|
||||
@ -103,7 +105,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
if (selecter.CheckRead(links[i].sock)) {
|
||||
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) return kSockError;
|
||||
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) {
|
||||
return kSockError;
|
||||
}
|
||||
}
|
||||
if (links[i].size_read != sizeof(EdgeType)) finished = false;
|
||||
}
|
||||
@ -124,13 +128,17 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
if (stage == 1) {
|
||||
const int pid = this->parent_index;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) return kSockError;
|
||||
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) {
|
||||
return kSockError;
|
||||
}
|
||||
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
|
||||
}
|
||||
if (stage == 2) {
|
||||
const int pid = this->parent_index;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) return kSockError;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) {
|
||||
return kSockError;
|
||||
}
|
||||
if (links[pid].size_read == sizeof(EdgeType)) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
|
||||
@ -141,7 +149,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
if (stage == 3) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError;
|
||||
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) {
|
||||
return kSockError;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -150,4 +160,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ENGINE_ROBUST_INL_H
|
||||
#endif // RABIT_ENGINE_ROBUST_INL_H_
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_robust.cc
|
||||
* \brief Robust implementation of Allreduce
|
||||
*
|
||||
@ -9,10 +10,10 @@
|
||||
#define NOMINMAX
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include <rabit/io.h>
|
||||
#include <rabit/utils.h>
|
||||
#include <rabit/engine.h>
|
||||
#include <rabit/rabit-inl.h>
|
||||
#include "rabit/io.h"
|
||||
#include "rabit/utils.h"
|
||||
#include "rabit/engine.h"
|
||||
#include "rabit/rabit-inl.h"
|
||||
#include "./allreduce_robust.h"
|
||||
|
||||
namespace rabit {
|
||||
@ -30,10 +31,10 @@ void AllreduceRobust::Shutdown(void) {
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check point must return true");
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// execute check ack step, load happens here
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check ack must return true");
|
||||
"Shutdown: check ack must return true");
|
||||
AllreduceBase::Shutdown();
|
||||
}
|
||||
/*!
|
||||
@ -89,7 +90,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
||||
} else {
|
||||
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
resbuf.PushTemp(seq_counter, type_nbytes, count);
|
||||
seq_counter += 1;
|
||||
@ -102,7 +103,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
||||
*/
|
||||
void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
// skip action in single node
|
||||
if (world_size == 1) return;
|
||||
if (world_size == 1) return;
|
||||
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||
// now we are free to remove the last result, if any
|
||||
if (resbuf.LastSeqNo() != -1 &&
|
||||
@ -119,7 +120,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
||||
} else {
|
||||
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
resbuf.PushTemp(seq_counter, 1, total_size);
|
||||
seq_counter += 1;
|
||||
@ -151,7 +152,8 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
||||
// skip action in single node
|
||||
if (world_size == 1) return 0;
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
utils::Check(local_model == NULL,
|
||||
"need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
// check if we succesful
|
||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
|
||||
@ -171,9 +173,10 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
||||
// load from buffer
|
||||
utils::MemoryBufferStream fs(&global_checkpoint);
|
||||
if (global_checkpoint.length() == 0) {
|
||||
version_number = 0;
|
||||
version_number = 0;
|
||||
} else {
|
||||
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number");
|
||||
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
|
||||
"read in version number");
|
||||
global_model->Load(fs);
|
||||
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
|
||||
"local model inconsistent, nlocal=%d", nlocal);
|
||||
@ -212,9 +215,10 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
|
||||
version_number += 1; return;
|
||||
}
|
||||
if (num_local_replica == 0) {
|
||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
if (num_local_replica != 0) {
|
||||
utils::Check(local_model == NULL,
|
||||
"need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||
}
|
||||
if (num_local_replica != 0) {
|
||||
while (true) {
|
||||
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
|
||||
// save model model to new version place
|
||||
@ -247,10 +251,10 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
|
||||
fs.Write(&version_number, sizeof(version_number));
|
||||
global_model->Save(fs);
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// execute check ack step, load happens here
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||
"check ack must return true");
|
||||
"check ack must return true");
|
||||
}
|
||||
/*!
|
||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||
@ -383,7 +387,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
||||
*/
|
||||
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||
if (err_type == kSuccess) return true;
|
||||
{// simple way, shutdown all links
|
||||
{
|
||||
// simple way, shutdown all links
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
|
||||
}
|
||||
@ -392,8 +397,8 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||
}
|
||||
// this was old way
|
||||
// TryResetLinks still causes possible errors, so not use this one
|
||||
while(err_type != kSuccess) {
|
||||
switch(err_type) {
|
||||
while (err_type != kSuccess) {
|
||||
switch (err_type) {
|
||||
case kGetExcept: err_type = TryResetLinks(); break;
|
||||
case kSockError: {
|
||||
TryResetLinks();
|
||||
@ -416,7 +421,7 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||
* \param out_index the edge index of output link
|
||||
* \return the shorest distance result of out edge specified by out_index
|
||||
*/
|
||||
inline std::pair<int,size_t>
|
||||
inline std::pair<int, size_t>
|
||||
ShortestDist(const std::pair<bool, size_t> &node_value,
|
||||
const std::vector< std::pair<int, size_t> > &dist_in,
|
||||
size_t out_index) {
|
||||
@ -484,8 +489,9 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
||||
int *p_recvlink,
|
||||
std::vector<bool> *p_req_in) {
|
||||
int best_link = -2;
|
||||
{// get the shortest distance to the request point
|
||||
std::vector< std::pair<int,size_t> > dist_in, dist_out;
|
||||
{
|
||||
// get the shortest distance to the request point
|
||||
std::vector<std::pair<int, size_t> > dist_in, dist_out;
|
||||
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
|
||||
&dist_in, &dist_out, ShortestDist);
|
||||
if (succ != kSuccess) return succ;
|
||||
@ -512,7 +518,7 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
||||
&req_in, &req_out, DataRequest);
|
||||
if (succ != kSuccess) return succ;
|
||||
// set p_req_in
|
||||
p_req_in->resize(req_in.size());
|
||||
p_req_in->resize(req_in.size());
|
||||
for (size_t i = 0; i < req_in.size(); ++i) {
|
||||
// set p_req_in
|
||||
(*p_req_in)[i] = (req_in[i] != 0);
|
||||
@ -591,19 +597,23 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
if (role == kRequestData) {
|
||||
const int pid = recv_link;
|
||||
if (selecter.CheckRead(links[pid].sock)) {
|
||||
if(!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError;
|
||||
if (!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError;
|
||||
}
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (req_in[i] && links[i].size_write != links[pid].size_read &&
|
||||
selecter.CheckWrite(links[i].sock)) {
|
||||
if(!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) return kSockError;
|
||||
if (!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) {
|
||||
return kSockError;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (role == kHaveData) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (req_in[i] && selecter.CheckWrite(links[i].sock)) {
|
||||
if(!links[i].WriteFromArray(sendrecvbuf_, size)) return kSockError;
|
||||
if (!links[i].WriteFromArray(sendrecvbuf_, size)) {
|
||||
return kSockError;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -616,13 +626,14 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
|
||||
}
|
||||
utils::Assert(min_write <= links[pid].size_read, "boundary check");
|
||||
if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
|
||||
if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
|
||||
}
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (req_in[i] && selecter.CheckWrite(links[i].sock) && links[pid].size_read != links[i].size_write) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (req_in[i] && selecter.CheckWrite(links[i].sock) &&
|
||||
links[pid].size_read != links[i].size_write) {
|
||||
size_t start = links[i].size_write % buffer_size;
|
||||
// send out data from ring buffer
|
||||
size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
|
||||
size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
|
||||
ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite);
|
||||
if (len != -1) {
|
||||
links[i].size_write += len;
|
||||
@ -648,15 +659,15 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
*/
|
||||
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
||||
// check in local data
|
||||
RecoverType role = requester ? kRequestData : kHaveData;
|
||||
ReturnType succ;
|
||||
RecoverType role = requester ? kRequestData : kHaveData;
|
||||
ReturnType succ;
|
||||
if (num_local_replica != 0) {
|
||||
if (requester) {
|
||||
// clear existing history, if any, before load
|
||||
local_rptr[local_chkpt_version].clear();
|
||||
local_chkpt[local_chkpt_version].clear();
|
||||
local_chkpt[local_chkpt_version].clear();
|
||||
}
|
||||
// recover local checkpoint
|
||||
// recover local checkpoint
|
||||
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
|
||||
&local_chkpt[local_chkpt_version]);
|
||||
if (succ != kSuccess) return succ;
|
||||
@ -716,7 +727,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
||||
// if we goes to this place, use must have already setup the state once
|
||||
utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1,
|
||||
"TryGetResult::Checkpoint");
|
||||
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
|
||||
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
|
||||
}
|
||||
// handles normal data recovery
|
||||
RecoverType role;
|
||||
@ -735,8 +746,9 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
||||
utils::Check(data_size != 0, "zero size check point is not allowed");
|
||||
if (role == kRequestData || role == kHaveData) {
|
||||
utils::Check(data_size == size,
|
||||
"Allreduce Recovered data size do not match the specification of function call\n"\
|
||||
"Please check if calling sequence of recovered program is the same the original one in current VersionNumber");
|
||||
"Allreduce Recovered data size do not match the specification of function call.\n"\
|
||||
"Please check if calling sequence of recovered program is the " \
|
||||
"same the original one in current VersionNumber");
|
||||
}
|
||||
return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in);
|
||||
}
|
||||
@ -766,7 +778,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
||||
while (true) {
|
||||
this->ReportStatus();
|
||||
// action
|
||||
ActionSummary act = req;
|
||||
ActionSummary act = req;
|
||||
// get the reduced action
|
||||
if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
|
||||
if (act.check_ack()) {
|
||||
@ -816,7 +828,8 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
||||
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
|
||||
if (requester) return true;
|
||||
} else {
|
||||
// all the request is same, this is most recent command that is yet to be executed
|
||||
// all the request is same,
|
||||
// this is most recent command that is yet to be executed
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -855,7 +868,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
|
||||
}
|
||||
const int n = num_local_replica;
|
||||
{// backward passing, passing state in backward direction of the ring
|
||||
{
|
||||
// backward passing, passing state in backward direction of the ring
|
||||
const int nlocal = static_cast<int>(rptr.size() - 1);
|
||||
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
||||
std::vector<int> msg_back(n + 1);
|
||||
@ -897,10 +911,10 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
// update rptr
|
||||
rptr.resize(nread_end + 1);
|
||||
for (int i = nlocal; i < nread_end; ++i) {
|
||||
rptr[i + 1] = rptr[i] + sizes[i];
|
||||
rptr[i + 1] = rptr[i] + sizes[i];
|
||||
}
|
||||
chkpt.resize(rptr.back());
|
||||
// pass data through the link
|
||||
// pass data through the link
|
||||
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
|
||||
rptr[nwrite_start], rptr[nread_end],
|
||||
ring_next, ring_prev);
|
||||
@ -908,7 +922,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
|
||||
}
|
||||
}
|
||||
{// forward passing, passing state in forward direction of the ring
|
||||
{
|
||||
// forward passing, passing state in forward direction of the ring
|
||||
const int nlocal = static_cast<int>(rptr.size() - 1);
|
||||
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
||||
std::vector<int> msg_forward(n + 1);
|
||||
@ -926,7 +941,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
1 * sizeof(int), 2 * sizeof(int),
|
||||
0 * sizeof(int), 1 * sizeof(int),
|
||||
ring_next, ring_prev);
|
||||
if (succ != kSuccess) return succ;
|
||||
if (succ != kSuccess) return succ;
|
||||
// calculate the number of things we can read from next link
|
||||
int nread_end = nlocal, nwrite_end = 1;
|
||||
// have to have itself in order to get other data from prev link
|
||||
@ -936,7 +951,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
nread_end = std::max(nread_end, i + 1);
|
||||
nwrite_end = i + 1;
|
||||
}
|
||||
if (nwrite_end > n) nwrite_end = n;
|
||||
if (nwrite_end > n) nwrite_end = n;
|
||||
} else {
|
||||
nread_end = 0; nwrite_end = 0;
|
||||
}
|
||||
@ -963,7 +978,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
rptr[i + 1] = rptr[i] + sizes[i];
|
||||
}
|
||||
chkpt.resize(rptr.back());
|
||||
// pass data through the link
|
||||
// pass data through the link
|
||||
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
|
||||
rptr[nwrite_start], rptr[nwrite_end],
|
||||
ring_prev, ring_next);
|
||||
@ -995,7 +1010,8 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
||||
if (num_local_replica == 0) return kSuccess;
|
||||
std::vector<size_t> &rptr = *p_local_rptr;
|
||||
std::string &chkpt = *p_local_chkpt;
|
||||
utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state");
|
||||
utils::Assert(rptr.size() == 2,
|
||||
"TryCheckinLocalState must have exactly 1 state");
|
||||
const int n = num_local_replica;
|
||||
std::vector<size_t> sizes(n + 1);
|
||||
sizes[0] = rptr[1] - rptr[0];
|
||||
@ -1012,9 +1028,9 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
||||
rptr.resize(n + 2);
|
||||
for (int i = 1; i <= n; ++i) {
|
||||
rptr[i + 1] = rptr[i] + sizes[i];
|
||||
}
|
||||
}
|
||||
chkpt.resize(rptr.back());
|
||||
// pass data through the link
|
||||
// pass data through the link
|
||||
succ = RingPassing(BeginPtr(chkpt),
|
||||
rptr[1], rptr[n + 1],
|
||||
rptr[0], rptr[n],
|
||||
@ -1050,13 +1066,14 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
|
||||
LinkRecord *read_link,
|
||||
LinkRecord *write_link) {
|
||||
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
|
||||
utils::Assert(write_end <= read_end, "RingPassing: boundary check1, write_end=%lu, read_end=%lu", write_end, read_end);
|
||||
utils::Assert(write_end <= read_end,
|
||||
"RingPassing: boundary check1");
|
||||
utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2");
|
||||
utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3");
|
||||
// take reference
|
||||
LinkRecord &prev = *read_link, &next = *write_link;
|
||||
// send recv buffer
|
||||
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
while (true) {
|
||||
bool finished = true;
|
||||
utils::SelectHelper selecter;
|
||||
@ -1066,7 +1083,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
|
||||
}
|
||||
if (write_ptr < read_ptr && write_ptr != write_end) {
|
||||
selecter.WatchWrite(next.sock);
|
||||
finished = false;
|
||||
finished = false;
|
||||
}
|
||||
selecter.WatchException(prev.sock);
|
||||
selecter.WatchException(next.sock);
|
||||
@ -1078,7 +1095,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
|
||||
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
|
||||
if (len == 0) {
|
||||
prev.sock.Close(); return kSockError;
|
||||
}
|
||||
}
|
||||
if (len != -1) {
|
||||
read_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_robust.h
|
||||
* \brief Robust implementation of Allreduce
|
||||
* using TCP non-block socket and tree-shape reduction.
|
||||
@ -7,10 +8,12 @@
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_ROBUST_H
|
||||
#define RABIT_ALLREDUCE_ROBUST_H
|
||||
#ifndef RABIT_ALLREDUCE_ROBUST_H_
|
||||
#define RABIT_ALLREDUCE_ROBUST_H_
|
||||
#include <vector>
|
||||
#include <rabit/engine.h>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "rabit/engine.h"
|
||||
#include "./allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
@ -111,11 +114,11 @@ class AllreduceRobust : public AllreduceBase {
|
||||
private:
|
||||
// constant one byte out of band message to indicate error happening
|
||||
// and mark for channel cleanup
|
||||
const static char kOOBReset = 95;
|
||||
static const char kOOBReset = 95;
|
||||
// and mark for channel cleanup, after OOB signal
|
||||
const static char kResetMark = 97;
|
||||
static const char kResetMark = 97;
|
||||
// and mark for channel cleanup
|
||||
const static char kResetAck = 97;
|
||||
static const char kResetAck = 97;
|
||||
/*! \brief type of roles each node can play during recovery */
|
||||
enum RecoverType {
|
||||
/*! \brief current node have data */
|
||||
@ -132,29 +135,29 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*/
|
||||
struct ActionSummary {
|
||||
// maximumly allowed sequence id
|
||||
const static int kSpecialOp = (1 << 26);
|
||||
static const int kSpecialOp = (1 << 26);
|
||||
// special sequence number for local state checkpoint
|
||||
const static int kLocalCheckPoint = (1 << 26) - 2;
|
||||
static const int kLocalCheckPoint = (1 << 26) - 2;
|
||||
// special sequnce number for local state checkpoint ack signal
|
||||
const static int kLocalCheckAck = (1 << 26) - 1;
|
||||
static const int kLocalCheckAck = (1 << 26) - 1;
|
||||
//---------------------------------------------
|
||||
// The following are bit mask of flag used in
|
||||
// The following are bit mask of flag used in
|
||||
//----------------------------------------------
|
||||
// some node want to load check point
|
||||
const static int kLoadCheck = 1;
|
||||
static const int kLoadCheck = 1;
|
||||
// some node want to do check point
|
||||
const static int kCheckPoint = 2;
|
||||
static const int kCheckPoint = 2;
|
||||
// check point Ack, we use a two phase message in check point,
|
||||
// this is the second phase of check pointing
|
||||
const static int kCheckAck = 4;
|
||||
static const int kCheckAck = 4;
|
||||
// there are difference sequence number the nodes proposed
|
||||
// this means we want to do recover execution of the lower sequence
|
||||
// action instead of normal execution
|
||||
const static int kDiffSeq = 8;
|
||||
static const int kDiffSeq = 8;
|
||||
// constructor
|
||||
ActionSummary(void) {}
|
||||
// constructor of action
|
||||
ActionSummary(int flag, int minseqno = kSpecialOp) {
|
||||
// constructor of action
|
||||
explicit ActionSummary(int flag, int minseqno = kSpecialOp) {
|
||||
seqcode = (minseqno << 4) | flag;
|
||||
}
|
||||
// minimum number of all operations
|
||||
@ -181,10 +184,11 @@ class AllreduceRobust : public AllreduceBase {
|
||||
inline int flag(void) const {
|
||||
return seqcode & 15;
|
||||
}
|
||||
// reducer for Allreduce, used to get the result ActionSummary from all nodes
|
||||
inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
|
||||
// reducer for Allreduce, get the result ActionSummary from all nodes
|
||||
inline static void Reducer(const void *src_, void *dst_,
|
||||
int len, const MPI::Datatype &dtype) {
|
||||
const ActionSummary *src = (const ActionSummary*)src_;
|
||||
ActionSummary *dst = (ActionSummary*)dst_;
|
||||
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
int src_seqno = src[i].min_seqno();
|
||||
int dst_seqno = dst[i].min_seqno();
|
||||
@ -192,7 +196,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
if (src_seqno == dst_seqno) {
|
||||
dst[i] = ActionSummary(flag, src_seqno);
|
||||
} else {
|
||||
dst[i] = ActionSummary(flag | kDiffSeq, std::min(src_seqno, dst_seqno));
|
||||
dst[i] = ActionSummary(flag | kDiffSeq,
|
||||
std::min(src_seqno, dst_seqno));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -222,7 +227,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
data_.resize(rptr_.back() + nhop);
|
||||
return BeginPtr(data_) + rptr_.back();
|
||||
}
|
||||
// push the result in temp to the
|
||||
// push the result in temp to the
|
||||
inline void PushTemp(int seqid, size_t type_nbytes, size_t count) {
|
||||
size_t size = type_nbytes * count;
|
||||
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
@ -234,13 +239,14 @@ class AllreduceRobust : public AllreduceBase {
|
||||
size_.push_back(size);
|
||||
utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent");
|
||||
}
|
||||
// return the stored result of seqid, if any
|
||||
// return the stored result of seqid, if any
|
||||
inline void* Query(int seqid, size_t *p_size) {
|
||||
size_t idx = std::lower_bound(seqno_.begin(), seqno_.end(), seqid) - seqno_.begin();
|
||||
size_t idx = std::lower_bound(seqno_.begin(),
|
||||
seqno_.end(), seqid) - seqno_.begin();
|
||||
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
|
||||
*p_size = size_[idx];
|
||||
return BeginPtr(data_) + rptr_[idx];
|
||||
}
|
||||
}
|
||||
// drop last stored result
|
||||
inline void DropLast(void) {
|
||||
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
|
||||
@ -254,15 +260,16 @@ class AllreduceRobust : public AllreduceBase {
|
||||
if (seqno_.size() == 0) return -1;
|
||||
return seqno_.back();
|
||||
}
|
||||
|
||||
private:
|
||||
// sequence number of each
|
||||
// sequence number of each
|
||||
std::vector<int> seqno_;
|
||||
// pointer to the positions
|
||||
std::vector<size_t> rptr_;
|
||||
// actual size of each buffer
|
||||
std::vector<size_t> size_;
|
||||
// content of the buffer
|
||||
std::vector<uint64_t> data_;
|
||||
std::vector<uint64_t> data_;
|
||||
};
|
||||
/*!
|
||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||
@ -291,14 +298,16 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* \param buf the buffer to store the result
|
||||
* \param size the total size of the buffer
|
||||
* \param flag flag information about the action \sa ActionSummary
|
||||
* \param seqno sequence number of the action, if it is special action with flag set, seqno needs to be set to ActionSummary::kSpecialOp
|
||||
* \param seqno sequence number of the action, if it is special action with flag set,
|
||||
* seqno needs to be set to ActionSummary::kSpecialOp
|
||||
*
|
||||
* \return if this function can return true or false
|
||||
* - true means buf already set to the
|
||||
* result by recovering procedure, the action is complete, no further action is needed
|
||||
* - true means buf already set to the
|
||||
* result by recovering procedure, the action is complete, no further action is needed
|
||||
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
||||
*/
|
||||
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kSpecialOp);
|
||||
bool RecoverExec(void *buf, size_t size, int flag,
|
||||
int seqno = ActionSummary::kSpecialOp);
|
||||
/*!
|
||||
* \brief try to load check point
|
||||
*
|
||||
@ -363,7 +372,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
void *sendrecvbuf_,
|
||||
size_t size,
|
||||
int recv_link,
|
||||
const std::vector<bool> &req_in);
|
||||
const std::vector<bool> &req_in);
|
||||
/*!
|
||||
* \brief try to recover the local state, making each local state to be the result of itself
|
||||
* plus replication of states in previous num_local_replica hops in the ring
|
||||
@ -446,17 +455,17 @@ o * the input state must exactly one saved state(local state of current node)
|
||||
inline ReturnType MsgPassing(const NodeType &node_value,
|
||||
std::vector<EdgeType> *p_edge_in,
|
||||
std::vector<EdgeType> *p_edge_out,
|
||||
EdgeType (*func) (const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)
|
||||
);
|
||||
EdgeType (*func)
|
||||
(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index));
|
||||
//---- recovery data structure ----
|
||||
// the round of result buffer, used to mode the result
|
||||
int result_buffer_round;
|
||||
// result buffer of all reduce
|
||||
ResultBuffer resbuf;
|
||||
// last check point global model
|
||||
std::string global_checkpoint;
|
||||
std::string global_checkpoint;
|
||||
// number of replica for local state/model
|
||||
int num_local_replica;
|
||||
// --- recovery data structure for local checkpoint
|
||||
@ -465,16 +474,15 @@ o * the input state must exactly one saved state(local state of current node)
|
||||
// pointer to memory position in the local model
|
||||
// local model is stored in CSR format(like a sparse matrices)
|
||||
// local_model[rptr[0]:rptr[1]] stores the model of current node
|
||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
|
||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
|
||||
std::vector<size_t> local_rptr[2];
|
||||
// storage for local model replicas
|
||||
std::string local_chkpt[2];
|
||||
// version of local checkpoint can be 1 or 0
|
||||
int local_chkpt_version;
|
||||
int local_chkpt_version;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
// implementation of inline template function
|
||||
#include "./allreduce_robust-inl.h"
|
||||
|
||||
#endif // RABIT_ALLREDUCE_ROBUST_H
|
||||
#endif // RABIT_ALLREDUCE_ROBUST_H_
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine.cc
|
||||
* \brief this file governs which implementation of engine we are actually using
|
||||
* provides a singleton of engine interface
|
||||
@ -41,16 +42,17 @@ void Finalize(void) {
|
||||
IEngine *GetEngine(void) {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||
red, prepare_fun, prepare_arg);
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_empty.cc
|
||||
* \brief this file provides a dummy implementation of engine that does nothing
|
||||
* this file provides a way to fall back to single node program without causing too many dependencies
|
||||
@ -25,9 +26,10 @@ class EmptyEngine : public IEngine {
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("EmptyEngine is not fault tolerant");
|
||||
@ -51,7 +53,7 @@ class EmptyEngine : public IEngine {
|
||||
virtual int GetWorldSize(void) const {
|
||||
return 1;
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const {
|
||||
return std::string("");
|
||||
}
|
||||
@ -59,6 +61,7 @@ class EmptyEngine : public IEngine {
|
||||
// simply print information into the tracker
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
|
||||
private:
|
||||
int version_number;
|
||||
};
|
||||
@ -77,11 +80,11 @@ void Finalize(void) {
|
||||
IEngine *GetEngine(void) {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_mock.cc
|
||||
* \brief this is an engine implementation that will
|
||||
* insert failures at certain call points, to test if the engine is robust to failure
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_mpi.cc
|
||||
* \brief this file gives an implementation of engine interface using MPI,
|
||||
* this will allow rabit programs to run with MPI, but does not come with fault tolerance
|
||||
@ -8,10 +9,10 @@
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <cstdio>
|
||||
#include <rabit/engine.h>
|
||||
#include <rabit/utils.h>
|
||||
#include <mpi.h>
|
||||
#include <cstdio>
|
||||
#include "rabit/engine.h"
|
||||
#include "rabit/utils.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
@ -27,9 +28,10 @@ class MPIEngine : public IEngine {
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead");
|
||||
utils::Error("MPIEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
@ -48,13 +50,13 @@ class MPIEngine : public IEngine {
|
||||
}
|
||||
/*! \brief get rank of current node */
|
||||
virtual int GetRank(void) const {
|
||||
return MPI::COMM_WORLD.Get_rank();
|
||||
return MPI::COMM_WORLD.Get_rank();
|
||||
}
|
||||
/*! \brief get total number of processes */
|
||||
virtual int GetWorldSize(void) const {
|
||||
return MPI::COMM_WORLD.Get_size();
|
||||
return MPI::COMM_WORLD.Get_size();
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const {
|
||||
int len;
|
||||
char name[MPI_MAX_PROCESSOR_NAME];
|
||||
@ -68,6 +70,7 @@ class MPIEngine : public IEngine {
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int version_number;
|
||||
};
|
||||
@ -91,7 +94,7 @@ IEngine *GetEngine(void) {
|
||||
// transform enum to MPI data type
|
||||
inline MPI::Datatype GetType(mpi::DataType dtype) {
|
||||
using namespace mpi;
|
||||
switch(dtype) {
|
||||
switch (dtype) {
|
||||
case kInt: return MPI::INT;
|
||||
case kUInt: return MPI::UNSIGNED;
|
||||
case kFloat: return MPI::FLOAT;
|
||||
@ -103,7 +106,7 @@ inline MPI::Datatype GetType(mpi::DataType dtype) {
|
||||
// transform enum to MPI OP
|
||||
inline MPI::Op GetOp(mpi::OpType otype) {
|
||||
using namespace mpi;
|
||||
switch(otype) {
|
||||
switch (otype) {
|
||||
case kMax: return MPI::MAX;
|
||||
case kMin: return MPI::MIN;
|
||||
case kSum: return MPI::SUM;
|
||||
@ -112,17 +115,18 @@ inline MPI::Op GetOp(mpi::OpType otype) {
|
||||
utils::Error("unknown mpi::OpType");
|
||||
return MPI::MAX;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
|
||||
count, GetType(dtype), GetOp(op));
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
@ -152,7 +156,7 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
created_type_nbytes_ = type_nbytes;
|
||||
htype_ = dtype;
|
||||
}
|
||||
|
||||
|
||||
MPI::Op *op = new MPI::Op();
|
||||
MPI::User_function *pf = redfunc;
|
||||
op->Init(pf, true);
|
||||
@ -175,7 +179,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
dtype->Commit();
|
||||
created_type_nbytes_ = type_nbytes;
|
||||
}
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
|
||||
}
|
||||
} // namespace engine
|
||||
|
||||
80
src/socket.h
80
src/socket.h
@ -1,10 +1,11 @@
|
||||
#ifndef RABIT_SOCKET_H
|
||||
#define RABIT_SOCKET_H
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file socket.h
|
||||
* \brief this file aims to provide a wrapper of sockets
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_SOCKET_H_
|
||||
#define RABIT_SOCKET_H_
|
||||
#if defined(_WIN32)
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
@ -21,7 +22,7 @@
|
||||
#endif
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <rabit/utils.h>
|
||||
#include "rabit/utils.h"
|
||||
|
||||
#if defined(_WIN32)
|
||||
typedef int ssize_t;
|
||||
@ -68,9 +69,11 @@ struct SockAddr {
|
||||
inline std::string AddrStr(void) const {
|
||||
std::string buf; buf.resize(256);
|
||||
#ifdef _WIN32
|
||||
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, &buf[0], buf.length());
|
||||
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
|
||||
&buf[0], buf.length());
|
||||
#else
|
||||
const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length());
|
||||
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
|
||||
&buf[0], buf.length());
|
||||
#endif
|
||||
Assert(s != NULL, "cannot decode address");
|
||||
return std::string(s);
|
||||
@ -94,12 +97,12 @@ class Socket {
|
||||
*/
|
||||
inline static void Startup(void) {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsa_data;
|
||||
WSADATA wsa_data;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != -1) {
|
||||
Socket::Error("Startup");
|
||||
}
|
||||
Socket::Error("Startup");
|
||||
}
|
||||
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
||||
WSACleanup();
|
||||
WSACleanup();
|
||||
utils::Error("Could not find a usable version of Winsock.dll\n");
|
||||
}
|
||||
#endif
|
||||
@ -118,11 +121,11 @@ class Socket {
|
||||
* it will set it back to block mode
|
||||
*/
|
||||
inline void SetNonBlock(bool non_block) {
|
||||
#ifdef _WIN32
|
||||
u_long mode = non_block ? 1 : 0;
|
||||
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
||||
#ifdef _WIN32
|
||||
u_long mode = non_block ? 1 : 0;
|
||||
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
||||
Socket::Error("SetNonBlock");
|
||||
}
|
||||
}
|
||||
#else
|
||||
int flag = fcntl(sockfd, F_GETFL, 0);
|
||||
if (flag == -1) {
|
||||
@ -143,7 +146,8 @@ class Socket {
|
||||
* \param addr
|
||||
*/
|
||||
inline void Bind(const SockAddr &addr) {
|
||||
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) {
|
||||
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == -1) {
|
||||
Socket::Error("Bind");
|
||||
}
|
||||
}
|
||||
@ -154,10 +158,11 @@ class Socket {
|
||||
* \return the port successfully bound to, or -1 if it failed to bind any port
|
||||
*/
|
||||
inline int TryBindHost(int start_port, int end_port) {
|
||||
// TODO, add prefix check
|
||||
// TODO(tqchen) add prefix check
|
||||
for (int port = start_port; port < end_port; ++port) {
|
||||
SockAddr addr("0.0.0.0", port);
|
||||
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) {
|
||||
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0) {
|
||||
return port;
|
||||
}
|
||||
if (errno != EADDRINUSE) {
|
||||
@ -179,22 +184,22 @@ class Socket {
|
||||
inline bool BadSocket(void) const {
|
||||
if (IsClosed()) return true;
|
||||
int err = GetSockError();
|
||||
if (err == EBADF || err == EINTR) return true;
|
||||
if (err == EBADF || err == EINTR) return true;
|
||||
return false;
|
||||
}
|
||||
/*! \brief check if socket is already closed */
|
||||
inline bool IsClosed(void) const {
|
||||
return sockfd == INVALID_SOCKET;
|
||||
}
|
||||
}
|
||||
/*! \brief close the socket */
|
||||
inline void Close(void) {
|
||||
if (sockfd != INVALID_SOCKET) {
|
||||
#ifdef _WIN32
|
||||
closesocket(sockfd);
|
||||
#else
|
||||
close(sockfd);
|
||||
close(sockfd);
|
||||
#endif
|
||||
sockfd = INVALID_SOCKET;
|
||||
sockfd = INVALID_SOCKET;
|
||||
} else {
|
||||
Error("Socket::Close double close the socket or close without create");
|
||||
}
|
||||
@ -204,6 +209,7 @@ class Socket {
|
||||
int errsv = errno;
|
||||
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
|
||||
}
|
||||
@ -227,7 +233,7 @@ class TCPSocket : public Socket{
|
||||
int opt = static_cast<int>(keepalive);
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) {
|
||||
Socket::Error("SetKeepAlive");
|
||||
}
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief create the socket, call this before using socket
|
||||
@ -273,7 +279,8 @@ class TCPSocket : public Socket{
|
||||
* \return whether connect is successful
|
||||
*/
|
||||
inline bool Connect(const SockAddr &addr) {
|
||||
return connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0;
|
||||
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0;
|
||||
}
|
||||
/*!
|
||||
* \brief send data using the socket
|
||||
@ -284,7 +291,7 @@ class TCPSocket : public Socket{
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
||||
}
|
||||
/*!
|
||||
@ -296,7 +303,7 @@ class TCPSocket : public Socket{
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
||||
}
|
||||
/*!
|
||||
@ -331,7 +338,8 @@ class TCPSocket : public Socket{
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
||||
ssize_t ret = recv(sockfd, buf,
|
||||
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
||||
if (ret == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
|
||||
Socket::Error("RecvAll");
|
||||
@ -385,7 +393,7 @@ struct SelectHelper {
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchRead(SOCKET fd) {
|
||||
FD_SET(fd, &read_set);
|
||||
FD_SET(fd, &read_set);
|
||||
if (fd > maxfd) maxfd = fd;
|
||||
}
|
||||
/*!
|
||||
@ -403,7 +411,7 @@ struct SelectHelper {
|
||||
inline void WatchException(SOCKET fd) {
|
||||
FD_SET(fd, &except_set);
|
||||
if (fd > maxfd) maxfd = fd;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor is ready for read
|
||||
* \param fd file descriptor to check status
|
||||
@ -435,8 +443,9 @@ struct SelectHelper {
|
||||
fd_set wait_set;
|
||||
FD_ZERO(&wait_set);
|
||||
FD_SET(fd, &wait_set);
|
||||
return Select_(static_cast<int>(fd + 1), NULL, NULL, &wait_set, timeout);
|
||||
}
|
||||
return Select_(static_cast<int>(fd + 1),
|
||||
NULL, NULL, &wait_set, timeout);
|
||||
}
|
||||
/*!
|
||||
* \brief perform select on the set defined
|
||||
* \param select_read whether to watch for read event
|
||||
@ -454,9 +463,10 @@ struct SelectHelper {
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) {
|
||||
inline static int Select_(int maxfd, fd_set *rfds,
|
||||
fd_set *wfds, fd_set *efds, long timeout) {
|
||||
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
|
||||
if (timeout == 0) {
|
||||
return select(maxfd, rfds, wfds, efds, NULL);
|
||||
@ -465,12 +475,12 @@ struct SelectHelper {
|
||||
tm.tv_usec = (timeout % 1000) * 1000;
|
||||
tm.tv_sec = timeout / 1000;
|
||||
return select(maxfd, rfds, wfds, efds, &tm);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SOCKET maxfd;
|
||||
|
||||
SOCKET maxfd;
|
||||
fd_set read_set, write_set, except_set;
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif
|
||||
#endif // RABIT_SOCKET_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user