Merge commit 'd87691ec603db325d5b1c5db1186295a748df7cc' as 'subtree/rabit'
This commit is contained in:
6
subtree/rabit/src/README.md
Normal file
6
subtree/rabit/src/README.md
Normal file
@@ -0,0 +1,6 @@
|
||||
Source Files of Rabit
|
||||
====
|
||||
* This folder contains the source files of rabit library
|
||||
* The library headers are in folder [include](../include)
|
||||
* The .h files in this folder are internal header files that are only used by rabit and will not be seen by users
|
||||
|
||||
590
subtree/rabit/src/allreduce_base.cc
Normal file
590
subtree/rabit/src/allreduce_base.cc
Normal file
@@ -0,0 +1,590 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_base.cc
|
||||
* \brief Basic implementation of AllReduce
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <map>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include "./allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
// constructor
|
||||
AllreduceBase::AllreduceBase(void) {
|
||||
tracker_uri = "NULL";
|
||||
tracker_port = 9000;
|
||||
host_uri = "";
|
||||
slave_port = 9010;
|
||||
nport_trial = 1000;
|
||||
rank = 0;
|
||||
world_size = -1;
|
||||
hadoop_mode = 0;
|
||||
version_number = 0;
|
||||
task_id = "NULL";
|
||||
err_link = NULL;
|
||||
this->SetParam("rabit_reduce_buffer", "256MB");
|
||||
}
|
||||
|
||||
// initialization function
|
||||
void AllreduceBase::Init(void) {
|
||||
// setup from enviroment variables
|
||||
{
|
||||
// handling for hadoop
|
||||
const char *task_id = getenv("mapred_tip_id");
|
||||
if (task_id == NULL) {
|
||||
task_id = getenv("mapreduce_task_id");
|
||||
}
|
||||
if (hadoop_mode != 0) {
|
||||
utils::Check(task_id != NULL,
|
||||
"hadoop_mode is set but cannot find mapred_task_id");
|
||||
}
|
||||
if (task_id != NULL) {
|
||||
this->SetParam("rabit_task_id", task_id);
|
||||
this->SetParam("rabit_hadoop_mode", "1");
|
||||
}
|
||||
const char *attempt_id = getenv("mapred_task_id");
|
||||
if (attempt_id != 0) {
|
||||
const char *att = strrchr(attempt_id, '_');
|
||||
int num_trial;
|
||||
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
|
||||
this->SetParam("rabit_num_trial", att + 1);
|
||||
}
|
||||
}
|
||||
// handling for hadoop
|
||||
const char *num_task = getenv("mapred_map_tasks");
|
||||
if (num_task == NULL) {
|
||||
num_task = getenv("mapreduce_job_maps");
|
||||
}
|
||||
if (hadoop_mode != 0) {
|
||||
utils::Check(num_task != NULL,
|
||||
"hadoop_mode is set but cannot find mapred_map_tasks");
|
||||
}
|
||||
if (num_task != NULL) {
|
||||
this->SetParam("rabit_world_size", num_task);
|
||||
}
|
||||
}
|
||||
// clear the setting before start reconnection
|
||||
this->rank = -1;
|
||||
//---------------------
|
||||
// start socket
|
||||
utils::Socket::Startup();
|
||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||
this->host_uri = utils::SockAddr::GetHostName();
|
||||
// get information from tracker
|
||||
this->ReConnectLinks();
|
||||
}
|
||||
|
||||
void AllreduceBase::Shutdown(void) {
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
all_links[i].sock.Close();
|
||||
}
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
|
||||
if (tracker_uri == "NULL") return;
|
||||
// notify tracker rank i have shutdown
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
tracker.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
}
|
||||
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
utils::Printf("%s", msg.c_str()); return;
|
||||
}
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("print"));
|
||||
tracker.SendStr(msg);
|
||||
tracker.Close();
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
* \param val parameter value
|
||||
*/
|
||||
void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
|
||||
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
|
||||
if (!strcmp(name, "rabit_task_id")) task_id = val;
|
||||
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
||||
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
||||
if (!strcmp(name, "rabit_reduce_buffer")) {
|
||||
char unit;
|
||||
uint64_t amount;
|
||||
if (sscanf(val, "%lu%c", &amount, &unit) == 2) {
|
||||
switch (unit) {
|
||||
case 'B': reduce_buffer_size = (amount + 7)/ 8; break;
|
||||
case 'K': reduce_buffer_size = amount << 7UL; break;
|
||||
case 'M': reduce_buffer_size = amount << 17UL; break;
|
||||
case 'G': reduce_buffer_size = amount << 27UL; break;
|
||||
default: utils::Error("invalid format for reduce buffer");
|
||||
}
|
||||
} else {
|
||||
utils::Error("invalid format for reduce_buffer,"\
|
||||
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}");
|
||||
}
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
int magic = kMagic;
|
||||
// get information from tracker
|
||||
utils::TCPSocket tracker;
|
||||
tracker.Create();
|
||||
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
|
||||
utils::Socket::Error("Connect");
|
||||
}
|
||||
using utils::Assert;
|
||||
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||
"ReConnectLink failure 1");
|
||||
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
|
||||
"ReConnectLink failure 2");
|
||||
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
|
||||
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 3");
|
||||
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 3");
|
||||
tracker.SendStr(task_id);
|
||||
return tracker;
|
||||
}
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
*/
|
||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0; world_size = 1; return;
|
||||
}
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string(cmd));
|
||||
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
std::map<int, int> tree_neighbors;
|
||||
using utils::Assert;
|
||||
// get new ranks
|
||||
int newrank, num_neighbors;
|
||||
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\
|
||||
sizeof(parent_rank), "ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(rank == -1 || newrank == rank,
|
||||
"must keep rank to same if the node already have one");
|
||||
rank = newrank;
|
||||
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
||||
sizeof(num_neighbors), "ReConnectLink failure 4");
|
||||
for (int i = 0; i < num_neighbors; ++i) {
|
||||
int nrank;
|
||||
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||
"ReConnectLink failure 4");
|
||||
tree_neighbors[nrank] = 1;
|
||||
}
|
||||
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||
"ReConnectLink failure 4");
|
||||
// create listening socket
|
||||
utils::TCPSocket sock_listen;
|
||||
sock_listen.Create();
|
||||
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
|
||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||
sock_listen.Listen();
|
||||
|
||||
// get number of to connect and number of to accept nodes from tracker
|
||||
int num_conn, num_accept, num_error = 1;
|
||||
do {
|
||||
// send over good links
|
||||
std::vector<int> good_link;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) {
|
||||
good_link.push_back(static_cast<int>(all_links[i].rank));
|
||||
} else {
|
||||
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
||||
}
|
||||
}
|
||||
int ngood = static_cast<int>(good_link.size());
|
||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||
"ReConnectLink failure 5");
|
||||
for (size_t i = 0; i < good_link.size(); ++i) {
|
||||
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
|
||||
sizeof(good_link[i]), "ReConnectLink failure 6");
|
||||
}
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
"ReConnectLink failure 7");
|
||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
|
||||
sizeof(num_accept), "ReConnectLink failure 8");
|
||||
num_error = 0;
|
||||
for (int i = 0; i < num_conn; ++i) {
|
||||
LinkRecord r;
|
||||
int hport, hrank;
|
||||
std::string hname;
|
||||
tracker.RecvStr(&hname);
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
|
||||
"ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
|
||||
"ReConnectLink failure 10");
|
||||
r.sock.Create();
|
||||
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
||||
num_error += 1; r.sock.Close(); continue;
|
||||
}
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 12");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 13");
|
||||
utils::Check(hrank == r.rank,
|
||||
"ReConnectLink failure, link rank inconsistent");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == hrank) {
|
||||
Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
}
|
||||
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
||||
"ReConnectLink failure 14");
|
||||
} while (num_error != 0);
|
||||
// send back socket listening port to tracker
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
|
||||
"ReConnectLink failure 14");
|
||||
// close connection to tracker
|
||||
tracker.Close();
|
||||
// listen to incoming links
|
||||
for (int i = 0; i < num_accept; ++i) {
|
||||
LinkRecord r;
|
||||
r.sock = sock_listen.Accept();
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 15");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 15");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == r.rank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
}
|
||||
// close listening sockets
|
||||
sock_listen.Close();
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
all_links[i].sock.SetKeepAlive(true);
|
||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||
if (all_links[i].rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
}
|
||||
tree_links.plinks.push_back(&all_links[i]);
|
||||
}
|
||||
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
||||
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
||||
}
|
||||
Assert(parent_rank == -1 || parent_index != -1,
|
||||
"cannot find parent in the link");
|
||||
Assert(prev_rank == -1 || ring_prev != NULL,
|
||||
"cannot find prev ring in the link");
|
||||
Assert(next_rank == -1 || ring_next != NULL,
|
||||
"cannot find next ring in the link");
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
*
|
||||
* NOTE on Allreduce:
|
||||
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
|
||||
* It only means the current node get the correct result of Allreduce.
|
||||
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0 || count == 0) return kSuccess;
|
||||
// total size of message
|
||||
const size_t total_size = type_nbytes * count;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
// send recv buffer
|
||||
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
// size of space that we already performs reduce in up pass
|
||||
size_t size_up_reduce = 0;
|
||||
// size of space that we have already passed to parent
|
||||
size_t size_up_out = 0;
|
||||
// size of message we received, and send in the down pass
|
||||
size_t size_down_in = 0;
|
||||
// initialize the link ring-buffer and pointer
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
|
||||
}
|
||||
links[i].ResetSize();
|
||||
}
|
||||
// if no childs, no need to reduce
|
||||
if (nlink == static_cast<int>(parent_index != -1)) {
|
||||
size_up_reduce = total_size;
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
// select helper
|
||||
bool finished = true;
|
||||
utils::SelectHelper selecter;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i == parent_index) {
|
||||
if (size_down_in != total_size) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
// only watch for exception in live channels
|
||||
selecter.WatchException(links[i].sock);
|
||||
finished = false;
|
||||
}
|
||||
if (size_up_out != total_size && size_up_out < size_up_reduce) {
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
}
|
||||
} else {
|
||||
if (links[i].size_read != total_size) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
}
|
||||
// size_write <= size_read
|
||||
if (links[i].size_write != total_size){
|
||||
if (links[i].size_write < size_down_in) {
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
}
|
||||
// only watch for exception in live channels
|
||||
selecter.WatchException(links[i].sock);
|
||||
finished = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
// finish runing allreduce
|
||||
if (finished) break;
|
||||
// select must return
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// recive OOB message from some link
|
||||
if (selecter.CheckExcept(links[i].sock)) {
|
||||
return ReportError(&links[i], kGetExcept);
|
||||
}
|
||||
}
|
||||
// read data from childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
|
||||
ReturnType ret = links[i].ReadToRingBuffer(size_up_out);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
// this node have childs, peform reduce
|
||||
if (nlink > static_cast<int>(parent_index != -1)) {
|
||||
size_t buffer_size = 0;
|
||||
// do upstream reduce
|
||||
size_t max_reduce = total_size;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
max_reduce= std::min(max_reduce, links[i].size_read);
|
||||
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
|
||||
"buffer size inconsistent");
|
||||
buffer_size = links[i].buffer_size;
|
||||
}
|
||||
}
|
||||
utils::Assert(buffer_size != 0, "must assign buffer_size");
|
||||
// round to type_n4bytes
|
||||
max_reduce = (max_reduce / type_nbytes * type_nbytes);
|
||||
// peform reduce, can be at most two rounds
|
||||
while (size_up_reduce < max_reduce) {
|
||||
// start position
|
||||
size_t start = size_up_reduce % buffer_size;
|
||||
// peform read till end of buffer
|
||||
size_t nread = std::min(buffer_size - start,
|
||||
max_reduce - size_up_reduce);
|
||||
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
reducer(links[i].buffer_head + start,
|
||||
sendrecvbuf + size_up_reduce,
|
||||
static_cast<int>(nread / type_nbytes),
|
||||
MPI::Datatype(type_nbytes));
|
||||
}
|
||||
}
|
||||
size_up_reduce += nread;
|
||||
}
|
||||
}
|
||||
if (parent_index != -1) {
|
||||
// pass message up to parent, can pass data that are already been reduced
|
||||
if (size_up_out < size_up_reduce) {
|
||||
ssize_t len = links[parent_index].sock.
|
||||
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
|
||||
if (len != -1) {
|
||||
size_up_out += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return(errno);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[parent_index], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
// read data from parent
|
||||
if (selecter.CheckRead(links[parent_index].sock) &&
|
||||
total_size > size_down_in) {
|
||||
ssize_t len = links[parent_index].sock.
|
||||
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
|
||||
if (len == 0) {
|
||||
links[parent_index].sock.Close();
|
||||
return ReportError(&links[parent_index], kRecvZeroLen);
|
||||
}
|
||||
if (len != -1) {
|
||||
size_down_in += static_cast<size_t>(len);
|
||||
utils::Assert(size_down_in <= size_up_out,
|
||||
"Allreduce: boundary error");
|
||||
} else {
|
||||
ReturnType ret = Errno2Return(errno);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[parent_index], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// this is root, can use reduce as most recent point
|
||||
size_down_in = size_up_out = size_up_reduce;
|
||||
}
|
||||
// can pass message down to childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && links[i].size_write < size_down_in) {
|
||||
ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param total_size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
AllreduceBase::ReturnType
|
||||
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0 || total_size == 0) return kSuccess;
|
||||
utils::Check(root < world_size,
|
||||
"Broadcast: root should be smaller than world size");
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
// size of space already read from data
|
||||
size_t size_in = 0;
|
||||
// input link, -2 means unknown yet, -1 means this is root
|
||||
int in_link = -2;
|
||||
|
||||
// initialize the link statistics
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
}
|
||||
// root have all the data
|
||||
if (this->rank == root) {
|
||||
size_in = total_size;
|
||||
in_link = -1;
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
bool finished = true;
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (in_link == -2) {
|
||||
selecter.WatchRead(links[i].sock); finished = false;
|
||||
}
|
||||
if (i == in_link && links[i].size_read != total_size) {
|
||||
selecter.WatchRead(links[i].sock); finished = false;
|
||||
}
|
||||
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
|
||||
if (links[i].size_write < size_in) {
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
}
|
||||
finished = false;
|
||||
}
|
||||
selecter.WatchException(links[i].sock);
|
||||
}
|
||||
// finish running
|
||||
if (finished) break;
|
||||
// select
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// recive OOB message from some link
|
||||
if (selecter.CheckExcept(links[i].sock)) {
|
||||
return ReportError(&links[i], kGetExcept);
|
||||
}
|
||||
}
|
||||
if (in_link == -2) {
|
||||
// probe in-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (selecter.CheckRead(links[i].sock)) {
|
||||
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
size_in = links[i].size_read;
|
||||
if (size_in != 0) {
|
||||
in_link = i; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// read from in link
|
||||
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
|
||||
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[in_link], ret);
|
||||
}
|
||||
size_in = links[in_link].size_read;
|
||||
}
|
||||
}
|
||||
// send data to all out-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != in_link && links[i].size_write < size_in) {
|
||||
ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in);
|
||||
if (ret != kSuccess) {
|
||||
return ReportError(&links[i], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
436
subtree/rabit/src/allreduce_base.h
Normal file
436
subtree/rabit/src/allreduce_base.h
Normal file
@@ -0,0 +1,436 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_base.h
|
||||
* \brief Basic implementation of AllReduce
|
||||
* using TCP non-block socket and tree-shape reduction.
|
||||
*
|
||||
* This implementation provides basic utility of AllReduce and Broadcast
|
||||
* without considering node failure
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_BASE_H_
|
||||
#define RABIT_ALLREDUCE_BASE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "rabit/utils.h"
|
||||
#include "rabit/engine.h"
|
||||
#include "./socket.h"
|
||||
|
||||
namespace MPI {
|
||||
// MPI data type to be compatible with existing MPI interface
|
||||
class Datatype {
|
||||
public:
|
||||
size_t type_size;
|
||||
explicit Datatype(size_t type_size) : type_size(type_size) {}
|
||||
};
|
||||
}
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief implementation of basic Allreduce engine */
|
||||
class AllreduceBase : public IEngine {
|
||||
public:
|
||||
// magic number to verify server
|
||||
static const int kMagic = 0xff99;
|
||||
// constant one byte out of band message to indicate error happening
|
||||
AllreduceBase(void);
|
||||
virtual ~AllreduceBase(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(void);
|
||||
// shutdown the engine
|
||||
virtual void Shutdown(void);
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
* \param val parameter value
|
||||
*/
|
||||
virtual void SetParam(const char *name, const char *val);
|
||||
/*!
|
||||
* \brief print the msg in the tracker,
|
||||
* this function can be used to communicate the information of the progress to
|
||||
* the user who monitors the tracker
|
||||
* \param msg message to be printed in the tracker
|
||||
*/
|
||||
virtual void TrackerPrint(const std::string &msg);
|
||||
/*! \brief get rank */
|
||||
virtual int GetRank(void) const {
|
||||
return rank;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
virtual int GetWorldSize(void) const {
|
||||
if (world_size == -1) return 1;
|
||||
return world_size;
|
||||
}
|
||||
/*! \brief get rank */
|
||||
virtual std::string GetHost(void) const {
|
||||
return host_uri;
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_,
|
||||
type_nbytes, count, reducer) == kSuccess,
|
||||
"Allreduce failed");
|
||||
}
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
||||
"Broadcast failed");
|
||||
}
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, user should do necessary initialization by themselves
|
||||
*
|
||||
* Common usage example:
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
* when certain condition is met(see detailed expplaination).
|
||||
*
|
||||
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
|
||||
* In another words, global_model model can be changed only between last call of
|
||||
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||
*
|
||||
* For example, suppose the calling sequence is:
|
||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||
*
|
||||
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
|
||||
* improve efficiency of the program.
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
* \return version number of current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
* \sa LoadCheckPoint, CheckPoint
|
||||
*/
|
||||
virtual int VersionNumber(void) const {
|
||||
return version_number;
|
||||
}
|
||||
/*!
|
||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throw an exception out,
|
||||
* this function is only used for test purpose
|
||||
*/
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("InitAfterException: not implemented");
|
||||
}
|
||||
/*!
|
||||
* \brief report current status to the job tracker
|
||||
* depending on the job tracker we are in
|
||||
*/
|
||||
inline void ReportStatus(void) const {
|
||||
if (hadoop_mode != 0) {
|
||||
fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
|
||||
version_number, seq_counter);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
/*! \brief enumeration of possible returning results from Try functions */
|
||||
enum ReturnTypeEnum {
|
||||
/*! \brief execution is successful */
|
||||
kSuccess,
|
||||
/*! \brief a link was reset by peer */
|
||||
kConnReset,
|
||||
/*! \brief received a zero length message */
|
||||
kRecvZeroLen,
|
||||
/*! \brief a neighbor node go down, the connection is dropped */
|
||||
kSockError,
|
||||
/*!
|
||||
* \brief another node which is not my neighbor go down,
|
||||
* get Out-of-Band exception notification from my neighbor
|
||||
*/
|
||||
kGetExcept
|
||||
};
|
||||
/*! \brief struct return type to avoid implicit conversion to int/bool */
|
||||
struct ReturnType {
|
||||
/*! \brief internal return type */
|
||||
ReturnTypeEnum value;
|
||||
// constructor
|
||||
ReturnType() {}
|
||||
ReturnType(ReturnTypeEnum value) : value(value){}
|
||||
inline bool operator==(const ReturnTypeEnum &v) const {
|
||||
return value == v;
|
||||
}
|
||||
inline bool operator!=(const ReturnTypeEnum &v) const {
|
||||
return value != v;
|
||||
}
|
||||
};
|
||||
/*! \brief translate errno to return type */
|
||||
inline static ReturnType Errno2Return(int errsv) {
|
||||
if (errsv == EAGAIN || errsv == EWOULDBLOCK) return kSuccess;
|
||||
if (errsv == ECONNRESET) return kConnReset;
|
||||
return kSockError;
|
||||
}
|
||||
// link record to a neighbor
|
||||
struct LinkRecord {
|
||||
public:
|
||||
// socket to get data from/to link
|
||||
utils::TCPSocket sock;
|
||||
// rank of the node in this link
|
||||
int rank;
|
||||
// size of data readed from link
|
||||
size_t size_read;
|
||||
// size of data sent to the link
|
||||
size_t size_write;
|
||||
// pointer to buffer head
|
||||
char *buffer_head;
|
||||
// buffer size, in bytes
|
||||
size_t buffer_size;
|
||||
// constructor
|
||||
LinkRecord(void)
|
||||
: buffer_head(NULL), buffer_size(0) {
|
||||
}
|
||||
// initialize buffer
|
||||
inline void InitBuffer(size_t type_nbytes, size_t count,
|
||||
size_t reduce_buffer_size) {
|
||||
size_t n = (type_nbytes * count + 7)/ 8;
|
||||
buffer_.resize(std::min(reduce_buffer_size, n));
|
||||
// make sure align to type_nbytes
|
||||
buffer_size =
|
||||
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
|
||||
utils::Assert(type_nbytes <= buffer_size,
|
||||
"too large type_nbytes=%lu, buffer_size=%lu",
|
||||
type_nbytes, buffer_size);
|
||||
// set buffer head
|
||||
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
|
||||
}
|
||||
// reset the recv and sent size
|
||||
inline void ResetSize(void) {
|
||||
size_write = size_read = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief read data into ring-buffer, with care not to existing useful override data
|
||||
* position after protect_start
|
||||
* \param protect_start all data start from protect_start is still needed in buffer
|
||||
* read shall not override this
|
||||
* \return the type of reading
|
||||
*/
|
||||
inline ReturnType ReadToRingBuffer(size_t protect_start) {
|
||||
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
|
||||
size_t ngap = size_read - protect_start;
|
||||
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
|
||||
size_t offset = size_read % buffer_size;
|
||||
size_t nmax = std::min(buffer_size - ngap, buffer_size - offset);
|
||||
if (nmax == 0) return kSuccess;
|
||||
ssize_t len = sock.Recv(buffer_head + offset, nmax);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
sock.Close(); return kRecvZeroLen;
|
||||
}
|
||||
if (len == -1) return Errno2Return(errno);
|
||||
size_read += static_cast<size_t>(len);
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief read data into array,
|
||||
* this function can not be used together with ReadToRingBuffer
|
||||
* a link can either read into the ring buffer, or existing array
|
||||
* \param max_size maximum size of array
|
||||
* \return true if it is an successful read, false if there is some error happens, check errno
|
||||
*/
|
||||
inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
|
||||
if (max_size == size_read) return kSuccess;
|
||||
char *p = static_cast<char*>(recvbuf_);
|
||||
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
||||
// length equals 0, remote disconnected
|
||||
if (len == 0) {
|
||||
sock.Close(); return kRecvZeroLen;
|
||||
}
|
||||
if (len == -1) return Errno2Return(errno);
|
||||
size_read += static_cast<size_t>(len);
|
||||
return kSuccess;
|
||||
}
|
||||
/*!
|
||||
* \brief write data in array to sock
|
||||
* \param sendbuf_ head of array
|
||||
* \param max_size maximum size of array
|
||||
* \return true if it is an successful write, false if there is some error happens, check errno
|
||||
*/
|
||||
inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
|
||||
const char *p = static_cast<const char*>(sendbuf_);
|
||||
ssize_t len = sock.Send(p + size_write, max_size - size_write);
|
||||
if (len == -1) return Errno2Return(errno);
|
||||
size_write += static_cast<size_t>(len);
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
private:
|
||||
// recv buffer to get data from child
|
||||
// aligned with 64 bits, will be able to perform 64 bits operations freely
|
||||
std::vector<uint64_t> buffer_;
|
||||
};
|
||||
/*!
|
||||
* \brief simple data structure that works like a vector
|
||||
* but takes reference instead of space
|
||||
*/
|
||||
struct RefLinkVector {
|
||||
std::vector<LinkRecord*> plinks;
|
||||
inline LinkRecord &operator[](size_t i) {
|
||||
return *plinks[i];
|
||||
}
|
||||
inline size_t size(void) const {
|
||||
return plinks.size();
|
||||
}
|
||||
};
|
||||
/*!
|
||||
* \brief initialize connection to the tracker
|
||||
* \return a socket that initializes the connection
|
||||
*/
|
||||
utils::TCPSocket ConnectTracker(void) const;
|
||||
/*!
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
* \param cmd possible command to sent to tracker
|
||||
*/
|
||||
void ReConnectLinks(const char *cmd = "start");
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
*
|
||||
* NOTE on Allreduce:
|
||||
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
|
||||
* It only means the current node get the correct result of Allreduce.
|
||||
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryAllreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer);
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
|
||||
/*!
|
||||
* \brief function used to report error when a link goes wrong
|
||||
* \param link the pointer to the link who causes the error
|
||||
* \param err the error type
|
||||
*/
|
||||
inline ReturnType ReportError(LinkRecord *link, ReturnType err) {
|
||||
err_link = link; return err;
|
||||
}
|
||||
//---- data structure related to model ----
|
||||
// call sequence counter, records how many calls we made so far
|
||||
// from last call to CheckPoint, LoadCheckPoint
|
||||
int seq_counter;
|
||||
// version number of model
|
||||
int version_number;
|
||||
// whether the job is running in hadoop
|
||||
int hadoop_mode;
|
||||
//---- local data related to link ----
|
||||
// index of parent link, can be -1, meaning this is root of the tree
|
||||
int parent_index;
|
||||
// rank of parent node, can be -1
|
||||
int parent_rank;
|
||||
// sockets of all links this connects to
|
||||
std::vector<LinkRecord> all_links;
|
||||
// used to record the link where things goes wrong
|
||||
LinkRecord *err_link;
|
||||
// all the links in the reduction tree connection
|
||||
RefLinkVector tree_links;
|
||||
// pointer to links in the ring
|
||||
LinkRecord *ring_prev, *ring_next;
|
||||
//----- meta information-----
|
||||
// unique identifier of the possible job this process is doing
|
||||
// used to assign ranks, optional, default to NULL
|
||||
std::string task_id;
|
||||
// uri of current host, to be set by Init
|
||||
std::string host_uri;
|
||||
// uri of tracker
|
||||
std::string tracker_uri;
|
||||
// port of tracker address
|
||||
int tracker_port;
|
||||
// port of slave process
|
||||
int slave_port, nport_trial;
|
||||
// reduce buffer size
|
||||
size_t reduce_buffer_size;
|
||||
// current rank
|
||||
int rank;
|
||||
// world size
|
||||
int world_size;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ALLREDUCE_BASE_H
|
||||
100
subtree/rabit/src/allreduce_mock.h
Normal file
100
subtree/rabit/src/allreduce_mock.h
Normal file
@@ -0,0 +1,100 @@
|
||||
/*!
|
||||
* \file allreduce_mock.h
|
||||
* \brief Mock test module of AllReduce engine,
|
||||
* insert failures in certain call point, to test if the engine is robust to failure
|
||||
*
|
||||
* \author Ignacio Cano, Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_MOCK_H
|
||||
#define RABIT_ALLREDUCE_MOCK_H
|
||||
#include <vector>
|
||||
#include <rabit/engine.h>
|
||||
#include <map>
|
||||
#include "./allreduce_robust.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
class AllreduceMock : public AllreduceRobust {
|
||||
public:
|
||||
// constructor
|
||||
AllreduceMock(void) {
|
||||
num_trial = 0;
|
||||
}
|
||||
// destructor
|
||||
virtual ~AllreduceMock(void) {}
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
AllreduceRobust::SetParam(name, val);
|
||||
// additional parameters
|
||||
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
|
||||
if (!strcmp(name, "mock")) {
|
||||
MockKey k;
|
||||
utils::Check(sscanf(val, "%d,%d,%d,%d",
|
||||
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
|
||||
"invalid mock parameter");
|
||||
mock_map[k] = 1;
|
||||
}
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||
count, reducer, prepare_fun, prepare_arg);
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
||||
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
||||
AllreduceRobust::CheckPoint(global_model, local_model);
|
||||
}
|
||||
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
||||
AllreduceRobust::LazyCheckPoint(global_model);
|
||||
}
|
||||
|
||||
private:
|
||||
// key to identify the mock stage
|
||||
struct MockKey {
|
||||
int rank;
|
||||
int version;
|
||||
int seqno;
|
||||
int ntrial;
|
||||
MockKey(void) {}
|
||||
MockKey(int rank, int version, int seqno, int ntrial)
|
||||
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
|
||||
inline bool operator==(const MockKey &b) const {
|
||||
return rank == b.rank &&
|
||||
version == b.version &&
|
||||
seqno == b.seqno &&
|
||||
ntrial == b.ntrial;
|
||||
}
|
||||
inline bool operator<(const MockKey &b) const {
|
||||
if (rank != b.rank) return rank < b.rank;
|
||||
if (version != b.version) return version < b.version;
|
||||
if (seqno != b.seqno) return seqno < b.seqno;
|
||||
return ntrial < b.ntrial;
|
||||
}
|
||||
};
|
||||
// number of failure trials
|
||||
int num_trial;
|
||||
// record all mock actions
|
||||
std::map<MockKey, int> mock_map;
|
||||
// used to generate all kinds of exceptions
|
||||
inline void Verify(const MockKey &key, const char *name) {
|
||||
if (mock_map.count(key) != 0) {
|
||||
num_trial += 1;
|
||||
fprintf(stderr, "[%d]@@@Hit Mock Error:%s\n", rank, name);
|
||||
exit(-2);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ALLREDUCE_MOCK_H
|
||||
161
subtree/rabit/src/allreduce_robust-inl.h
Normal file
161
subtree/rabit/src/allreduce_robust-inl.h
Normal file
@@ -0,0 +1,161 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_robust-inl.h
|
||||
* \brief implementation of inline template function in AllreduceRobust
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_ENGINE_ROBUST_INL_H_
|
||||
#define RABIT_ENGINE_ROBUST_INL_H_
|
||||
#include <vector>
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*!
|
||||
* \brief run message passing algorithm on the allreduce tree
|
||||
* the result is edge message stored in p_edge_in and p_edge_out
|
||||
* \param node_value the value associated with current node
|
||||
* \param p_edge_in used to store input message from each of the edge
|
||||
* \param p_edge_out used to store output message from each of the edge
|
||||
* \param func a function that defines the message passing rule
|
||||
* Parameters of func:
|
||||
* - node_value same as node_value in the main function
|
||||
* - edge_in the array of input messages from each edge,
|
||||
* this includes the output edge, which should be excluded
|
||||
* - out_index array the index of output edge, the function should
|
||||
* exclude the output edge when compute the message passing value
|
||||
* Return of func:
|
||||
* the function returns the output message based on the input message and node_value
|
||||
*
|
||||
* \tparam EdgeType type of edge message, must be simple struct
|
||||
* \tparam NodeType type of node value
|
||||
*/
|
||||
template<typename NodeType, typename EdgeType>
|
||||
inline AllreduceRobust::ReturnType
|
||||
AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
std::vector<EdgeType> *p_edge_in,
|
||||
std::vector<EdgeType> *p_edge_out,
|
||||
EdgeType (*func)
|
||||
(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.size() == 0) return kSuccess;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.size());
|
||||
// initialize the pointers
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
}
|
||||
std::vector<EdgeType> &edge_in = *p_edge_in;
|
||||
std::vector<EdgeType> &edge_out = *p_edge_out;
|
||||
edge_in.resize(nlink);
|
||||
edge_out.resize(nlink);
|
||||
// stages in the process
|
||||
// 0: recv messages from childs
|
||||
// 1: send message to parent
|
||||
// 2: recv message from parent
|
||||
// 3: send message to childs
|
||||
int stage = 0;
|
||||
// if no childs, no need to, directly start passing message
|
||||
if (nlink == static_cast<int>(parent_index != -1)) {
|
||||
utils::Assert(parent_index == 0, "parent must be 0");
|
||||
edge_out[parent_index] = func(node_value, edge_in, parent_index);
|
||||
stage = 1;
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
// for node with no parent, directly do stage 3
|
||||
if (parent_index == -1) {
|
||||
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
|
||||
}
|
||||
// select helper
|
||||
utils::SelectHelper selecter;
|
||||
bool done = (stage == 3);
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
selecter.WatchException(links[i].sock);
|
||||
switch (stage) {
|
||||
case 0:
|
||||
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
|
||||
selecter.WatchRead(links[i].sock);
|
||||
}
|
||||
break;
|
||||
case 1: if (i == parent_index) selecter.WatchWrite(links[i].sock); break;
|
||||
case 2: if (i == parent_index) selecter.WatchRead(links[i].sock); break;
|
||||
case 3:
|
||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
done = false;
|
||||
}
|
||||
break;
|
||||
default: utils::Error("invalid stage");
|
||||
}
|
||||
}
|
||||
// finish all the stages, and write out message
|
||||
if (done) break;
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// recive OOB message from some link
|
||||
if (selecter.CheckExcept(links[i].sock)) {
|
||||
return ReportError(&links[i], kGetExcept);
|
||||
}
|
||||
}
|
||||
if (stage == 0) {
|
||||
bool finished = true;
|
||||
// read data from childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
if (selecter.CheckRead(links[i].sock)) {
|
||||
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
|
||||
if (ret != kSuccess) return ReportError(&links[i], ret);
|
||||
}
|
||||
if (links[i].size_read != sizeof(EdgeType)) finished = false;
|
||||
}
|
||||
}
|
||||
// if no parent, jump to stage 3, otherwise do stage 1
|
||||
if (finished) {
|
||||
if (parent_index != -1) {
|
||||
edge_out[parent_index] = func(node_value, edge_in, parent_index);
|
||||
stage = 1;
|
||||
} else {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
edge_out[i] = func(node_value, edge_in, i);
|
||||
}
|
||||
stage = 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (stage == 1) {
|
||||
const int pid = this->parent_index;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
ReturnType ret = links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType));
|
||||
if (ret != kSuccess) return ReportError(&links[pid], ret);
|
||||
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
|
||||
}
|
||||
if (stage == 2) {
|
||||
const int pid = this->parent_index;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
ReturnType ret = links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType));
|
||||
if (ret != kSuccess) return ReportError(&links[pid], ret);
|
||||
if (links[pid].size_read == sizeof(EdgeType)) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
|
||||
}
|
||||
stage = 3;
|
||||
}
|
||||
}
|
||||
if (stage == 3) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||
ReturnType ret = links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType));
|
||||
if (ret != kSuccess) return ReportError(&links[i], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ENGINE_ROBUST_INL_H_
|
||||
1178
subtree/rabit/src/allreduce_robust.cc
Normal file
1178
subtree/rabit/src/allreduce_robust.cc
Normal file
File diff suppressed because it is too large
Load Diff
553
subtree/rabit/src/allreduce_robust.h
Normal file
553
subtree/rabit/src/allreduce_robust.h
Normal file
@@ -0,0 +1,553 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_robust.h
|
||||
* \brief Robust implementation of Allreduce
|
||||
* using TCP non-block socket and tree-shape reduction.
|
||||
*
|
||||
* This implementation considers the failure of nodes
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_ROBUST_H_
|
||||
#define RABIT_ALLREDUCE_ROBUST_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "rabit/engine.h"
|
||||
#include "./allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief implementation of fault tolerant all reduce engine */
|
||||
class AllreduceRobust : public AllreduceBase {
|
||||
public:
|
||||
AllreduceRobust(void);
|
||||
virtual ~AllreduceRobust(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(void);
|
||||
/*! \brief shutdown the engine */
|
||||
virtual void Shutdown(void);
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
* \param val parameter value
|
||||
*/
|
||||
virtual void SetParam(const char *name, const char *val);
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
*/
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun = NULL,
|
||||
void *prepare_arg = NULL);
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
*/
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, user should do necessary initialization by themselves
|
||||
*
|
||||
* Common usage example:
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
this->CheckPoint_(global_model, local_model, false);
|
||||
}
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
* when certain condition is met(see detailed expplaination).
|
||||
*
|
||||
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
|
||||
* In another words, global_model model can be changed only between last call of
|
||||
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||
*
|
||||
* For example, suppose the calling sequence is:
|
||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||
*
|
||||
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
|
||||
* improve efficiency of the program.
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
this->CheckPoint_(global_model, NULL, true);
|
||||
}
|
||||
/*!
|
||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throw an exception out,
|
||||
* this function is only used for test purpose
|
||||
*/
|
||||
virtual void InitAfterException(void) {
|
||||
// simple way, shutdown all links
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
|
||||
}
|
||||
ReConnectLinks("recover");
|
||||
}
|
||||
|
||||
private:
|
||||
// constant one byte out of band message to indicate error happening
|
||||
// and mark for channel cleanup
|
||||
static const char kOOBReset = 95;
|
||||
// and mark for channel cleanup, after OOB signal
|
||||
static const char kResetMark = 97;
|
||||
// and mark for channel cleanup
|
||||
static const char kResetAck = 97;
|
||||
/*! \brief type of roles each node can play during recovery */
|
||||
enum RecoverType {
|
||||
/*! \brief current node have data */
|
||||
kHaveData = 0,
|
||||
/*! \brief current node request data */
|
||||
kRequestData = 1,
|
||||
/*! \brief current node only helps to pass data around */
|
||||
kPassData = 2
|
||||
};
|
||||
/*!
|
||||
* \brief summary of actions proposed in all nodes
|
||||
* this data structure is used to make consensus decision
|
||||
* about next action to take in the recovery mode
|
||||
*/
|
||||
struct ActionSummary {
|
||||
// maximumly allowed sequence id
|
||||
static const int kSpecialOp = (1 << 26);
|
||||
// special sequence number for local state checkpoint
|
||||
static const int kLocalCheckPoint = (1 << 26) - 2;
|
||||
// special sequnce number for local state checkpoint ack signal
|
||||
static const int kLocalCheckAck = (1 << 26) - 1;
|
||||
//---------------------------------------------
|
||||
// The following are bit mask of flag used in
|
||||
//----------------------------------------------
|
||||
// some node want to load check point
|
||||
static const int kLoadCheck = 1;
|
||||
// some node want to do check point
|
||||
static const int kCheckPoint = 2;
|
||||
// check point Ack, we use a two phase message in check point,
|
||||
// this is the second phase of check pointing
|
||||
static const int kCheckAck = 4;
|
||||
// there are difference sequence number the nodes proposed
|
||||
// this means we want to do recover execution of the lower sequence
|
||||
// action instead of normal execution
|
||||
static const int kDiffSeq = 8;
|
||||
// constructor
|
||||
ActionSummary(void) {}
|
||||
// constructor of action
|
||||
explicit ActionSummary(int flag, int minseqno = kSpecialOp) {
|
||||
seqcode = (minseqno << 4) | flag;
|
||||
}
|
||||
// minimum number of all operations
|
||||
inline int min_seqno(void) const {
|
||||
return seqcode >> 4;
|
||||
}
|
||||
// whether the operation set contains a load_check
|
||||
inline bool load_check(void) const {
|
||||
return (seqcode & kLoadCheck) != 0;
|
||||
}
|
||||
// whether the operation set contains a check point
|
||||
inline bool check_point(void) const {
|
||||
return (seqcode & kCheckPoint) != 0;
|
||||
}
|
||||
// whether the operation set contains a check ack
|
||||
inline bool check_ack(void) const {
|
||||
return (seqcode & kCheckAck) != 0;
|
||||
}
|
||||
// whether the operation set contains different sequence number
|
||||
inline bool diff_seq(void) const {
|
||||
return (seqcode & kDiffSeq) != 0;
|
||||
}
|
||||
// returns the operation flag of the result
|
||||
inline int flag(void) const {
|
||||
return seqcode & 15;
|
||||
}
|
||||
// reducer for Allreduce, get the result ActionSummary from all nodes
|
||||
inline static void Reducer(const void *src_, void *dst_,
|
||||
int len, const MPI::Datatype &dtype) {
|
||||
const ActionSummary *src = (const ActionSummary*)src_;
|
||||
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
int src_seqno = src[i].min_seqno();
|
||||
int dst_seqno = dst[i].min_seqno();
|
||||
int flag = src[i].flag() | dst[i].flag();
|
||||
if (src_seqno == dst_seqno) {
|
||||
dst[i] = ActionSummary(flag, src_seqno);
|
||||
} else {
|
||||
dst[i] = ActionSummary(flag | kDiffSeq,
|
||||
std::min(src_seqno, dst_seqno));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// internel sequence code
|
||||
int seqcode;
|
||||
};
|
||||
/*! \brief data structure to remember result of Bcast and Allreduce calls */
|
||||
class ResultBuffer {
|
||||
public:
|
||||
// constructor
|
||||
ResultBuffer(void) {
|
||||
this->Clear();
|
||||
}
|
||||
// clear the existing record
|
||||
inline void Clear(void) {
|
||||
seqno_.clear(); size_.clear();
|
||||
rptr_.clear(); rptr_.push_back(0);
|
||||
data_.clear();
|
||||
}
|
||||
// allocate temporal space
|
||||
inline void *AllocTemp(size_t type_nbytes, size_t count) {
|
||||
size_t size = type_nbytes * count;
|
||||
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
utils::Assert(nhop != 0, "cannot allocate 0 size memory");
|
||||
data_.resize(rptr_.back() + nhop);
|
||||
return BeginPtr(data_) + rptr_.back();
|
||||
}
|
||||
// push the result in temp to the
|
||||
inline void PushTemp(int seqid, size_t type_nbytes, size_t count) {
|
||||
size_t size = type_nbytes * count;
|
||||
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
if (seqno_.size() != 0) {
|
||||
utils::Assert(seqno_.back() < seqid, "PushTemp seqid inconsistent");
|
||||
}
|
||||
seqno_.push_back(seqid);
|
||||
rptr_.push_back(rptr_.back() + nhop);
|
||||
size_.push_back(size);
|
||||
utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent");
|
||||
}
|
||||
// return the stored result of seqid, if any
|
||||
inline void* Query(int seqid, size_t *p_size) {
|
||||
size_t idx = std::lower_bound(seqno_.begin(),
|
||||
seqno_.end(), seqid) - seqno_.begin();
|
||||
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
|
||||
*p_size = size_[idx];
|
||||
return BeginPtr(data_) + rptr_[idx];
|
||||
}
|
||||
// drop last stored result
|
||||
inline void DropLast(void) {
|
||||
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
|
||||
seqno_.pop_back();
|
||||
rptr_.pop_back();
|
||||
size_.pop_back();
|
||||
data_.resize(rptr_.back());
|
||||
}
|
||||
// the sequence number of last stored result
|
||||
inline int LastSeqNo(void) const {
|
||||
if (seqno_.size() == 0) return -1;
|
||||
return seqno_.back();
|
||||
}
|
||||
|
||||
private:
|
||||
// sequence number of each
|
||||
std::vector<int> seqno_;
|
||||
// pointer to the positions
|
||||
std::vector<size_t> rptr_;
|
||||
// actual size of each buffer
|
||||
std::vector<size_t> size_;
|
||||
// content of the buffer
|
||||
std::vector<uint64_t> data_;
|
||||
};
|
||||
/*!
|
||||
* \brief internal consistency check function,
|
||||
* use check to ensure user always call CheckPoint/LoadCheckPoint
|
||||
* with or without local but not both, this function will set the approperiate settings
|
||||
* in the first call of LoadCheckPoint/CheckPoint
|
||||
*
|
||||
* \param with_local whether the user calls CheckPoint with local model
|
||||
*/
|
||||
void LocalModelCheck(bool with_local);
|
||||
/*!
|
||||
* \brief internal implementation of checkpoint, support both lazy and normal way
|
||||
*
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
* \param lazy_checkpt whether the action is lazy checkpoint
|
||||
*
|
||||
* \sa CheckPoint, LazyCheckPoint
|
||||
*/
|
||||
void CheckPoint_(const ISerializable *global_model,
|
||||
const ISerializable *local_model,
|
||||
bool lazy_checkpt);
|
||||
/*!
|
||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||
* after this function finishes, all the messages received and sent
|
||||
* before in all live links are discarded,
|
||||
* This allows us to get a fresh start after error has happened
|
||||
*
|
||||
* TODO(tqchen): this function is not yet functioning was not used by engine,
|
||||
* simple resetlink and reconnect strategy is used
|
||||
*
|
||||
* \return this function can return kSuccess or kSockError
|
||||
* when kSockError is returned, it simply means there are bad sockets in the links,
|
||||
* and some link recovery proceduer is needed
|
||||
*/
|
||||
ReturnType TryResetLinks(void);
|
||||
/*!
|
||||
* \brief if err_type indicates an error
|
||||
* recover links according to the error type reported
|
||||
* if there is no error, return true
|
||||
* \param err_type the type of error happening in the system
|
||||
* \return true if err_type is kSuccess, false otherwise
|
||||
*/
|
||||
bool CheckAndRecover(ReturnType err_type);
|
||||
/*!
|
||||
* \brief try to run recover execution for a request action described by flag and seqno,
|
||||
* the function will keep blocking to run possible recovery operations before the specified action,
|
||||
* until the requested result is received by a recovering procedure,
|
||||
* or the function discovers that the requested action is not yet executed, and return false
|
||||
*
|
||||
* \param buf the buffer to store the result
|
||||
* \param size the total size of the buffer
|
||||
* \param flag flag information about the action \sa ActionSummary
|
||||
* \param seqno sequence number of the action, if it is special action with flag set,
|
||||
* seqno needs to be set to ActionSummary::kSpecialOp
|
||||
*
|
||||
* \return if this function can return true or false
|
||||
* - true means buf already set to the
|
||||
* result by recovering procedure, the action is complete, no further action is needed
|
||||
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
||||
*/
|
||||
bool RecoverExec(void *buf, size_t size, int flag,
|
||||
int seqno = ActionSummary::kSpecialOp);
|
||||
/*!
|
||||
* \brief try to load check point
|
||||
*
|
||||
* This is a collaborative function called by all nodes
|
||||
* only the nodes with requester set to true really needs to load the check point
|
||||
* other nodes acts as collaborative roles to complete this request
|
||||
*
|
||||
* \param requester whether current node is the requester
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryLoadCheckPoint(bool requester);
|
||||
/*!
|
||||
* \brief try to get the result of operation specified by seqno
|
||||
*
|
||||
* This is a collaborative function called by all nodes
|
||||
* only the nodes with requester set to true really needs to get the result
|
||||
* other nodes acts as collaborative roles to complete this request
|
||||
*
|
||||
* \param buf the buffer to store the result, this parameter is only used when current node is requester
|
||||
* \param size the total size of the buffer, this parameter is only used when current node is requester
|
||||
* \param seqno sequence number of the operation, this is unique index of a operation in current iteration
|
||||
* \param requester whether current node is the requester
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
|
||||
/*!
|
||||
* \brief try to decide the routing strategy for recovery
|
||||
* \param role the current role of the node
|
||||
* \param p_size used to store the size of the message, for node in state kHaveData,
|
||||
* this size must be set correctly before calling the function
|
||||
* for others, this surves as output parameter
|
||||
|
||||
* \param p_recvlink used to store the link current node should recv data from, if necessary
|
||||
* this can be -1, which means current node have the data
|
||||
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
||||
*
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryRecoverData
|
||||
*/
|
||||
ReturnType TryDecideRouting(RecoverType role,
|
||||
size_t *p_size,
|
||||
int *p_recvlink,
|
||||
std::vector<bool> *p_req_in);
|
||||
/*!
|
||||
* \brief try to finish the data recovery request,
|
||||
* this function is used together with TryDecideRouting
|
||||
* \param role the current role of the node
|
||||
* \param sendrecvbuf_ the buffer to store the data to be sent/recived
|
||||
* - if the role is kHaveData, this stores the data to be sent
|
||||
* - if the role is kRequestData, this is the buffer to store the result
|
||||
* - if the role is kPassData, this will not be used, and can be NULL
|
||||
* \param size the size of the data, obtained from TryDecideRouting
|
||||
* \param recv_link the link index to receive data, if necessary, obtained from TryDecideRouting
|
||||
* \param req_in the request of each link to send data, obtained from TryDecideRouting
|
||||
*
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryDecideRouting
|
||||
*/
|
||||
ReturnType TryRecoverData(RecoverType role,
|
||||
void *sendrecvbuf_,
|
||||
size_t size,
|
||||
int recv_link,
|
||||
const std::vector<bool> &req_in);
|
||||
/*!
|
||||
* \brief try to recover the local state, making each local state to be the result of itself
|
||||
* plus replication of states in previous num_local_replica hops in the ring
|
||||
*
|
||||
* The input parameters must contain the valid local states available in current nodes,
|
||||
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
|
||||
* If there is sufficient information in the ring, when the function returns, local_chkpt will
|
||||
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
|
||||
* If there is no sufficient information in the ring, this function the number of checkpoints
|
||||
* will be less than the specified value
|
||||
*
|
||||
* \param p_local_rptr the pointer to the segment pointers in the states array
|
||||
* \param p_local_chkpt the pointer to the storage of local check points
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
std::string *p_local_chkpt);
|
||||
/*!
|
||||
* \brief try to checkpoint local state, this function is called in normal executation phase
|
||||
* of checkpoint that contains local state
|
||||
o * the input state must exactly one saved state(local state of current node),
|
||||
* after complete, this function will get local state from previous num_local_replica nodes and put them
|
||||
* into local_chkpt and local_rptr
|
||||
*
|
||||
* It is also OK to call TryRecoverLocalState instead,
|
||||
* TryRecoverLocalState makes less assumption about the input, and requires more communications
|
||||
*
|
||||
* \param p_local_rptr the pointer to the segment pointers in the states array
|
||||
* \param p_local_chkpt the pointer to the storage of local check points
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryRecoverLocalState
|
||||
*/
|
||||
ReturnType TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
||||
std::string *p_local_chkpt);
|
||||
/*!
|
||||
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||
* this allows data to stream over a ring structure
|
||||
* sendrecvbuf[0:read_ptr] are already provided by current node
|
||||
* current node will recv sendrecvbuf[read_ptr:read_end] from prev link
|
||||
* current node will send sendrecvbuf[write_ptr:write_end] to next link
|
||||
* write_ptr will wait till the data is readed before sending the data
|
||||
* this function requires read_end >= write_end
|
||||
*
|
||||
* \param sendrecvbuf_ the place to hold the incoming and outgoing data
|
||||
* \param read_ptr the initial read pointer
|
||||
* \param read_end the ending position to read
|
||||
* \param write_ptr the initial write pointer
|
||||
* \param write_end the ending position to write
|
||||
* \param read_link pointer to link to previous position in ring
|
||||
* \param write_link pointer to link of next position in ring
|
||||
*/
|
||||
ReturnType RingPassing(void *senrecvbuf_,
|
||||
size_t read_ptr,
|
||||
size_t read_end,
|
||||
size_t write_ptr,
|
||||
size_t write_end,
|
||||
LinkRecord *read_link,
|
||||
LinkRecord *write_link);
|
||||
/*!
|
||||
* \brief run message passing algorithm on the allreduce tree
|
||||
* the result is edge message stored in p_edge_in and p_edge_out
|
||||
* \param node_value the value associated with current node
|
||||
* \param p_edge_in used to store input message from each of the edge
|
||||
* \param p_edge_out used to store output message from each of the edge
|
||||
* \param func a function that defines the message passing rule
|
||||
* Parameters of func:
|
||||
* - node_value same as node_value in the main function
|
||||
* - edge_in the array of input messages from each edge,
|
||||
* this includes the output edge, which should be excluded
|
||||
* - out_index array the index of output edge, the function should
|
||||
* exclude the output edge when compute the message passing value
|
||||
* Return of func:
|
||||
* the function returns the output message based on the input message and node_value
|
||||
*
|
||||
* \tparam EdgeType type of edge message, must be simple struct
|
||||
* \tparam NodeType type of node value
|
||||
*/
|
||||
template<typename NodeType, typename EdgeType>
|
||||
inline ReturnType MsgPassing(const NodeType &node_value,
|
||||
std::vector<EdgeType> *p_edge_in,
|
||||
std::vector<EdgeType> *p_edge_out,
|
||||
EdgeType (*func)
|
||||
(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index));
|
||||
//---- recovery data structure ----
|
||||
// the round of result buffer, used to mode the result
|
||||
int result_buffer_round;
|
||||
// result buffer of all reduce
|
||||
ResultBuffer resbuf;
|
||||
// last check point global model
|
||||
std::string global_checkpoint;
|
||||
// lazy checkpoint of global model
|
||||
const ISerializable *global_lazycheck;
|
||||
// number of replica for local state/model
|
||||
int num_local_replica;
|
||||
// number of default local replica
|
||||
int default_local_replica;
|
||||
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
|
||||
int use_local_model;
|
||||
// number of replica for global state/model
|
||||
int num_global_replica;
|
||||
// number of times recovery happens
|
||||
int recover_counter;
|
||||
// --- recovery data structure for local checkpoint
|
||||
// there is two version of the data structure,
|
||||
// at one time one version is valid and another is used as temp memory
|
||||
// pointer to memory position in the local model
|
||||
// local model is stored in CSR format(like a sparse matrices)
|
||||
// local_model[rptr[0]:rptr[1]] stores the model of current node
|
||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
|
||||
std::vector<size_t> local_rptr[2];
|
||||
// storage for local model replicas
|
||||
std::string local_chkpt[2];
|
||||
// version of local checkpoint can be 1 or 0
|
||||
int local_chkpt_version;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
// implementation of inline template function
|
||||
#include "./allreduce_robust-inl.h"
|
||||
#endif // RABIT_ALLREDUCE_ROBUST_H_
|
||||
80
subtree/rabit/src/engine.cc
Normal file
80
subtree/rabit/src/engine.cc
Normal file
@@ -0,0 +1,80 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine.cc
|
||||
* \brief this file governs which implementation of engine we are actually using
|
||||
* provides an singleton of engine interface
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
|
||||
#include <rabit/engine.h>
|
||||
#include "./allreduce_base.h"
|
||||
#include "./allreduce_robust.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
// singleton sync manager
|
||||
#ifndef RABIT_USE_MOCK
|
||||
AllreduceRobust manager;
|
||||
#else
|
||||
AllreduceMock manager;
|
||||
#endif
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
char name[256], val[256];
|
||||
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
|
||||
manager.SetParam(name, val);
|
||||
}
|
||||
}
|
||||
manager.Init();
|
||||
}
|
||||
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void) {
|
||||
manager.Shutdown();
|
||||
}
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void) {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||
red, prepare_fun, prepare_arg);
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void)
|
||||
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {}
|
||||
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return static_cast<int>(dtype.type_size);
|
||||
}
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
|
||||
redfunc_ = redfunc;
|
||||
}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
|
||||
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
|
||||
redfunc_, prepare_fun, prepare_arg);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
111
subtree/rabit/src/engine_empty.cc
Normal file
111
subtree/rabit/src/engine_empty.cc
Normal file
@@ -0,0 +1,111 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_empty.cc
|
||||
* \brief this file provides a dummy implementation of engine that does nothing
|
||||
* this file provides a way to fall back to single node program without causing too many dependencies
|
||||
* This is usually NOT needed, use engine_mpi or engine for real distributed version
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
|
||||
#include <rabit/engine.h>
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief EmptyEngine */
|
||||
class EmptyEngine : public IEngine {
|
||||
public:
|
||||
EmptyEngine(void) {
|
||||
version_number = 0;
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Error("EmptyEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("EmptyEngine is not fault tolerant");
|
||||
}
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual int VersionNumber(void) const {
|
||||
return version_number;
|
||||
}
|
||||
/*! \brief get rank of current node */
|
||||
virtual int GetRank(void) const {
|
||||
return 0;
|
||||
}
|
||||
/*! \brief get total number of */
|
||||
virtual int GetWorldSize(void) const {
|
||||
return 1;
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const {
|
||||
return std::string("");
|
||||
}
|
||||
virtual void TrackerPrint(const std::string &msg) {
|
||||
// simply print information into the tracker
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
|
||||
private:
|
||||
int version_number;
|
||||
};
|
||||
|
||||
// singleton sync manager
|
||||
EmptyEngine manager;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
}
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void) {
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void) {
|
||||
return &manager;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {}
|
||||
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return 0;
|
||||
}
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
16
subtree/rabit/src/engine_mock.cc
Normal file
16
subtree/rabit/src/engine_mock.cc
Normal file
@@ -0,0 +1,16 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_mock.cc
|
||||
* \brief this is an engine implementation that will
|
||||
* insert failures in certain call point, to test if the engine is robust to failure
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
// define use MOCK, os we will use mock Manager
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
// switch engine to AllreduceMock
|
||||
#define RABIT_USE_MOCK
|
||||
#include "./allreduce_mock.h"
|
||||
#include "./engine.cc"
|
||||
|
||||
194
subtree/rabit/src/engine_mpi.cc
Normal file
194
subtree/rabit/src/engine_mpi.cc
Normal file
@@ -0,0 +1,194 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file engine_mpi.cc
|
||||
* \brief this file gives an implementation of engine interface using MPI,
|
||||
* this will allow rabit program to run with MPI, but do not comes with fault tolerant
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define _CRT_SECURE_NO_DEPRECATE
|
||||
#define NOMINMAX
|
||||
#include <mpi.h>
|
||||
#include <cstdio>
|
||||
#include "rabit/engine.h"
|
||||
#include "rabit/utils.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief implementation of engine using MPI */
|
||||
class MPIEngine : public IEngine {
|
||||
public:
|
||||
MPIEngine(void) {
|
||||
version_number = 0;
|
||||
}
|
||||
virtual void Allreduce(void *sendrecvbuf_,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Error("MPIEngine:: Allreduce is not supported,"\
|
||||
"use Allreduce_ instead");
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) {
|
||||
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
|
||||
}
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("MPI is not fault tolerant");
|
||||
}
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual int VersionNumber(void) const {
|
||||
return version_number;
|
||||
}
|
||||
/*! \brief get rank of current node */
|
||||
virtual int GetRank(void) const {
|
||||
return MPI::COMM_WORLD.Get_rank();
|
||||
}
|
||||
/*! \brief get total number of */
|
||||
virtual int GetWorldSize(void) const {
|
||||
return MPI::COMM_WORLD.Get_size();
|
||||
}
|
||||
/*! \brief get the host name of current node */
|
||||
virtual std::string GetHost(void) const {
|
||||
int len;
|
||||
char name[MPI_MAX_PROCESSOR_NAME];
|
||||
MPI::Get_processor_name(name, len);
|
||||
name[len] = '\0';
|
||||
return std::string(name);
|
||||
}
|
||||
virtual void TrackerPrint(const std::string &msg) {
|
||||
// simply print information into the tracker
|
||||
if (GetRank() == 0) {
|
||||
utils::Printf("%s", msg.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int version_number;
|
||||
};
|
||||
|
||||
// singleton sync manager
|
||||
MPIEngine manager;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
MPI::Init(argc, argv);
|
||||
}
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void) {
|
||||
MPI::Finalize();
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void) {
|
||||
return &manager;
|
||||
}
|
||||
// transform enum to MPI data type
|
||||
inline MPI::Datatype GetType(mpi::DataType dtype) {
|
||||
using namespace mpi;
|
||||
switch (dtype) {
|
||||
case kChar: return MPI::CHAR;
|
||||
case kUChar: return MPI::BYTE;
|
||||
case kInt: return MPI::INT;
|
||||
case kUInt: return MPI::UNSIGNED;
|
||||
case kLong: return MPI::LONG;
|
||||
case kULong: return MPI::UNSIGNED_LONG;
|
||||
case kFloat: return MPI::FLOAT;
|
||||
case kDouble: return MPI::DOUBLE;
|
||||
}
|
||||
utils::Error("unknown mpi::DataType");
|
||||
return MPI::CHAR;
|
||||
}
|
||||
// transform enum to MPI OP
|
||||
inline MPI::Op GetOp(mpi::OpType otype) {
|
||||
using namespace mpi;
|
||||
switch (otype) {
|
||||
case kMax: return MPI::MAX;
|
||||
case kMin: return MPI::MIN;
|
||||
case kSum: return MPI::SUM;
|
||||
case kBitwiseOR: return MPI::BOR;
|
||||
}
|
||||
utils::Error("unknown mpi::OpType");
|
||||
return MPI::MAX;
|
||||
}
|
||||
// perform in-place allreduce, on sendrecvbuf
|
||||
void Allreduce_(void *sendrecvbuf,
|
||||
size_t type_nbytes,
|
||||
size_t count,
|
||||
IEngine::ReduceFunction red,
|
||||
mpi::DataType dtype,
|
||||
mpi::OpType op,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
|
||||
count, GetType(dtype), GetOp(op));
|
||||
}
|
||||
|
||||
// code for reduce handle
|
||||
ReduceHandle::ReduceHandle(void)
|
||||
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
|
||||
}
|
||||
ReduceHandle::~ReduceHandle(void) {
|
||||
if (handle_ != NULL) {
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
|
||||
op->Free();
|
||||
delete op;
|
||||
}
|
||||
if (htype_ != NULL) {
|
||||
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
|
||||
dtype->Free();
|
||||
delete dtype;
|
||||
}
|
||||
}
|
||||
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
|
||||
return dtype.Get_size();
|
||||
}
|
||||
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
|
||||
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
|
||||
if (type_nbytes != 0) {
|
||||
MPI::Datatype *dtype = new MPI::Datatype();
|
||||
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
||||
dtype->Commit();
|
||||
created_type_nbytes_ = type_nbytes;
|
||||
htype_ = dtype;
|
||||
}
|
||||
|
||||
MPI::Op *op = new MPI::Op();
|
||||
MPI::User_function *pf = redfunc;
|
||||
op->Init(pf, true);
|
||||
handle_ = op;
|
||||
}
|
||||
void ReduceHandle::Allreduce(void *sendrecvbuf,
|
||||
size_t type_nbytes, size_t count,
|
||||
IEngine::PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
|
||||
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
|
||||
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
|
||||
if (created_type_nbytes_ != type_nbytes || dtype == NULL) {
|
||||
if (dtype == NULL) {
|
||||
dtype = new MPI::Datatype();
|
||||
} else {
|
||||
dtype->Free();
|
||||
}
|
||||
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
|
||||
dtype->Commit();
|
||||
created_type_nbytes_ = type_nbytes;
|
||||
}
|
||||
if (prepare_fun != NULL) prepare_fun(prepare_arg);
|
||||
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
499
subtree/rabit/src/socket.h
Normal file
499
subtree/rabit/src/socket.h
Normal file
@@ -0,0 +1,499 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file socket.h
|
||||
* \brief this file aims to provide a wrapper of sockets
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_SOCKET_H_
|
||||
#define RABIT_SOCKET_H_
|
||||
#if defined(_WIN32)
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
#ifdef _MSC_VER
|
||||
#pragma comment(lib, "Ws2_32.lib")
|
||||
#endif
|
||||
#else
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <errno.h>
|
||||
#include <unistd.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <netinet/in.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/select.h>
|
||||
#include <sys/ioctl.h>
|
||||
#endif
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include "rabit/utils.h"
|
||||
|
||||
#if defined(_WIN32)
|
||||
typedef int ssize_t;
|
||||
typedef int sock_size_t;
|
||||
#else
|
||||
typedef int SOCKET;
|
||||
typedef size_t sock_size_t;
|
||||
const int INVALID_SOCKET = -1;
|
||||
#endif
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*! \brief data structure for network address */
|
||||
struct SockAddr {
|
||||
sockaddr_in addr;
|
||||
// constructor
|
||||
SockAddr(void) {}
|
||||
SockAddr(const char *url, int port) {
|
||||
this->Set(url, port);
|
||||
}
|
||||
inline static std::string GetHostName(void) {
|
||||
std::string buf; buf.resize(256);
|
||||
utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
|
||||
return std::string(buf.c_str());
|
||||
}
|
||||
/*!
|
||||
* \brief set the address
|
||||
* \param url the url of the address
|
||||
* \param port the port of address
|
||||
*/
|
||||
inline void Set(const char *host, int port) {
|
||||
hostent *hp = gethostbyname(host);
|
||||
Check(hp != NULL, "cannot obtain address of %s", host);
|
||||
memset(&addr, 0, sizeof(addr));
|
||||
addr.sin_family = AF_INET;
|
||||
addr.sin_port = htons(port);
|
||||
memcpy(&addr.sin_addr, hp->h_addr_list[0], hp->h_length);
|
||||
}
|
||||
/*! \brief return port of the address*/
|
||||
inline int port(void) const {
|
||||
return ntohs(addr.sin_port);
|
||||
}
|
||||
/*! \return a string representation of the address */
|
||||
inline std::string AddrStr(void) const {
|
||||
std::string buf; buf.resize(256);
|
||||
#ifdef _WIN32
|
||||
const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
|
||||
&buf[0], buf.length());
|
||||
#else
|
||||
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
|
||||
&buf[0], buf.length());
|
||||
#endif
|
||||
Assert(s != NULL, "cannot decode address");
|
||||
return std::string(s);
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief base class containing common operations of TCP and UDP sockets
|
||||
*/
|
||||
class Socket {
|
||||
public:
|
||||
/*! \brief the file descriptor of socket */
|
||||
SOCKET sockfd;
|
||||
// default conversion to int
|
||||
inline operator SOCKET() const {
|
||||
return sockfd;
|
||||
}
|
||||
/*!
|
||||
* \brief start up the socket module
|
||||
* call this before using the sockets
|
||||
*/
|
||||
inline static void Startup(void) {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsa_data;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
|
||||
Socket::Error("Startup");
|
||||
}
|
||||
if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
|
||||
WSACleanup();
|
||||
utils::Error("Could not find a usable version of Winsock.dll\n");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
/*!
|
||||
* \brief shutdown the socket module after use, all sockets need to be closed
|
||||
*/
|
||||
inline static void Finalize(void) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
}
|
||||
/*!
|
||||
* \brief set this socket to use non-blocking mode
|
||||
* \param non_block whether set it to be non-block, if it is false
|
||||
* it will set it back to block mode
|
||||
*/
|
||||
inline void SetNonBlock(bool non_block) {
|
||||
#ifdef _WIN32
|
||||
u_long mode = non_block ? 1 : 0;
|
||||
if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
|
||||
Socket::Error("SetNonBlock");
|
||||
}
|
||||
#else
|
||||
int flag = fcntl(sockfd, F_GETFL, 0);
|
||||
if (flag == -1) {
|
||||
Socket::Error("SetNonBlock-1");
|
||||
}
|
||||
if (non_block) {
|
||||
flag |= O_NONBLOCK;
|
||||
} else {
|
||||
flag &= ~O_NONBLOCK;
|
||||
}
|
||||
if (fcntl(sockfd, F_SETFL, flag) == -1) {
|
||||
Socket::Error("SetNonBlock-2");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
/*!
|
||||
* \brief bind the socket to an address
|
||||
* \param addr
|
||||
*/
|
||||
inline void Bind(const SockAddr &addr) {
|
||||
if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == -1) {
|
||||
Socket::Error("Bind");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief try bind the socket to host, from start_port to end_port
|
||||
* \param start_port starting port number to try
|
||||
* \param end_port ending port number to try
|
||||
* \return the port successfully bind to, return -1 if failed to bind any port
|
||||
*/
|
||||
inline int TryBindHost(int start_port, int end_port) {
|
||||
// TODO(tqchen) add prefix check
|
||||
for (int port = start_port; port < end_port; ++port) {
|
||||
SockAddr addr("0.0.0.0", port);
|
||||
if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0) {
|
||||
return port;
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
if (WSAGetLastError() != WSAEADDRINUSE) {
|
||||
Socket::Error("TryBindHost");
|
||||
}
|
||||
#else
|
||||
if (errno != EADDRINUSE) {
|
||||
Socket::Error("TryBindHost");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
/*! \brief get last error code if any */
|
||||
inline int GetSockError(void) const {
|
||||
int error = 0;
|
||||
socklen_t len = sizeof(error);
|
||||
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
|
||||
Error("GetSockError");
|
||||
}
|
||||
return error;
|
||||
}
|
||||
/*! \brief check if anything bad happens */
|
||||
inline bool BadSocket(void) const {
|
||||
if (IsClosed()) return true;
|
||||
int err = GetSockError();
|
||||
if (err == EBADF || err == EINTR) return true;
|
||||
return false;
|
||||
}
|
||||
/*! \brief check if socket is already closed */
|
||||
inline bool IsClosed(void) const {
|
||||
return sockfd == INVALID_SOCKET;
|
||||
}
|
||||
/*! \brief close the socket */
|
||||
inline void Close(void) {
|
||||
if (sockfd != INVALID_SOCKET) {
|
||||
#ifdef _WIN32
|
||||
closesocket(sockfd);
|
||||
#else
|
||||
close(sockfd);
|
||||
#endif
|
||||
sockfd = INVALID_SOCKET;
|
||||
} else {
|
||||
Error("Socket::Close double close the socket or close without create");
|
||||
}
|
||||
}
|
||||
// report an socket error
|
||||
inline static void Error(const char *msg) {
|
||||
int errsv = errno;
|
||||
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief a wrapper of TCP socket that hopefully be cross platform
|
||||
*/
|
||||
class TCPSocket : public Socket{
|
||||
public:
|
||||
// constructor
|
||||
TCPSocket(void) : Socket(INVALID_SOCKET) {
|
||||
}
|
||||
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
||||
}
|
||||
/*!
|
||||
* \brief enable/disable TCP keepalive
|
||||
* \param keepalive whether to set the keep alive option on
|
||||
*/
|
||||
inline void SetKeepAlive(bool keepalive) {
|
||||
int opt = static_cast<int>(keepalive);
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
|
||||
Socket::Error("SetKeepAlive");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief create the socket, call this before using socket
|
||||
* \param af domain
|
||||
*/
|
||||
inline void Create(int af = PF_INET) {
|
||||
sockfd = socket(PF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd == INVALID_SOCKET) {
|
||||
Socket::Error("Create");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief perform listen of the socket
|
||||
* \param backlog backlog parameter
|
||||
*/
|
||||
inline void Listen(int backlog = 16) {
|
||||
listen(sockfd, backlog);
|
||||
}
|
||||
/*! \brief get a new connection */
|
||||
TCPSocket Accept(void) {
|
||||
SOCKET newfd = accept(sockfd, NULL, NULL);
|
||||
if (newfd == INVALID_SOCKET) {
|
||||
Socket::Error("Accept");
|
||||
}
|
||||
return TCPSocket(newfd);
|
||||
}
|
||||
/*!
|
||||
* \brief decide whether the socket is at OOB mark
|
||||
* \return 1 if at mark, 0 if not, -1 if an error occured
|
||||
*/
|
||||
inline int AtMark(void) const {
|
||||
#ifdef _WIN32
|
||||
unsigned long atmark;
|
||||
if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
|
||||
#else
|
||||
int atmark;
|
||||
if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
|
||||
#endif
|
||||
return static_cast<int>(atmark);
|
||||
}
|
||||
/*!
|
||||
* \brief connect to an address
|
||||
* \param addr the address to connect to
|
||||
* \return whether connect is successful
|
||||
*/
|
||||
inline bool Connect(const SockAddr &addr) {
|
||||
return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
|
||||
sizeof(addr.addr)) == 0;
|
||||
}
|
||||
/*!
|
||||
* \brief send data using the socket
|
||||
* \param buf the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \param flags extra flags
|
||||
* \return size of data actually sent
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
|
||||
}
|
||||
/*!
|
||||
* \brief receive data using the socket
|
||||
* \param buf_ the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \param flags extra flags
|
||||
* \return size of data actually received
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
|
||||
}
|
||||
/*!
|
||||
* \brief peform block write that will attempt to send all data out
|
||||
* can still return smaller than request when error occurs
|
||||
* \param buf the pointer to the buffer
|
||||
* \param len the size of the buffer
|
||||
* \return size of data actually sent
|
||||
*/
|
||||
inline size_t SendAll(const void *buf_, size_t len) {
|
||||
const char *buf = reinterpret_cast<const char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
|
||||
if (ret == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
|
||||
Socket::Error("SendAll");
|
||||
}
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
* \brief peforma block read that will attempt to read all data
|
||||
* can still return smaller than request when error occurs
|
||||
* \param buf_ the buffer pointer
|
||||
* \param len length of data to recv
|
||||
* \return size of data actually sent
|
||||
*/
|
||||
inline size_t RecvAll(void *buf_, size_t len) {
|
||||
char *buf = reinterpret_cast<char*>(buf_);
|
||||
size_t ndone = 0;
|
||||
while (ndone < len) {
|
||||
ssize_t ret = recv(sockfd, buf,
|
||||
static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
|
||||
if (ret == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) return ndone;
|
||||
Socket::Error("RecvAll");
|
||||
}
|
||||
if (ret == 0) return ndone;
|
||||
buf += ret;
|
||||
ndone += ret;
|
||||
}
|
||||
return ndone;
|
||||
}
|
||||
/*!
|
||||
* \brief send a string over network
|
||||
* \param str the string to be sent
|
||||
*/
|
||||
inline void SendStr(const std::string &str) {
|
||||
int len = static_cast<int>(str.length());
|
||||
utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
|
||||
"error during send SendStr");
|
||||
if (len != 0) {
|
||||
utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
|
||||
"error during send SendStr");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief recv a string from network
|
||||
* \param out_str the string to receive
|
||||
*/
|
||||
inline void RecvStr(std::string *out_str) {
|
||||
int len;
|
||||
utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
|
||||
"error during send RecvStr");
|
||||
out_str->resize(len);
|
||||
if (len != 0) {
|
||||
utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
|
||||
"error during send SendStr");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief helper data structure to perform select */
|
||||
struct SelectHelper {
|
||||
public:
|
||||
SelectHelper(void) {
|
||||
FD_ZERO(&read_set);
|
||||
FD_ZERO(&write_set);
|
||||
FD_ZERO(&except_set);
|
||||
maxfd = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief add file descriptor to watch for read
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchRead(SOCKET fd) {
|
||||
FD_SET(fd, &read_set);
|
||||
if (fd > maxfd) maxfd = fd;
|
||||
}
|
||||
/*!
|
||||
* \brief add file descriptor to watch for write
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchWrite(SOCKET fd) {
|
||||
FD_SET(fd, &write_set);
|
||||
if (fd > maxfd) maxfd = fd;
|
||||
}
|
||||
/*!
|
||||
* \brief add file descriptor to watch for exception
|
||||
* \param fd file descriptor to be watched
|
||||
*/
|
||||
inline void WatchException(SOCKET fd) {
|
||||
FD_SET(fd, &except_set);
|
||||
if (fd > maxfd) maxfd = fd;
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor is ready for read
|
||||
* \param fd file descriptor to check status
|
||||
*/
|
||||
inline bool CheckRead(SOCKET fd) const {
|
||||
return FD_ISSET(fd, &read_set) != 0;
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor is ready for write
|
||||
* \param fd file descriptor to check status
|
||||
*/
|
||||
inline bool CheckWrite(SOCKET fd) const {
|
||||
return FD_ISSET(fd, &write_set) != 0;
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor has any exception
|
||||
* \param fd file descriptor to check status
|
||||
*/
|
||||
inline bool CheckExcept(SOCKET fd) const {
|
||||
return FD_ISSET(fd, &except_set) != 0;
|
||||
}
|
||||
/*!
|
||||
* \brief wait for exception event on a single descriptor
|
||||
* \param fd the file descriptor to wait the event for
|
||||
* \param timeout the timeout counter, can be 0, which means wait until the event happen
|
||||
* \return 1 if success, 0 if timeout, and -1 if error occurs
|
||||
*/
|
||||
inline static int WaitExcept(SOCKET fd, long timeout = 0) {
|
||||
fd_set wait_set;
|
||||
FD_ZERO(&wait_set);
|
||||
FD_SET(fd, &wait_set);
|
||||
return Select_(static_cast<int>(fd + 1),
|
||||
NULL, NULL, &wait_set, timeout);
|
||||
}
|
||||
/*!
|
||||
* \brief peform select on the set defined
|
||||
* \param select_read whether to watch for read event
|
||||
* \param select_write whether to watch for write event
|
||||
* \param select_except whether to watch for exception event
|
||||
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
|
||||
* \return number of active descriptors selected,
|
||||
* return -1 if error occurs
|
||||
*/
|
||||
inline int Select(long timeout = 0) {
|
||||
int ret = Select_(static_cast<int>(maxfd + 1),
|
||||
&read_set, &write_set, &except_set, timeout);
|
||||
if (ret == -1) {
|
||||
Socket::Error("Select");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
inline static int Select_(int maxfd, fd_set *rfds,
|
||||
fd_set *wfds, fd_set *efds, long timeout) {
|
||||
#if !defined(_WIN32)
|
||||
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
|
||||
#endif
|
||||
if (timeout == 0) {
|
||||
return select(maxfd, rfds, wfds, efds, NULL);
|
||||
} else {
|
||||
timeval tm;
|
||||
tm.tv_usec = (timeout % 1000) * 1000;
|
||||
tm.tv_sec = timeout / 1000;
|
||||
return select(maxfd, rfds, wfds, efds, &tm);
|
||||
}
|
||||
}
|
||||
|
||||
SOCKET maxfd;
|
||||
fd_set read_set, write_set, except_set;
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace rabit
|
||||
#endif // RABIT_SOCKET_H_
|
||||
Reference in New Issue
Block a user