cpplint pass

This commit is contained in:
tqchen 2014-12-28 05:12:07 -08:00
parent 15836eb98e
commit 27d6977a3e
16 changed files with 406 additions and 328 deletions

View File

@ -36,4 +36,4 @@ $(ALIB):
ar cr $@ $+
clean:
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~

View File

@ -1,6 +1,5 @@
#ifndef RABIT_RABIT_H
#define RABIT_RABIT_H
/*!
* Copyright (c) 2014 by Contributors
* \file rabit.h
* \brief This file defines unified Allreduce/Broadcast interface of rabit
* The actual implementation is redirected to rabit engine
@ -9,12 +8,14 @@
* rabit.h and serializable.h are all the user needs to use the rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_RABIT_H_
#define RABIT_RABIT_H_
#include <string>
#include <vector>
// optional support of lambda functions in C++11, if available
#if __cplusplus >= 201103L
#include <functional>
#endif // C++11
#endif // C++11
// contains definition of ISerializable
#include "./rabit_serializable.h"
// engine definition of rabit, defines internal implementation
@ -116,7 +117,7 @@ inline void Broadcast(std::string *sendrecv_data, int root);
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *arg) = NULL,
void (*prepare_fun)(void *arg) = NULL,
void *prepare_arg = NULL);
// C++11 support for lambda prepare function
@ -142,9 +143,9 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* \tparam DType type of data
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun);
#endif // C++11
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun);
#endif // C++11
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
@ -228,6 +229,7 @@ class Reducer {
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun);
#endif
private:
/*! \brief function handle to do reduce */
engine::ReduceHandle handle_;
@ -274,6 +276,7 @@ class SerializeReducer {
size_t max_nbyte, size_t count,
std::function<void()> prepare_fun);
#endif
private:
/*! \brief function handle to do reduce */
engine::ReduceHandle handle_;
@ -283,4 +286,4 @@ class SerializeReducer {
} // namespace rabit
// implementation of template functions
#include "./rabit/rabit-inl.h"
#endif // RABIT_ALLREDUCE_H
#endif // RABIT_RABIT_H_

View File

@ -1,10 +1,12 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine.h
* \brief This file defines the core interface of allreduce library
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_ENGINE_H
#define RABIT_ENGINE_H
#ifndef RABIT_ENGINE_H_
#define RABIT_ENGINE_H_
#include <string>
#include "../rabit_serializable.h"
namespace MPI {
@ -122,7 +124,7 @@ class IEngine {
virtual int GetRank(void) const = 0;
/*! \brief get total number of nodes (the world size) */
virtual int GetWorldSize(void) const = 0;
/*! \brief get the host name of current node */
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const = 0;
/*!
* \brief print the msg in the tracker,
@ -211,7 +213,7 @@ class ReduceHandle {
/*! \return the number of bytes occupied by the type */
static int TypeSize(const MPI::Datatype &dtype);
private:
protected:
// handle data field
void *handle_;
// handle to the type field
@ -221,5 +223,4 @@ class ReduceHandle {
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_H
#endif // RABIT_ENGINE_H_

View File

@ -1,16 +1,19 @@
#ifndef RABIT_UTILS_IO_H
#define RABIT_UTILS_IO_H
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include "./utils.h"
#include "../rabit_serializable.h"
/*!
* Copyright (c) 2014 by Contributors
* \file io.h
* \brief utilities that implements different serializable interface
* \author Tianqi Chen
*/
#ifndef RABIT_UTILS_IO_H_
#define RABIT_UTILS_IO_H_
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include <algorithm>
#include "./utils.h"
#include "../rabit_serializable.h"
namespace rabit {
namespace utils {
/*! \brief interface of i/o stream that support seek */
@ -25,8 +28,9 @@ class ISeekStream: public IStream {
/*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public ISeekStream {
public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)),
buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
virtual ~MemoryFixSizeBuffer(void) {}
@ -40,7 +44,7 @@ struct MemoryFixSizeBuffer : public ISeekStream {
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_,
utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size");
memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
@ -59,12 +63,12 @@ struct MemoryFixSizeBuffer : public ISeekStream {
size_t buffer_size_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryFixSizeBuffer
}; // class MemoryFixSizeBuffer
/*! \brief an in-memory buffer that can be read and written through the stream interface */
struct MemoryBufferStream : public ISeekStream {
public:
MemoryBufferStream(std::string *p_buffer)
explicit MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
@ -82,7 +86,7 @@ struct MemoryBufferStream : public ISeekStream {
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
@ -97,36 +101,7 @@ struct MemoryBufferStream : public ISeekStream {
std::string *p_buffer_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryBufferStream
/*! \brief seekable stream that forwards every operation to a C stdio FILE handle */
class FileStream : public ISeekStream {
 public:
  /*! \brief wrap a file handle supplied by the caller */
  explicit FileStream(FILE *fp) : fp(fp) {}
  /*! \brief create a stream that has no file attached yet */
  explicit FileStream(void) : fp(NULL) {}
  /*!
   * \brief read a single item of `size` bytes into ptr
   * \return the value returned by fread, i.e. the number of items read
   */
  virtual size_t Read(void *ptr, size_t size) {
    const size_t nitem = std::fread(ptr, size, 1, fp);
    return nitem;
  }
  /*! \brief write a single item of `size` bytes taken from ptr */
  virtual void Write(const void *ptr, size_t size) {
    std::fwrite(ptr, size, 1, fp);
  }
  /*! \brief move the file position to pos, counted from the start of the file */
  virtual void Seek(size_t pos) {
    const long offset = static_cast<long>(pos);
    std::fseek(fp, offset, SEEK_SET);
  }
  /*! \brief report the current file position as given by ftell */
  virtual size_t Tell(void) {
    return std::ftell(fp);
  }
  /*! \brief close the file if one is attached; calling it again is a no-op */
  inline void Close(void) {
    if (fp == NULL) return;
    std::fclose(fp);
    fp = NULL;
  }

 private:
  /*! \brief underlying file handle, NULL when nothing is attached */
  FILE *fp;
};
}; // class MemoryBufferStream
} // namespace utils
} // namespace rabit
#endif
#endif // RABIT_UTILS_IO_H_

View File

@ -1,10 +1,11 @@
#ifndef RABIT_UTILS_H_
#define RABIT_UTILS_H_
/*!
* Copyright (c) 2014 by Contributors
* \file utils.h
* \brief simple utils to support the code
* \author Tianqi Chen
*/
#ifndef RABIT_UTILS_H_
#define RABIT_UTILS_H_
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio>
#include <string>
@ -19,7 +20,7 @@
#define fopen64 std::fopen
#endif
#ifdef _MSC_VER
// NOTE: sprintf_s is not equivalent to snprintf,
// NOTE: sprintf_s is not equivalent to snprintf,
// they are equivalent when success, which is sufficient for our case
#define snprintf sprintf_s
#define vsnprintf vsprintf_s
@ -30,7 +31,7 @@
#endif
#endif
#ifdef __APPLE__
#ifdef __APPLE__
#define off64_t off_t
#define fopen64 std::fopen
#endif
@ -186,5 +187,5 @@ inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return NULL;
return &str[0];
}
} // namespace rabit
} // namespace rabit
#endif // RABIT_UTILS_H_

View File

@ -1,13 +1,14 @@
#ifndef RABIT_RABIT_SERIALIZABLE_H
#define RABIT_RABIT_SERIALIZABLE_H
#include <vector>
#include <string>
#include "./rabit/utils.h"
/*!
* Copyright (c) 2014 by Contributors
* \file serializable.h
* \brief defines serializable interface of rabit
* \author Tianqi Chen
*/
#ifndef RABIT_RABIT_SERIALIZABLE_H_
#define RABIT_RABIT_SERIALIZABLE_H_
#include <vector>
#include <string>
#include "./rabit/utils.h"
namespace rabit {
/*!
* \brief interface of stream I/O, used by ISerializable
@ -96,4 +97,4 @@ class ISerializable {
virtual void Save(IStream &fo) const = 0;
};
} // namespace rabit
#endif
#endif // RABIT_RABIT_SERIALIZABLE_H_

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.cc
* \brief Basic implementation of AllReduce
*
@ -32,13 +33,15 @@ AllreduceBase::AllreduceBase(void) {
// initialization function
void AllreduceBase::Init(void) {
// setup from environment variables
{// handling for hadoop
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");
if (task_id == NULL) {
task_id = getenv("mapreduce_task_id");
}
if (hadoop_mode != 0) {
utils::Check(task_id != NULL, "hadoop_mode is set but cannot find mapred_task_id");
utils::Check(task_id != NULL,
"hadoop_mode is set but cannot find mapred_task_id");
}
if (task_id != NULL) {
this->SetParam("rabit_task_id", task_id);
@ -48,7 +51,7 @@ void AllreduceBase::Init(void) {
if (attempt_id != 0) {
const char *att = strrchr(attempt_id, '_');
int num_trial;
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
this->SetParam("rabit_num_trial", att + 1);
}
}
@ -58,7 +61,8 @@ void AllreduceBase::Init(void) {
num_task = getenv("mapreduce_job_maps");
}
if (hadoop_mode != 0) {
utils::Check(num_task != NULL, "hadoop_mode is set but cannot find mapred_map_tasks");
utils::Check(num_task != NULL,
"hadoop_mode is set but cannot find mapred_map_tasks");
}
if (num_task != NULL) {
this->SetParam("rabit_world_size", num_task);
@ -81,11 +85,11 @@ void AllreduceBase::Shutdown(void) {
}
all_links.clear();
tree_links.plinks.clear();
if (tracker_uri == "NULL") return;
// notify the tracker that this rank has shut down
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
tracker.SendStr(std::string("shutdown"));
tracker.Close();
utils::TCPSocket::Finalize();
}
@ -107,11 +111,11 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
if (!strcmp(name, "rabit_task_id")) task_id = val;
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
if (!strcmp(name, "rabit_reduce_buffer")) {
char unit;
unsigned long amount;
uint64_t amount;
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
switch (unit) {
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
@ -121,7 +125,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
default: utils::Error("invalid format for reduce buffer");
}
} else {
utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
utils::Error("invalid format for reduce_buffer,"\
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
}
}
}
@ -137,11 +142,16 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
utils::Socket::Error("Connect");
}
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 1");
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 2");
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 3");
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
tracker.SendStr(task_id);
return tracker;
}
@ -161,29 +171,30 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
int prev_rank, next_rank;
// the rank of neighbors
std::map<int, int> tree_neighbors;
{// get new ranks
int newrank, num_neighbors;
utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
rank = newrank;
utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors),
"ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
int nrank;
utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
"ReConnectLink failure 4");
tree_neighbors[nrank] = 1;
}
utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
using utils::Assert;
// get new ranks
int newrank, num_neighbors;
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\
sizeof(parent_rank), "ReConnectLink failure 4");
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
Assert(rank == -1 || newrank == rank,
"must keep rank to same if the node already have one");
rank = newrank;
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
sizeof(num_neighbors), "ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
int nrank;
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
"ReConnectLink failure 4");
tree_neighbors[nrank] = 1;
}
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
// create listening socket
utils::TCPSocket sock_listen;
sock_listen.Create();
@ -204,56 +215,67 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
}
}
int ngood = static_cast<int>(good_link.size());
utils::Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) {
utils::Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == sizeof(good_link[i]),
"ReConnectLink failure 6");
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
sizeof(good_link[i]), "ReConnectLink failure 6");
}
utils::Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
utils::Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
"ReConnectLink failure 8");
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
sizeof(num_accept), "ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.RecvStr(&hname);
utils::Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
utils::Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
tracker.RecvStr(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
"ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1; r.sock.Close(); continue;
}
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12");
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 12");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 13");
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
}
}
if (!match) all_links.push_back(r);
}
utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to tracker
utils::Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
"ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
tracker.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
r.sock = sock_listen.Accept();
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 15");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
utils::Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
}
}
@ -278,9 +300,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
}
utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link");
utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link");
utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link");
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
"cannot find next ring in the link");
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
@ -326,7 +351,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// if no children, no need to reduce
if (nlink == static_cast<int>(parent_index != -1)) {
size_up_reduce = total_size;
}
}
// while we have not passed the messages out
while (true) {
// select helper
@ -347,7 +372,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
}
// size_write <= size_read
// size_write <= size_read
if (links[i].size_write != total_size) {
selecter.WatchWrite(links[i].sock);
// only watch for exception in live channels
@ -358,11 +383,11 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
}
// finished running allreduce
if (finished) break;
// select must return
// select must return
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
}
// read data from children
@ -392,7 +417,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// start position
size_t start = size_up_reduce % buffer_size;
// perform read till end of buffer
size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce);
size_t nread = std::min(buffer_size - start,
max_reduce - size_up_reduce);
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
@ -407,7 +433,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
}
if (parent_index != -1) {
// pass message up to parent, can pass data that has already been reduced
if (selecter.CheckWrite(links[parent_index].sock)) {
if (selecter.CheckWrite(links[parent_index].sock)) {
ssize_t len = links[parent_index].sock.
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
if (len != -1) {
@ -417,7 +443,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) {
if (selecter.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) {
@ -425,7 +452,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
}
if (len != -1) {
size_down_in += static_cast<size_t>(len);
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
utils::Assert(size_down_in <= size_up_out,
"Allreduce: boundary error");
} else {
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
}
@ -437,11 +465,13 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// can pass message down to children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) {
return kSockError;
}
}
}
}
return kSuccess;
return kSuccess;
}
/*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
@ -455,14 +485,15 @@ AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
utils::Check(root < world_size,
"Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.size());
// size of space already read from data
size_t size_in = 0;
// input link, -2 means unknown yet, -1 means this is root
int in_link = -2;
// initialize the link statistics
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
@ -471,9 +502,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
if (this->rank == root) {
size_in = total_size;
in_link = -1;
}
}
// while we have not passed the messages out
while(true) {
while (true) {
bool finished = true;
// select helper
utils::SelectHelper selecter;
@ -487,7 +518,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
selecter.WatchWrite(links[i].sock); finished = false;
}
selecter.WatchException(links[i].sock);
selecter.WatchException(links[i].sock);
}
// finish running
if (finished) break;
@ -495,14 +526,16 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) {
if (!links[i].ReadToArray(sendrecvbuf_, total_size)) return kSockError;
if (!links[i].ReadToArray(sendrecvbuf_, total_size)) {
return kSockError;
}
size_in = links[i].size_read;
if (size_in != 0) {
in_link = i; break;
@ -512,7 +545,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
} else {
// read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) return kSockError;
if (!links[in_link].ReadToArray(sendrecvbuf_, total_size)) {
return kSockError;
}
size_in = links[in_link].size_read;
}
}

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.h
* \brief Basic implementation of AllReduce
* using TCP non-block socket and tree-shape reduction.
@ -8,13 +9,14 @@
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_BASE_H
#define RABIT_ALLREDUCE_BASE_H
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_
#include <vector>
#include <string>
#include <rabit/utils.h>
#include <rabit/engine.h>
#include <algorithm>
#include "rabit/utils.h"
#include "rabit/engine.h"
#include "./socket.h"
namespace MPI {
@ -22,7 +24,7 @@ namespace MPI {
class Datatype {
public:
size_t type_size;
Datatype(size_t type_size) : type_size(type_size) {}
explicit Datatype(size_t type_size) : type_size(type_size) {}
};
}
namespace rabit {
@ -31,7 +33,7 @@ namespace engine {
class AllreduceBase : public IEngine {
public:
// magic number to verify server
const static int kMagic = 0xff99;
static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
@ -79,12 +81,13 @@ class AllreduceBase : public IEngine {
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess,
utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
"Allreduce failed");
}
/*!
@ -201,12 +204,16 @@ class AllreduceBase : public IEngine {
// constructor
LinkRecord(void) {}
// initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
inline void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
size_t n = (type_nbytes * count + 7)/ 8;
buffer_.resize(std::min(reduce_buffer_size, n));
// make sure align to type_nbytes
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size);
buffer_size =
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
utils::Assert(type_nbytes <= buffer_size,
"too large type_nbytes=%lu, buffer_size=%lu",
type_nbytes, buffer_size);
// set buffer head
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
}
@ -225,7 +232,7 @@ class AllreduceBase : public IEngine {
size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
size_t offset = size_read % buffer_size;
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
if (nmax == 0) return true;
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
@ -235,7 +242,7 @@ class AllreduceBase : public IEngine {
if (len == -1) return errno == EAGAIN || errno == EWOULDBLOCK;
size_read += static_cast<size_t>(len);
return true;
}
}
/*!
* \brief read data into array,
* this function can not be used together with ReadToRingBuffer

View File

@ -1,11 +1,13 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust-inl.h
* \brief implementation of inline template function in AllreduceRobust
*
* \author Tianqi Chen
*/
#ifndef RABIT_ENGINE_ROBUST_INL_H
#define RABIT_ENGINE_ROBUST_INL_H
#ifndef RABIT_ENGINE_ROBUST_INL_H_
#define RABIT_ENGINE_ROBUST_INL_H_
#include <vector>
namespace rabit {
namespace engine {
@ -33,14 +35,14 @@ inline AllreduceRobust::ReturnType
AllreduceRobust::MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType (*func) (const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)
) {
EdgeType (*func)
(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)) {
RefLinkVector &links = tree_links;
if (links.size() == 0) return kSuccess;
// number of links
const int nlink = static_cast<int>(links.size());
const int nlink = static_cast<int>(links.size());
// initialize the pointers
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
@ -58,7 +60,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
// if no children, there is nothing to wait for; directly start passing message
if (nlink == static_cast<int>(parent_index != -1)) {
utils::Assert(parent_index == 0, "parent must be 0");
edge_out[parent_index] = func(node_value, edge_in, parent_index);
edge_out[parent_index] = func(node_value, edge_in, parent_index);
stage = 1;
}
// while we have not passed the messages out
@ -94,7 +96,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
}
if (stage == 0) {
@ -103,7 +105,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) return kSockError;
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) {
return kSockError;
}
}
if (links[i].size_read != sizeof(EdgeType)) finished = false;
}
@ -124,13 +128,17 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (stage == 1) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) return kSockError;
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) {
return kSockError;
}
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
}
if (stage == 2) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) return kSockError;
utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) {
return kSockError;
}
if (links[pid].size_read == sizeof(EdgeType)) {
for (int i = 0; i < nlink; ++i) {
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
@ -141,7 +149,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (stage == 3) {
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError;
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) {
return kSockError;
}
}
}
}
@ -150,4 +160,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
}
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_ROBUST_INL_H
#endif // RABIT_ENGINE_ROBUST_INL_H_

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust.cc
* \brief Robust implementation of Allreduce
*
@ -9,10 +10,10 @@
#define NOMINMAX
#include <limits>
#include <utility>
#include <rabit/io.h>
#include <rabit/utils.h>
#include <rabit/engine.h>
#include <rabit/rabit-inl.h>
#include "rabit/io.h"
#include "rabit/utils.h"
#include "rabit/engine.h"
#include "rabit/rabit-inl.h"
#include "./allreduce_robust.h"
namespace rabit {
@ -30,10 +31,10 @@ void AllreduceRobust::Shutdown(void) {
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
"Shutdown: check point must return true");
// reset result buffer
resbuf.Clear(); seq_counter = 0;
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"Shutdown: check ack must return true");
"Shutdown: check ack must return true");
AllreduceBase::Shutdown();
}
/*!
@ -89,7 +90,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
} else {
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
}
}
}
}
resbuf.PushTemp(seq_counter, type_nbytes, count);
seq_counter += 1;
@ -102,7 +103,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
*/
void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
// skip action in single node
if (world_size == 1) return;
if (world_size == 1) return;
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
// now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 &&
@ -119,7 +120,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
} else {
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
}
}
}
}
resbuf.PushTemp(seq_counter, 1, total_size);
seq_counter += 1;
@ -151,7 +152,8 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
// skip action in single node
if (world_size == 1) return 0;
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
utils::Check(local_model == NULL,
"need to set num_local_replica larger than 1 to checkpoint local_model");
}
// check if we succesful
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
@ -171,9 +173,10 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
// load from buffer
utils::MemoryBufferStream fs(&global_checkpoint);
if (global_checkpoint.length() == 0) {
version_number = 0;
version_number = 0;
} else {
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number");
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
"read in version number");
global_model->Load(fs);
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
"local model inconsistent, nlocal=%d", nlocal);
@ -212,9 +215,10 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
version_number += 1; return;
}
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
}
if (num_local_replica != 0) {
utils::Check(local_model == NULL,
"need to set num_local_replica larger than 1 to checkpoint local_model");
}
if (num_local_replica != 0) {
while (true) {
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
// save model model to new version place
@ -247,10 +251,10 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs);
// reset result buffer
resbuf.Clear(); seq_counter = 0;
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"check ack must return true");
"check ack must return true");
}
/*!
* \brief reset all the existing links by sending Out-of-Band message marker
@ -383,7 +387,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
*/
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
if (err_type == kSuccess) return true;
{// simple way, shutdown all links
{
// simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
}
@ -392,8 +397,8 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
}
// this was old way
// TryResetLinks still causes possible errors, so not use this one
while(err_type != kSuccess) {
switch(err_type) {
while (err_type != kSuccess) {
switch (err_type) {
case kGetExcept: err_type = TryResetLinks(); break;
case kSockError: {
TryResetLinks();
@ -416,7 +421,7 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
* \param out_index the edge index of output link
* \return the shortest distance result of out edge specified by out_index
*/
inline std::pair<int,size_t>
inline std::pair<int, size_t>
ShortestDist(const std::pair<bool, size_t> &node_value,
const std::vector< std::pair<int, size_t> > &dist_in,
size_t out_index) {
@ -484,8 +489,9 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
int *p_recvlink,
std::vector<bool> *p_req_in) {
int best_link = -2;
{// get the shortest distance to the request point
std::vector< std::pair<int,size_t> > dist_in, dist_out;
{
// get the shortest distance to the request point
std::vector<std::pair<int, size_t> > dist_in, dist_out;
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
&dist_in, &dist_out, ShortestDist);
if (succ != kSuccess) return succ;
@ -512,7 +518,7 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
&req_in, &req_out, DataRequest);
if (succ != kSuccess) return succ;
// set p_req_in
p_req_in->resize(req_in.size());
p_req_in->resize(req_in.size());
for (size_t i = 0; i < req_in.size(); ++i) {
// set p_req_in
(*p_req_in)[i] = (req_in[i] != 0);
@ -591,19 +597,23 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (role == kRequestData) {
const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) {
if(!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError;
if (!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError;
}
for (int i = 0; i < nlink; ++i) {
if (req_in[i] && links[i].size_write != links[pid].size_read &&
selecter.CheckWrite(links[i].sock)) {
if(!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) return kSockError;
if (!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) {
return kSockError;
}
}
}
}
if (role == kHaveData) {
for (int i = 0; i < nlink; ++i) {
if (req_in[i] && selecter.CheckWrite(links[i].sock)) {
if(!links[i].WriteFromArray(sendrecvbuf_, size)) return kSockError;
if (!links[i].WriteFromArray(sendrecvbuf_, size)) {
return kSockError;
}
}
}
}
@ -616,13 +626,14 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
}
utils::Assert(min_write <= links[pid].size_read, "boundary check");
if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
}
for (int i = 0; i < nlink; ++i) {
if (req_in[i] && selecter.CheckWrite(links[i].sock) && links[pid].size_read != links[i].size_write) {
for (int i = 0; i < nlink; ++i) {
if (req_in[i] && selecter.CheckWrite(links[i].sock) &&
links[pid].size_read != links[i].size_write) {
size_t start = links[i].size_write % buffer_size;
// send out data from ring buffer
size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite);
if (len != -1) {
links[i].size_write += len;
@ -648,15 +659,15 @@ AllreduceRobust::TryRecoverData(RecoverType role,
*/
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
// check in local data
RecoverType role = requester ? kRequestData : kHaveData;
ReturnType succ;
RecoverType role = requester ? kRequestData : kHaveData;
ReturnType succ;
if (num_local_replica != 0) {
if (requester) {
// clear existing history, if any, before load
local_rptr[local_chkpt_version].clear();
local_chkpt[local_chkpt_version].clear();
local_chkpt[local_chkpt_version].clear();
}
// recover local checkpoint
// recover local checkpoint
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
&local_chkpt[local_chkpt_version]);
if (succ != kSuccess) return succ;
@ -716,7 +727,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
// if we goes to this place, use must have already setup the state once
utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1,
"TryGetResult::Checkpoint");
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
}
// handles normal data recovery
RecoverType role;
@ -735,8 +746,9 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
utils::Check(data_size != 0, "zero size check point is not allowed");
if (role == kRequestData || role == kHaveData) {
utils::Check(data_size == size,
"Allreduce Recovered data size do not match the specification of function call\n"\
"Please check if calling sequence of recovered program is the same the original one in current VersionNumber");
"Allreduce Recovered data size do not match the specification of function call.\n"\
"Please check if calling sequence of recovered program is the " \
"same the original one in current VersionNumber");
}
return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in);
}
@ -766,7 +778,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
while (true) {
this->ReportStatus();
// action
ActionSummary act = req;
ActionSummary act = req;
// get the reduced action
if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
if (act.check_ack()) {
@ -816,7 +828,8 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
if (requester) return true;
} else {
// all the request is same, this is most recent command that is yet to be executed
// all the request is same,
// this is most recent command that is yet to be executed
return false;
}
}
@ -855,7 +868,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
}
const int n = num_local_replica;
{// backward passing, passing state in backward direction of the ring
{
// backward passing, passing state in backward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_back(n + 1);
@ -897,10 +911,10 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
// update rptr
rptr.resize(nread_end + 1);
for (int i = nlocal; i < nread_end; ++i) {
rptr[i + 1] = rptr[i] + sizes[i];
rptr[i + 1] = rptr[i] + sizes[i];
}
chkpt.resize(rptr.back());
// pass data through the link
// pass data through the link
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
rptr[nwrite_start], rptr[nread_end],
ring_next, ring_prev);
@ -908,7 +922,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
}
}
{// forward passing, passing state in forward direction of the ring
{
// forward passing, passing state in forward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_forward(n + 1);
@ -926,7 +941,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
1 * sizeof(int), 2 * sizeof(int),
0 * sizeof(int), 1 * sizeof(int),
ring_next, ring_prev);
if (succ != kSuccess) return succ;
if (succ != kSuccess) return succ;
// calculate the number of things we can read from next link
int nread_end = nlocal, nwrite_end = 1;
// have to have itself in order to get other data from prev link
@ -936,7 +951,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
nread_end = std::max(nread_end, i + 1);
nwrite_end = i + 1;
}
if (nwrite_end > n) nwrite_end = n;
if (nwrite_end > n) nwrite_end = n;
} else {
nread_end = 0; nwrite_end = 0;
}
@ -963,7 +978,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
rptr[i + 1] = rptr[i] + sizes[i];
}
chkpt.resize(rptr.back());
// pass data through the link
// pass data through the link
succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end],
rptr[nwrite_start], rptr[nwrite_end],
ring_prev, ring_next);
@ -995,7 +1010,8 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
if (num_local_replica == 0) return kSuccess;
std::vector<size_t> &rptr = *p_local_rptr;
std::string &chkpt = *p_local_chkpt;
utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state");
utils::Assert(rptr.size() == 2,
"TryCheckinLocalState must have exactly 1 state");
const int n = num_local_replica;
std::vector<size_t> sizes(n + 1);
sizes[0] = rptr[1] - rptr[0];
@ -1012,9 +1028,9 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
rptr.resize(n + 2);
for (int i = 1; i <= n; ++i) {
rptr[i + 1] = rptr[i] + sizes[i];
}
}
chkpt.resize(rptr.back());
// pass data through the link
// pass data through the link
succ = RingPassing(BeginPtr(chkpt),
rptr[1], rptr[n + 1],
rptr[0], rptr[n],
@ -1050,13 +1066,14 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
LinkRecord *read_link,
LinkRecord *write_link) {
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
utils::Assert(write_end <= read_end, "RingPassing: boundary check1, write_end=%lu, read_end=%lu", write_end, read_end);
utils::Assert(write_end <= read_end,
"RingPassing: boundary check1");
utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2");
utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3");
// take reference
LinkRecord &prev = *read_link, &next = *write_link;
// send recv buffer
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
while (true) {
bool finished = true;
utils::SelectHelper selecter;
@ -1066,7 +1083,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
}
if (write_ptr < read_ptr && write_ptr != write_end) {
selecter.WatchWrite(next.sock);
finished = false;
finished = false;
}
selecter.WatchException(prev.sock);
selecter.WatchException(next.sock);
@ -1078,7 +1095,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
if (len == 0) {
prev.sock.Close(); return kSockError;
}
}
if (len != -1) {
read_ptr += static_cast<size_t>(len);
} else {

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust.h
* \brief Robust implementation of Allreduce
* using TCP non-block socket and tree-shape reduction.
@ -7,10 +8,12 @@
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_ROBUST_H
#define RABIT_ALLREDUCE_ROBUST_H
#ifndef RABIT_ALLREDUCE_ROBUST_H_
#define RABIT_ALLREDUCE_ROBUST_H_
#include <vector>
#include <rabit/engine.h>
#include <string>
#include <algorithm>
#include "rabit/engine.h"
#include "./allreduce_base.h"
namespace rabit {
@ -111,11 +114,11 @@ class AllreduceRobust : public AllreduceBase {
private:
// constant one byte out of band message to indicate error happening
// and mark for channel cleanup
const static char kOOBReset = 95;
static const char kOOBReset = 95;
// and mark for channel cleanup, after OOB signal
const static char kResetMark = 97;
static const char kResetMark = 97;
// and mark for channel cleanup
const static char kResetAck = 97;
static const char kResetAck = 97;
/*! \brief type of roles each node can play during recovery */
enum RecoverType {
/*! \brief current node have data */
@ -132,29 +135,29 @@ class AllreduceRobust : public AllreduceBase {
*/
struct ActionSummary {
// maximum allowed sequence id
const static int kSpecialOp = (1 << 26);
static const int kSpecialOp = (1 << 26);
// special sequence number for local state checkpoint
const static int kLocalCheckPoint = (1 << 26) - 2;
static const int kLocalCheckPoint = (1 << 26) - 2;
// special sequence number for local state checkpoint ack signal
const static int kLocalCheckAck = (1 << 26) - 1;
static const int kLocalCheckAck = (1 << 26) - 1;
//---------------------------------------------
// The following are bit mask of flag used in
// The following are bit mask of flag used in
//----------------------------------------------
// some node want to load check point
const static int kLoadCheck = 1;
static const int kLoadCheck = 1;
// some node want to do check point
const static int kCheckPoint = 2;
static const int kCheckPoint = 2;
// check point Ack, we use a two phase message in check point,
// this is the second phase of check pointing
const static int kCheckAck = 4;
static const int kCheckAck = 4;
// there are difference sequence number the nodes proposed
// this means we want to do recover execution of the lower sequence
// action instead of normal execution
const static int kDiffSeq = 8;
static const int kDiffSeq = 8;
// default constructor: empty body and no initializer list, so nothing is
// explicitly initialized here
ActionSummary(void) {}
// constructor of action
ActionSummary(int flag, int minseqno = kSpecialOp) {
// constructor of action
explicit ActionSummary(int flag, int minseqno = kSpecialOp) {
seqcode = (minseqno << 4) | flag;
}
// minimum number of all operations
@ -181,10 +184,11 @@ class AllreduceRobust : public AllreduceBase {
inline int flag(void) const {
  // the flag bits are the low four bits of seqcode
  const int flag_mask = (1 << 4) - 1;
  return seqcode & flag_mask;
}
// reducer for Allreduce, used to get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
// reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_,
int len, const MPI::Datatype &dtype) {
const ActionSummary *src = (const ActionSummary*)src_;
ActionSummary *dst = (ActionSummary*)dst_;
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) {
int src_seqno = src[i].min_seqno();
int dst_seqno = dst[i].min_seqno();
@ -192,7 +196,8 @@ class AllreduceRobust : public AllreduceBase {
if (src_seqno == dst_seqno) {
dst[i] = ActionSummary(flag, src_seqno);
} else {
dst[i] = ActionSummary(flag | kDiffSeq, std::min(src_seqno, dst_seqno));
dst[i] = ActionSummary(flag | kDiffSeq,
std::min(src_seqno, dst_seqno));
}
}
}
@ -222,7 +227,7 @@ class AllreduceRobust : public AllreduceBase {
data_.resize(rptr_.back() + nhop);
return BeginPtr(data_) + rptr_.back();
}
// push the result in temp to the
// push the result in temp to the
inline void PushTemp(int seqid, size_t type_nbytes, size_t count) {
size_t size = type_nbytes * count;
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
@ -234,13 +239,14 @@ class AllreduceRobust : public AllreduceBase {
size_.push_back(size);
utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent");
}
// return the stored result of seqid, if any
// return the stored result of seqid, if any
inline void* Query(int seqid, size_t *p_size) {
size_t idx = std::lower_bound(seqno_.begin(), seqno_.end(), seqid) - seqno_.begin();
size_t idx = std::lower_bound(seqno_.begin(),
seqno_.end(), seqid) - seqno_.begin();
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
*p_size = size_[idx];
return BeginPtr(data_) + rptr_[idx];
}
}
// drop last stored result
inline void DropLast(void) {
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
@ -254,15 +260,16 @@ class AllreduceRobust : public AllreduceBase {
if (seqno_.size() == 0) return -1;
return seqno_.back();
}
private:
// sequence number of each
// sequence number of each
std::vector<int> seqno_;
// pointer to the positions
std::vector<size_t> rptr_;
// actual size of each buffer
std::vector<size_t> size_;
// content of the buffer
std::vector<uint64_t> data_;
std::vector<uint64_t> data_;
};
/*!
* \brief reset all the existing links by sending Out-of-Band message marker
@ -291,14 +298,16 @@ class AllreduceRobust : public AllreduceBase {
* \param buf the buffer to store the result
* \param size the total size of the buffer
* \param flag flag information about the action \sa ActionSummary
* \param seqno sequence number of the action, if it is special action with flag set, seqno needs to be set to ActionSummary::kSpecialOp
* \param seqno sequence number of the action, if it is special action with flag set,
* seqno needs to be set to ActionSummary::kSpecialOp
*
* \return if this function can return true or false
* - true means buf already set to the
* result by recovering procedure, the action is complete, no further action is needed
* - true means buf already set to the
* result by recovering procedure, the action is complete, no further action is needed
* - false means this is the latest action that has not yet been executed, need to execute the action
*/
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kSpecialOp);
bool RecoverExec(void *buf, size_t size, int flag,
int seqno = ActionSummary::kSpecialOp);
/*!
* \brief try to load check point
*
@ -363,7 +372,7 @@ class AllreduceRobust : public AllreduceBase {
void *sendrecvbuf_,
size_t size,
int recv_link,
const std::vector<bool> &req_in);
const std::vector<bool> &req_in);
/*!
* \brief try to recover the local state, making each local state to be the result of itself
* plus replication of states in previous num_local_replica hops in the ring
@ -446,17 +455,17 @@ o * the input state must exactly one saved state(local state of current node)
inline ReturnType MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType (*func) (const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)
);
EdgeType (*func)
(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index));
//---- recovery data structure ----
// the round of result buffer, used to mode the result
int result_buffer_round;
// result buffer of all reduce
ResultBuffer resbuf;
// last check point global model
std::string global_checkpoint;
std::string global_checkpoint;
// number of replica for local state/model
int num_local_replica;
// --- recovery data structure for local checkpoint
@ -465,16 +474,15 @@ o * the input state must exactly one saved state(local state of current node)
// pointer to memory position in the local model
// local model is stored in CSR format(like a sparse matrices)
// local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
std::vector<size_t> local_rptr[2];
// storage for local model replicas
std::string local_chkpt[2];
// version of local checkpoint can be 1 or 0
int local_chkpt_version;
int local_chkpt_version;
};
} // namespace engine
} // namespace rabit
// implementation of inline template function
#include "./allreduce_robust-inl.h"
#endif // RABIT_ALLREDUCE_ROBUST_H
#endif // RABIT_ALLREDUCE_ROBUST_H_

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine.cc
* \brief this file governs which implementation of engine we are actually using
* provides an singleton of engine interface
@ -41,16 +42,17 @@ void Finalize(void) {
IEngine *GetEngine(void) {
  // expose the engine object `manager` through the IEngine interface
  IEngine *engine = &manager;
  return engine;
}
// perform in-place allreduce, on sendrecvbuf
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
red, prepare_fun, prepare_arg);
}
// code for reduce handle

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies
@ -25,9 +26,10 @@ class EmptyEngine : public IEngine {
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead");
utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
}
virtual void InitAfterException(void) {
utils::Error("EmptyEngine is not fault tolerant");
@ -51,7 +53,7 @@ class EmptyEngine : public IEngine {
virtual int GetWorldSize(void) const {
  // this engine always reports a world size of one
  const int world_size = 1;
  return world_size;
}
/*! \brief get the host name of current node */
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
  // this engine reports an empty host name
  return std::string();
}
@ -59,6 +61,7 @@ class EmptyEngine : public IEngine {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
int version_number;
};
@ -77,11 +80,11 @@ void Finalize(void) {
IEngine *GetEngine(void) {
  // expose the engine object `manager` through the IEngine interface
  IEngine *engine = &manager;
  return engine;
}
// perform in-place allreduce, on sendrecvbuf
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc
* \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mpi.cc
* \brief this file gives an implementation of engine interface using MPI,
* this will allow rabit program to run with MPI, but do not comes with fault tolerant
@ -8,10 +9,10 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <cstdio>
#include <rabit/engine.h>
#include <rabit/utils.h>
#include <mpi.h>
#include <cstdio>
#include "rabit/engine.h"
#include "rabit/utils.h"
namespace rabit {
namespace engine {
@ -27,9 +28,10 @@ class MPIEngine : public IEngine {
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead");
utils::Error("MPIEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
}
virtual void InitAfterException(void) {
@ -48,13 +50,13 @@ class MPIEngine : public IEngine {
}
/*! \brief get rank of current node */
virtual int GetRank(void) const {
return MPI::COMM_WORLD.Get_rank();
return MPI::COMM_WORLD.Get_rank();
}
/*! \brief get total number of */
virtual int GetWorldSize(void) const {
return MPI::COMM_WORLD.Get_size();
return MPI::COMM_WORLD.Get_size();
}
/*! \brief get the host name of current node */
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
int len;
char name[MPI_MAX_PROCESSOR_NAME];
@ -68,6 +70,7 @@ class MPIEngine : public IEngine {
utils::Printf("%s", msg.c_str());
}
}
private:
int version_number;
};
@ -91,7 +94,7 @@ IEngine *GetEngine(void) {
// transform enum to MPI data type
inline MPI::Datatype GetType(mpi::DataType dtype) {
using namespace mpi;
switch(dtype) {
switch (dtype) {
case kInt: return MPI::INT;
case kUInt: return MPI::UNSIGNED;
case kFloat: return MPI::FLOAT;
@ -103,7 +106,7 @@ inline MPI::Datatype GetType(mpi::DataType dtype) {
// transform enum to MPI OP
inline MPI::Op GetOp(mpi::OpType otype) {
using namespace mpi;
switch(otype) {
switch (otype) {
case kMax: return MPI::MAX;
case kMin: return MPI::MIN;
case kSum: return MPI::SUM;
@ -112,17 +115,18 @@ inline MPI::Op GetOp(mpi::OpType otype) {
utils::Error("unknown mpi::OpType");
return MPI::MAX;
}
// perform in-place allreduce, on sendrecvbuf
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
count, GetType(dtype), GetOp(op));
}
// code for reduce handle
@ -152,7 +156,7 @@ void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
created_type_nbytes_ = type_nbytes;
htype_ = dtype;
}
MPI::Op *op = new MPI::Op();
MPI::User_function *pf = redfunc;
op->Init(pf, true);
@ -175,7 +179,7 @@ void ReduceHandle::Allreduce(void *sendrecvbuf,
dtype->Commit();
created_type_nbytes_ = type_nbytes;
}
if (prepare_fun != NULL) prepare_fun(prepare_arg);
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
}
} // namespace engine

View File

@ -1,10 +1,11 @@
#ifndef RABIT_SOCKET_H
#define RABIT_SOCKET_H
/*!
* Copyright (c) 2014 by Contributors
* \file socket.h
* \brief this file aims to provide a wrapper of sockets
* \author Tianqi Chen
*/
#ifndef RABIT_SOCKET_H_
#define RABIT_SOCKET_H_
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
@ -21,7 +22,7 @@
#endif
#include <string>
#include <cstring>
#include <rabit/utils.h>
#include "rabit/utils.h"
#if defined(_WIN32)
typedef int ssize_t;
@ -68,9 +69,11 @@ struct SockAddr {
inline std::string AddrStr(void) const {
std::string buf; buf.resize(256);
#ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, &buf[0], buf.length());
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
&buf[0], buf.length());
#else
const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length());
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length());
#endif
Assert(s != NULL, "cannot decode address");
return std::string(s);
@ -94,12 +97,12 @@ class Socket {
*/
inline static void Startup(void) {
#ifdef _WIN32
WSADATA wsa_data;
WSADATA wsa_data;
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != -1) {
Socket::Error("Startup");
}
Socket::Error("Startup");
}
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
WSACleanup();
WSACleanup();
utils::Error("Could not find a usable version of Winsock.dll\n");
}
#endif
@ -118,11 +121,11 @@ class Socket {
* it will set it back to block mode
*/
inline void SetNonBlock(bool non_block) {
#ifdef _WIN32
u_long mode = non_block ? 1 : 0;
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
#ifdef _WIN32
u_long mode = non_block ? 1 : 0;
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
Socket::Error("SetNonBlock");
}
}
#else
int flag = fcntl(sockfd, F_GETFL, 0);
if (flag == -1) {
@ -143,7 +146,8 @@ class Socket {
* \param addr
*/
inline void Bind(const SockAddr &addr) {
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) {
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == -1) {
Socket::Error("Bind");
}
}
@ -154,10 +158,11 @@ class Socket {
* \return the port successfully bind to, return -1 if failed to bind any port
*/
inline int TryBindHost(int start_port, int end_port) {
// TODO, add prefix check
// TODO(tqchen) add prefix check
for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port);
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) {
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0) {
return port;
}
if (errno != EADDRINUSE) {
@ -179,22 +184,22 @@ class Socket {
inline bool BadSocket(void) const {
// a socket that is already closed is always reported as bad
if (IsClosed()) return true;
// otherwise decide from the error code returned by GetSockError()
int err = GetSockError();
// only EBADF and EINTR are treated as bad; every other code returns false
if (err == EBADF || err == EINTR) return true;
if (err == EBADF || err == EINTR) return true;
return false;
}
/*! \brief check if socket is already closed */
inline bool IsClosed(void) const {
// the closed state is encoded as sockfd holding the INVALID_SOCKET sentinel
return sockfd == INVALID_SOCKET;
}
}
/*! \brief close the socket */
inline void Close(void) {
if (sockfd != INVALID_SOCKET) {
#ifdef _WIN32
closesocket(sockfd);
#else
close(sockfd);
close(sockfd);
#endif
sockfd = INVALID_SOCKET;
sockfd = INVALID_SOCKET;
} else {
Error("Socket::Close double close the socket or close without create");
}
@ -204,6 +209,7 @@ class Socket {
int errsv = errno;
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
}
protected:
// wrap an existing socket handle: the constructor only stores sockfd and
// has an empty body
explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
}
@ -227,7 +233,7 @@ class TCPSocket : public Socket{
int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive");
}
}
}
/*!
* \brief create the socket, call this before using socket
@ -273,7 +279,8 @@ class TCPSocket : public Socket{
* \return whether connect is successful
*/
inline bool Connect(const SockAddr &addr) {
// true only when connect() returns 0; a failure is returned as false to
// the caller and Socket::Error is not invoked here
return connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0;
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0;
}
/*!
* \brief send data using the socket
@ -284,7 +291,7 @@ class TCPSocket : public Socket{
* return -1 if error occurs
*/
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
// view the untyped buffer as const char* before handing it to send()
const char *buf = reinterpret_cast<const char*>(buf_);
const char *buf = reinterpret_cast<const char*>(buf_);
// len is converted to sock_size_t for the call; the result of send()
// is returned unchanged
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
}
/*!
@ -296,7 +303,7 @@ class TCPSocket : public Socket{
* return -1 if error occurs
*/
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
// view the untyped buffer as char* before handing it to recv()
char *buf = reinterpret_cast<char*>(buf_);
char *buf = reinterpret_cast<char*>(buf_);
// len is converted to sock_size_t for the call; the result of recv()
// is returned unchanged
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
}
/*!
@ -331,7 +338,8 @@ class TCPSocket : public Socket{
char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
ssize_t ret = recv(sockfd, buf,
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
if (ret == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
Socket::Error("RecvAll");
@ -385,7 +393,7 @@ struct SelectHelper {
* \param fd file descriptor to be watched
*/
inline void WatchRead(SOCKET fd) {
// add fd to read_set
FD_SET(fd, &read_set);
FD_SET(fd, &read_set);
// raise maxfd when fd is larger than the current value
if (fd > maxfd) maxfd = fd;
}
/*!
@ -403,7 +411,7 @@ struct SelectHelper {
inline void WatchException(SOCKET fd) {
// add fd to except_set
FD_SET(fd, &except_set);
// raise maxfd when fd is larger than the current value
if (fd > maxfd) maxfd = fd;
}
}
/*!
* \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status
@ -435,8 +443,9 @@ struct SelectHelper {
fd_set wait_set;
FD_ZERO(&wait_set);
FD_SET(fd, &wait_set);
return Select_(static_cast<int>(fd + 1), NULL, NULL, &wait_set, timeout);
}
return Select_(static_cast<int>(fd + 1),
NULL, NULL, &wait_set, timeout);
}
/*!
* \brief perform select on the set defined
* \param select_read whether to watch for read event
@ -454,9 +463,10 @@ struct SelectHelper {
}
return ret;
}
private:
inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) {
inline static int Select_(int maxfd, fd_set *rfds,
fd_set *wfds, fd_set *efds, long timeout) {
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
if (timeout == 0) {
return select(maxfd, rfds, wfds, efds, NULL);
@ -465,12 +475,12 @@ struct SelectHelper {
tm.tv_usec = (timeout % 1000) * 1000;
tm.tv_sec = timeout / 1000;
return select(maxfd, rfds, wfds, efds, &tm);
}
}
}
SOCKET maxfd;
SOCKET maxfd;
fd_set read_set, write_set, except_set;
};
} // namespace utils
} // namespace rabit
#endif
#endif // RABIT_SOCKET_H_