Merge rabit

This commit is contained in:
fis
2020-08-18 03:52:33 +08:00
81 changed files with 11230 additions and 0 deletions

31
rabit/src/CMakeLists.txt Normal file
View File

@@ -0,0 +1,31 @@
option(DMLC_ROOT "Specify root of external dmlc core.")
add_library(allreduce_base "")
add_library(allreduce_mock "")
target_sources(
allreduce_base
PRIVATE
allreduce_base.cc
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h
)
target_sources(
allreduce_mock
PRIVATE
allreduce_robust.cc
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h
)
target_include_directories(
allreduce_base
PUBLIC
${DMLC_ROOT}/include
${CMAKE_CURRENT_LIST_DIR}/../../include)
target_include_directories(
allreduce_mock
PUBLIC
${DMLC_ROOT}/include
${CMAKE_CURRENT_LIST_DIR}/../../include)

6
rabit/src/README.md Normal file
View File

@@ -0,0 +1,6 @@
Source Files of Rabit
====
* This folder contains the source files of rabit library
* The library headers are in folder [include](../include)
* The .h files in this folder are internal header files that are only used by rabit and will not be seen by users

979
rabit/src/allreduce_base.cc Normal file
View File

@@ -0,0 +1,979 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.cc
* \brief Basic implementation of AllReduce
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#define NOMINMAX
#include "allreduce_base.h"
#include <rabit/base.h>
#include <netinet/tcp.h>
#include <cstring>
#include <map>
namespace rabit {
namespace engine {
// constructor
AllreduceBase::AllreduceBase(void) {
tracker_uri = "NULL";
tracker_port = 9000;
host_uri = "";
slave_port = 9010;
nport_trial = 1000;
rank = 0;
world_size = -1;
connect_retry = 5;
hadoop_mode = 0;
version_number = 0;
// 32 K items
reduce_ring_mincount = 32 << 10;
// 1M reducer size each time
tree_reduce_minsize = 1 << 20;
// tracker URL
task_id = "NULL";
err_link = NULL;
dmlc_role = "worker";
this->SetParam("rabit_reduce_buffer", "256MB");
// setup possible enviroment variable of interest
// include dmlc support direct variables
env_vars.push_back("DMLC_TASK_ID");
env_vars.push_back("DMLC_ROLE");
env_vars.push_back("DMLC_NUM_ATTEMPT");
env_vars.push_back("DMLC_TRACKER_URI");
env_vars.push_back("DMLC_TRACKER_PORT");
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
}
// initialization function
bool AllreduceBase::Init(int argc, char* argv[]) {
// setup from enviroment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
const char *value = getenv(env_vars[i].c_str());
if (value != NULL) {
this->SetParam(env_vars[i].c_str(), value);
}
}
// pass in arguments override env variable.
for (int i = 0; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) {
this->SetParam(name, val);
}
}
{
// handling for hadoop
const char *task_id = getenv("mapred_tip_id");
if (task_id == NULL) {
task_id = getenv("mapreduce_task_id");
}
if (hadoop_mode) {
utils::Check(task_id != NULL,
"hadoop_mode is set but cannot find mapred_task_id");
}
if (task_id != NULL) {
this->SetParam("rabit_task_id", task_id);
this->SetParam("rabit_hadoop_mode", "1");
}
const char *attempt_id = getenv("mapred_task_id");
if (attempt_id != 0) {
const char *att = strrchr(attempt_id, '_');
int num_trial;
if (att != NULL && sscanf(att + 1, "%d", &num_trial) == 1) {
this->SetParam("rabit_num_trial", att + 1);
}
}
// handling for hadoop
const char *num_task = getenv("mapred_map_tasks");
if (num_task == NULL) {
num_task = getenv("mapreduce_job_maps");
}
if (hadoop_mode) {
utils::Check(num_task != NULL,
"hadoop_mode is set but cannot find mapred_map_tasks");
}
if (num_task != NULL) {
this->SetParam("rabit_world_size", num_task);
}
}
if (dmlc_role != "worker") {
fprintf(stderr, "Rabit Module currently only work with dmlc worker"\
", quit this program by exit 0\n");
exit(0);
}
// clear the setting before start reconnection
this->rank = -1;
//---------------------
// start socket
utils::Socket::Startup();
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = utils::SockAddr::GetHostName();
// get information from tracker
return this->ReConnectLinks();
}
bool AllreduceBase::Shutdown(void) {
try {
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
}
all_links.clear();
tree_links.plinks.clear();
if (tracker_uri == "NULL") return true;
// notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
tracker.Close();
utils::TCPSocket::Finalize();
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed to shutdown due to %s\n", e.what());
return false;
}
}
void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("print"));
tracker.SendStr(msg);
tracker.Close();
}
// util to parse data with unit suffix
inline size_t ParseUnit(const char *name, const char *val) {
char unit;
unsigned long amt; // NOLINT(*)
int n = sscanf(val, "%lu%c", &amt, &unit);
size_t amount = amt;
if (n == 2) {
switch (unit) {
case 'B': return amount;
case 'K': return amount << 10UL;
case 'M': return amount << 20UL;
case 'G': return amount << 30UL;
default: utils::Error("invalid format for %s", name); return 0;
}
} else if (n == 1) {
return amount;
} else {
utils::Error("invalid format for %s," \
"shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name);
return 0;
}
}
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val;
if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val);
if (!strcmp(name, "rabit_task_id")) task_id = val;
if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val;
if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val);
if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val);
if (!strcmp(name, "rabit_tree_reduce_minsize")) tree_reduce_minsize = atoi(val);
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
reduce_ring_mincount = atoi(val);
utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0");
}
if (!strcmp(name, "rabit_reduce_buffer")) {
reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3;
}
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
connect_retry = atoi(val);
}
if (!strcmp(name, "rabit_bootstrap_cache")) {
rabit_bootstrap_cache = utils::StringToBool(val);
}
if (!strcmp(name, "rabit_debug")) {
rabit_debug = utils::StringToBool(val);
}
if (!strcmp(name, "rabit_timeout")) {
rabit_timeout = utils::StringToBool(val);
}
if (!strcmp(name, "rabit_timeout_sec")) {
timeout_sec = atoi(val);
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
}
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
if (!strcmp(val, "true"))
rabit_enable_tcp_no_delay = true;
else
rabit_enable_tcp_no_delay = false;
}
}
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
int magic = kMagic;
// get information from tracker
utils::TCPSocket tracker;
tracker.Create();
int retry = 0;
do {
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
if (++retry >= connect_retry) {
fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str());
utils::Socket::Error("Connect");
} else {
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
#if defined(_MSC_VER) || defined (__MINGW32__)
Sleep(retry << 1);
#else
sleep(retry << 1);
#endif
continue;
}
}
break;
} while (1);
using utils::Assert;
Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 1");
Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic),
"ReConnectLink failure 2");
utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure");
Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 3");
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
tracker.SendStr(task_id);
return tracker;
}
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
*/
bool AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return true;
}
try {
utils::TCPSocket tracker = this->ConnectTracker();
fprintf(stdout, "task %s connected to the tracker\n", task_id.c_str());
tracker.SendStr(std::string(cmd));
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors
std::map<int, int> tree_neighbors;
using utils::Assert;
// get new ranks
int newrank, num_neighbors;
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \
sizeof(parent_rank), "ReConnectLink failure 4");
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
Assert(rank == -1 || newrank == rank,
"must keep rank to same if the node already have one");
rank = newrank;
// tracker got overwhelemed and not able to assign correct rank
if (rank == -1) exit(-1);
fprintf(stdout, "task %s got new rank %d\n", task_id.c_str(), rank);
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
sizeof(num_neighbors), "ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
int nrank;
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
"ReConnectLink failure 4");
tree_neighbors[nrank] = 1;
}
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
utils::TCPSocket sock_listen;
if (!sock_listen.IsClosed()) {
sock_listen.Close();
}
// create listening socket
sock_listen.Create();
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
// send over good links
std::vector<int> good_link;
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i].rank));
} else {
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
}
}
int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) {
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
sizeof(good_link[i]), "ReConnectLink failure 6");
}
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
sizeof(num_accept), "ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.RecvStr(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
"ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1;
r.sock.Close();
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 12");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 13");
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
match = true;
break;
}
}
if (!match) all_links.push_back(r);
}
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
"ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
r.sock = sock_listen.Accept();
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 15");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
match = true;
break;
}
}
if (!match) all_links.push_back(r);
}
sock_listen.Close();
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
int tcpNoDelay = 1;
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true);
if (rabit_enable_tcp_no_delay) {
setsockopt(all_links[i].sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
}
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());
}
tree_links.plinks.push_back(&all_links[i]);
}
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
}
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
"cannot find next ring in the link");
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed in ReconnectLink %s\n", e.what());
return false;
}
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*
* NOTE on Allreduce:
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
* It only means the current node get the correct result of Allreduce.
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
if (count > reduce_ring_mincount) {
return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer);
} else {
return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer);
}
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || count == 0) return kSuccess;
// total size of message
const size_t total_size = type_nbytes * count;
// number of links
const int nlink = static_cast<int>(links.size());
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// size of space that we already performs reduce in up pass
size_t size_up_reduce = 0;
// size of space that we have already passed to parent
size_t size_up_out = 0;
// size of message we received, and send in the down pass
size_t size_down_in = 0;
// minimal size of each reducer
const size_t eachreduce = (tree_reduce_minsize / type_nbytes * type_nbytes);
// initialize the link ring-buffer and pointer
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
links[i].InitBuffer(type_nbytes, count, reduce_buffer_size);
}
links[i].ResetSize();
}
// if no childs, no need to reduce
if (nlink == static_cast<int>(parent_index != -1)) {
size_up_reduce = total_size;
}
// while we have not passed the messages out
while (true) {
// select helper
bool finished = true;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
if (size_down_in != total_size) {
watcher.WatchRead(links[i].sock);
// only watch for exception in live channels
watcher.WatchException(links[i].sock);
finished = false;
}
if (size_up_out != total_size && size_up_out < size_up_reduce) {
watcher.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
watcher.WatchRead(links[i].sock);
}
// size_write <= size_read
if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) {
watcher.WatchWrite(links[i].sock);
}
// only watch for exception in live channels
watcher.WatchException(links[i].sock);
finished = false;
}
}
}
// finish runing allreduce
if (finished) break;
// select must return
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
// read data from childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
// make sure to receive minimal reducer size
// since each child reduce and sends the minimal reducer size
while (links[i].size_read < total_size
&& links[i].size_read - size_up_reduce < eachreduce) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
}
// this node have childs, peform reduce
if (nlink > static_cast<int>(parent_index != -1)) {
size_t buffer_size = 0;
// do upstream reduce
size_t max_reduce = total_size;
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
max_reduce = std::min(max_reduce, links[i].size_read);
utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size,
"buffer size inconsistent");
buffer_size = links[i].buffer_size;
}
}
utils::Assert(buffer_size != 0, "must assign buffer_size");
// round to type_n4bytes
max_reduce = (max_reduce / type_nbytes * type_nbytes);
// if max reduce is less than total size, we reduce multiple times of
// eachreduce size
if (max_reduce < total_size)
max_reduce = max_reduce - max_reduce % eachreduce;
// peform reduce, can be at most two rounds
while (size_up_reduce < max_reduce) {
// start position
size_t start = size_up_reduce % buffer_size;
// peform read till end of buffer
size_t nread = std::min(buffer_size - start,
max_reduce - size_up_reduce);
utils::Assert(nread % type_nbytes == 0, "Allreduce: size check");
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
reducer(links[i].buffer_head + start,
sendrecvbuf + size_up_reduce,
static_cast<int>(nread / type_nbytes),
MPI::Datatype(type_nbytes));
}
}
size_up_reduce += nread;
}
}
if (parent_index != -1) {
// pass message up to parent, can pass data that are already been reduced
if (size_up_out < size_up_reduce) {
ssize_t len = links[parent_index].sock.
Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out);
if (len != -1) {
size_up_out += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) {
return ReportError(&links[parent_index], ret);
}
}
}
// read data from parent
if (watcher.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
size_t left_size = total_size-size_down_in;
size_t reduce_size_min = std::min(left_size, eachreduce);
size_t recved = 0;
while (recved < reduce_size_min) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) {
links[parent_index].sock.Close();
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {
size_down_in += static_cast<size_t>(len);
utils::Assert(size_down_in <= size_up_out,
"Allreduce: boundary error");
recved+=len;
// if it receives more data than each reduce, it means the next block is sent.
// we double the reduce_size_min or add to left_size
while (recved > reduce_size_min) {
reduce_size_min += std::min(left_size-reduce_size_min, eachreduce);
}
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) {
return ReportError(&links[parent_index], ret);
}
}
}
}
} else {
// this is root, can use reduce as most recent point
size_down_in = size_up_out = size_up_reduce;
}
// can pass message down to childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write < size_down_in) {
ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
}
return kSuccess;
}
/*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
* \param sendrecvbuf_ buffer for both sending and recving data
* \param total_size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
RefLinkVector &links = tree_links;
if (links.size() == 0 || total_size == 0) return kSuccess;
utils::Check(root < world_size,
"Broadcast: root should be smaller than world size");
// number of links
const int nlink = static_cast<int>(links.size());
// size of space already read from data
size_t size_in = 0;
// input link, -2 means unknown yet, -1 means this is root
int in_link = -2;
// initialize the link statistics
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
}
// root have all the data
if (this->rank == root) {
size_in = total_size;
in_link = -1;
}
// while we have not passed the messages out
while (true) {
bool finished = true;
// select helper
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (in_link == -2) {
watcher.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
watcher.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) {
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
watcher.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// select
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
size_in = links[i].size_read;
if (size_in != 0) {
in_link = i; break;
}
}
}
} else {
// read from in link
if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[in_link], ret);
}
size_in = links[in_link].size_read;
}
}
// send data to all out-link
for (int i = 0; i < nlink; ++i) {
if (i != in_link && links[i].size_write < size_in) {
ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
}
}
}
}
return kSuccess;
}
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice) {
// read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to reply on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size &&
rank == (prev.rank + 1) % world_size,
"need to assume rank structure");
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
const size_t stop_read = total_size + slice_begin;
const size_t stop_write = total_size + slice_begin - size_prev_slice;
size_t write_ptr = slice_begin;
size_t read_ptr = slice_end;
while (true) {
// select helper
bool finished = true;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = next.sock.Recv(sendrecvbuf + start, size);
if (len != -1) {
read_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) return ReportError(&next, ret);
}
}
if (write_ptr < read_ptr && write_ptr != stop_write) {
size_t size = std::min(read_ptr, stop_write) - write_ptr;
size_t start = write_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
if (len != -1) {
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) return ReportError(&prev, ret);
}
}
}
return kSuccess;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail,
* and will return the cause of failure
*
* Ring-based algorithm
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce
*/
AllreduceBase::ReturnType
AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
// read from next link and send to prev one
LinkRecord &prev = *ring_prev, &next = *ring_next;
// need to reply on special rank structure
utils::Assert(next.rank == (rank + 1) % world_size &&
rank == (prev.rank + 1) % world_size,
"need to assume rank structure");
// total size of message
const size_t total_size = type_nbytes * count;
size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n;
size_t r = static_cast<size_t>(next.rank);
size_t write_ptr = std::min(r * step, count) * type_nbytes;
size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes;
size_t reduce_ptr = read_ptr;
// send recv buffer
char *sendrecvbuf = reinterpret_cast<char*>(sendrecvbuf_);
// position to stop reading
const size_t stop_read = total_size + write_ptr;
// position to stop writing
size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes;
if (stop_write > stop_read) {
stop_write -= total_size;
utils::Assert(write_ptr <= stop_write, "write ptr boundary check");
}
// use ring buffer in next position
next.InitBuffer(type_nbytes, step, reduce_buffer_size);
// set size_read to read pointer for ring buffer to work properly
next.size_read = read_ptr;
while (true) {
// select helper
bool finished = true;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);
}
// sync the rate
read_ptr = next.size_read;
utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank);
const size_t buffer_size = next.buffer_size;
size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes;
while (reduce_ptr < max_reduce) {
size_t bstart = reduce_ptr % buffer_size;
size_t nread = std::min(buffer_size - bstart,
max_reduce - reduce_ptr);
size_t rstart = reduce_ptr % total_size;
nread = std::min(nread, total_size - rstart);
reducer(next.buffer_head + bstart,
sendrecvbuf + rstart,
static_cast<int>(nread / type_nbytes),
MPI::Datatype(type_nbytes));
reduce_ptr += nread;
}
}
if (write_ptr < reduce_ptr && write_ptr != stop_write) {
size_t size = std::min(reduce_ptr, stop_write) - write_ptr;
size_t start = write_ptr % total_size;
if (start + size > total_size) {
size = total_size - start;
}
ssize_t len = prev.sock.Send(sendrecvbuf + start, size);
if (len != -1) {
write_ptr += static_cast<size_t>(len);
} else {
ReturnType ret = Errno2Return();
if (ret != kSuccess) return ReportError(&prev, ret);
}
}
}
return kSuccess;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllreduceBase::ReturnType
AllreduceBase::TryAllreduceRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer);
if (ret != kSuccess) return ret;
size_t n = static_cast<size_t>(world_size);
size_t step = (count + n - 1) / n;
size_t begin = std::min(rank * step, count) * type_nbytes;
size_t end = std::min((rank + 1) * step, count) * type_nbytes;
// previous rank
int prank = ring_prev->rank;
// get rank of previous
return TryAllgatherRing
(sendrecvbuf_, type_nbytes * count,
begin, end,
(std::min((prank + 1) * step, count) -
std::min(prank * step, count)) * type_nbytes);
}
} // namespace engine
} // namespace rabit

589
rabit/src/allreduce_base.h Normal file
View File

@@ -0,0 +1,589 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_base.h
* \brief Basic implementation of AllReduce
* using TCP non-block socket and tree-shape reduction.
*
* This implementation provides basic utility of AllReduce and Broadcast
* without considering node failure
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_BASE_H_
#define RABIT_ALLREDUCE_BASE_H_
#include <vector>
#include <string>
#include <algorithm>
#include "rabit/internal/utils.h"
#include "rabit/internal/engine.h"
#include "rabit/internal/socket.h"
#ifdef RABIT_CXXTESTDEFS_H
#define private public
#define protected public
#endif // RABIT_CXXTESTDEFS_H
namespace MPI {
// MPI data type to be compatible with existing MPI interface
class Datatype {
public:
size_t type_size;
explicit Datatype(size_t type_size) : type_size(type_size) {}
};
}
namespace rabit {
namespace engine {
/*! \brief implementation of basic Allreduce engine */
class AllreduceBase : public IEngine {
public:
// magic number to verify server
static const int kMagic = 0xff99;
// constant one byte out of band message to indicate error happening
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
// initialize the manager
virtual bool Init(int argc, char* argv[]);
// shutdown the engine
virtual bool Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
/*!
* \brief print the msg in the tracker,
* this function can be used to communicate the information of the progress to
* the user who monitors the tracker
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
/*! \brief get rank of previous node in ring topology*/
virtual int GetRingPrevRank(void) const {
return ring_prev->rank;
}
/*! \brief get rank */
virtual int GetRank(void) const {
return rank;
}
/*! \brief get rank */
virtual int GetWorldSize(void) const {
if (world_size == -1) return 1;
return world_size;
}
/*! \brief whether is distributed or not */
virtual bool IsDistributed(void) const {
return tracker_uri != "NULL";
}
/*! \brief get rank */
virtual std::string GetHost(void) const {
return host_uri;
}
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
slice_begin, slice_end, size_prev_slice) == kSuccess,
"AllgatherRing failed");
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryAllreduce(sendrecvbuf_,
type_nbytes, count, reducer) == kSuccess,
"Allreduce failed");
}
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and recving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE, const int _line = _LINE, const char* _caller = _CALLER) {
if (world_size == 1 || world_size == -1) return;
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"Broadcast failed");
}
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
return 0;
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
* So only CheckPoint with global_model if possible
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
version_number += 1;
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1;
}
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const {
return version_number;
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
utils::Error("InitAfterException: not implemented");
}
/*!
* \brief report current status to the job tracker
* depending on the job tracker we are in
*/
inline void ReportStatus(void) const {
if (hadoop_mode != 0) {
fprintf(stderr, "reporter:status:Rabit Phase[%03d] Operation %03d\n",
version_number, seq_counter);
}
}
protected:
/*! \brief enumeration of possible returning results from Try functions */
enum ReturnTypeEnum {
/*! \brief execution is successful */
kSuccess,
/*! \brief a link was reset by peer */
kConnReset,
/*! \brief received a zero length message */
kRecvZeroLen,
/*! \brief a neighbor node go down, the connection is dropped */
kSockError,
/*!
* \brief another node which is not my neighbor go down,
* get Out-of-Band exception notification from my neighbor
*/
kGetExcept
};
/*! \brief struct return type to avoid implicit conversion to int/bool */
struct ReturnType {
/*! \brief internal return type */
ReturnTypeEnum value;
// constructor
ReturnType() {}
ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*)
inline bool operator==(const ReturnTypeEnum &v) const {
return value == v;
}
inline bool operator!=(const ReturnTypeEnum &v) const {
return value != v;
}
};
/*! \brief translate errno to return type */
inline static ReturnType Errno2Return() {
int errsv = utils::Socket::GetLastError();
if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess;
#ifdef _WIN32
if (errsv == WSAEWOULDBLOCK) return kSuccess;
if (errsv == WSAECONNRESET) return kConnReset;
#endif // _WIN32
if (errsv == ECONNRESET) return kConnReset;
return kSockError;
}
// link record to a neighbor
struct LinkRecord {
public:
// socket to get data from/to link
utils::TCPSocket sock;
// rank of the node in this link
int rank;
// size of data readed from link
size_t size_read;
// size of data sent to the link
size_t size_write;
// pointer to buffer head
char *buffer_head;
// buffer size, in bytes
size_t buffer_size;
// constructor
LinkRecord(void)
: buffer_head(NULL), buffer_size(0) {
}
// initialize buffer
inline void InitBuffer(size_t type_nbytes, size_t count,
size_t reduce_buffer_size) {
size_t n = (type_nbytes * count + 7)/ 8;
buffer_.resize(std::min(reduce_buffer_size, n));
// make sure align to type_nbytes
buffer_size =
buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes;
utils::Assert(type_nbytes <= buffer_size,
"too large type_nbytes=%lu, buffer_size=%lu",
type_nbytes, buffer_size);
// set buffer head
buffer_head = reinterpret_cast<char*>(BeginPtr(buffer_));
}
// reset the recv and sent size
inline void ResetSize(void) {
size_write = size_read = 0;
}
/*!
* \brief read data into ring-buffer, with care not to existing useful override data
* position after protect_start
* \param protect_start all data start from protect_start is still needed in buffer
* read shall not override this
* \param max_size_read maximum logical amount we can read, size_read cannot exceed this value
* \return the type of reading
*/
inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) {
utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated");
utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check");
size_t ngap = size_read - protect_start;
utils::Assert(ngap <= buffer_size, "Allreduce: boundary check");
size_t offset = size_read % buffer_size;
size_t nmax = max_size_read - size_read;
nmax = std::min(nmax, buffer_size - ngap);
nmax = std::min(nmax, buffer_size - offset);
if (nmax == 0) return kSuccess;
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
return kSuccess;
}
/*!
* \brief read data into array,
* this function can not be used together with ReadToRingBuffer
* a link can either read into the ring buffer, or existing array
* \param max_size maximum size of array
* \return true if it is an successful read, false if there is some error happens, check errno
*/
inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) {
if (max_size == size_read) return kSuccess;
char *p = static_cast<char*>(recvbuf_);
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
return kSuccess;
}
/*!
* \brief write data in array to sock
* \param sendbuf_ head of array
* \param max_size maximum size of array
* \return true if it is an successful write, false if there is some error happens, check errno
*/
inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) {
const char *p = static_cast<const char*>(sendbuf_);
ssize_t len = sock.Send(p + size_write, max_size - size_write);
if (len == -1) return Errno2Return();
size_write += static_cast<size_t>(len);
return kSuccess;
}
private:
// recv buffer to get data from child
// aligned with 64 bits, will be able to perform 64 bits operations freely
std::vector<uint64_t> buffer_;
};
/*!
* \brief simple data structure that works like a vector
* but takes reference instead of space
*/
struct RefLinkVector {
std::vector<LinkRecord*> plinks;
inline LinkRecord &operator[](size_t i) {
return *plinks[i];
}
inline size_t size(void) const {
return plinks.size();
}
};
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker(void) const;
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
* \param cmd possible command to sent to tracker
*/
bool ReConnectLinks(const char *cmd = "start");
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*
* NOTE on Allreduce:
* The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce.
* It only means the current node get the correct result of Allreduce.
* However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
/*!
* \brief perform in-place allreduce, on sendrecvbuf,
* this function implements tree-shape reduction
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduceTree(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
size_t slice_begin, size_t slice_end,
size_t size_prev_slice);
/*!
* \brief perform in-place allreduce, reduce on the sendrecvbuf,
*
* after the function, node k get k-th segment of the reduction result
* the k-th segment is defined by [k * step, min((k + 1) * step,count) )
* where step = ceil(count / world_size)
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType, TryAllreduce
*/
ReturnType TryReduceScatterRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* use a ring based algorithm, reduce-scatter + allgather
*
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryAllreduceRing(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer);
/*!
* \brief function used to report error when a link goes wrong
* \param link the pointer to the link who causes the error
* \param err the error type
*/
inline ReturnType ReportError(LinkRecord *link, ReturnType err) {
err_link = link; return err;
}
//---- data structure related to model ----
// call sequence counter, records how many calls we made so far
// from last call to CheckPoint, LoadCheckPoint
int seq_counter;
// version number of model
int version_number;
// whether the job is running in hadoop
bool hadoop_mode;
//---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree
int parent_index;
// rank of parent node, can be -1
int parent_rank;
// sockets of all links this connects to
std::vector<LinkRecord> all_links;
// used to record the link where things goes wrong
LinkRecord *err_link;
// all the links in the reduction tree connection
RefLinkVector tree_links;
// pointer to links in the ring
LinkRecord *ring_prev, *ring_next;
//----- meta information-----
// list of enviroment variables that are of possible interest
std::vector<std::string> env_vars;
// unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL
std::string task_id;
// uri of current host, to be set by Init
std::string host_uri;
// uri of tracker
std::string tracker_uri;
// role in dmlc jobs
std::string dmlc_role;
// port of tracker address
int tracker_port;
// port of slave process
int slave_port, nport_trial;
// reduce buffer size
size_t reduce_buffer_size;
// reduction method
int reduce_method;
// mininum count of cells to use ring based method
size_t reduce_ring_mincount;
// minimul block size per tree reduce
size_t tree_reduce_minsize;
// current rank
int rank;
// world size
int world_size;
// connect retry time
int connect_retry;
// enable bootstrap cache 0 false 1 true
bool rabit_bootstrap_cache = false;
// enable detailed logging
bool rabit_debug = false;
// by default, if rabit worker not recover in half an hour exit
int timeout_sec = 1800;
// flag to enable rabit_timeout
bool rabit_timeout = false;
// Enable TCP node delay
bool rabit_enable_tcp_no_delay = false;
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_BASE_H_

206
rabit/src/allreduce_mock.h Normal file
View File

@@ -0,0 +1,206 @@
/*!
* Copyright by Contributors
* \file allreduce_mock.h
* \brief Mock test module of AllReduce engine,
* insert failures in certain call point, to test if the engine is robust to failure
*
* \author Ignacio Cano, Tianqi Chen
*/
#ifndef RABIT_ALLREDUCE_MOCK_H_
#define RABIT_ALLREDUCE_MOCK_H_
#include <vector>
#include <map>
#include <sstream>
#include "rabit/internal/engine.h"
#include "rabit/internal/timer.h"
#include "allreduce_robust.h"
namespace rabit {
namespace engine {
class AllreduceMock : public AllreduceRobust {
public:
// constructor
AllreduceMock(void) {
num_trial = 0;
force_local = 0;
report_stats = 0;
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
}
// destructor
virtual ~AllreduceMock(void) {}
virtual void SetParam(const char *name, const char *val) {
AllreduceRobust::SetParam(name, val);
// additional parameters
if (!strcmp(name, "rabit_num_trial")) num_trial = atoi(val);
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial = atoi(val);
if (!strcmp(name, "report_stats")) report_stats = atoi(val);
if (!strcmp(name, "force_local")) force_local = atoi(val);
if (!strcmp(name, "mock")) {
MockKey k;
utils::Check(sscanf(val, "%d,%d,%d,%d",
&k.rank, &k.version, &k.seqno, &k.ntrial) == 4,
"invalid mock parameter");
mock_map[k] = 1;
}
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
double tstart = utils::GetTime();
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg,
_file, _line, _caller);
tsum_allreduce += utils::GetTime() - tstart;
}
virtual void Allgather(void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Allgather");
double tstart = utils::GetTime();
AllreduceRobust::Allgather(sendrecvbuf, total_size,
slice_begin, slice_end,
size_prev_slice, _file, _line, _caller);
tsum_allgather += utils::GetTime() - tstart;
}
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) {
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
time_checkpoint = utils::GetTime();
if (force_local == 0) {
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
} else {
DummySerializer dum;
ComboSerializer com(global_model, local_model);
return AllreduceRobust::LoadCheckPoint(&dum, &com);
}
}
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
double tstart = utils::GetTime();
double tbet_chkpt = tstart - time_checkpoint;
if (force_local == 0) {
AllreduceRobust::CheckPoint(global_model, local_model);
} else {
DummySerializer dum;
ComboSerializer com(global_model, local_model);
AllreduceRobust::CheckPoint(&dum, &com);
}
time_checkpoint = utils::GetTime();
double tcost = utils::GetTime() - tstart;
if (report_stats != 0 && rank == 0) {
std::stringstream ss;
ss << "[v" << version_number << "] global_size=" << global_checkpoint.length()
<< ",local_size=" << (local_chkpt[0].length() + local_chkpt[1].length())
<< ",check_tcost="<< tcost <<" sec"
<< ",allreduce_tcost=" << tsum_allreduce << " sec"
<< ",allgather_tcost=" << tsum_allgather << " sec"
<< ",between_chpt=" << tbet_chkpt << "sec\n";
this->TrackerPrint(ss.str());
}
tsum_allreduce = 0.0;
tsum_allgather = 0.0;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model);
}
protected:
// force checkpoint to local
int force_local;
// whether report statistics
int report_stats;
// sum of allreduce
double tsum_allreduce;
// sum of allgather
double tsum_allgather;
double time_checkpoint;
private:
struct DummySerializer : public Serializable {
virtual void Load(Stream *fi) {
}
virtual void Save(Stream *fo) const {
}
};
struct ComboSerializer : public Serializable {
Serializable *lhs;
Serializable *rhs;
const Serializable *c_lhs;
const Serializable *c_rhs;
ComboSerializer(Serializable *lhs, Serializable *rhs)
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
}
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
}
virtual void Load(Stream *fi) {
if (lhs != NULL) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi);
}
virtual void Save(Stream *fo) const {
if (c_lhs != NULL) c_lhs->Save(fo);
if (c_rhs != NULL) c_rhs->Save(fo);
}
};
// key to identify the mock stage
struct MockKey {
int rank;
int version;
int seqno;
int ntrial;
MockKey(void) {}
MockKey(int rank, int version, int seqno, int ntrial)
: rank(rank), version(version), seqno(seqno), ntrial(ntrial) {}
inline bool operator==(const MockKey &b) const {
return rank == b.rank &&
version == b.version &&
seqno == b.seqno &&
ntrial == b.ntrial;
}
inline bool operator<(const MockKey &b) const {
if (rank != b.rank) return rank < b.rank;
if (version != b.version) return version < b.version;
if (seqno != b.seqno) return seqno < b.seqno;
return ntrial < b.ntrial;
}
};
// number of failure trials
int num_trial;
// record all mock actions
std::map<MockKey, int> mock_map;
// used to generate all kinds of exceptions
inline void Verify(const MockKey &key, const char *name) {
if (mock_map.count(key) != 0) {
num_trial += 1;
// data processing frameworks runs on shared process
_error("[%d]@@@Hit Mock Error:%s ", rank, name);
}
}
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_MOCK_H_

View File

@@ -0,0 +1,169 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust-inl.h
* \brief implementation of inline template function in AllreduceRobust
*
* \author Tianqi Chen
*/
#ifndef RABIT_ALLREDUCE_ROBUST_INL_H_
#define RABIT_ALLREDUCE_ROBUST_INL_H_
#include <vector>
namespace rabit {
namespace engine {
/*!
* \brief run message passing algorithm on the allreduce tree
* the result is edge message stored in p_edge_in and p_edge_out
* \param node_value the value associated with current node
* \param p_edge_in used to store input message from each of the edge
* \param p_edge_out used to store output message from each of the edge
* \param func a function that defines the message passing rule
* Parameters of func:
* - node_value same as node_value in the main function
* - edge_in the array of input messages from each edge,
* this includes the output edge, which should be excluded
* - out_index array the index of output edge, the function should
* exclude the output edge when compute the message passing value
* Return of func:
* the function returns the output message based on the input message and node_value
*
* \tparam EdgeType type of edge message, must be simple struct
* \tparam NodeType type of node value
*/
template<typename NodeType, typename EdgeType>
inline AllreduceRobust::ReturnType
AllreduceRobust::MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType(*func)
(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index)) {
RefLinkVector &links = tree_links;
if (links.size() == 0) return kSuccess;
// number of links
const int nlink = static_cast<int>(links.size());
// initialize the pointers
for (int i = 0; i < nlink; ++i) {
links[i].ResetSize();
}
std::vector<EdgeType> &edge_in = *p_edge_in;
std::vector<EdgeType> &edge_out = *p_edge_out;
edge_in.resize(nlink);
edge_out.resize(nlink);
// stages in the process
// 0: recv messages from childs
// 1: send message to parent
// 2: recv message from parent
// 3: send message to childs
int stage = 0;
// if no childs, no need to, directly start passing message
if (nlink == static_cast<int>(parent_index != -1)) {
utils::Assert(parent_index == 0, "parent must be 0");
edge_out[parent_index] = func(node_value, edge_in, parent_index);
stage = 1;
}
// while we have not passed the messages out
while (true) {
// for node with no parent, directly do stage 3
if (parent_index == -1) {
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
}
// poll helper
utils::PollHelper watcher;
bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) {
watcher.WatchException(links[i].sock);
switch (stage) {
case 0:
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
watcher.WatchRead(links[i].sock);
}
break;
case 1:
if (i == parent_index) {
watcher.WatchWrite(links[i].sock);
}
break;
case 2:
if (i == parent_index) {
watcher.WatchRead(links[i].sock);
}
break;
case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
watcher.WatchWrite(links[i].sock);
done = false;
}
break;
default: utils::Error("invalid stage");
}
}
// finish all the stages, and write out message
if (done) break;
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (stage == 0) {
bool finished = true;
// read data from childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}
if (links[i].size_read != sizeof(EdgeType)) finished = false;
}
}
// if no parent, jump to stage 3, otherwise do stage 1
if (finished) {
if (parent_index != -1) {
edge_out[parent_index] = func(node_value, edge_in, parent_index);
stage = 1;
} else {
for (int i = 0; i < nlink; ++i) {
edge_out[i] = func(node_value, edge_in, i);
}
stage = 3;
}
}
}
if (stage == 1) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
ReturnType ret = links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[pid], ret);
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
}
if (stage == 2) {
const int pid = this->parent_index;
utils::Assert(pid != -1, "MsgPassing invalid stage");
ReturnType ret = links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[pid], ret);
if (links[pid].size_read == sizeof(EdgeType)) {
for (int i = 0; i < nlink; ++i) {
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
}
stage = 3;
}
}
if (stage == 3) {
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
ReturnType ret = links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}
}
}
}
return kSuccess;
}
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_ROBUST_INL_H_

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,672 @@
/*!
* Copyright (c) 2014 by Contributors
* \file allreduce_robust.h
* \brief Robust implementation of Allreduce
* using TCP non-block socket and tree-shape reduction.
*
* This implementation considers the failure of nodes
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#ifndef RABIT_ALLREDUCE_ROBUST_H_
#define RABIT_ALLREDUCE_ROBUST_H_
#include <future>
#include <vector>
#include <string>
#include <algorithm>
#include "rabit/internal/engine.h"
#include "allreduce_base.h"
namespace rabit {
namespace engine {
/*! \brief implementation of fault tolerant all reduce engine */
class AllreduceRobust : public AllreduceBase {
public:
AllreduceRobust(void);
virtual ~AllreduceRobust(void) {}
// initialize the manager
virtual bool Init(int argc, char* argv[]);
/*! \brief shutdown the engine */
virtual bool Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
/*!
* \brief perform immutable local bootstrap cache insertion
* \param key unique cache key
* \param buf buffer of allreduce/robust payload to copy
* \param buflen total number of bytes
* \return -1 if no recovery cache fetched otherwise 0
*/
int SetBootstrapCache(const std::string &key, const void *buf,
const size_t type_nbytes, const size_t count);
/*!
* \brief perform bootstrap cache lookup if nodes in fault recovery
* \param key unique cache key
* \param buf buffer for recv allreduce/robust payload
* \param buflen total number of bytes
*/
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
const size_t count);
/*!
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
* the data provided by current node k is [slice_begin, slice_end),
* the next node's segment must start with slice_end
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
* use a ring based algorithm
*
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
* \param total_size total size of data to be gathered
* \param slice_begin beginning of the current slice
* \param slice_end end of the current slice
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and recving data
* \param type_nbytes the unit number of bytes the type have
* \param count number of elements to be reduced
* \param reducer reduce function
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
* will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_.
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
* \param prepare_arg argument used to passed into the lazy preprocessing function
* \param prepare_arg argument used to passed into the lazy preprocessing function
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun = NULL,
void *prepare_arg = NULL,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and recving data
* \param size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER);
/*!
* \brief load latest check point
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
* So only CheckPoint with global_model if possible
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
this->CheckPoint_(global_model, local_model, false);
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const Serializable *global_model) {
this->CheckPoint_(global_model, NULL, true);
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
// simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
}
ReConnectLinks("recover");
}
protected:
// constant one byte out of band message to indicate error happening
// and mark for channel cleanup
static const char kOOBReset = 95;
// and mark for channel cleanup, after OOB signal
static const char kResetMark = 97;
// and mark for channel cleanup
static const char kResetAck = 97;
/*! \brief type of roles each node can play during recovery */
enum RecoverType {
/*! \brief current node have data */
kHaveData = 0,
/*! \brief current node request data */
kRequestData = 1,
/*! \brief current node only helps to pass data around */
kPassData = 2
};
enum SeqType {
/*! \brief apply to rabit seq code */
kSeq = 0,
/*! \brief apply to rabit cache seq code */
kCache = 1
};
/*!
* \brief summary of actions proposed in all nodes
* this data structure is used to make consensus decision
* about next action to take in the recovery mode
*/
struct ActionSummary {
// maximumly allowed sequence id
static const u_int32_t kSpecialOp = (1 << 26);
// special sequence number for local state checkpoint
static const u_int32_t kLocalCheckPoint = (1 << 26) - 2;
// special sequnce number for local state checkpoint ack signal
static const u_int32_t kLocalCheckAck = (1 << 26) - 1;
//---------------------------------------------
// The following are bit mask of flag used in
//----------------------------------------------
// some node want to load check point
static const int kLoadCheck = 1;
// some node want to do check point
static const int kCheckPoint = 2;
// check point Ack, we use a two phase message in check point,
// this is the second phase of check pointing
static const int kCheckAck = 4;
// there are difference sequence number the nodes proposed
// this means we want to do recover execution of the lower sequence
// action instead of normal execution
static const int kDiffSeq = 8;
// there are nodes request load cache
static const int kLoadBootstrapCache = 16;
// constructor
ActionSummary(void) {}
// constructor of action
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
u_int32_t minseqno = kSpecialOp, u_int32_t maxseqno = kSpecialOp) {
seqcode = (minseqno << 5) | seqno_flag;
maxseqcode = (maxseqno << 5) | cache_flag;
}
// minimum number of all operations by default
// maximum number of all cache operations otherwise
inline u_int32_t seqno(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
return code >> 5;
}
// whether the operation set contains a load_check
inline bool load_check(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
return (code & kLoadCheck) != 0;
}
// whether the operation set contains a load_cache
inline bool load_cache(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
return (code & kLoadBootstrapCache) != 0;
}
// whether the operation set contains a check point
inline bool check_point(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
return (code & kCheckPoint) != 0;
}
// whether the operation set contains a check ack
inline bool check_ack(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
return (code & kCheckAck) != 0;
}
// whether the operation set contains different sequence number
inline bool diff_seq() const {
return (seqcode & kDiffSeq) != 0;
}
// returns the operation flag of the result
inline int flag(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
return code & 31;
}
// print flags in user friendly way
inline void print_flags(int rank, std::string prefix ) {
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
rank, prefix.c_str(),
seqno(), check_point(), check_ack(), load_cache(),
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
}
// reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_,
int len, const MPI::Datatype &dtype) {
const ActionSummary *src = (const ActionSummary*)src_;
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
for (int i = 0; i < len; ++i) {
u_int32_t min_seqno = std::min(src[i].seqno(), dst[i].seqno());
u_int32_t max_seqno = std::max(src[i].seqno(SeqType::kCache),
dst[i].seqno(SeqType::kCache));
int action_flag = src[i].flag() | dst[i].flag();
// if any node is not requester set to 0 otherwise 1
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
// if seqno is different in src and destination
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
// apply or to both seq diff flag as well as cache seq diff flag
dst[i] = ActionSummary(action_flag | seq_diff_flag,
role_flag, min_seqno, max_seqno);
}
}
private:
// internel sequence code min of rabit seqno
u_int32_t seqcode;
// internal sequence code max of cache seqno
u_int32_t maxseqcode;
};
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
class ResultBuffer{
public:
// constructor
ResultBuffer(void) {
this->Clear();
}
// clear the existing record
inline void Clear(void) {
seqno_.clear(); size_.clear();
rptr_.clear(); rptr_.push_back(0);
data_.clear();
}
// allocate temporal space
inline void *AllocTemp(size_t type_nbytes, size_t count) {
size_t size = type_nbytes * count;
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
utils::Assert(nhop != 0, "cannot allocate 0 size memory");
// allocate addational nhop buffer size
data_.resize(rptr_.back() + nhop);
return BeginPtr(data_) + rptr_.back();
}
// push the result in temp to the
inline void PushTemp(int seqid, size_t type_nbytes, size_t count) {
size_t size = type_nbytes * count;
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
if (seqno_.size() != 0) {
utils::Assert(seqno_.back() < seqid, "PushTemp seqid inconsistent");
}
seqno_.push_back(seqid);
rptr_.push_back(rptr_.back() + nhop);
size_.push_back(size);
utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent");
}
// return the stored result of seqid, if any
inline void* Query(int seqid, size_t *p_size) {
size_t idx = std::lower_bound(seqno_.begin(),
seqno_.end(), seqid) - seqno_.begin();
if (idx == seqno_.size() || seqno_[idx] != seqid) return NULL;
*p_size = size_[idx];
return BeginPtr(data_) + rptr_[idx];
}
// drop last stored result
inline void DropLast(void) {
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
seqno_.pop_back();
rptr_.pop_back();
size_.pop_back();
data_.resize(rptr_.back());
}
// the sequence number of last stored result
inline int LastSeqNo(void) const {
if (seqno_.size() == 0) return -1;
return seqno_.back();
}
private:
// sequence number of each
std::vector<int> seqno_;
// pointer to the positions
std::vector<size_t> rptr_;
// actual size of each buffer
std::vector<size_t> size_;
// content of the buffer
std::vector<uint64_t> data_;
};
/*!
* \brief internal consistency check function,
* use check to ensure user always call CheckPoint/LoadCheckPoint
* with or without local but not both, this function will set the approperiate settings
* in the first call of LoadCheckPoint/CheckPoint
*
* \param with_local whether the user calls CheckPoint with local model
*/
void LocalModelCheck(bool with_local);
/*!
* \brief internal implementation of checkpoint, support both lazy and normal way
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
* \param lazy_checkpt whether the action is lazy checkpoint
*
* \sa CheckPoint, LazyCheckPoint
*/
void CheckPoint_(const Serializable *global_model,
const Serializable *local_model,
bool lazy_checkpt);
/*!
* \brief reset the all the existing links by sending Out-of-Band message marker
* after this function finishes, all the messages received and sent
* before in all live links are discarded,
* This allows us to get a fresh start after error has happened
*
* TODO(tqchen): this function is not yet functioning was not used by engine,
* simple resetlink and reconnect strategy is used
*
* \return this function can return kSuccess or kSockError
* when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery proceduer is needed
*/
ReturnType TryResetLinks(void);
/*!
* \brief if err_type indicates an error
* recover links according to the error type reported
* if there is no error, return true
* \param err_type the type of error happening in the system
* \return true if err_type is kSuccess, false otherwise
*/
bool CheckAndRecover(ReturnType err_type);
/*!
* \brief try to run recover execution for a request action described by flag and seqno,
* the function will keep blocking to run possible recovery operations before the specified action,
* until the requested result is received by a recovering procedure,
* or the function discovers that the requested action is not yet executed, and return false
*
* \param buf the buffer to store the result
* \param size the total size of the buffer
* \param flag flag information about the action \sa ActionSummary
* \param seqno sequence number of the action, if it is special action with flag set,
* seqno needs to be set to ActionSummary::kSpecialOp
*
* \return if this function can return true or false
* - true means buf already set to the
* result by recovering procedure, the action is complete, no further action is needed
* - false means this is the lastest action that has not yet been executed, need to execute the action
*/
bool RecoverExec(void *buf, size_t size, int flag,
int seqno = ActionSummary::kSpecialOp,
int cacheseqno = ActionSummary::kSpecialOp,
const char* caller = _CALLER);
/*!
* \brief try to load check point
*
* This is a collaborative function called by all nodes
* only the nodes with requester set to true really needs to load the check point
* other nodes acts as collaborative roles to complete this request
*
* \param requester whether current node is the requester
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryLoadCheckPoint(bool requester);
/*!
* \brief try to load cache
*
* This is a collaborative function called by all nodes
* only the nodes with requester set to true really needs to load the check point
* other nodes acts as collaborative roles to complete this request
* \param requester whether current node is the requester
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryRestoreCache(bool requester, const int min_seq = ActionSummary::kSpecialOp,
const int max_seq = ActionSummary::kSpecialOp);
/*!
* \brief try to get the result of operation specified by seqno
*
* This is a collaborative function called by all nodes
* only the nodes with requester set to true really needs to get the result
* other nodes acts as collaborative roles to complete this request
*
* \param buf the buffer to store the result, this parameter is only used when current node is requester
* \param size the total size of the buffer, this parameter is only used when current node is requester
* \param seqno sequence number of the operation, this is unique index of a operation in current iteration
* \param requester whether current node is the requester
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
/*!
* \brief try to decide the routing strategy for recovery
* \param role the current role of the node
* \param p_size used to store the size of the message, for node in state kHaveData,
* this size must be set correctly before calling the function
* for others, this surves as output parameter
* \param p_recvlink used to store the link current node should recv data from, if necessary
* this can be -1, which means current node have the data
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
*
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType, TryRecoverData
*/
ReturnType TryDecideRouting(RecoverType role,
size_t *p_size,
int *p_recvlink,
std::vector<bool> *p_req_in);
/*!
* \brief try to finish the data recovery request,
* this function is used together with TryDecideRouting
* \param role the current role of the node
* \param sendrecvbuf_ the buffer to store the data to be sent/recived
* - if the role is kHaveData, this stores the data to be sent
* - if the role is kRequestData, this is the buffer to store the result
* - if the role is kPassData, this will not be used, and can be NULL
* \param size the size of the data, obtained from TryDecideRouting
* \param recv_link the link index to receive data, if necessary, obtained from TryDecideRouting
* \param req_in the request of each link to send data, obtained from TryDecideRouting
*
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType, TryDecideRouting
*/
ReturnType TryRecoverData(RecoverType role,
void *sendrecvbuf_,
size_t size,
int recv_link,
const std::vector<bool> &req_in);
/*!
* \brief try to recover the local state, making each local state to be the result of itself
* plus replication of states in previous num_local_replica hops in the ring
*
* The input parameters must contain the valid local states available in current nodes,
* This function try ist best to "complete" the missing parts of local_rptr and local_chkpt
* If there is sufficient information in the ring, when the function returns, local_chkpt will
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
* If there is no sufficient information in the ring, this function the number of checkpoints
* will be less than the specified value
*
* \param p_local_rptr the pointer to the segment pointers in the states array
* \param p_local_chkpt the pointer to the storage of local check points
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
std::string *p_local_chkpt);
/*!
* \brief try to checkpoint local state, this function is called in normal executation phase
* of checkpoint that contains local state
o * the input state must exactly one saved state(local state of current node),
* after complete, this function will get local state from previous num_local_replica nodes and put them
* into local_chkpt and local_rptr
*
* It is also OK to call TryRecoverLocalState instead,
* TryRecoverLocalState makes less assumption about the input, and requires more communications
*
* \param p_local_rptr the pointer to the segment pointers in the states array
* \param p_local_chkpt the pointer to the storage of local check points
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType, TryRecoverLocalState
*/
ReturnType TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
std::string *p_local_chkpt);
/*!
* \brief perform a ring passing to receive data from prev link, and sent data to next link
* this allows data to stream over a ring structure
* sendrecvbuf[0:read_ptr] are already provided by current node
* current node will recv sendrecvbuf[read_ptr:read_end] from prev link
* current node will send sendrecvbuf[write_ptr:write_end] to next link
* write_ptr will wait till the data is readed before sending the data
* this function requires read_end >= write_end
*
* \param sendrecvbuf_ the place to hold the incoming and outgoing data
* \param read_ptr the initial read pointer
* \param read_end the ending position to read
* \param write_ptr the initial write pointer
* \param write_end the ending position to write
* \param read_link pointer to link to previous position in ring
* \param write_link pointer to link of next position in ring
*/
ReturnType RingPassing(void *senrecvbuf_,
size_t read_ptr,
size_t read_end,
size_t write_ptr,
size_t write_end,
LinkRecord *read_link,
LinkRecord *write_link);
/*!
* \brief run message passing algorithm on the allreduce tree
* the result is edge message stored in p_edge_in and p_edge_out
* \param node_value the value associated with current node
* \param p_edge_in used to store input message from each of the edge
* \param p_edge_out used to store output message from each of the edge
* \param func a function that defines the message passing rule
* Parameters of func:
* - node_value same as node_value in the main function
* - edge_in the array of input messages from each edge,
* this includes the output edge, which should be excluded
* - out_index array the index of output edge, the function should
* exclude the output edge when compute the message passing value
* Return of func:
* the function returns the output message based on the input message and node_value
*
* \tparam EdgeType type of edge message, must be simple struct
* \tparam NodeType type of node value
*/
template<typename NodeType, typename EdgeType>
inline ReturnType MsgPassing(const NodeType &node_value,
std::vector<EdgeType> *p_edge_in,
std::vector<EdgeType> *p_edge_out,
EdgeType(*func)
(const NodeType &node_value,
const std::vector<EdgeType> &edge_in,
size_t out_index));
//---- recovery data structure ----
// the round of result buffer, used to mode the result
int result_buffer_round;
// result buffer of all reduce
ResultBuffer resbuf;
// current cached allreduce/braodcast sequence number
int cur_cache_seq;
// result buffer of cached all reduce
ResultBuffer cachebuf;
// key of each cache entry
ResultBuffer lookupbuf;
// last check point global model
std::string global_checkpoint;
// lazy checkpoint of global model
const Serializable *global_lazycheck;
// number of replica for local state/model
int num_local_replica;
// number of default local replica
int default_local_replica;
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
int use_local_model;
// number of replica for global state/model
int num_global_replica;
// number of times recovery happens
int recover_counter;
// --- recovery data structure for local checkpoint
// there is two version of the data structure,
// at one time one version is valid and another is used as temp memory
// pointer to memory position in the local model
// local model is stored in CSR format(like a sparse matrices)
// local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
std::vector<size_t> local_rptr[2];
// storage for local model replicas
std::string local_chkpt[2];
// version of local checkpoint can be 1 or 0
int local_chkpt_version;
// if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
bool checkpoint_loaded;
// sidecar executing timeout task
std::future<bool> rabit_timeout_task;
// flag to shutdown rabit_timeout_task before timeout
std::atomic<bool> shutdown_timeout{false};
// error handler
void (* _error)(const char *fmt, ...) = utils::Error;
// assert handler
void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert;
};
} // namespace engine
} // namespace rabit
// implementation of inline template function
#include "./allreduce_robust-inl.h"
#endif // RABIT_ALLREDUCE_ROBUST_H_

333
rabit/src/c_api.cc Normal file
View File

@@ -0,0 +1,333 @@
// Copyright by Contributors
// implementations in ctypes
#include <rabit/base.h>
#include <cstring>
#include <string>
#include "rabit/rabit.h"
#include "rabit/c_api.h"
namespace rabit {
namespace c_api {
// helper use to avoid BitOR operator
template<typename OP, typename DType>
struct FHelper {
static void
Allreduce(DType *senrecvbuf_,
size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
rabit::Allreduce<OP>(senrecvbuf_, count,
prepare_fun, prepare_arg);
}
};
template<typename DType>
struct FHelper<op::BitOR, DType> {
static void
Allreduce(DType *senrecvbuf_,
size_t count,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
utils::Error("DataType does not support bitwise or operation");
}
};
template<typename OP>
void Allreduce_(void *sendrecvbuf_,
size_t count,
engine::mpi::DataType enum_dtype,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi;
switch (enum_dtype) {
case kChar:
rabit::Allreduce<OP>
(static_cast<char*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kUChar:
rabit::Allreduce<OP>
(static_cast<unsigned char*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kInt:
rabit::Allreduce<OP>
(static_cast<int*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kUInt:
rabit::Allreduce<OP>
(static_cast<unsigned*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kLong:
rabit::Allreduce<OP>
(static_cast<long*>(sendrecvbuf_), // NOLINT(*)
count, prepare_fun, prepare_arg);
return;
case kULong:
rabit::Allreduce<OP>
(static_cast<unsigned long*>(sendrecvbuf_), // NOLINT(*)
count, prepare_fun, prepare_arg);
return;
case kFloat:
FHelper<OP, float>::Allreduce
(static_cast<float*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
case kDouble:
FHelper<OP, double>::Allreduce
(static_cast<double*>(sendrecvbuf_),
count, prepare_fun, prepare_arg);
return;
default: utils::Error("unknown data_type");
}
}
void Allreduce(void *sendrecvbuf,
size_t count,
engine::mpi::DataType enum_dtype,
engine::mpi::OpType enum_op,
void (*prepare_fun)(void *arg),
void *prepare_arg) {
using namespace engine::mpi;
switch (enum_op) {
case kMax:
Allreduce_<op::Max>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kMin:
Allreduce_<op::Min>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kSum:
Allreduce_<op::Sum>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseOR:
Allreduce_<op::BitOR>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
default: utils::Error("unknown enum_op");
}
}
void Allgather(void *sendrecvbuf_,
size_t total_size,
size_t beginIndex,
size_t size_node_slice,
size_t size_prev_slice,
int enum_dtype) {
using namespace engine::mpi;
size_t type_size = 0;
switch (enum_dtype) {
case kChar:
type_size = sizeof(char);
rabit::Allgather(static_cast<char*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kUChar:
type_size = sizeof(unsigned char);
rabit::Allgather(static_cast<unsigned char*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kInt:
type_size = sizeof(int);
rabit::Allgather(static_cast<int*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kUInt:
type_size = sizeof(unsigned);
rabit::Allgather(static_cast<unsigned*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kLong:
type_size = sizeof(int64_t);
rabit::Allgather(static_cast<int64_t*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kULong:
type_size = sizeof(uint64_t);
rabit::Allgather(static_cast<uint64_t*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kFloat:
type_size = sizeof(float);
rabit::Allgather(static_cast<float*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
case kDouble:
type_size = sizeof(double);
rabit::Allgather(static_cast<double*>(sendrecvbuf_), total_size * type_size,
beginIndex * type_size, (beginIndex + size_node_slice) * type_size,
size_prev_slice * type_size);
break;
default: utils::Error("unknown data_type");
}
}
// wrapper for serialization
struct ReadWrapper : public Serializable {
std::string *p_str;
explicit ReadWrapper(std::string *p_str)
: p_str(p_str) {}
virtual void Load(Stream *fi) {
uint64_t sz;
utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
"Read pickle string");
p_str->resize(sz);
if (sz != 0) {
utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0,
"Read pickle string");
}
}
virtual void Save(Stream *fo) const {
utils::Error("not implemented");
}
};
struct WriteWrapper : public Serializable {
const char *data;
size_t length;
explicit WriteWrapper(const char *data,
size_t length)
: data(data), length(length) {
}
virtual void Load(Stream *fi) {
utils::Error("not implemented");
}
virtual void Save(Stream *fo) const {
uint64_t sz = static_cast<uint16_t>(length);
fo->Write(&sz, sizeof(sz));
fo->Write(data, length * sizeof(char));
}
};
} // namespace c_api
} // namespace rabit
RABIT_DLL bool RabitInit(int argc, char *argv[]) {
return rabit::Init(argc, argv);
}
RABIT_DLL bool RabitFinalize() {
return rabit::Finalize();
}
RABIT_DLL int RabitGetRingPrevRank() {
return rabit::GetRingPrevRank();
}
RABIT_DLL int RabitGetRank() {
return rabit::GetRank();
}
RABIT_DLL int RabitGetWorldSize() {
return rabit::GetWorldSize();
}
RABIT_DLL int RabitIsDistributed() {
return rabit::IsDistributed();
}
RABIT_DLL void RabitTrackerPrint(const char *msg) {
std::string m(msg);
rabit::TrackerPrint(m);
}
RABIT_DLL void RabitGetProcessorName(char *out_name,
rbt_ulong *out_len,
rbt_ulong max_len) {
std::string s = rabit::GetProcessorName();
if (s.length() > max_len) {
s.resize(max_len - 1);
}
strcpy(out_name, s.c_str()); // NOLINT(*)
*out_len = static_cast<rbt_ulong>(s.length());
}
RABIT_DLL void RabitBroadcast(void *sendrecv_data,
rbt_ulong size, int root) {
rabit::Broadcast(sendrecv_data, size, root);
}
RABIT_DLL void RabitAllgather(void *sendrecvbuf_, size_t total_size,
size_t beginIndex, size_t size_node_slice,
size_t size_prev_slice, int enum_dtype) {
rabit::c_api::Allgather(sendrecvbuf_,
total_size,
beginIndex,
size_node_slice,
size_prev_slice,
static_cast<rabit::engine::mpi::DataType>(enum_dtype));
}
RABIT_DLL void RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype,
int enum_op, void (*prepare_fun)(void *arg),
void *prepare_arg) {
rabit::c_api::Allreduce
(sendrecvbuf, count,
static_cast<rabit::engine::mpi::DataType>(enum_dtype),
static_cast<rabit::engine::mpi::OpType>(enum_op),
prepare_fun, prepare_arg);
}
RABIT_DLL int RabitLoadCheckPoint(char **out_global_model,
rbt_ulong *out_global_len,
char **out_local_model,
rbt_ulong *out_local_len) {
// NOTE: this function is not thread-safe
using rabit::BeginPtr;
using namespace rabit::c_api; // NOLINT(*)
static std::string global_buffer;
static std::string local_buffer;
ReadWrapper sg(&global_buffer);
ReadWrapper sl(&local_buffer);
int version;
if (out_local_model == NULL) {
version = rabit::LoadCheckPoint(&sg, NULL);
*out_global_model = BeginPtr(global_buffer);
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
} else {
version = rabit::LoadCheckPoint(&sg, &sl);
*out_global_model = BeginPtr(global_buffer);
*out_global_len = static_cast<rbt_ulong>(global_buffer.length());
*out_local_model = BeginPtr(local_buffer);
*out_local_len = static_cast<rbt_ulong>(local_buffer.length());
}
return version;
}
RABIT_DLL void RabitCheckPoint(const char *global_model, rbt_ulong global_len,
const char *local_model, rbt_ulong local_len) {
using namespace rabit::c_api; // NOLINT(*)
WriteWrapper sg(global_model, global_len);
WriteWrapper sl(local_model, local_len);
if (local_model == NULL) {
rabit::CheckPoint(&sg, NULL);
} else {
rabit::CheckPoint(&sg, &sl);
}
}
RABIT_DLL int RabitVersionNumber() {
return rabit::VersionNumber();
}
RABIT_DLL int RabitLinkTag() {
return 0;
}

142
rabit/src/engine.cc Normal file
View File

@@ -0,0 +1,142 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine.cc
* \brief this file governs which implementation of engine we are actually using
* provides an singleton of engine interface
*
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
*/
#include <rabit/base.h>
#include <memory>
#include "rabit/internal/engine.h"
#include "allreduce_base.h"
#include "allreduce_robust.h"
#include "rabit/internal/thread_local.h"
namespace rabit {
namespace engine {
// singleton sync manager
#ifndef RABIT_USE_BASE
#ifndef RABIT_USE_MOCK
typedef AllreduceRobust Manager;
#else
typedef AllreduceMock Manager;
#endif // RABIT_USE_MOCK
#else
typedef AllreduceBase Manager;
#endif // RABIT_USE_BASE
/*! \brief entry to to easily hold returning information */
struct ThreadLocalEntry {
/*! \brief stores the current engine */
std::unique_ptr<Manager> engine;
/*! \brief whether init has been called */
bool initialized;
/*! \brief constructor */
ThreadLocalEntry() : initialized(false) {}
};
// define the threadlocal store.
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
/*! \brief intiialize the synchronization module */
bool Init(int argc, char *argv[]) {
ThreadLocalEntry* e = EngineThreadLocal::Get();
if (e->engine.get() == nullptr) {
e->initialized = true;
e->engine.reset(new Manager());
return e->engine->Init(argc, argv);
} else {
return true;
}
}
/*! \brief finalize syncrhonization module */
bool Finalize() {
ThreadLocalEntry* e = EngineThreadLocal::Get();
if (e->engine.get() != nullptr) {
if (e->engine->Shutdown()) {
e->engine.reset(nullptr);
e->initialized = false;
return true;
} else {
return false;
}
} else {
return true;
}
}
/*! \brief singleton method to get engine */
IEngine *GetEngine() {
// un-initialized default manager.
static AllreduceBase default_manager;
ThreadLocalEntry* e = EngineThreadLocal::Get();
IEngine* ptr = e->engine.get();
if (ptr == nullptr) {
utils::Check(!e->initialized, "the rabit has not been initialized");
return &default_manager;
} else {
return ptr;
}
}
// perform in-place allgather, on sendrecvbuf
void Allgather(void *sendrecvbuf_, size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin,
slice_end, size_prev_slice, _file, _line, _caller);
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun,
prepare_arg, _file, _line, _caller);
}
// code for reduce handle
ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return static_cast<int>(dtype.type_size);
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(redfunc_ == NULL, "cannot initialize reduce handle twice");
redfunc_ = redfunc;
}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
utils::Assert(redfunc_ != NULL, "must intialize handle to call AllReduce");
GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count,
redfunc_, prepare_fun, prepare_arg,
_file, _line, _caller);
}
} // namespace engine
} // namespace rabit

14
rabit/src/engine_base.cc Normal file
View File

@@ -0,0 +1,14 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc
* \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure
* \author Tianqi Chen
*/
// define use MOCK, os we will use mock Manager
#define NOMINMAX
#include <rabit/base.h>
// switch engine to AllreduceMock
#define RABIT_USE_BASE
#include "engine.cc"

143
rabit/src/engine_empty.cc Normal file
View File

@@ -0,0 +1,143 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_empty.cc
* \brief this file provides a dummy implementation of engine that does nothing
* this file provides a way to fall back to single node program without causing too many dependencies
* This is usually NOT needed, use engine_mpi or engine for real distributed version
* \author Tianqi Chen
*/
#define NOMINMAX
#include <rabit/base.h>
#include "rabit/internal/engine.h"
namespace rabit {
namespace engine {
/*! \brief EmptyEngine */
class EmptyEngine : public IEngine {
public:
EmptyEngine(void) {
version_number = 0;
}
virtual void Allgather(void *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
utils::Error("EmptyEngine:: Allgather is not supported");
}
virtual int GetRingPrevRank(void) const {
utils::Error("EmptyEngine:: GetRingPrevRank is not supported");
return -1;
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
utils::Error("EmptyEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line, const char* _caller) {
}
virtual void InitAfterException(void) {
utils::Error("EmptyEngine is not fault tolerant");
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
return 0;
}
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}
/*! \brief get rank of current node */
virtual int GetRank(void) const {
return 0;
}
/*! \brief get total number of */
virtual int GetWorldSize(void) const {
return 1;
}
/*! \brief whether it is distributed */
virtual bool IsDistributed(void) const {
return false;
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
return std::string("");
}
virtual void TrackerPrint(const std::string &msg) {
// simply print information into the tracker
utils::Printf("%s", msg.c_str());
}
private:
int version_number;
};
// singleton sync manager
EmptyEngine manager;
/*! \brief intiialize the synchronization module */
bool Init(int argc, char *argv[]) {
return true;
}
/*! \brief finalize syncrhonization module */
bool Finalize(void) {
return true;
}
/*! \brief singleton method to get engine */
IEngine *GetEngine(void) {
return &manager;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
}
// code for reduce handle
ReduceHandle::ReduceHandle(void) : handle_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return 0;
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
}
} // namespace engine
} // namespace rabit

15
rabit/src/engine_mock.cc Normal file
View File

@@ -0,0 +1,15 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mock.cc
* \brief this is an engine implementation that will
* insert failures in certain call point, to test if the engine is robust to failure
* \author Tianqi Chen
*/
// define use MOCK, os we will use mock Manager
#define NOMINMAX
// switch engine to AllreduceMock
#define RABIT_USE_MOCK
#include <rabit/base.h>
#include "allreduce_mock.h"
#include "engine.cc"

247
rabit/src/engine_mpi.cc Normal file
View File

@@ -0,0 +1,247 @@
/*!
* Copyright (c) 2014 by Contributors
* \file engine_mpi.cc
* \brief this file gives an implementation of engine interface using MPI,
* this will allow rabit program to run with MPI, but do not comes with fault tolerant
*
* \author Tianqi Chen
*/
#define NOMINMAX
#include <mpi.h>
#include <rabit/base.h>
#include <cstdio>
#include <string>
#include "rabit/internal/engine.h"
#include "rabit/internal/utils.h"
namespace rabit {
namespace engine {
/*! \brief implementation of engine using MPI */
class MPIEngine : public IEngine {
public:
MPIEngine(void) {
version_number = 0;
}
virtual void Allgather(void *sendrecvbuf_,
size_t total_size,
size_t slice_begin,
size_t slice_end,
size_t size_prev_slice,
const char* _file,
const int _line,
const char* _caller) {
utils::Error("MPIEngine:: Allgather is not supported");
}
virtual void Allreduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
utils::Error("MPIEngine:: Allreduce is not supported,"\
"use Allreduce_ instead");
}
virtual int GetRingPrevRank(void) const {
utils::Error("MPIEngine:: GetRingPrevRank is not supported");
}
virtual void Broadcast(void *sendrecvbuf_, size_t size, int root,
const char* _file, const int _line,
const char* _caller) {
MPI::COMM_WORLD.Bcast(sendrecvbuf_, size, MPI::CHAR, root);
}
virtual void InitAfterException(void) {
utils::Error("MPI is not fault tolerant");
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
return 0;
}
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}
/*! \brief get rank of current node */
virtual int GetRank(void) const {
return MPI::COMM_WORLD.Get_rank();
}
/*! \brief get total number of */
virtual int GetWorldSize(void) const {
return MPI::COMM_WORLD.Get_size();
}
/*! \brief whether it is distributed */
virtual bool IsDistributed(void) const {
return true;
}
/*! \brief get the host name of current node */
virtual std::string GetHost(void) const {
int len;
char name[MPI_MAX_PROCESSOR_NAME];
MPI::Get_processor_name(name, len);
name[len] = '\0';
return std::string(name);
}
virtual void TrackerPrint(const std::string &msg) {
// simply print information into the tracker
if (GetRank() == 0) {
utils::Printf("%s", msg.c_str());
}
}
private:
int version_number;
};
// singleton sync manager
MPIEngine manager;
/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
try {
MPI::Init(argc, argv);
return true;
} catch (const std::exception& e) {
fprintf(stderr, " failed in MPI Init %s\n", e.what());
return false;
}
}
/*! \brief finalize syncrhonization module */
bool Finalize(void) {
try {
MPI::Finalize();
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed in MPI shutdown %s\n", e.what());
return false;
}
}
/*! \brief singleton method to get engine */
IEngine *GetEngine(void) {
return &manager;
}
// transform enum to MPI data type
inline MPI::Datatype GetType(mpi::DataType dtype) {
using namespace mpi;
switch (dtype) {
case kChar: return MPI::CHAR;
case kUChar: return MPI::BYTE;
case kInt: return MPI::INT;
case kUInt: return MPI::UNSIGNED;
case kLong: return MPI::LONG;
case kULong: return MPI::UNSIGNED_LONG;
case kFloat: return MPI::FLOAT;
case kDouble: return MPI::DOUBLE;
case kLongLong: return MPI::LONG_LONG;
case kULongLong: return MPI::UNSIGNED_LONG_LONG;
}
utils::Error("unknown mpi::DataType");
return MPI::CHAR;
}
// transform enum to MPI OP
inline MPI::Op GetOp(mpi::OpType otype) {
using namespace mpi;
switch (otype) {
case kMax: return MPI::MAX;
case kMin: return MPI::MIN;
case kSum: return MPI::SUM;
case kBitwiseOR: return MPI::BOR;
}
utils::Error("unknown mpi::OpType");
return MPI::MAX;
}
// perform in-place allreduce, on sendrecvbuf
void Allreduce_(void *sendrecvbuf,
size_t type_nbytes,
size_t count,
IEngine::ReduceFunction red,
mpi::DataType dtype,
mpi::OpType op,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf,
count, GetType(dtype), GetOp(op));
}
// code for reduce handle
ReduceHandle::ReduceHandle(void)
: handle_(NULL), redfunc_(NULL), htype_(NULL) {
}
ReduceHandle::~ReduceHandle(void) {
if (handle_ != NULL) {
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
op->Free();
delete op;
}
if (htype_ != NULL) {
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
dtype->Free();
delete dtype;
}
}
int ReduceHandle::TypeSize(const MPI::Datatype &dtype) {
return dtype.Get_size();
}
void ReduceHandle::Init(IEngine::ReduceFunction redfunc, size_t type_nbytes) {
utils::Assert(handle_ == NULL, "cannot initialize reduce handle twice");
if (type_nbytes != 0) {
MPI::Datatype *dtype = new MPI::Datatype();
if (type_nbytes % 8 == 0) {
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*)
} else if (type_nbytes % 4 == 0) {
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
} else {
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
}
dtype->Commit();
created_type_nbytes_ = type_nbytes;
htype_ = dtype;
}
MPI::Op *op = new MPI::Op();
MPI::User_function *pf = redfunc;
op->Init(pf, true);
handle_ = op;
}
void ReduceHandle::Allreduce(void *sendrecvbuf,
size_t type_nbytes, size_t count,
IEngine::PreprocFunction prepare_fun,
void *prepare_arg,
const char* _file,
const int _line,
const char* _caller) {
utils::Assert(handle_ != NULL, "must intialize handle to call AllReduce");
MPI::Op *op = reinterpret_cast<MPI::Op*>(handle_);
MPI::Datatype *dtype = reinterpret_cast<MPI::Datatype*>(htype_);
if (created_type_nbytes_ != type_nbytes || dtype == NULL) {
if (dtype == NULL) {
dtype = new MPI::Datatype();
} else {
dtype->Free();
}
if (type_nbytes % 8 == 0) {
*dtype = MPI::LONG.Create_contiguous(type_nbytes / sizeof(long)); // NOLINT(*)
} else if (type_nbytes % 4 == 0) {
*dtype = MPI::INT.Create_contiguous(type_nbytes / sizeof(int));
} else {
*dtype = MPI::CHAR.Create_contiguous(type_nbytes);
}
dtype->Commit();
created_type_nbytes_ = type_nbytes;
}
if (prepare_fun != NULL) prepare_fun(prepare_arg);
MPI::COMM_WORLD.Allreduce(MPI_IN_PLACE, sendrecvbuf, count, *dtype, *op);
}
} // namespace engine
} // namespace rabit