cpplint pass

This commit is contained in:
tqchen 2014-12-28 05:12:07 -08:00
parent 15836eb98e
commit 27d6977a3e
16 changed files with 406 additions and 328 deletions

View File

@ -36,4 +36,4 @@ $(ALIB):
ar cr $@ $+
clean:
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~
$(RM) $(OBJ) $(MPIOBJ) $(ALIB) $(MPIALIB) *~ src/*~ include/*~ include/*/*~

View File

@ -1,6 +1,5 @@
#ifndef RABIT_RABIT_H
#define RABIT_RABIT_H
/*!
* Copyright (c) 2014 by Contributors
* \file rabit.h
* \brief This file defines unified Allreduce/Broadcast interface of rabit
* The actual implementation is redirected to rabit engine
@ -9,6 +8,8 @@
* rabit.h and serializable.h is all the user need to use rabit interface
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_RABIT_H_
#define RABIT_RABIT_H_
#include <string>
#include <vector>
// optionally support of lambda function in C++11, if available
@ -142,9 +143,9 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* \tparam DType type of data
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count, std::function<void()> prepare_fun);
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun);
#endif // C++11
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
@ -228,6 +229,7 @@ class Reducer {
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun);
#endif
private:
/*! \brief function handle to do reduce */
engine::ReduceHandle handle_;
@ -274,6 +276,7 @@ class SerializeReducer {
size_t max_nbyte, size_t count,
std::function<void()> prepare_fun);
#endif
private:
/*! \brief function handle to do reduce */
engine::ReduceHandle handle_;
@ -283,4 +286,4 @@ class SerializeReducer {
} // namespace rabit
// implementation of template functions
#include "./rabit/rabit-inl.h"
#endif // RABIT_ALLREDUCE_H
#endif // RABIT_RABIT_H_

View File

@ -1,10 +1,12 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine.h
* \brief This file defines the core interface of allreduce library
* \author Tianqi Chen, Nacho, Tianyi
*/
#ifndef RABIT_ENGINE_H
#define RABIT_ENGINE_H
#ifndef RABIT_ENGINE_H_
#define RABIT_ENGINE_H_
#include <string>
#include "../rabit_serializable.h"
namespace MPI {
@ -211,7 +213,7 @@ class ReduceHandle {
/*! \return the number of bytes occupied by the type */
static int TypeSize(const MPI::Datatype &dtype);
private:
protected:
// handle data field
void *handle_;
// handle to the type field
@ -221,5 +223,4 @@ class ReduceHandle {
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_H
#endif // RABIT_ENGINE_H_

View File

@ -1,16 +1,19 @@
#ifndef RABIT_UTILS_IO_H
#define RABIT_UTILS_IO_H
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include "./utils.h"
#include "../rabit_serializable.h"
/*!
* Copyright (c) 2014 by Contributors
* \file io.h
* \brief utilities that implements different serializable interface
* \author Tianqi Chen
*/
#ifndef RABIT_UTILS_IO_H_
#define RABIT_UTILS_IO_H_
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include <algorithm>
#include "./utils.h"
#include "../rabit_serializable.h"
namespace rabit {
namespace utils {
/*! \brief interface of i/o stream that support seek */
@ -26,7 +29,8 @@ class ISeekStream: public IStream {
struct MemoryFixSizeBuffer : public ISeekStream {
public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
: p_buffer_(reinterpret_cast<char*>(p_buffer)),
buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
virtual ~MemoryFixSizeBuffer(void) {}
@ -64,7 +68,7 @@ struct MemoryFixSizeBuffer : public ISeekStream {
/*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public ISeekStream {
public:
MemoryBufferStream(std::string *p_buffer)
explicit MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
@ -98,35 +102,6 @@ struct MemoryBufferStream : public ISeekStream {
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryBufferStream
/*! \brief implementation of file i/o stream */
class FileStream : public ISeekStream {
public:
explicit FileStream(FILE *fp) : fp(fp) {}
explicit FileStream(void) {
this->fp = NULL;
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, size, 1, fp);
}
virtual void Write(const void *ptr, size_t size) {
std::fwrite(ptr, size, 1, fp);
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET);
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
inline void Close(void) {
if (fp != NULL){
std::fclose(fp); fp = NULL;
}
}
private:
FILE *fp;
};
} // namespace utils
} // namespace rabit
#endif
#endif // RABIT_UTILS_IO_H_

View File

@ -1,10 +1,11 @@
#ifndef RABIT_UTILS_H_
#define RABIT_UTILS_H_
/*!
* Copyright (c) 2014 by Contributors
* \file utils.h
* \brief simple utils to support the code
* \author Tianqi Chen
*/
#ifndef RABIT_UTILS_H_
#define RABIT_UTILS_H_
#define _CRT_SECURE_NO_WARNINGS
#include <cstdio>
#include <string>

View File

@ -1,13 +1,14 @@
#ifndef RABIT_RABIT_SERIALIZABLE_H
#define RABIT_RABIT_SERIALIZABLE_H
#include <vector>
#include <string>
#include "./rabit/utils.h"
/*!
* Copyright (c) 2014 by Contributors
* \file serializable.h
* \brief defines serializable interface of rabit
* \author Tianqi Chen
*/
#ifndef RABIT_RABIT_SERIALIZABLE_H_
#define RABIT_RABIT_SERIALIZABLE_H_
#include <vector>
#include <string>
#include "./rabit/utils.h"
namespace rabit {
/*!
* \brief interface of stream I/O, used by ISerializable
@ -96,4 +97,4 @@ class ISerializable {
virtual void Save(IStream &fo) const = 0;
};
} // namespace rabit
#endif
#endif // RABIT_RABIT_SERIALIZABLE_H_

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.cc
* \brief Basic implementation of AllReduce
*
@ -32,13 +33,15 @@ AllreduceBase::AllreduceBase(void) {
// initialization function
void AllreduceBase::Init(void) {
// setup from enviroment variables
{// handling for hadoop
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");
if (task_id == NULL) {
task_id = getenv("mapreduce_task_id");
}
if (hadoop_mode != 0) {
utils::Check(task_id != NULL, "hadoop_mode is set but cannot find mapred_task_id");
utils::Check(task_id != NULL,
"hadoop_mode is set but cannot find mapred_task_id");
}
if (task_id != NULL) {
this->SetParam("rabit_task_id", task_id);
@ -58,7 +61,8 @@ void AllreduceBase::Init(void) {
num_task = getenv("mapreduce_job_maps");
}
if (hadoop_mode != 0) {
utils::Check(num_task != NULL, "hadoop_mode is set but cannot find mapred_map_tasks");
utils::Check(num_task != NULL,
"hadoop_mode is set but cannot find mapred_map_tasks");
}
if (num_task != NULL) {
this->SetParam("rabit_world_size", num_task);
@ -111,7 +115,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
if (!strcmp(name, "rabit_reduce_buffer")) {
char unit;
unsigned long amount;
uint64_t amount;
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
switch (unit) {
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
@ -121,7 +125,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
default: utils::Error("invalid format for reduce buffer");
}
} else {
utils::Error("invalid format for reduce_buffer, shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
utils::Error("invalid format for reduce_buffer,"\
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
}
}
}
@ -137,11 +142,16 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
utils::Socket::Error("Connect");
}
utils::Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1");
utils::Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2");
using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 1");
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 2");
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
utils::Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
utils::Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 3");
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
tracker.SendStr(task_id);
return tracker;
}
@ -161,29 +171,30 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
int prev_rank, next_rank;
// the rank of neighbors
std::map<int, int> tree_neighbors;
{// get new ranks
using utils::Assert;
// get new ranks
int newrank, num_neighbors;
utils::Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank),
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\
sizeof(parent_rank), "ReConnectLink failure 4");
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one");
Assert(rank == -1 || newrank == rank,
"must keep rank to same if the node already have one");
rank = newrank;
utils::Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == sizeof(num_neighbors),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
sizeof(num_neighbors), "ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
int nrank;
utils::Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
"ReConnectLink failure 4");
tree_neighbors[nrank] = 1;
}
utils::Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
utils::Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
}
// create listening socket
utils::TCPSocket sock_listen;
sock_listen.Create();
@ -204,56 +215,67 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
}
}
int ngood = static_cast<int>(good_link.size());
utils::Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) {
utils::Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == sizeof(good_link[i]),
"ReConnectLink failure 6");
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
sizeof(good_link[i]), "ReConnectLink failure 6");
}
utils::Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
utils::Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept),
"ReConnectLink failure 8");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
sizeof(num_accept), "ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.RecvStr(&hname);
utils::Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
utils::Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
"ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1; r.sock.Close(); continue;
}
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12");
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 13");
utils::Check(hrank == r.rank, "ReConnectLink failure, link rank inconsistent");
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 12");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 13");
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
}
}
if (!match) all_links.push_back(r);
}
utils::Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14");
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to tracker
utils::Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
"ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
r.sock = sock_listen.Accept();
utils::Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15");
utils::Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), "ReConnectLink failure 15");
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 15");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(), "Override a link that is active");
utils::Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
}
}
@ -278,9 +300,12 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
}
utils::Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link");
utils::Assert(prev_rank == -1 || ring_prev != NULL, "cannot find prev ring in the link");
utils::Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link");
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
"cannot find next ring in the link");
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
@ -392,7 +417,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// start position
size_t start = size_up_reduce % buffer_size;
// peform read till end of buffer
size_t nread = std::min(buffer_size - start, max_reduce - size_up_reduce);
size_t nread = std::min(buffer_size - start,
max_reduce - size_up_reduce);
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
@ -417,7 +443,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock) && total_size > size_down_in) {
if (selecter.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) {
@ -425,7 +452,8 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
}
if (len != -1) {
size_down_in += static_cast<size_t>(len);
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
utils::Assert(size_down_in <= size_up_out,
"Allreduce: boundary error");
} else {
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
}
@ -437,7 +465,9 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
// can pass message down to childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckWrite(links[i].sock)) {
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) return kSockError;
if (!links[i].WriteFromArray(sendrecvbuf, size_down_in)) {
return kSockError;
}
}
}
}
@ -455,7 +485,8 @@ AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size, "Broadcast: root should be smaller than world size");
utils::Check(root < world_size,
"Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.size());
// size of space already read from data
@ -502,7 +533,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) {
if (!links[i].ReadToArray(sendrecvbuf_, total_size)) return kSockError;
if (!links[i].ReadToArray(sendrecvbuf_, total_size)) {
return kSockError;
}
size_in = links[i].size_read;
if (size_in != 0) {
in_link = i; break;
@ -512,7 +545,9 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
} else {
// read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
if(!links[in_link].ReadToArray(sendrecvbuf_, total_size)) return kSockError;
if (!links[in_link].ReadToArray(sendrecvbuf_, total_size)) {
return kSockError;
}
size_in = links[in_link].size_read;
}
}

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.h
* \brief Basic implementation of AllReduce
* using TCP non-block socket and tree-shape reduction.
@ -8,13 +9,14 @@
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_BASE_H
#define RABIT_ALLREDUCE_BASE_H
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_
#include <vector>
#include <string>
#include <rabit/utils.h>
#include <rabit/engine.h>
#include <algorithm>
#include "rabit/utils.h"
#include "rabit/engine.h"
#include "./socket.h"
namespace MPI {
@ -22,7 +24,7 @@ namespace MPI {
class Datatype {
public:
size_t type_size;
Datatype(size_t type_size) : type_size(type_size) {}
explicit Datatype(size_t type_size) : type_size(type_size) {}
};
}
namespace rabit {
@ -31,7 +33,7 @@ namespace engine {
class AllreduceBase : public IEngine {
public:
// magic number to verify server
const static int kMagic = 0xff99;
static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
@ -84,7 +86,8 @@ class AllreduceBase : public IEngine {
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == kSuccess,
utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
"Allreduce failed");
}
/*!
@ -201,12 +204,16 @@ class AllreduceBase : public IEngine {
// constructor
LinkRecord(void) {}
// initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
inline void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
size_t n = (type_nbytes * count + 7)/ 8;
buffer_.resize(std::min(reduce_buffer_size, n));
// make sure align to type_nbytes
buffer_size = buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
utils::Assert(type_nbytes <= buffer_size, "too large type_nbytes=%lu, buffer_size=%lu", type_nbytes, buffer_size);
buffer_size =
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
utils::Assert(type_nbytes <= buffer_size,
"too large type_nbytes=%lu, buffer_size=%lu",
type_nbytes, buffer_size);
// set buffer head
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
}

View File

@ -1,11 +1,13 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust-inl.h
* \brief implementation of inline template function in AllreduceRobust
*
* \author Tianqi Chen
*/
#ifndef RABIT_ENGINE_ROBUST_INL_H
#define RABIT_ENGINE_ROBUST_INL_H
#ifndef RABIT_ENGINE_ROBUST_INL_H_
#define RABIT_ENGINE_ROBUST_INL_H_
#include <vector>
namespace rabit {
namespace engine {
@ -33,10 +35,10 @@ inline AllreduceRobust::ReturnType
AllreduceRobust::MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType (*func) (const NodeType &node_value,
EdgeType (*func)
(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)
) {
size_t out_index)) {
RefLinkVector &links = tree_links;
if (links.size() == 0) return kSuccess;
// number of links
@ -103,7 +105,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) return kSockError;
if (!links[i].ReadToArray(&edge_in[i], sizeof(EdgeType))) {
return kSockError;
}
}
if (links[i].size_read != sizeof(EdgeType)) finished = false;
}
@ -124,13 +128,17 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (stage == 1) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) return kSockError;
if (!links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType))) {
return kSockError;
}
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
}
if (stage == 2) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) return kSockError;
if (!links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType))) {
return kSockError;
}
if (links[pid].size_read == sizeof(EdgeType)) {
for (int i = 0; i < nlink; ++i) {
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
@ -141,7 +149,9 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (stage == 3) {
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) return kSockError;
if (!links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType))) {
return kSockError;
}
}
}
}
@ -150,4 +160,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
}
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_ROBUST_INL_H
#endif // RABIT_ENGINE_ROBUST_INL_H_

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust.cc
* \brief Robust implementation of Allreduce
*
@ -9,10 +10,10 @@
#define NOMINMAX
#include <limits>
#include <utility>
#include <rabit/io.h>
#include <rabit/utils.h>
#include <rabit/engine.h>
#include <rabit/rabit-inl.h>
#include "rabit/io.h"
#include "rabit/utils.h"
#include "rabit/engine.h"
#include "rabit/rabit-inl.h"
#include "./allreduce_robust.h"
namespace rabit {
@ -151,7 +152,8 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
// skip action in single node
if (world_size == 1) return 0;
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
utils::Check(local_model == NULL,
"need to set num_local_replica larger than 1 to checkpoint local_model");
}
// check if we succesful
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp)) {
@ -173,7 +175,8 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
if (global_checkpoint.length() == 0) {
version_number = 0;
} else {
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number");
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
"read in version number");
global_model->Load(fs);
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
"local model inconsistent, nlocal=%d", nlocal);
@ -212,7 +215,8 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
version_number += 1; return;
}
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
utils::Check(local_model == NULL,
"need to set num_local_replica larger than 1 to checkpoint local_model");
}
if (num_local_replica != 0) {
while (true) {
@ -383,7 +387,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
*/
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
if (err_type == kSuccess) return true;
{// simple way, shutdown all links
{
// simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
}
@ -484,7 +489,8 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
int *p_recvlink,
std::vector<bool> *p_req_in) {
int best_link = -2;
{// get the shortest distance to the request point
{
// get the shortest distance to the request point
std::vector<std::pair<int, size_t> > dist_in, dist_out;
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
&dist_in, &dist_out, ShortestDist);
@ -596,14 +602,18 @@ AllreduceRobust::TryRecoverData(RecoverType role,
for (int i = 0; i < nlink; ++i) {
if (req_in[i] && links[i].size_write != links[pid].size_read &&
selecter.CheckWrite(links[i].sock)) {
if(!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) return kSockError;
if (!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) {
return kSockError;
}
}
}
}
if (role == kHaveData) {
for (int i = 0; i < nlink; ++i) {
if (req_in[i] && selecter.CheckWrite(links[i].sock)) {
if(!links[i].WriteFromArray(sendrecvbuf_, size)) return kSockError;
if (!links[i].WriteFromArray(sendrecvbuf_, size)) {
return kSockError;
}
}
}
}
@ -619,7 +629,8 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
}
for (int i = 0; i < nlink; ++i) {
if (req_in[i] && selecter.CheckWrite(links[i].sock) && links[pid].size_read != links[i].size_write) {
if (req_in[i] && selecter.CheckWrite(links[i].sock) &&
links[pid].size_read != links[i].size_write) {
size_t start = links[i].size_write % buffer_size;
// send out data from ring buffer
size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
@ -735,8 +746,9 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
utils::Check(data_size != 0, "zero size check point is not allowed");
if (role == kRequestData || role == kHaveData) {
utils::Check(data_size == size,
"Allreduce Recovered data size do not match the specification of function call\n"\
"Please check if calling sequence of recovered program is the same the original one in current VersionNumber");
"Allreduce Recovered data size do not match the specification of function call.\n"\
"Please check if calling sequence of recovered program is the " \
"same the original one in current VersionNumber");
}
return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in);
}
@ -816,7 +828,8 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
if (requester) return true;
} else {
// all the request is same, this is most recent command that is yet to be executed
// all the request is same,
// this is most recent command that is yet to be executed
return false;
}
}
@ -855,7 +868,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
}
const int n = num_local_replica;
{// backward passing, passing state in backward direction of the ring
{
// backward passing, passing state in backward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_back(n + 1);
@ -908,7 +922,8 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
}
}
{// forward passing, passing state in forward direction of the ring
{
// forward passing, passing state in forward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_forward(n + 1);
@ -995,7 +1010,8 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
if (num_local_replica == 0) return kSuccess;
std::vector<size_t> &rptr = *p_local_rptr;
std::string &chkpt = *p_local_chkpt;
utils::Assert(rptr.size() == 2, "TryCheckinLocalState must have exactly 1 state");
utils::Assert(rptr.size() == 2,
"TryCheckinLocalState must have exactly 1 state");
const int n = num_local_replica;
std::vector<size_t> sizes(n + 1);
sizes[0] = rptr[1] - rptr[0];
@ -1050,7 +1066,8 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
LinkRecord *read_link,
LinkRecord *write_link) {
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
utils::Assert(write_end <= read_end, "RingPassing: boundary check1, write_end=%lu, read_end=%lu", write_end, read_end);
utils::Assert(write_end <= read_end,
"RingPassing: boundary check1");
utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2");
utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3");
// take reference

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust.h
* \brief Robust implementation of Allreduce
* using TCP non-block socket and tree-shape reduction.
@ -7,10 +8,12 @@
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_ROBUST_H
#define RABIT_ALLREDUCE_ROBUST_H
#ifndef RABIT_ALLREDUCE_ROBUST_H_
#define RABIT_ALLREDUCE_ROBUST_H_
#include <vector>
#include <rabit/engine.h>
#include <string>
#include <algorithm>
#include "rabit/engine.h"
#include "./allreduce_base.h"
namespace rabit {
@ -111,11 +114,11 @@ class AllreduceRobust : public AllreduceBase {
private:
// constant one byte out of band message to indicate error happening
// and mark for channel cleanup
const static char kOOBReset = 95;
static const char kOOBReset = 95;
// and mark for channel cleanup, after OOB signal
const static char kResetMark = 97;
static const char kResetMark = 97;
// and mark for channel cleanup
const static char kResetAck = 97;
static const char kResetAck = 97;
/*! \brief type of roles each node can play during recovery */
enum RecoverType {
/*! \brief current node have data */
@ -132,29 +135,29 @@ class AllreduceRobust : public AllreduceBase {
*/
struct ActionSummary {
// maximumly allowed sequence id
const static int kSpecialOp = (1 << 26);
static const int kSpecialOp = (1 << 26);
// special sequence number for local state checkpoint
const static int kLocalCheckPoint = (1 << 26) - 2;
static const int kLocalCheckPoint = (1 << 26) - 2;
// special sequnce number for local state checkpoint ack signal
const static int kLocalCheckAck = (1 << 26) - 1;
static const int kLocalCheckAck = (1 << 26) - 1;
//---------------------------------------------
// The following are bit mask of flag used in
//----------------------------------------------
// some node want to load check point
const static int kLoadCheck = 1;
static const int kLoadCheck = 1;
// some node want to do check point
const static int kCheckPoint = 2;
static const int kCheckPoint = 2;
// check point Ack, we use a two phase message in check point,
// this is the second phase of check pointing
const static int kCheckAck = 4;
static const int kCheckAck = 4;
// there are difference sequence number the nodes proposed
// this means we want to do recover execution of the lower sequence
// action instead of normal execution
const static int kDiffSeq = 8;
static const int kDiffSeq = 8;
// constructor
ActionSummary(void) {}
// constructor of action
ActionSummary(int flag, int minseqno = kSpecialOp) {
explicit ActionSummary(int flag, int minseqno = kSpecialOp) {
seqcode = (minseqno << 4) | flag;
}
// minimum number of all operations
@ -181,10 +184,11 @@ class AllreduceRobust : public AllreduceBase {
inline int flag(void) const {
return seqcode & 15;
}
// reducer for Allreduce, used to get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
// reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_,
int len, const MPI::Datatype &dtype) {
const ActionSummary *src = (const ActionSummary*)src_;
ActionSummary *dst = (ActionSummary*)dst_;
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) {
int src_seqno = src[i].min_seqno();
int dst_seqno = dst[i].min_seqno();
@ -192,7 +196,8 @@ class AllreduceRobust : public AllreduceBase {
if (src_seqno == dst_seqno) {
dst[i] = ActionSummary(flag, src_seqno);
} else {
dst[i] = ActionSummary(flag | kDiffSeq, std::min(src_seqno, dst_seqno));
dst[i] = ActionSummary(flag | kDiffSeq,
std::min(src_seqno, dst_seqno));
}
}
}
@ -236,7 +241,8 @@ class AllreduceRobust : public AllreduceBase {
}
// return the stored result of seqid, if any
inline void* Query(int seqid, size_t *p_size) {
size_t idx = std::lower_bound(seqno_.begin(), seqno_.end(), seqid) - seqno_.begin();
size_t idx = std::lower_bound(seqno_.begin(),
seqno_.end(), seqid) - seqno_.begin();
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
*p_size = size_[idx];
return BeginPtr(data_) + rptr_[idx];
@ -254,6 +260,7 @@ class AllreduceRobust : public AllreduceBase {
if (seqno_.size() == 0) return -1;
return seqno_.back();
}
private:
// sequence number of each
std::vector<int> seqno_;
@ -291,14 +298,16 @@ class AllreduceRobust : public AllreduceBase {
* \param buf the buffer to store the result
* \param size the total size of the buffer
* \param flag flag information about the action \sa ActionSummary
* \param seqno sequence number of the action, if it is special action with flag set, seqno needs to be set to ActionSummary::kSpecialOp
* \param seqno sequence number of the action, if it is special action with flag set,
* seqno needs to be set to ActionSummary::kSpecialOp
*
* \return if this function can return true or false
* - true means buf already set to the
* result by recovering procedure, the action is complete, no further action is needed
* - false means this is the lastest action that has not yet been executed, need to execute the action
*/
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kSpecialOp);
bool RecoverExec(void *buf, size_t size, int flag,
int seqno = ActionSummary::kSpecialOp);
/*!
* \brief try to load check point
*
@ -446,10 +455,10 @@ o * the input state must exactly one saved state(local state of current node)
inline ReturnType MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType (*func) (const NodeType &node_value,
EdgeType (*func)
(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)
);
size_t out_index));
//---- recovery data structure ----
// the round of result buffer, used to mode the result
int result_buffer_round;
@ -465,7 +474,7 @@ o * the input state must exactly one saved state(local state of current node)
// pointer to memory position in the local model
// local model is stored in CSR format(like a sparse matrices)
// local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
std::vector<size_t> local_rptr[2];
// storage for local model replicas
std::string local_chkpt[2];
@ -476,5 +485,4 @@ o * the input state must exactly one saved state(local state of current node)
} // namespace rabit
// implementation of inline template function
#include "./allreduce_robust-inl.h"
#endif // RABIT_ALLREDUCE_ROBUST_H
#endif // RABIT_ALLREDUCE_ROBUST_H_

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine.cc
* \brief this file governs which implementation of engine we are actually using
* provides an singleton of engine interface
@ -50,7 +51,8 @@ void Allreduce_(void *sendrecvbuf,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg);
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
red, prepare_fun, prepare_arg);
}
// code for reduce handle

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies
@ -25,7 +26,8 @@ class EmptyEngine : public IEngine {
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Error("EmptyEngine:: Allreduce is not supported, use Allreduce_ instead");
utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
}
@ -59,6 +61,7 @@ class EmptyEngine : public IEngine {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
int version_number;
};

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc
* \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure

View File

@ -1,4 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mpi.cc
* \brief this file gives an implementation of engine interface using MPI,
* this will allow rabit program to run with MPI, but do not comes with fault tolerant
@ -8,10 +9,10 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <cstdio>
#include <rabit/engine.h>
#include <rabit/utils.h>
#include <mpi.h>
#include <cstdio>
#include "rabit/engine.h"
#include "rabit/utils.h"
namespace rabit {
namespace engine {
@ -27,7 +28,8 @@ class MPIEngine : public IEngine {
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Error("MPIEngine:: Allreduce is not supported, use Allreduce_ instead");
utils::Error("MPIEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
@ -68,6 +70,7 @@ class MPIEngine : public IEngine {
utils::Printf("%s", msg.c_str());
}
}
private:
int version_number;
};
@ -122,7 +125,8 @@ void Allreduce_(void *sendrecvbuf,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, GetType(dtype), GetOp(op));
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
count, GetType(dtype), GetOp(op));
}
// code for reduce handle

View File

@ -1,10 +1,11 @@
#ifndef RABIT_SOCKET_H
#define RABIT_SOCKET_H
/*!
* Copyright (c) 2014 by Contributors
* \file socket.h
* \brief this file aims to provide a wrapper of sockets
* \author Tianqi Chen
*/
#ifndef RABIT_SOCKET_H_
#define RABIT_SOCKET_H_
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
@ -21,7 +22,7 @@
#endif
#include <string>
#include <cstring>
#include <rabit/utils.h>
#include "rabit/utils.h"
#if defined(_WIN32)
typedef int ssize_t;
@ -68,9 +69,11 @@ struct SockAddr {
inline std::string AddrStr(void) const {
std::string buf; buf.resize(256);
#ifdef _WIN32
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, &buf[0], buf.length());
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
&buf[0], buf.length());
#else
const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length());
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length());
#endif
Assert(s != NULL, "cannot decode address");
return std::string(s);
@ -143,7 +146,8 @@ class Socket {
* \param addr
*/
inline void Bind(const SockAddr &addr) {
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == -1) {
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == -1) {
Socket::Error("Bind");
}
}
@ -154,10 +158,11 @@ class Socket {
* \return the port successfully bind to, return -1 if failed to bind any port
*/
inline int TryBindHost(int start_port, int end_port) {
// TODO, add prefix check
// TODO(tqchen) add prefix check
for (int port = start_port; port < end_port; ++port) {
SockAddr addr("0.0.0.0", port);
if (bind(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0) {
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0) {
return port;
}
if (errno != EADDRINUSE) {
@ -204,6 +209,7 @@ class Socket {
int errsv = errno;
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
}
protected:
explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
}
@ -273,7 +279,8 @@ class TCPSocket : public Socket{
* \return whether connect is successful
*/
inline bool Connect(const SockAddr &addr) {
return connect(sockfd, (sockaddr*)&addr.addr, sizeof(addr.addr)) == 0;
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
sizeof(addr.addr)) == 0;
}
/*!
* \brief send data using the socket
@ -331,7 +338,8 @@ class TCPSocket : public Socket{
char *buf = reinterpret_cast<char*>(buf_);
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
ssize_t ret = recv(sockfd, buf,
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
if (ret == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
Socket::Error("RecvAll");
@ -435,7 +443,8 @@ struct SelectHelper {
fd_set wait_set;
FD_ZERO(&wait_set);
FD_SET(fd, &wait_set);
return Select_(static_cast<int>(fd + 1), NULL, NULL, &wait_set, timeout);
return Select_(static_cast<int>(fd + 1),
NULL, NULL, &wait_set, timeout);
}
/*!
* \brief peform select on the set defined
@ -456,7 +465,8 @@ struct SelectHelper {
}
private:
inline static int Select_(int maxfd, fd_set *rfds, fd_set *wfds, fd_set *efds, long timeout) {
inline static int Select_(int maxfd, fd_set *rfds,
fd_set *wfds, fd_set *efds, long timeout) {
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
if (timeout == 0) {
return select(maxfd, rfds, wfds, efds, NULL);
@ -473,4 +483,4 @@ struct SelectHelper {
};
} // namespace utils
} // namespace rabit
#endif
#endif // RABIT_SOCKET_H_