Merge commit 'd87691ec603db325d5b1c5db1186295a748df7cc' as 'subtree/rabit'

This commit is contained in:
tqchen
2015-01-18 21:08:17 -08:00
68 changed files with 9081 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
Source Files of Rabit
====
* This folder contains the source files of the rabit library
* The library headers are in folder [include](../include)
* The .h files in this folder are internal header files that are only used by rabit and will not be seen by users

View File

@@ -0,0 +1,590 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.cc
* \brief Basic implementation of AllReduce
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <map>
#include <cstdlib>
#include <cstring>
#include "./allreduce_base.h"
namespace rabit {
namespace engine {
// constructor
AllreduceBase::AllreduceBase(void) {
  // default tracker location; the literal "NULL" uri means single node mode
  tracker_uri = "NULL";
  tracker_port = 9000;
  host_uri = "";
  // first port to try for the listening socket and how many ports to try
  slave_port = 9010;
  nport_trial = 1000;
  rank = 0;
  world_size = -1;
  hadoop_mode = 0;
  version_number = 0;
  // seq_counter is printed by ReportStatus, so it must start from a defined value
  seq_counter = 0;
  task_id = "NULL";
  err_link = NULL;
  // no tree parent and no ring neighbours are known until ReConnectLinks runs;
  // ReConnectLinks asserts on ring_prev/ring_next being non-NULL, so they must
  // not be left indeterminate
  parent_index = -1;
  parent_rank = -1;
  ring_prev = NULL;
  ring_next = NULL;
  // goes through SetParam so that reduce_buffer_size is derived consistently
  this->SetParam("rabit_reduce_buffer", "256MB");
}
// initialization function
void AllreduceBase::Init(void) {
  // ---- pick up settings from environment variables (hadoop) ----
  {
    // task identifier: try mapred_tip_id, fall back to mapreduce_task_id
    const char *env_task = getenv("mapred_tip_id");
    if (env_task == NULL) env_task = getenv("mapreduce_task_id");
    if (hadoop_mode != 0) {
      utils::Check(env_task != NULL,
                   "hadoop_mode is set but cannot find mapred_task_id");
    }
    if (env_task != NULL) {
      this->SetParam("rabit_task_id", env_task);
      // finding a task id switches hadoop mode on
      this->SetParam("rabit_hadoop_mode", "1");
    }
    // attempt identifier: the integer after the last '_' is forwarded
    // as rabit_num_trial
    const char *env_attempt = getenv("mapred_task_id");
    if (env_attempt != NULL) {
      const char *last_sep = strrchr(env_attempt, '_');
      if (last_sep != NULL) {
        int parsed_trial;
        if (sscanf(last_sep + 1, "%d", &parsed_trial) == 1) {
          this->SetParam("rabit_num_trial", last_sep + 1);
        }
      }
    }
    // number of tasks: try mapred_map_tasks, fall back to mapreduce_job_maps
    const char *env_ntask = getenv("mapred_map_tasks");
    if (env_ntask == NULL) env_ntask = getenv("mapreduce_job_maps");
    if (hadoop_mode != 0) {
      utils::Check(env_ntask != NULL,
                   "hadoop_mode is set but cannot find mapred_map_tasks");
    }
    if (env_ntask != NULL) {
      this->SetParam("rabit_world_size", env_ntask);
    }
  }
  // rank -1 tells ReConnectLinks that no rank has been assigned yet
  this->rank = -1;
  // ---- bring up the socket layer and connect ----
  utils::Socket::Startup();
  utils::Assert(all_links.size() == 0, "can only call Init once");
  this->host_uri = utils::SockAddr::GetHostName();
  // fetch rank and link layout from the tracker
  this->ReConnectLinks();
}
void AllreduceBase::Shutdown(void) {
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
}
all_links.clear();
tree_links.plinks.clear();
if (tracker_uri == "NULL") return;
// notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
tracker.Close();
utils::TCPSocket::Finalize();
}
void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("print"));
tracker.SendStr(msg);
tracker.Close();
}
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
void AllreduceBase::SetParam(const char *name, const char *val) {
  // names that are not recognised here are silently ignored
  if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
  if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
  if (!strcmp(name, "rabit_task_id")) task_id = val;
  if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
  if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
  if (!strcmp(name, "rabit_reduce_buffer")) {
    char unit;
    // %lu requires an unsigned long argument; scanning into a uint64_t is a
    // format/argument mismatch wherever unsigned long is not 64 bits wide
    unsigned long amount;  // NOLINT(runtime/int)
    if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
      // widen before shifting so the shift is done in size_t
      const size_t n = static_cast<size_t>(amount);
      // reduce_buffer_size counts 8-byte words, hence the shifts are
      // 3 bits smaller than the byte-unit shifts (10, 20, 30)
      switch (unit) {
        case 'B': reduce_buffer_size = (n + 7) / 8; break;
        case 'K': reduce_buffer_size = n << 7UL; break;
        case 'M': reduce_buffer_size = n << 17UL; break;
        case 'G': reduce_buffer_size = n << 27UL; break;
        default: utils::Error("invalid format for reduce buffer");
      }
    } else {
      utils::Error("invalid format for reduce_buffer,"\
                   "shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
    }
  }
}
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
  // open a fresh TCP connection to the tracker
  utils::TCPSocket conn;
  conn.Create();
  const utils::SockAddr tracker_addr(tracker_uri.c_str(), tracker_port);
  if (!conn.Connect(tracker_addr)) {
    utils::Socket::Error("Connect");
  }
  // handshake: send the magic number and expect the same value echoed back
  int handshake = kMagic;
  utils::Assert(conn.SendAll(&handshake, sizeof(handshake)) == sizeof(handshake),
                "ReConnectLink failure 1");
  utils::Assert(conn.RecvAll(&handshake, sizeof(handshake)) == sizeof(handshake),
                "ReConnectLink failure 2");
  utils::Check(handshake == kMagic,
               "sync::Invalid tracker message, init failure");
  // identify this worker: rank, world size, then task id
  utils::Assert(conn.SendAll(&rank, sizeof(rank)) == sizeof(rank),
                "ReConnectLink failure 3");
  utils::Assert(conn.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
                "ReConnectLink failure 3");
  conn.SendStr(task_id);
  return conn;
}
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
*/
void AllreduceBase::ReConnectLinks(const char *cmd) {
  // single node mode
  if (tracker_uri == "NULL") {
    rank = 0; world_size = 1; return;
  }
  utils::TCPSocket tracker = this->ConnectTracker();
  tracker.SendStr(std::string(cmd));
  // the rank of previous link, next link in ring
  int prev_rank, next_rank;
  // the rank of neighbors in the tree (used as a set, the value is ignored)
  std::map<int, int> tree_neighbors;
  using utils::Assert;
  // get new ranks
  int newrank, num_neighbors;
  Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
         "ReConnectLink failure 4");
  Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==
         sizeof(parent_rank), "ReConnectLink failure 4");
  Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
         "ReConnectLink failure 4");
  // a node that already has a rank must be handed the same rank again
  Assert(rank == -1 || newrank == rank,
         "must keep rank to same if the node already have one");
  rank = newrank;
  Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) ==
         sizeof(num_neighbors), "ReConnectLink failure 4");
  for (int i = 0; i < num_neighbors; ++i) {
    int nrank;
    Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
           "ReConnectLink failure 4");
    tree_neighbors[nrank] = 1;
  }
  Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
         "ReConnectLink failure 4");
  Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
         "ReConnectLink failure 4");
  // create listening socket
  utils::TCPSocket sock_listen;
  sock_listen.Create();
  int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
  utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
  sock_listen.Listen();
  // get number of to connect and number of to accept nodes from tracker
  int num_conn, num_accept, num_error = 1;
  // repeat the exchange until every requested outgoing connection succeeds
  do {
    // send over good links
    std::vector<int> good_link;
    for (size_t i = 0; i < all_links.size(); ++i) {
      if (!all_links[i].sock.BadSocket()) {
        good_link.push_back(static_cast<int>(all_links[i].rank));
      } else {
        if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
      }
    }
    int ngood = static_cast<int>(good_link.size());
    Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
           "ReConnectLink failure 5");
    for (size_t i = 0; i < good_link.size(); ++i) {
      Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) ==
             sizeof(good_link[i]), "ReConnectLink failure 6");
    }
    Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
           "ReConnectLink failure 7");
    Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) ==
           sizeof(num_accept), "ReConnectLink failure 8");
    num_error = 0;
    for (int i = 0; i < num_conn; ++i) {
      LinkRecord r;
      int hport, hrank;
      std::string hname;
      tracker.RecvStr(&hname);
      Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
             "ReConnectLink failure 9");
      Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
             "ReConnectLink failure 10");
      r.sock.Create();
      if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
        // count the failure and let the tracker schedule another round
        num_error += 1; r.sock.Close(); continue;
      }
      // exchange ranks with the peer and verify it is the one the tracker named
      Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
             "ReConnectLink failure 12");
      Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
             "ReConnectLink failure 13");
      utils::Check(hrank == r.rank,
                   "ReConnectLink failure, link rank inconsistent");
      // reuse the existing record for this rank if there is one
      bool match = false;
      for (size_t j = 0; j < all_links.size(); ++j) {
        if (all_links[j].rank == hrank) {
          Assert(all_links[j].sock.IsClosed(),
                 "Override a link that is active");
          all_links[j].sock = r.sock; match = true; break;
        }
      }
      if (!match) all_links.push_back(r);
    }
    Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
           "ReConnectLink failure 14");
  } while (num_error != 0);
  // send back socket listening port to tracker
  Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
         "ReConnectLink failure 14");
  // close connection to tracker
  tracker.Close();
  // listen to incoming links
  for (int i = 0; i < num_accept; ++i) {
    LinkRecord r;
    r.sock = sock_listen.Accept();
    Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
           "ReConnectLink failure 15");
    Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
           "ReConnectLink failure 15");
    // reuse the existing record for this rank if there is one
    bool match = false;
    for (size_t j = 0; j < all_links.size(); ++j) {
      if (all_links[j].rank == r.rank) {
        utils::Assert(all_links[j].sock.IsClosed(),
                      "Override a link that is active");
        all_links[j].sock = r.sock; match = true; break;
      }
    }
    if (!match) all_links.push_back(r);
  }
  // close listening sockets
  sock_listen.Close();
  this->parent_index = -1;
  // forget the old ring pointers: push_back above may have moved the records,
  // and the assertions below are only meaningful when starting from NULL
  ring_prev = NULL;
  ring_next = NULL;
  // setup tree links and ring structure
  tree_links.plinks.clear();
  for (size_t i = 0; i < all_links.size(); ++i) {
    utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
    // set the socket to non-blocking mode, enable TCP keepalive
    all_links[i].sock.SetNonBlock(true);
    all_links[i].sock.SetKeepAlive(true);
    if (tree_neighbors.count(all_links[i].rank) != 0) {
      if (all_links[i].rank == parent_rank) {
        parent_index = static_cast<int>(tree_links.plinks.size());
      }
      tree_links.plinks.push_back(&all_links[i]);
    }
    if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
    if (all_links[i].rank == next_rank) ring_next = &all_links[i];
  }
  Assert(parent_rank == -1 || parent_index != -1,
         "cannot find parent in the link");
  Assert(prev_rank == -1 || ring_prev != NULL,
         "cannot find prev ring in the link");
  Assert(next_rank == -1 || ring_next != NULL,
         "cannot find next ring in the link");
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*
* NOTE on Allreduce:
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
* It only means the current node get the correct result of Allreduce.
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
RefLinkVector &links = tree_links;
// nothing to exchange when there are no tree links or no elements
if (links.size() == 0 || count == 0) return kSuccess;
// total size of message
const size_t total_size = type_nbytes * count;
// number of links
const int nlink = static_cast<int>(links.size());
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// size of space that we already performs reduce in up pass
size_t size_up_reduce = 0;
// size of space that we have already passed to parent
size_t size_up_out = 0;
// size of message we received, and send in the down pass
size_t size_down_in = 0;
// initialize the link ring-buffer and pointer
// (only child links get a ring buffer; data from the parent is received
// directly into sendrecvbuf)
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
}
links[i].ResetSize();
}
// if no children, no need to reduce
// (nlink equals 1 when the only link is the parent, 0 when there is none)
if (nlink == static_cast<int>(parent_index != -1)) {
size_up_reduce = total_size;
}
// while we have not passed the messages out
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
// parent link: read the reduced result, write what is already reduced
if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock);
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
finished = false;
}
if (size_up_out != total_size && size_up_out < size_up_reduce) {
selecter.WatchWrite(links[i].sock);
}
} else {
// child link: read its contribution, write the result back down
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
}
// size_write <= size_read
if (links[i].size_write != total_size){
if (links[i].size_write < size_down_in) {
selecter.WatchWrite(links[i].sock);
}
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
finished = false;
}
}
}
// finish running allreduce
if (finished) break;
// select must return
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
// bytes from size_up_out onwards have not been sent to the parent yet,
// so they are passed as the protected start of the ring buffer
ReturnType ret = links[i].ReadToRingBuffer(size_up_out);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
// this node has children, perform reduce
if (nlink > static_cast<int>(parent_index != -1)) {
size_t buffer_size = 0;
// do upstream reduce
// (can only reduce up to what every child has delivered so far)
size_t max_reduce = total_size;
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
max_reduce= std::min(max_reduce, links[i].size_read);
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
"buffer size inconsistent");
buffer_size = links[i].buffer_size;
}
}
utils::Assert(buffer_size != 0, "must assign buffer_size");
// round down to a multiple of type_nbytes
max_reduce = (max_reduce / type_nbytes * type_nbytes);
// perform reduce, can be at most two rounds
while (size_up_reduce < max_reduce) {
// start position
size_t start = size_up_reduce % buffer_size;
// perform read till end of buffer
size_t nread = std::min(buffer_size - start,
max_reduce - size_up_reduce);
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
// reducer is called with the child's buffered data first and the
// matching position in sendrecvbuf second
reducer(links[i].buffer_head + start,
sendrecvbuf + size_up_reduce,
static_cast<int>(nread / type_nbytes),
MPI::Datatype(type_nbytes));
}
}
size_up_reduce += nread;
}
}
if (parent_index != -1) {
// pass message up to parent, can pass data that are already been reduced
if (size_up_out < size_up_reduce) {
ssize_t len = links[parent_index].sock.
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
if (len != -1) {
size_up_out += static_cast<size_t>(len);
} else {
// Errno2Return maps EAGAIN/EWOULDBLOCK to kSuccess, so those just retry
ReturnType ret = Errno2Return(errno);
if (ret != kSuccess) {
return ReportError(&links[parent_index], ret);
}
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
// a zero-length receive is reported as kRecvZeroLen after closing the socket
if (len == 0) {
links[parent_index].sock.Close();
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {
size_down_in += static_cast<size_t>(len);
utils::Assert(size_down_in <= size_up_out,
"Allreduce: boundary error");
} else {
ReturnType ret = Errno2Return(errno);
if (ret != kSuccess) {
return ReportError(&links[parent_index], ret);
}
}
}
} else {
// this is root, can use reduce as most recent point
size_down_in = size_up_out = size_up_reduce;
}
// can pass message down to children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write < size_down_in) {
ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
}
return kSuccess;
}
/*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
* \param sendrecvbuf_ buffer for both sending and recving data
* \param total_size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links;
// nothing to do when there are no tree links or no data
if (links.size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size,
"Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.size());
// size of space already read from data
size_t size_in = 0;
// input link, -2 means unknown yet, -1 means this is root
int in_link = -2;
// initialize the link statistics
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
}
// root have all the data
if (this->rank == root) {
size_in = total_size;
in_link = -1;
}
// while we have not passed the messages out
while (true) {
bool finished = true;
// select helper
utils::SelectHelper selecter;
for (int i = 0; i < nlink; ++i) {
// input link not known yet: any link may be the one data arrives on
if (in_link == -2) {
selecter.WatchRead(links[i].sock); finished = false;
}
// known input link that has not delivered everything yet
if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
}
// every other link is an output link once the input link is known
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) {
selecter.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// select
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (in_link == -2) {
// probe in-link
// (the first link that actually delivers bytes becomes the input link)
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
size_in = links[i].size_read;
if (size_in != 0) {
in_link = i; break;
}
}
}
} else {
// read from in link
// (in_link is -1 on the root, which has nothing to read)
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[in_link], ret);
}
size_in = links[in_link].size_read;
}
}
// send data to all out-link
// (size_in stays 0 until an input link is found, so nothing is sent before)
for (int i = 0; i < nlink; ++i) {
if (i != in_link && links[i].size_write < size_in) {
ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
}
return kSuccess;
}
} // namespace engine
} // namespace rabit

View File

@@ -0,0 +1,436 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.h
* \brief Basic implementation of AllReduce
* using TCP non-block socket and tree-shape reduction.
*
* This implementation provides basic utility of AllReduce and Broadcast
* without considering node failure
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_
#include <vector>
#include <string>
#include <algorithm>
#include "rabit/utils.h"
#include "rabit/engine.h"
#include "./socket.h"
namespace MPI {
// Stand-in for the MPI data type so that reducers written against the
// MPI interface can be called; it only carries the element size.
class Datatype {
 public:
  // number of bytes of one element of the type
  size_t type_size;
  // explicit so that a plain size is never converted silently
  explicit Datatype(size_t type_size) {
    this->type_size = type_size;
  }
};
}  // namespace MPI
namespace rabit {
namespace engine {
/*! \brief implementation of basic Allreduce engine */
class AllreduceBase : public IEngine {
public:
// magic number to verify server
static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
// initialize the manager
virtual void Init(void);
// shutdown the engine
virtual void Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
/*! \brief get rank */
virtual int GetRank(void) const {
return rank;
}
/*! \brief get world size, reported as 1 while the world size is still unset (-1) */
virtual int GetWorldSize(void) const {
  // -1 marks a world size that was never set; report a single node then
  return (world_size == -1) ? 1 : world_size;
}
/*! \brief get the host uri of the current node */
virtual std::string GetHost(void) const {
return host_uri;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_fun Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_fun will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
"Allreduce failed");
}
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and recving data
* \param total_size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"Broadcast failed");
}
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(ISerializable *global_model,
                           ISerializable *local_model = NULL) {
  // nothing is loaded here and the models are left untouched,
  // so the reported version is always 0
  const int loaded_version = 0;
  return loaded_version;
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
* So only CheckPoint with global_model if possible
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const ISerializable *global_model,
                        const ISerializable *local_model = NULL) {
  // the models are not stored here; only the version counter advances
  ++version_number;
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain conditions are met (see the detailed explanation).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) {
  // the model pointer is not remembered here; only the version counter advances
  ++version_number;
}
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const {
return version_number;
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
utils::Error("InitAfterException: not implemented");
}
/*!
* \brief report current status to the job tracker
* depending on the job tracker we are in
*/
inline void ReportStatus(void) const {
  // status lines are only written when hadoop mode is on
  if (hadoop_mode == 0) return;
  fprintf(stderr,
          "reporter:status:Rabit Phase[%03d] Operation %03d\n",
          version_number,
          seq_counter);
}
protected:
/*! \brief enumeration of possible returning results from Try functions */
enum ReturnTypeEnum {
/*! \brief execution is successful */
kSuccess,
/*! \brief a link was reset by peer (errno was ECONNRESET) */
kConnReset,
/*! \brief received a zero length message; the receiving code closes the socket */
kRecvZeroLen,
/*! \brief a neighbor node go down, the connection is dropped (any other socket errno) */
kSockError,
/*!
* \brief another node which is not my neighbor go down,
* get Out-of-Band exception notification from my neighbor
*/
kGetExcept
};
/*! \brief struct return type to avoid implicit conversion to int/bool */
struct ReturnType {
/*! \brief internal return type */
ReturnTypeEnum value;
// constructor
ReturnType() {}
ReturnType(ReturnTypeEnum value) : value(value){}
inline bool operator==(const ReturnTypeEnum &v) const {
return value == v;
}
inline bool operator!=(const ReturnTypeEnum &v) const {
return value != v;
}
};
/*! \brief translate errno to return type */
inline static ReturnType Errno2Return(int errsv) {
  // "would block" is not a failure: the caller simply retries later
  const bool would_block = (errsv == EAGAIN) || (errsv == EWOULDBLOCK);
  if (would_block) return kSuccess;
  // a reset by the peer is reported separately from all other errors
  const bool peer_reset = (errsv == ECONNRESET);
  return peer_reset ? ReturnType(kConnReset) : ReturnType(kSockError);
}
// link record to a neighbor
struct LinkRecord {
 public:
  // socket to get data from/to link
  utils::TCPSocket sock;
  // rank of the node in this link, -1 until it has been received from the peer
  int rank;
  // size of data read from link
  size_t size_read;
  // size of data sent to the link
  size_t size_write;
  // pointer to buffer head
  char *buffer_head;
  // buffer size, in bytes
  size_t buffer_size;
  // constructor: every scalar member gets a defined value, because records
  // are copied into a std::vector before all fields have been assigned
  LinkRecord(void)
      : rank(-1), size_read(0), size_write(0),
        buffer_head(NULL), buffer_size(0) {
  }
  /*!
   * \brief allocate the ring buffer used to receive data from this link
   * \param type_nbytes number of bytes of one element
   * \param count number of elements in the whole message
   * \param reduce_buffer_size upper bound of the buffer, counted in 8-byte words
   */
  inline void InitBuffer(size_t type_nbytes, size_t count,
                         size_t reduce_buffer_size) {
    // number of 8-byte words needed to hold the whole message
    size_t n = (type_nbytes * count + 7)/ 8;
    buffer_.resize(std::min(reduce_buffer_size, n));
    // make sure align to type_nbytes
    buffer_size =
        buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
    utils::Assert(type_nbytes <= buffer_size,
                  "too large type_nbytes=%lu, buffer_size=%lu",
                  type_nbytes, buffer_size);
    // set buffer head
    buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
  }
  // reset the recv and sent size
  inline void ResetSize(void) {
    size_write = size_read = 0;
  }
  /*!
   * \brief read data into ring-buffer, taking care not to overwrite useful
   *  data at positions from protect_start onwards
   * \param protect_start all data start from protect_start is still needed in buffer
   *  read shall not override this
   * \return kSuccess when bytes were read, the buffer is full, or the socket
   *  would block; kRecvZeroLen when the peer disconnected (the socket is
   *  closed); otherwise the error translated from errno
   */
  inline ReturnType ReadToRingBuffer(size_t protect_start) {
    utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
    // number of buffered bytes that must not be overwritten yet
    size_t ngap = size_read - protect_start;
    utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
    size_t offset = size_read % buffer_size;
    // limited by the free space and by the distance to the end of the buffer
    size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
    if (nmax == 0) return kSuccess;
    ssize_t len = sock.Recv(buffer_head + offset, nmax);
    // length equals 0, remote disconnected
    if (len == 0) {
      sock.Close(); return kRecvZeroLen;
    }
    if (len == -1) return Errno2Return(errno);
    size_read += static_cast<size_t>(len);
    return kSuccess;
  }
  /*!
   * \brief read data into array,
   *  this function can not be used together with ReadToRingBuffer
   *  a link can either read into the ring buffer, or existing array
   * \param recvbuf_ head of the array to read into
   * \param max_size maximum size of array
   * \return kSuccess when bytes were read, the array is already full, or the
   *  socket would block; kRecvZeroLen when the peer disconnected (the socket
   *  is closed); otherwise the error translated from errno
   */
  inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
    if (max_size == size_read) return kSuccess;
    char *p = static_cast<char*>(recvbuf_);
    ssize_t len = sock.Recv(p + size_read, max_size - size_read);
    // length equals 0, remote disconnected
    if (len == 0) {
      sock.Close(); return kRecvZeroLen;
    }
    if (len == -1) return Errno2Return(errno);
    size_read += static_cast<size_t>(len);
    return kSuccess;
  }
  /*!
   * \brief write data in array to sock
   * \param sendbuf_ head of array
   * \param max_size maximum size of array
   * \return kSuccess when bytes were sent or the socket would block,
   *  otherwise the error translated from errno
   */
  inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
    const char *p = static_cast<const char*>(sendbuf_);
    ssize_t len = sock.Send(p + size_write, max_size - size_write);
    if (len == -1) return Errno2Return(errno);
    size_write += static_cast<size_t>(len);
    return kSuccess;
  }

 private:
  // recv buffer to get data from child
  // aligned with 64 bits, will be able to perform 64 bits operations freely
  std::vector<uint64_t> buffer_;
};
/*!
* \brief simple data structure that works like a vector
* but takes reference instead of space
*/
struct RefLinkVector {
  // pointers to link records stored elsewhere; nothing here is owned
  std::vector<LinkRecord*> plinks;
  // number of links referenced
  inline size_t size(void) const {
    const size_t nlinks = plinks.size();
    return nlinks;
  }
  // access the i-th link itself rather than the stored pointer
  inline LinkRecord &operator[](size_t i) {
    LinkRecord *entry = plinks[i];
    return *entry;
  }
};
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker(void) const;
/*!
* \brief connect to the tracker to fix the missing links
* this function is also used when the engine start up
* \param cmd possible command to sent to tracker
*/
void ReConnectLinks(const char *cmd = "start");
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*
* NOTE on Allreduce:
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
* It only means the current node get the correct result of Allreduce.
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
* \param sendrecvbuf_ buffer for both sending and recving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
/*!
* \brief function used to report error when a link goes wrong
* \param link the pointer to the link who causes the error
* \param err the error type
*/
inline ReturnType ReportError(LinkRecord *link, ReturnType err) {
  // remember which link failed, then hand the error back unchanged
  this->err_link = link;
  return err;
}
//---- data structure related to model ----
// call sequence counter, records how many calls were made so far
// since the last call to CheckPoint or LoadCheckPoint
int seq_counter;
// version number of the model
int version_number;
// whether the job is running in hadoop (non-zero means yes)
int hadoop_mode;
//---- local data related to link ----
// index of the parent link, can be -1, meaning this node is the root of the tree
int parent_index;
// rank of the parent node, can be -1
int parent_rank;
// sockets of all links this node connects to
std::vector<LinkRecord> all_links;
// records the link where things went wrong, set by ReportError
LinkRecord *err_link;
// all the links that take part in the reduction tree
RefLinkVector tree_links;
// pointers to the previous and next link in the ring
LinkRecord *ring_prev, *ring_next;
//----- meta information-----
// unique identifier of the job this process is doing,
// used to assign ranks, optional, defaults to the string "NULL"
std::string task_id;
// uri of the current host, to be set by Init
std::string host_uri;
// uri of the tracker
std::string tracker_uri;
// port of the tracker address
int tracker_port;
// slave_port: port of the slave process
// nport_trial: NOTE(review): undocumented in the original; presumably the number of
// consecutive ports to try starting from slave_port — confirm against Init
int slave_port, nport_trial;
// reduce buffer size
size_t reduce_buffer_size;
// rank of the current node
int rank;
// world size
int world_size;
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_BASE_H

View File

@@ -0,0 +1,100 @@
/*!
* \file allreduce_mock.h
* \brief Mock test module of AllReduce engine,
* insert failures in certain call point, to test if the engine is robust to failure
*
* \author Ignacio Cano, Tianqi Chen
*/
#ifndef RABIT_ALLREDUCE_MOCK_H
#define RABIT_ALLREDUCE_MOCK_H
#include <vector>
#include <rabit/engine.h>
#include <map>
#include "./allreduce_robust.h"
namespace rabit {
namespace engine {
/*!
 * \brief mock engine built on top of AllreduceRobust, it terminates the process
 *  at the call points registered through the "mock" parameter
 */
class AllreduceMock : public AllreduceRobust {
 public:
  // constructor, no trial has happened yet
  AllreduceMock(void) {
    num_trial = 0;
  }
  // destructor
  virtual ~AllreduceMock(void) {}
  /*!
   * \brief set parameters, the call is first forwarded to AllreduceRobust,
   *  then two additional keys are handled here
   *  - rabit_num_trial: sets the trial counter
   *  - mock: value of form "rank,version,seqno,ntrial", registers one failure point
   * \param name parameter name
   * \param val parameter value
   */
  virtual void SetParam(const char *name, const char *val) {
    AllreduceRobust::SetParam(name, val);
    if (strcmp(name, "rabit_num_trial") == 0) {
      num_trial = atoi(val);
    }
    if (strcmp(name, "mock") == 0) {
      MockKey key;
      utils::Check(sscanf(val, "%d,%d,%d,%d",
                          &key.rank, &key.version,
                          &key.seqno, &key.ntrial) == 4,
                   "invalid mock parameter");
      mock_map[key] = 1;
    }
  }
  // each operation below checks the failure table before running the real one
  virtual void Allreduce(void *sendrecvbuf_,
                         size_t type_nbytes,
                         size_t count,
                         ReduceFunction reducer,
                         PreprocFunction prepare_fun,
                         void *prepare_arg) {
    this->Verify(this->CurrentKey(), "AllReduce");
    AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
                               count, reducer, prepare_fun, prepare_arg);
  }
  virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
    this->Verify(this->CurrentKey(), "Broadcast");
    AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
  }
  virtual void CheckPoint(const ISerializable *global_model,
                          const ISerializable *local_model) {
    this->Verify(this->CurrentKey(), "CheckPoint");
    AllreduceRobust::CheckPoint(global_model, local_model);
  }
  virtual void LazyCheckPoint(const ISerializable *global_model) {
    this->Verify(this->CurrentKey(), "LazyCheckPoint");
    AllreduceRobust::LazyCheckPoint(global_model);
  }

 private:
  // identifies one mock stage: (rank, version, seqno, ntrial)
  struct MockKey {
    int rank;
    int version;
    int seqno;
    int ntrial;
    // fields are left unset, SetParam fills them through sscanf
    MockKey(void) {}
    MockKey(int rank, int version, int seqno, int ntrial)
        : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
    // two keys are equal when no field differs
    inline bool operator==(const MockKey &b) const {
      return !(rank != b.rank ||
               version != b.version ||
               seqno != b.seqno ||
               ntrial != b.ntrial);
    }
    // lexicographic order on (rank, version, seqno, ntrial), needed by std::map
    inline bool operator<(const MockKey &b) const {
      if (rank < b.rank) return true;
      if (b.rank < rank) return false;
      if (version < b.version) return true;
      if (b.version < version) return false;
      if (seqno < b.seqno) return true;
      if (b.seqno < seqno) return false;
      return ntrial < b.ntrial;
    }
  };
  // number of failure trials
  int num_trial;
  // table of all registered mock actions
  std::map<MockKey, int> mock_map;
  // key that describes the call point the engine is at right now
  inline MockKey CurrentKey(void) const {
    return MockKey(rank, version_number, seq_counter, num_trial);
  }
  // terminate the process with exit code -2 when key is a registered failure point
  inline void Verify(const MockKey &key, const char *name) {
    if (mock_map.find(key) == mock_map.end()) return;
    num_trial += 1;
    fprintf(stderr, "[%d]@@@Hit Mock Error:%s\n", rank, name);
    exit(-2);
  }
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_MOCK_H

View File

@@ -0,0 +1,161 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust-inl.h
* \brief implementation of inline template function in AllreduceRobust
*
* \author Tianqi Chen
*/
#ifndef RABIT_ENGINE_ROBUST_INL_H_
#define RABIT_ENGINE_ROBUST_INL_H_
#include <vector>
namespace rabit {
namespace engine {
/*!
 * \brief run message passing algorithm on the allreduce tree
 *  the result is the edge messages stored in p_edge_in and p_edge_out
 *
 *  Messages first flow from the children up to the parent, then from the parent
 *  back down to the children; every message is exactly sizeof(EdgeType) bytes.
 *
 * \param node_value the value associated with current node
 * \param p_edge_in used to store the input message from each of the edges,
 *        resized to the number of tree links
 * \param p_edge_out used to store the output message to each of the edges,
 *        resized to the number of tree links
 * \param func a function that defines the message passing rule
 *        Parameters of func:
 *          - node_value same as node_value in the main function
 *          - edge_in the array of input messages from each edge,
 *                    this includes the output edge, which should be excluded
 *          - out_index the index of the output edge, the function should
 *                    exclude the output edge when computing the message passing value
 *        Return of func:
 *          the function returns the output message based on the input message and node_value
 * \return kSuccess when all messages were exchanged (or there are no tree links),
 *         otherwise the error reported through ReportError for the failing link
 *
 * \tparam EdgeType type of edge message, must be simple struct
 * \tparam NodeType type of node value
 */
template<typename NodeType, typename EdgeType>
inline AllreduceRobust::ReturnType
AllreduceRobust::MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType (*func)
(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)) {
RefLinkVector &links = tree_links;
// no tree links at all, nothing to exchange
if (links.size() == 0) return kSuccess;
// number of links
const int nlink = static_cast<int>(links.size());
// initialize the pointers
// NOTE(review): ResetSize presumably clears size_read/size_write of the link — confirm in LinkRecord
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
}
std::vector<EdgeType> &edge_in = *p_edge_in;
std::vector<EdgeType> &edge_out = *p_edge_out;
edge_in.resize(nlink);
edge_out.resize(nlink);
// stages in the process
// 0: recv messages from children
// 1: send message to parent
// 2: recv message from parent
// 3: send message to children
int stage = 0;
// if there are no children, there is nothing to wait for, directly start passing the message
// nlink is non-zero here, so the equality only holds when nlink == 1 and that link is the parent
if (nlink == static_cast<int>(parent_index != -1)) {
utils::Assert(parent_index == 0, "parent must be 0");
edge_out[parent_index] = func(node_value, edge_in, parent_index);
stage = 1;
}
// while we have not passed the messages out
while (true) {
// a node with no parent goes directly from stage 0 to stage 3,
// so it must never be in stage 1 or 2
if (parent_index == -1) {
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
}
// select helper
utils::SelectHelper selecter;
// done stays true only in stage 3 once every child link has written a full message
bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) {
// exceptions are watched on every link, regardless of the stage
selecter.WatchException(links[i].sock);
switch (stage) {
case 0:
// wait for children that have not delivered a complete message yet
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
selecter.WatchRead(links[i].sock);
}
break;
case 1: if (i == parent_index) selecter.WatchWrite(links[i].sock); break;
case 2: if (i == parent_index) selecter.WatchRead(links[i].sock); break;
case 3:
// wait for children that have not received a complete message yet
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock);
done = false;
}
break;
default: utils::Error("invalid stage");
}
}
// finish all the stages, and write out message
if (done) break;
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
// the stages below are tested with plain ifs in increasing order, so a stage that
// completes falls straight into the next one within the same loop iteration
if (stage == 0) {
bool finished = true;
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}
if (links[i].size_read != sizeof(EdgeType)) finished = false;
}
}
// if no parent, jump to stage 3, otherwise do stage 1
if (finished) {
if (parent_index != -1) {
edge_out[parent_index] = func(node_value, edge_in, parent_index);
stage = 1;
} else {
// root: every link is a child, compute the outgoing message for each of them
for (int i = 0; i < nlink; ++i) {
edge_out[i] = func(node_value, edge_in, i);
}
stage = 3;
}
}
}
if (stage == 1) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
ReturnType ret = links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[pid], ret);
// move on only when the whole message has been written to the parent
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
}
if (stage == 2) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
ReturnType ret = links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[pid], ret);
// once the parent message is complete, compute the outgoing message for every child
if (links[pid].size_read == sizeof(EdgeType)) {
for (int i = 0; i < nlink; ++i) {
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
}
stage = 3;
}
}
if (stage == 3) {
// keep writing to the children whose message is not fully written yet
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
ReturnType ret = links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}
}
}
}
return kSuccess;
}
} // namespace engine
} // namespace rabit
#endif // RABIT_ENGINE_ROBUST_INL_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,553 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust.h
* \brief Robust implementation of Allreduce
* using TCP non-block socket and tree-shape reduction.
*
* This implementation considers the failure of nodes
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_ROBUST_H_
#define RABIT_ALLREDUCE_ROBUST_H_
#include <vector>
#include <string>
#include <algorithm>
#include "rabit/engine.h"
#include "./allreduce_base.h"
namespace rabit {
namespace engine {
/*! \brief implementation of fault tolerant all reduce engine */
class AllreduceRobust : public AllreduceBase {
public:
// constructor, defined out of line
AllreduceRobust(void);
// destructor, has an empty body
virtual ~AllreduceRobust(void) {}
/*! \brief initialize the manager */
virtual void Init(void);
/*! \brief shutdown the engine */
virtual void Shutdown(void);
/*!
 * \brief set a parameter of the engine
 * \param name parameter name
 * \param val parameter value, passed as a string
 */
virtual void SetParam(const char *name, const char *val);
/*!
 * \brief perform in-place allreduce, on sendrecvbuf_
 *        this function is NOT thread-safe
 * \param sendrecvbuf_ buffer for both sending and receiving data
 * \param type_nbytes the number of bytes of one element of the type
 * \param count number of elements to be reduced
 * \param reducer reduce function
 * \param prepare_fun Lazy preprocessing function, prepare_fun(prepare_arg)
 *                    will be called by the function before performing Allreduce,
 *                    to initialize the data in sendrecvbuf_.
 *                    If the result of Allreduce can be recovered directly,
 *                    then prepare_fun will NOT be called
 * \param prepare_arg argument passed into the lazy preprocessing function
 */
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL);
/*!
 * \brief broadcast data from root to all nodes
 * \param sendrecvbuf_ buffer for both sending and receiving data
 * \param total_size the size of the data to be broadcast
 * \param root the root worker id to broadcast the data from
 */
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
/*!
 * \brief load the latest check point
 * \param global_model pointer to the globally shared model/state
 *   when calling this function, the caller needs to guarantee that global_model
 *   is the same in all nodes
 * \param local_model pointer to the local model, that is specific to the current node/rank
 *   this can be NULL when no local model is needed
 *
 * \return the version number of the check point loaded
 *   if the returned version == 0, this means no model has been CheckPointed,
 *   the models are not touched, and the user should do the necessary initialization themselves
 *
 * Common usage example:
 *   int iter = rabit::LoadCheckPoint(&model);
 *   if (iter == 0) model.InitParameters();
 *   for (i = iter; i < max_iter; ++i) {
 *     do many things, include allreduce
 *     rabit::CheckPoint(model);
 *   }
 *
 * \sa CheckPoint, VersionNumber
 */
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL);
/*!
 * \brief checkpoint the model, meaning a stage of execution is finished;
 *  every call to check point increases the version number by one
 *
 * \param global_model pointer to the globally shared model/state
 *   when calling this function, the caller needs to guarantee that global_model
 *   is the same in all nodes
 * \param local_model pointer to the local model, that is specific to the current node/rank
 *   this can be NULL when no local state is needed
 *
 * NOTE: local_model requires explicit replication of the model for fault-tolerance, which
 *   brings a replication cost to the CheckPoint function. global_model does not need
 *   explicit replication. So CheckPoint with global_model only, if possible
 *
 * \sa LoadCheckPoint, VersionNumber
 */
virtual void CheckPoint(const ISerializable *global_model,
                        const ISerializable *local_model = NULL) {
  // normal, non-lazy checkpoint
  const bool lazy_checkpt = false;
  this->CheckPoint_(global_model, local_model, lazy_checkpt);
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) {
this->CheckPoint_(global_model, NULL, true);
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
// simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
}
ReConnectLinks("recover");
}
private:
// one-byte out-of-band message used to indicate that an error happened
// and that the channel needs to be cleaned up
static const char kOOBReset = 95;
// one-byte mark for channel cleanup, used after the OOB signal
static const char kResetMark = 97;
// one-byte mark for channel cleanup
// NOTE(review): kResetAck has the same value (97) as kResetMark, so the two bytes
// cannot be told apart by value — confirm against TryResetLinks that this is intended
static const char kResetAck = 97;
/*! \brief the roles a node can play during recovery */
enum RecoverType {
/*! \brief the current node has the data */
kHaveData = 0,
/*! \brief the current node requests the data */
kRequestData = 1,
/*! \brief the current node only helps to pass the data around */
kPassData = 2
};
/*!
 * \brief summary of the actions proposed by all nodes
 *  this data structure is used to reach a consensus decision
 *  about the next action to take in the recovery mode
 *
 *  The state is packed into one int: the low 4 bits hold the action flags,
 *  the remaining high bits hold the minimum sequence number.
 */
struct ActionSummary {
  // maximum allowed sequence id, also the seqno used by special operations
  static const int kSpecialOp = (1 << 26);
  // special sequence number for local state checkpoint
  static const int kLocalCheckPoint = (1 << 26) - 2;
  // special sequence number for local state checkpoint ack signal
  static const int kLocalCheckAck = (1 << 26) - 1;
  //---------------------------------------------
  // The following are the bit masks of the action flags
  //---------------------------------------------
  // some node wants to load check point
  static const int kLoadCheck = 1;
  // some node wants to do check point
  static const int kCheckPoint = 2;
  // check point ack, check pointing uses a two phase message,
  // this is the second phase of check pointing
  static const int kCheckAck = 4;
  // the nodes proposed different sequence numbers,
  // this means we want to do recover execution of the lower sequence
  // action instead of normal execution
  static const int kDiffSeq = 8;
  // default constructor, seqcode is left unset
  ActionSummary(void) {}
  // construct from an action flag and a sequence number
  explicit ActionSummary(int flag, int minseqno = kSpecialOp) {
    seqcode = flag | (minseqno << 4);
  }
  // minimum sequence number among all operations
  inline int min_seqno(void) const {
    return seqcode >> 4;
  }
  // whether the operation set contains a load_check
  inline bool load_check(void) const {
    return 0 != (seqcode & kLoadCheck);
  }
  // whether the operation set contains a check point
  inline bool check_point(void) const {
    return 0 != (seqcode & kCheckPoint);
  }
  // whether the operation set contains a check ack
  inline bool check_ack(void) const {
    return 0 != (seqcode & kCheckAck);
  }
  // whether the operation set contains different sequence numbers
  inline bool diff_seq(void) const {
    return 0 != (seqcode & kDiffSeq);
  }
  // the action flags of the result, i.e. the low 4 bits
  inline int flag(void) const {
    return seqcode & 0xF;
  }
  // reducer for Allreduce: merges src into dst element by element,
  // flags are or-ed together, the smaller sequence number wins and
  // kDiffSeq is raised whenever the two sequence numbers disagree
  inline static void Reducer(const void *src_, void *dst_,
                             int len, const MPI::Datatype &dtype) {
    const ActionSummary *src = static_cast<const ActionSummary*>(src_);
    ActionSummary *dst = static_cast<ActionSummary*>(dst_);
    for (int i = 0; i < len; ++i) {
      const int lhs_seq = src[i].min_seqno();
      const int rhs_seq = dst[i].min_seqno();
      int merged = src[i].flag() | dst[i].flag();
      if (lhs_seq != rhs_seq) merged |= kDiffSeq;
      dst[i] = ActionSummary(merged, std::min(lhs_seq, rhs_seq));
    }
  }

 private:
  // internal sequence code: (min_seqno << 4) | flag
  int seqcode;
};
/*!
 * \brief data structure to remember the results of Bcast and Allreduce calls
 *  results are packed back to back into one array of 64-bit words,
 *  result k occupies the words [rptr_[k], rptr_[k + 1])
 */
class ResultBuffer {
 public:
  // constructor, starts with an empty record
  ResultBuffer(void) {
    this->Clear();
  }
  // forget every stored result
  inline void Clear(void) {
    seqno_.clear();
    size_.clear();
    data_.clear();
    rptr_.clear();
    rptr_.push_back(0);
  }
  // allocate temporal space behind the last stored result,
  // the space only becomes a stored result after PushTemp
  inline void *AllocTemp(size_t type_nbytes, size_t count) {
    const size_t nword = NumWords(type_nbytes * count);
    utils::Assert(nword != 0, "cannot allocate 0 size memory");
    data_.resize(rptr_.back() + nword);
    return BeginPtr(data_) + rptr_.back();
  }
  // record the content of the temporal space as the result of seqid
  inline void PushTemp(int seqid, size_t type_nbytes, size_t count) {
    const size_t nbytes = type_nbytes * count;
    if (!seqno_.empty()) {
      utils::Assert(seqno_.back() < seqid, "PushTemp seqid inconsistent");
    }
    seqno_.push_back(seqid);
    size_.push_back(nbytes);
    rptr_.push_back(rptr_.back() + NumWords(nbytes));
    utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent");
  }
  // return the stored result of seqid and write its byte size to p_size,
  // return NULL and leave p_size alone when there is no such result
  inline void* Query(int seqid, size_t *p_size) {
    std::vector<int>::iterator pos =
        std::lower_bound(seqno_.begin(), seqno_.end(), seqid);
    if (pos == seqno_.end() || *pos != seqid) return NULL;
    const size_t idx = pos - seqno_.begin();
    *p_size = size_[idx];
    return BeginPtr(data_) + rptr_[idx];
  }
  // drop the last stored result
  inline void DropLast(void) {
    utils::Assert(!seqno_.empty(), "there is nothing to be dropped");
    seqno_.pop_back();
    size_.pop_back();
    rptr_.pop_back();
    data_.resize(rptr_.back());
  }
  // the sequence number of the last stored result, -1 when nothing is stored
  inline int LastSeqNo(void) const {
    return seqno_.empty() ? -1 : seqno_.back();
  }

 private:
  // number of 64-bit words needed to hold nbytes bytes, rounded up
  inline static size_t NumWords(size_t nbytes) {
    return (nbytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
  }
  // sequence number of each stored result, strictly increasing
  std::vector<int> seqno_;
  // word offset of each stored result in data_, has one more entry than seqno_
  std::vector<size_t> rptr_;
  // actual size in bytes of each stored result
  std::vector<size_t> size_;
  // content of the buffer
  std::vector<uint64_t> data_;
};
/*!
 * \brief internal consistency check function,
 *  used to ensure that the user always calls CheckPoint/LoadCheckPoint
 *  either with or without a local model, but not both; this function will set
 *  the appropriate settings in the first call of LoadCheckPoint/CheckPoint
 *
 * \param with_local whether the user calls CheckPoint with local model
 */
void LocalModelCheck(bool with_local);
/*!
 * \brief internal implementation of checkpoint, supports both the lazy and the normal way
 *
 * \param global_model pointer to the globally shared model/state
 *   when calling this function, the caller needs to guarantee that global_model
 *   is the same in all nodes
 * \param local_model pointer to the local model, that is specific to the current node/rank
 *   this can be NULL when no local state is needed
 * \param lazy_checkpt whether the action is a lazy checkpoint
 *
 * \sa CheckPoint, LazyCheckPoint
 */
void CheckPoint_(const ISerializable *global_model,
const ISerializable *local_model,
bool lazy_checkpt);
/*!
 * \brief reset all the existing links by sending an Out-of-Band message marker;
 *  after this function finishes, all the messages received and sent
 *  before on all live links are discarded.
 *  This allows us to get a fresh start after an error has happened
 *
 * TODO(tqchen): this function is not yet functioning and is not used by the engine,
 *  a simple reset-link-and-reconnect strategy is used instead
 *
 * \return this function can return kSuccess or kSockError
 *         when kSockError is returned, it simply means there are bad sockets in the links,
 *         and some link recovery procedure is needed
 */
ReturnType TryResetLinks(void);
/*!
 * \brief if err_type indicates an error,
 *        recover the links according to the error type reported;
 *        if there is no error, return true
 * \param err_type the type of error happening in the system
 * \return true if err_type is kSuccess, false otherwise
 */
bool CheckAndRecover(ReturnType err_type);
/*!
 * \brief try to run recover execution for a requested action described by flag and seqno,
 *        the function will keep blocking to run possible recovery operations before
 *        the specified action, until the requested result is received by a recovering
 *        procedure, or the function discovers that the requested action is not yet
 *        executed, and returns false
 *
 * \param buf the buffer to store the result
 * \param size the total size of the buffer
 * \param flag flag information about the action \sa ActionSummary
 * \param seqno sequence number of the action, if it is a special action with flag set,
 *              seqno needs to be set to ActionSummary::kSpecialOp
 *
 * \return this function can return true or false
 *    - true means buf is already set to the result by the recovering procedure,
 *      the action is complete, no further action is needed
 *    - false means this is the latest action that has not yet been executed,
 *      the caller needs to execute the action
 */
bool RecoverExec(void *buf, size_t size, int flag,
int seqno = ActionSummary::kSpecialOp);
/*!
 * \brief try to load the check point
 *
 *        This is a collaborative function called by all nodes;
 *        only the nodes with requester set to true really need to load the check point,
 *        the other nodes act in collaborative roles to complete this request
 *
 * \param requester whether the current node is the requester
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType
 */
ReturnType TryLoadCheckPoint(bool requester);
/*!
 * \brief try to get the result of the operation specified by seqno
 *
 *        This is a collaborative function called by all nodes;
 *        only the nodes with requester set to true really need to get the result,
 *        the other nodes act in collaborative roles to complete this request
 *
 * \param buf the buffer to store the result,
 *            this parameter is only used when the current node is the requester
 * \param size the total size of the buffer,
 *             this parameter is only used when the current node is the requester
 * \param seqno sequence number of the operation,
 *              this is the unique index of an operation in the current iteration
 * \param requester whether the current node is the requester
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType
 */
ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
/*!
 * \brief try to decide the routing strategy for recovery
 * \param role the current role of the node
 * \param p_size used to store the size of the message; for a node in state kHaveData,
 *               this size must be set correctly before calling the function,
 *               for the others, this serves as an output parameter
 * \param p_recvlink used to store the link the current node should recv data from,
 *                   if necessary; this can be -1, which means the current node has the data
 * \param p_req_in used to store the resulting vector,
 *                 indicating which links we should send the data to
 *
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType, TryRecoverData
 */
ReturnType TryDecideRouting(RecoverType role,
size_t *p_size,
int *p_recvlink,
std::vector<bool> *p_req_in);
/*!
 * \brief try to finish the data recovery request,
 *        this function is used together with TryDecideRouting
 * \param role the current role of the node
 * \param sendrecvbuf_ the buffer to store the data to be sent/received
 *          - if the role is kHaveData, this stores the data to be sent
 *          - if the role is kRequestData, this is the buffer to store the result
 *          - if the role is kPassData, this will not be used, and can be NULL
 * \param size the size of the data, obtained from TryDecideRouting
 * \param recv_link the link index to receive data from, if necessary,
 *                  obtained from TryDecideRouting
 * \param req_in the request of each link to send data, obtained from TryDecideRouting
 *
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType, TryDecideRouting
 */
ReturnType TryRecoverData(RecoverType role,
void *sendrecvbuf_,
size_t size,
int recv_link,
const std::vector<bool> &req_in);
/*!
 * \brief try to recover the local state, making each local state the result of itself
 *        plus the replication of the states in the previous num_local_replica hops in the ring
 *
 * The input parameters must contain the valid local states available in the current node.
 * This function tries its best to "complete" the missing parts of local_rptr and local_chkpt.
 * If there is sufficient information in the ring, when the function returns, local_chkpt will
 * contain num_local_replica + 1 checkpoints (including the chkpt of this node).
 * If there is not sufficient information in the ring, the number of checkpoints
 * will be less than the specified value
 *
 * \param p_local_rptr the pointer to the segment pointers in the states array
 * \param p_local_chkpt the pointer to the storage of local check points
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType
 */
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
std::string *p_local_chkpt);
/*!
 * \brief try to checkpoint the local state, this function is called in the normal
 *        execution phase of a checkpoint that contains local state.
 *        The input state must contain exactly one saved state (the local state of the
 *        current node); after completion, this function will have fetched the local states
 *        from the previous num_local_replica nodes and put them into local_chkpt and local_rptr
 *
 *  It is also OK to call TryRecoverLocalState instead,
 *  TryRecoverLocalState makes fewer assumptions about the input, and requires more communication
 *
 * \param p_local_rptr the pointer to the segment pointers in the states array
 * \param p_local_chkpt the pointer to the storage of local check points
 * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
 * \sa ReturnType, TryRecoverLocalState
 */
ReturnType TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
std::string *p_local_chkpt);
/*!
 * \brief perform a ring passing to receive data from the prev link, and send data to the
 *  next link; this allows data to stream over a ring structure
 *  sendrecvbuf[0:read_ptr] are already provided by the current node
 *  the current node will recv sendrecvbuf[read_ptr:read_end] from the prev link
 *  the current node will send sendrecvbuf[write_ptr:write_end] to the next link
 *  write_ptr will wait until the data has been read before sending the data
 *  this function requires read_end >= write_end
 *
 * \param sendrecvbuf_ the place to hold the incoming and outgoing data
 * \param read_ptr the initial read pointer
 * \param read_end the ending position to read
 * \param write_ptr the initial write pointer
 * \param write_end the ending position to write
 * \param read_link pointer to the link to the previous position in the ring
 * \param write_link pointer to the link to the next position in the ring
 * \return NOTE(review): not documented in the original; presumably kSuccess or an error
 *         code like the other Try* functions, see ReturnType — confirm
 */
// the first parameter was spelled senrecvbuf_ here, it is renamed so that the declaration
// agrees with its own documentation and with the other functions of this class
ReturnType RingPassing(void *sendrecvbuf_,
                       size_t read_ptr,
                       size_t read_end,
                       size_t write_ptr,
                       size_t write_end,
                       LinkRecord *read_link,
                       LinkRecord *write_link);
/*!
 * \brief run message passing algorithm on the allreduce tree
 *        the result is the edge messages stored in p_edge_in and p_edge_out
 * \param node_value the value associated with current node
 * \param p_edge_in used to store the input message from each of the edges
 * \param p_edge_out used to store the output message to each of the edges
 * \param func a function that defines the message passing rule
 *        Parameters of func:
 *          - node_value same as node_value in the main function
 *          - edge_in the array of input messages from each edge,
 *                    this includes the output edge, which should be excluded
 *          - out_index the index of the output edge, the function should
 *                    exclude the output edge when computing the message passing value
 *        Return of func:
 *          the function returns the output message based on the input message and node_value
 *
 * \tparam EdgeType type of edge message, must be simple struct
 * \tparam NodeType type of node value
 */
template<typename NodeType, typename EdgeType>
inline ReturnType MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType (*func)
(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index));
//---- recovery data structure ----
// the round of the result buffer
// NOTE(review): the original comment read "used to mode the result"; presumably the
// sequence number is taken modulo this value to decide which results to keep — confirm
int result_buffer_round;
// result buffer of allreduce and broadcast calls
ResultBuffer resbuf;
// global model of the last check point
std::string global_checkpoint;
// lazy checkpoint of the global model
const ISerializable *global_lazycheck;
// number of replicas for local state/model
int num_local_replica;
// default number of local replicas
int default_local_replica;
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
int use_local_model;
// number of replicas for global state/model
int num_global_replica;
// number of times recovery happened
int recover_counter;
// --- recovery data structure for local checkpoint
// there are two versions of the data structure,
// at any time one version is valid and the other is used as temp memory
// pointer to memory position in the local model
// local model is stored in CSR format (like a sparse matrix)
// local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
std::vector<size_t> local_rptr[2];
// storage for local model replicas
std::string local_chkpt[2];
// version of local checkpoint, can be 1 or 0
int local_chkpt_version;
};
} // namespace engine
} // namespace rabit
// implementation of inline template function
#include "./allreduce_robust-inl.h"
#endif // RABIT_ALLREDUCE_ROBUST_H_

View File

@@ -0,0 +1,80 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine.cc
* \brief this file governs which implementation of engine we are actually using
* provides a singleton of the engine interface
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <rabit/engine.h>
#include "./allreduce_base.h"
#include "./allreduce_robust.h"
namespace rabit {
namespace engine {
// singleton sync manager
// the mock engine is used when RABIT_USE_MOCK is defined, the robust engine otherwise
// NOTE(review): this file does not include allreduce_mock.h itself, so a build that defines
// RABIT_USE_MOCK presumably has to make AllreduceMock visible before this point — confirm
#ifndef RABIT_USE_MOCK
AllreduceRobust manager;
#else
AllreduceMock manager;
#endif
/*!
 * \brief initialize the synchronization module
 *  every argument of form name=value is passed to the engine as a parameter,
 *  arguments that do not match this form are ignored, argv[0] is skipped
 * \param argc number of arguments
 * \param argv the arguments
 */
void Init(int argc, char *argv[]) {
  // the field widths in the format string must stay at sizeof(buffer) - 1,
  // without them a long argument would make sscanf write past the end of the buffers;
  // an argument whose name does not fit no longer matches and is skipped,
  // a value that does not fit is cut to 255 characters
  char name[256], val[256];
  for (int i = 1; i < argc; ++i) {
    if (sscanf(argv[i], "%255[^=]=%255s", name, val) == 2) {
      manager.SetParam(name, val);
    }
  }
  manager.Init();
}
/*! \brief finalize synchronization module by shutting down the singleton manager */
void Finalize(void) {
manager.Shutdown();
}
/*!
 * \brief singleton method to get engine
 * \return pointer to the file-level manager instance
 */
IEngine *GetEngine(void) {
return &manager;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
red, prepare_fun, prepare_arg);
}
// code for reduce handle
// default constructor: starts with no handle, no reduce function and no type handle
ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
// destructor is a no-op in this engine implementation
ReduceHandle::~ReduceHandle(void) {}
// return the type_size field of dtype, converted to int
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return static_cast<int>(dtype.type_size);
}
// bind the reduce function to this handle; asserts that no function was bound before.
// type_nbytes is not used by this implementation
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
redfunc_, prepare_fun, prepare_arg);
}
} // namespace engine
} // namespace rabit

View File

@@ -0,0 +1,111 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies
* This is usually NOT needed, use engine_mpi or engine for real distributed version
* \author Tianqi Chen
*/
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <rabit/engine.h>
namespace rabit {
namespace engine {
/*! \brief EmptyEngine */
class EmptyEngine : public IEngine {
public:
EmptyEngine(void) {
version_number = 0;
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
}
virtual void InitAfterException(void) {
utils::Error("EmptyEngine is not fault tolerant");
}
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL) {
return 0;
}
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}
/*! \brief get rank of current node */
virtual int GetRank(void) const {
return 0;
}
/*! \brief get total number of */
virtual int GetWorldSize(void) const {
return 1;
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
return std::string("");
}
virtual void TrackerPrint(const std::string &msg) {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
int version_number;
};
// singleton sync manager; GetEngine in this file returns its address
EmptyEngine manager;
/*!
 * \brief initialize the synchronization module;
 *  this implementation does nothing and ignores its arguments
 */
void Init(int argc, char *argv[]) {
}
/*! \brief finalize synchronization module; this implementation does nothing */
void Finalize(void) {
}
/*!
 * \brief singleton method to get engine
 * \return pointer to the file-level EmptyEngine instance
 */
IEngine *GetEngine(void) {
return &manager;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {
}
// code for reduce handle
// default constructor: clear every handle member, including redfunc_,
// so that no pointer member is left uninitialized
ReduceHandle::ReduceHandle(void)
    : handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
// destructor is a no-op in this engine implementation
ReduceHandle::~ReduceHandle(void) {}
// this implementation ignores dtype and always returns 0
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return 0;
}
// this implementation does nothing: redfunc and type_nbytes are ignored
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg) {}
} // namespace engine
} // namespace rabit

View File

@@ -0,0 +1,16 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc
* \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure
* \author Tianqi Chen
*/
// define use MOCK, so we will use mock Manager
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
// switch engine to AllreduceMock
#define RABIT_USE_MOCK
#include "./allreduce_mock.h"
#include "./engine.cc"

View File

@@ -0,0 +1,194 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mpi.cc
* \brief this file gives an implementation of the engine interface using MPI;
* this allows rabit programs to run with MPI, but does not come with fault tolerance
*
* \author Tianqi Chen
*/
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <mpi.h>
#include <cstdio>
#include "rabit/engine.h"
#include "rabit/utils.h"
namespace rabit {
namespace engine {
/*! \brief implementation of engine using MPI */
class MPIEngine : public IEngine {
public:
MPIEngine(void) {
version_number = 0;
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
utils::Error("MPIEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
}
virtual void InitAfterException(void) {
utils::Error("MPI is not fault tolerant");
}
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL) {
return 0;
}
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}
/*! \brief get rank of current node */
virtual int GetRank(void) const {
return MPI::COMM_WORLD.Get_rank();
}
/*! \brief get total number of */
virtual int GetWorldSize(void) const {
return MPI::COMM_WORLD.Get_size();
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
int len;
char name[MPI_MAX_PROCESSOR_NAME];
MPI::Get_processor_name(name, len);
name[len] = '\0';
return std::string(name);
}
virtual void TrackerPrint(const std::string &msg) {
// simply print information into the tracker
if (GetRank() == 0) {
utils::Printf("%s", msg.c_str());
}
}
private:
int version_number;
};
// singleton sync manager; GetEngine in this file returns its address
MPIEngine manager;
/*! \brief initialize the synchronization module by passing the command line to MPI::Init */
void Init(int argc, char *argv[]) {
MPI::Init(argc, argv);
}
/*! \brief finalize synchronization module by calling MPI::Finalize */
void Finalize(void) {
MPI::Finalize();
}
/*!
 * \brief singleton method to get engine
 * \return pointer to the file-level MPIEngine instance
 */
IEngine *GetEngine(void) {
return &manager;
}
// map a rabit data type enum to the matching MPI datatype;
// an unrecognized value is reported through utils::Error
inline MPI::Datatype GetType(mpi::DataType dtype) {
  switch (dtype) {
    case mpi::kChar: return MPI::CHAR;
    case mpi::kUChar: return MPI::BYTE;
    case mpi::kInt: return MPI::INT;
    case mpi::kUInt: return MPI::UNSIGNED;
    case mpi::kLong: return MPI::LONG;
    case mpi::kULong: return MPI::UNSIGNED_LONG;
    case mpi::kFloat: return MPI::FLOAT;
    case mpi::kDouble: return MPI::DOUBLE;
    default: break;
  }
  utils::Error("unknown mpi::DataType");
  return MPI::CHAR;
}
// map a rabit reduction enum to the matching MPI operation;
// an unrecognized value is reported through utils::Error
inline MPI::Op GetOp(mpi::OpType otype) {
  switch (otype) {
    case mpi::kMax: return MPI::MAX;
    case mpi::kMin: return MPI::MIN;
    case mpi::kSum: return MPI::SUM;
    case mpi::kBitwiseOR: return MPI::BOR;
    default: break;
  }
  utils::Error("unknown mpi::OpType");
  return MPI::MAX;
}
// in-place allreduce on sendrecvbuf through MPI::COMM_WORLD;
// the preparation callback, when given, runs before the MPI call,
// type_nbytes and red are not used here, dtype and op select the MPI type and operation
void Allreduce_(void *sendrecvbuf,
                size_t type_nbytes,
                size_t count,
                IEngine::ReduceFunction red,
                mpi::DataType dtype,
                mpi::OpType op,
                IEngine::PreprocFunction prepare_fun,
                void *prepare_arg) {
  if (prepare_fun != NULL) {
    prepare_fun(prepare_arg);
  }
  MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE,
                            sendrecvbuf,
                            count,
                            GetType(dtype),
                            GetOp(op));
}
// code for reduce handle
// default constructor: starts with no MPI operation, no reduce function and no MPI datatype
// NOTE(review): created_type_nbytes_ is not set here; it is assigned in Init and Allreduce
ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
// destructor: free and delete the MPI operation and the MPI datatype, when present
ReduceHandle::~ReduceHandle(void) {
  if (handle_ != NULL) {
    MPI::Op *reduce_op = reinterpret_cast<MPI::Op*>(handle_);
    reduce_op->Free();
    delete reduce_op;
  }
  if (htype_ != NULL) {
    MPI::Datatype *data_type = reinterpret_cast<MPI::Datatype*>(htype_);
    data_type->Free();
    delete data_type;
  }
}
// return the size reported by dtype.Get_size()
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return dtype.Get_size();
}
// set up the handle for a user defined reduction; asserts that it was not set up before.
// when type_nbytes is non-zero, a contiguous MPI datatype of that many chars is
// created, committed and stored in htype_; in all cases redfunc is registered as
// an MPI operation and stored in handle_
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
if (type_nbytes != 0) {
MPI::Datatype *dtype = new MPI::Datatype();
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
dtype->Commit();
// remember the size the datatype was built for
created_type_nbytes_ = type_nbytes;
htype_ = dtype;
}
MPI::Op *op = new MPI::Op();
MPI::User_function *pf = redfunc;
op->Init(pf, true);
handle_ = op;
}
// run an in-place allreduce on sendrecvbuf with the MPI operation created by Init.
// the contiguous MPI datatype is (re)built when none exists yet or when it was
// built for a different type_nbytes; the preparation callback, when given,
// runs right before the MPI call
void ReduceHandle::Allreduce(void *sendrecvbuf,
                             size_t type_nbytes, size_t count,
                             IEngine::PreprocFunction prepare_fun,
                             void *prepare_arg) {
  utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
  MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
  MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
  // check for a missing datatype first: created_type_nbytes_ is only
  // assigned together with a datatype, so it must not be read before one exists
  if (dtype == NULL || created_type_nbytes_ != type_nbytes) {
    if (dtype == NULL) {
      dtype = new MPI::Datatype();
      // store the new datatype in the handle, so it is reused by later calls
      // and released by the destructor instead of being leaked
      htype_ = dtype;
    } else {
      dtype->Free();
    }
    *dtype = MPI::CHAR.Create_contiguous(type_nbytes);
    dtype->Commit();
    created_type_nbytes_ = type_nbytes;
  }
  if (prepare_fun != NULL) prepare_fun(prepare_arg);
  MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
}
} // namespace engine
} // namespace rabit

499
subtree/rabit/src/socket.h Normal file
View File

@@ -0,0 +1,499 @@
/*!
* Copyright (c) 2014 by Contributors
* \file socket.h
* \brief this file aims to provide a wrapper of sockets
* \author Tianqi Chen
*/
#ifndef RABIT_SOCKET_H_
#define RABIT_SOCKET_H_
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif
#else
#include <fcntl.h>
#include <netdb.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/ioctl.h>
#endif
#include <string>
#include <cstring>
#include "rabit/utils.h"
// platform dependent type aliases used by the socket wrappers in this file
#if defined(_WIN32)
// Windows build: declare ssize_t, and use int as the length type given to send/recv
typedef int ssize_t;
typedef int sock_size_t;
#else
// other platforms: declare SOCKET and INVALID_SOCKET, so the same names
// can be used in both builds
typedef int SOCKET;
typedef size_t sock_size_t;
const int INVALID_SOCKET = -1;
#endif
namespace rabit {
namespace utils {
/*! \brief data structure for network address */
struct SockAddr {
  /*! \brief the underlying socket address */
  sockaddr_in addr;
  // default constructor, addr is not filled in
  SockAddr(void) {}
  // construct the address of url:port, see Set
  SockAddr(const char *url, int port) {
    this->Set(url, port);
  }
  /*! \return the host name reported by gethostname */
  inline static std::string GetHostName(void) {
    std::string name_buf(256, '\0');
    utils::Check(gethostname(&name_buf[0], 256) != -1, "fail to get host name");
    // keep only the part before the first terminating zero
    return std::string(name_buf.c_str());
  }
  /*!
   * \brief set the address
   * \param host the host name to be resolved by gethostbyname
   * \param port the port of address
   */
  inline void Set(const char *host, int port) {
    hostent *entry = gethostbyname(host);
    Check(entry != NULL, "cannot obtain address of %s", host);
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    // the first address in the resolved list is used
    memcpy(&addr.sin_addr, entry->h_addr_list[0], entry->h_length);
  }
  /*! \brief return port of the address, in host byte order */
  inline int port(void) const {
    return ntohs(addr.sin_port);
  }
  /*! \return a string representation of the address */
  inline std::string AddrStr(void) const {
    std::string text_buf(256, '\0');
#ifdef _WIN32
    const char *text = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
                                 &text_buf[0], text_buf.length());
#else
    const char *text = inet_ntop(AF_INET, &addr.sin_addr,
                                 &text_buf[0], text_buf.length());
#endif
    Assert(text != NULL, "cannot decode address");
    return std::string(text);
  }
};
/*!
 * \brief base class containing common operations of TCP and UDP sockets
 */
class Socket {
 public:
  /*! \brief the file descriptor of socket */
  SOCKET sockfd;
  // default conversion to the underlying descriptor
  inline operator SOCKET() const {
    return sockfd;
  }
  /*!
   * \brief start up the socket module
   *   call this before using the sockets
   */
  inline static void Startup(void) {
#ifdef _WIN32
    WSADATA wsa_data;
    // WSAStartup returns zero on success and a non-zero error code on
    // failure (never -1), so any non-zero result is treated as an error
    if (WSAStartup(MAKEWORD(2, 2), &wsa_data) != 0) {
      Socket::Error("Startup");
    }
    if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
      WSACleanup();
      utils::Error("Could not find a usable version of Winsock.dll\n");
    }
#endif
  }
  /*!
   * \brief shutdown the socket module after use, all sockets need to be closed
   */
  inline static void Finalize(void) {
#ifdef _WIN32
    WSACleanup();
#endif
  }
  /*!
   * \brief set this socket to use non-blocking mode
   * \param non_block whether set it to be non-block, if it is false
   *        it will set it back to block mode
   */
  inline void SetNonBlock(bool non_block) {
#ifdef _WIN32
    u_long mode = non_block ? 1 : 0;
    if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
      Socket::Error("SetNonBlock");
    }
#else
    // read the current flags, flip only O_NONBLOCK, then write them back
    int flag = fcntl(sockfd, F_GETFL, 0);
    if (flag == -1) {
      Socket::Error("SetNonBlock-1");
    }
    if (non_block) {
      flag |= O_NONBLOCK;
    } else {
      flag &= ~O_NONBLOCK;
    }
    if (fcntl(sockfd, F_SETFL, flag) == -1) {
      Socket::Error("SetNonBlock-2");
    }
#endif
  }
  /*!
   * \brief bind the socket to an address, reports an error on failure
   * \param addr the address to bind to
   */
  inline void Bind(const SockAddr &addr) {
    if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
             sizeof(addr.addr)) == -1) {
      Socket::Error("Bind");
    }
  }
  /*!
   * \brief try bind the socket to host, from start_port to end_port
   * \param start_port starting port number to try
   * \param end_port ending port number to try, this port itself is not tried
   * \return the port successfully bind to, return -1 if failed to bind any port
   */
  inline int TryBindHost(int start_port, int end_port) {
    // TODO(tqchen) add prefix check
    for (int port = start_port; port < end_port; ++port) {
      SockAddr addr("0.0.0.0", port);
      if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
               sizeof(addr.addr)) == 0) {
        return port;
      }
      // only "address in use" lets the loop move on to the next port,
      // any other failure is reported as an error
#if defined(_WIN32)
      if (WSAGetLastError() != WSAEADDRINUSE) {
        Socket::Error("TryBindHost");
      }
#else
      if (errno != EADDRINUSE) {
        Socket::Error("TryBindHost");
      }
#endif
    }
    return -1;
  }
  /*! \brief get last error code if any */
  inline int GetSockError(void) const {
    int error = 0;
    socklen_t len = sizeof(error);
    if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
                   reinterpret_cast<char*>(&error), &len) != 0) {
      Error("GetSockError");
    }
    return error;
  }
  /*! \brief check if anything bad happens */
  inline bool BadSocket(void) const {
    if (IsClosed()) return true;
    int err = GetSockError();
    if (err == EBADF || err == EINTR) return true;
    return false;
  }
  /*! \brief check if socket is already closed */
  inline bool IsClosed(void) const {
    return sockfd == INVALID_SOCKET;
  }
  /*! \brief close the socket, reports an error when the socket is not open */
  inline void Close(void) {
    if (sockfd != INVALID_SOCKET) {
#ifdef _WIN32
      closesocket(sockfd);
#else
      close(sockfd);
#endif
      sockfd = INVALID_SOCKET;
    } else {
      Error("Socket::Close double close the socket or close without create");
    }
  }
  // report a socket error, msg is combined with the text for the current errno
  inline static void Error(const char *msg) {
    int errsv = errno;
    utils::Error("Socket %s Error:%s", msg, strerror(errsv));
  }

 protected:
  // wrap an existing descriptor, only available to subclasses
  explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
  }
};
/*!
 * \brief a wrapper of TCP socket that hopefully be cross platform
 */
class TCPSocket : public Socket{
 public:
  // constructor, the socket starts out closed until Create is called
  TCPSocket(void) : Socket(INVALID_SOCKET) {
  }
  // wrap an existing descriptor
  explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
  }
  /*!
   * \brief enable/disable TCP keepalive
   * \param keepalive whether to set the keep alive option on
   */
  inline void SetKeepAlive(bool keepalive) {
    int opt = static_cast<int>(keepalive);
    if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
                   reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
      Socket::Error("SetKeepAlive");
    }
  }
  /*!
   * \brief create the socket, call this before using socket
   * \param af domain
   */
  inline void Create(int af = PF_INET) {
    // honor the requested domain; the default keeps the previous PF_INET behavior
    sockfd = socket(af, SOCK_STREAM, 0);
    if (sockfd == INVALID_SOCKET) {
      Socket::Error("Create");
    }
  }
  /*!
   * \brief perform listen of the socket
   * \param backlog backlog parameter
   */
  inline void Listen(int backlog = 16) {
    listen(sockfd, backlog);
  }
  /*! \brief get a new connection, reports an error when accept fails */
  TCPSocket Accept(void) {
    SOCKET newfd = accept(sockfd, NULL, NULL);
    if (newfd == INVALID_SOCKET) {
      Socket::Error("Accept");
    }
    return TCPSocket(newfd);
  }
  /*!
   * \brief decide whether the socket is at OOB mark
   * \return 1 if at mark, 0 if not, -1 if an error occured
   */
  inline int AtMark(void) const {
#ifdef _WIN32
    unsigned long atmark;
    if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
#else
    int atmark;
    if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
#endif
    return static_cast<int>(atmark);
  }
  /*!
   * \brief connect to an address
   * \param addr the address to connect to
   * \return whether connect is successful
   */
  inline bool Connect(const SockAddr &addr) {
    return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
                   sizeof(addr.addr)) == 0;
  }
  /*!
   * \brief send data using the socket
   * \param buf_ the pointer to the buffer
   * \param len the size of the buffer
   * \param flag extra flags
   * \return size of data actually sent
   *         return -1 if error occurs
   */
  inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
    const char *buf = reinterpret_cast<const char*>(buf_);
    return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
  }
  /*!
   * \brief receive data using the socket
   * \param buf_ the pointer to the buffer
   * \param len the size of the buffer
   * \param flags extra flags
   * \return size of data actually received
   *         return -1 if error occurs
   */
  inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
    char *buf = reinterpret_cast<char*>(buf_);
    return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
  }
  /*!
   * \brief peform block write that will attempt to send all data out
   *    can still return smaller than request when error occurs
   * \param buf_ the pointer to the buffer
   * \param len the size of the buffer
   * \return size of data actually sent
   */
  inline size_t SendAll(const void *buf_, size_t len) {
    const char *buf = reinterpret_cast<const char*>(buf_);
    size_t ndone = 0;
    while (ndone < len) {
      // use the platform length type, same as Send, Recv and RecvAll
      ssize_t ret = send(sockfd, buf, static_cast<sock_size_t>(len - ndone), 0);
      if (ret == -1) {
        // would block: give back what was sent so far instead of failing
        if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
        Socket::Error("SendAll");
      }
      buf += ret;
      ndone += ret;
    }
    return ndone;
  }
  /*!
   * \brief peform a block read that will attempt to read all data
   *    can still return smaller than request when error occurs
   * \param buf_ the buffer pointer
   * \param len length of data to recv
   * \return size of data actually received
   */
  inline size_t RecvAll(void *buf_, size_t len) {
    char *buf = reinterpret_cast<char*>(buf_);
    size_t ndone = 0;
    while (ndone < len) {
      ssize_t ret = recv(sockfd, buf,
                         static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
      if (ret == -1) {
        // would block: give back what was received so far instead of failing
        if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
        Socket::Error("RecvAll");
      }
      // recv returned 0: stop and report the partial count
      if (ret == 0) return ndone;
      buf += ret;
      ndone += ret;
    }
    return ndone;
  }
  /*!
   * \brief send a string over network, as an int length followed by the content
   * \param str the string to be sent
   */
  inline void SendStr(const std::string &str) {
    int len = static_cast<int>(str.length());
    utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
                  "error during send SendStr");
    if (len != 0) {
      utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
                    "error during send SendStr");
    }
  }
  /*!
   * \brief recv a string from network, as an int length followed by the content
   * \param out_str the string to receive
   */
  inline void RecvStr(std::string *out_str) {
    int len;
    utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
                  "error during recv RecvStr");
    // the length comes from the peer: a negative value would be converted
    // into a huge size by resize, so reject it explicitly
    utils::Check(len >= 0, "RecvStr: received invalid string length %d", len);
    out_str->resize(len);
    if (len != 0) {
      utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
                    "error during recv RecvStr");
    }
  }
};
/*! \brief helper data structure to perform select */
struct SelectHelper {
 public:
  // start with empty read, write and exception sets
  SelectHelper(void) {
    FD_ZERO(&read_set);
    FD_ZERO(&write_set);
    FD_ZERO(&except_set);
    maxfd = 0;
  }
  /*!
   * \brief add file descriptor to watch for read
   * \param fd file descriptor to be watched
   */
  inline void WatchRead(SOCKET fd) {
    FD_SET(fd, &read_set);
    if (fd > maxfd) maxfd = fd;
  }
  /*!
   * \brief add file descriptor to watch for write
   * \param fd file descriptor to be watched
   */
  inline void WatchWrite(SOCKET fd) {
    FD_SET(fd, &write_set);
    if (fd > maxfd) maxfd = fd;
  }
  /*!
   * \brief add file descriptor to watch for exception
   * \param fd file descriptor to be watched
   */
  inline void WatchException(SOCKET fd) {
    FD_SET(fd, &except_set);
    if (fd > maxfd) maxfd = fd;
  }
  /*!
   * \brief Check if the descriptor is ready for read
   * \param fd file descriptor to check status
   */
  inline bool CheckRead(SOCKET fd) const {
    return FD_ISSET(fd, &read_set) != 0;
  }
  /*!
   * \brief Check if the descriptor is ready for write
   * \param fd file descriptor to check status
   */
  inline bool CheckWrite(SOCKET fd) const {
    return FD_ISSET(fd, &write_set) != 0;
  }
  /*!
   * \brief Check if the descriptor has any exception
   * \param fd file descriptor to check status
   */
  inline bool CheckExcept(SOCKET fd) const {
    return FD_ISSET(fd, &except_set) != 0;
  }
  /*!
   * \brief wait for exception event on a single descriptor
   * \param fd the file descriptor to wait the event for
   * \param timeout the timeout in milliseconds, can be 0,
   *        which means wait until the event happen
   * \return the result of select on that descriptor, -1 if error occurs
   */
  inline static int WaitExcept(SOCKET fd, long timeout = 0) {
    fd_set wait_set;
    FD_ZERO(&wait_set);
    FD_SET(fd, &wait_set);
    return Select_(static_cast<int>(fd + 1),
                   NULL, NULL, &wait_set, timeout);
  }
  /*!
   * \brief peform select on the read, write and exception sets built so far
   * \param timeout specify timeout in milliseconds,
   *        if equals 0, means select will always block
   * \return number of active descriptors selected,
   *         an error is reported when select fails
   */
  inline int Select(long timeout = 0) {
    int ret = Select_(static_cast<int>(maxfd + 1),
                      &read_set, &write_set, &except_set, timeout);
    if (ret == -1) {
      Socket::Error("Select");
    }
    return ret;
  }

 private:
  // call select, maxfd is the nfds argument (highest watched descriptor + 1)
  // and timeout is given in milliseconds, 0 means block without timeout
  inline static int Select_(int maxfd, fd_set *rfds,
                            fd_set *wfds, fd_set *efds, long timeout) {
#if !defined(_WIN32)
    // maxfd is already "highest descriptor + 1" here, and select accepts
    // values up to and including FD_SETSIZE
    utils::Assert(maxfd <= FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
#endif
    if (timeout == 0) {
      return select(maxfd, rfds, wfds, efds, NULL);
    } else {
      // split the milliseconds into whole seconds and remaining microseconds
      timeval tm;
      tm.tv_usec = (timeout % 1000) * 1000;
      tm.tv_sec = timeout / 1000;
      return select(maxfd, rfds, wfds, efds, &tm);
    }
  }
  // largest descriptor passed to the Watch functions so far
  SOCKET maxfd;
  // descriptor sets handed to select by Select
  fd_set read_set, write_set, except_set;
};
} // namespace utils
} // namespace rabit
#endif // RABIT_SOCKET_H_