exit when allreduce/broadcast error cause timeout (#112)

* keep async timeout task

* add missing pthread to cmake

* add tests

* Add a sleep period to avoid flushing the tracker.
This commit is contained in:
Chen Qin 2019-10-11 00:39:39 -07:00 committed by Jiaming Yuan
parent af7281afe3
commit 5d1b613910
17 changed files with 403 additions and 71 deletions

3
.gitignore vendored
View File

@ -47,3 +47,6 @@ mpich-3.2/
cmake-build-debug/
.vscode/
# cmake
build/

View File

@ -8,16 +8,29 @@ osx_image: xcode10.2
dist: xenial
language: cpp
# Use Build Matrix to do lint and build separately
env:
matrix:
- TASK=lint LINT_LANG=cpp
- TASK=lint LINT_LANG=python
- TASK=doc
- TASK=build
# - TASK=build
- TASK=mpi-build
- TASK=cmake-test
matrix:
exclude:
- os: osx
env: TASK=lint LINT_LANG=cpp
- os: osx
env: TASK=lint LINT_LANG=python
- os: osx
env: TASK=doc
- os: osx
env: TASK=build
# dependent apt packages
addons:
apt:

View File

@ -20,11 +20,16 @@ if(R_LIB OR MINGW OR WIN32)
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON)
else()
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
add_library(rabit_base src/allreduce_base.cc src/engine_base.cc src/c_api.cc)
find_package(Threads REQUIRED)
add_library(rabit_empty src/engine_empty.cc src/c_api.cc)
add_library(rabit_base src/allreduce_base.cc src/engine_base.cc src/c_api.cc)
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
add_library(rabit_mock SHARED src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
target_link_libraries(rabit Threads::Threads)
target_link_libraries(rabit_mock_static Threads::Threads)
target_link_libraries(rabit_mock Threads::Threads)
set(rabit_libs rabit rabit_base rabit_empty rabit_mock rabit_mock_static)
set_target_properties(rabit rabit_base rabit_empty rabit_mock rabit_mock_static

View File

@ -154,6 +154,8 @@ you can also refer to [wormhole](https://github.com/dmlc/wormhole/blob/master/le
int main(int argc, char *argv[]) {
...
rabit::Init(argc, argv);
// sync on expected model size before load checkpoint, if we pass rabit_bootstrap_cache=true
rabit::Allreduce<rabit::op::Max>(&model.size(), 1);
// load the latest checked model
int version = rabit::LoadCheckPoint(&model);
// initialize the model if it is the first version
@ -370,3 +372,12 @@ Allreduce/Broadcast calls after the checkpoint from some alive nodes.
This is just a conceptual introduction to rabit's fault tolerance model. The actual implementation is more sophisticated,
and can deal with more complicated cases such as multiple nodes failure and node failure during recovery phase.
Rabit Timeout
---------------
In certain cases, a rabit cluster may lack the resources needed to restart failed workers.
Because the fault-tolerance model assumes infinite retry, this can leave the entire cluster hanging indefinitely.
To avoid this, rabit can start a sidecar thread when the fault-tolerant runtime observes an allreduce/broadcast error.
By default, the thread waits 30 minutes for recovery; if the worker has not recovered by then, the worker program exits.
Users can opt in to this feature with rabit_timeout=true and change the threshold with rabit_timeout_sec=x (in seconds).

View File

@ -7,6 +7,7 @@
#ifndef RABIT_INTERNAL_UTILS_H_
#define RABIT_INTERNAL_UTILS_H_
#define _CRT_SECURE_NO_WARNINGS
#include <string.h>
#include <cstdio>
#include <string>
#include <cstdlib>
@ -66,6 +67,11 @@ const int kPrintBuffer = 1 << 12;
* co-locate in the same process */
extern bool STOP_PROCESS_ON_ERROR;
/*!
 * \brief parse a config string into a bool
 *  The result is true when the string equals "true" (ASCII case-insensitive)
 *  or when atoi() parses it to a non-zero integer; everything else,
 *  including NULL, yields false.
 * \param s NUL-terminated config value, may be NULL
 * \return the parsed boolean
 */
inline bool StringToBool(const char* s) {
  if (s == NULL) return false;
  // hand-rolled comparison: strcasecmp is a POSIX function declared in
  // <strings.h>, so relying on it breaks non-POSIX toolchains
  const char kTrue[] = "true";
  int i = 0;
  for (; kTrue[i] != '\0'; ++i) {
    char c = s[i];
    if (c >= 'A' && c <= 'Z') c = static_cast<char>(c - 'A' + 'a');
    // also stops at the end of a shorter input, so s is never over-read
    if (c != kTrue[i]) break;
  }
  if (kTrue[i] == '\0' && s[i] == '\0') return true;
  return atoi(s) != 0;
}
#ifndef RABIT_CUSTOMIZE_MSG_
/*!
* \brief handling of Assert error, caused by inappropriate input
@ -86,7 +92,7 @@ inline void HandleAssertError(const char *msg) {
*/
inline void HandleCheckError(const char *msg) {
if (STOP_PROCESS_ON_ERROR) {
fprintf(stderr, "%s, shutting down process", msg);
fprintf(stderr, "%s, shutting down process\n", msg);
exit(-1);
} else {
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);

View File

@ -25,8 +25,9 @@ if [ ${TASK} == "cmake-test" ]; then
mkdir build
cd build
cmake -DRABIT_BUILD_TESTS=ON -DRABIT_BUILD_DMLC=ON -DGTEST_ROOT=${HOME}/.local ..
#unit tests
make
# known osx gtest 1.8 issue
cp ${HOME}/.local/lib/*.dylib .
make -j$(nproc)
make test
make install || exit -1
cd ../test

View File

@ -73,7 +73,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
if (task_id == NULL) {
task_id = getenv("mapreduce_task_id");
}
if (hadoop_mode != 0) {
if (hadoop_mode) {
utils::Check(task_id != NULL,
"hadoop_mode is set but cannot find mapred_task_id");
}
@ -94,7 +94,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
if (num_task == NULL) {
num_task = getenv("mapreduce_job_maps");
}
if (hadoop_mode != 0) {
if (hadoop_mode) {
utils::Check(num_task != NULL,
"hadoop_mode is set but cannot find mapred_map_tasks");
}
@ -188,7 +188,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val);
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
reduce_ring_mincount = atoi(val);
utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0");
@ -209,10 +209,17 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
}
if (!strcmp(name, "rabit_bootstrap_cache")) {
rabit_bootstrap_cache = atoi(val);
rabit_bootstrap_cache = utils::StringToBool(val);
}
if (!strcmp(name, "rabit_debug")) {
rabit_debug = atoi(val);
rabit_debug = utils::StringToBool(val);
}
if (!strcmp(name, "rabit_timeout")) {
rabit_timeout = utils::StringToBool(val);
}
if (!strcmp(name, "rabit_timeout_sec")) {
  timeout_sec = atoi(val);
  // validate the parsed number of seconds itself; checking the rabit_timeout
  // on/off flag here would make the range check depend on argument order
  utils::Assert(timeout_sec > 0, "rabit_timeout_sec should be greater than 0 second");
}
}
/*!

View File

@ -496,7 +496,7 @@ class AllreduceBase : public IEngine {
// version number of model
int version_number;
// whether the job is running in hadoop
int hadoop_mode;
bool hadoop_mode;
//---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree
int parent_index;
@ -543,9 +543,13 @@ class AllreduceBase : public IEngine {
// backdoor port
int port = 0;
// enable bootstrap cache 0 false 1 true
int rabit_bootstrap_cache = 0;
bool rabit_bootstrap_cache = false;
// enable detailed logging
int rabit_debug = 0;
bool rabit_debug = false;
// by default, exit if a rabit worker does not recover within half an hour
int timeout_sec = 1800;
// flag to enable rabit_timeout
bool rabit_timeout = false;
};
} // namespace engine
} // namespace rabit

View File

@ -176,7 +176,7 @@ class AllreduceMock : public AllreduceRobust {
if (mock_map.count(key) != 0) {
num_trial += 1;
// data processing frameworks runs on shared process
utils::Error("[%d]@@@Hit Mock Error:%s\n", rank, name);
_error("[%d]@@@Hit Mock Error:%s ", rank, name);
}
}
};

View File

@ -8,6 +8,8 @@
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <chrono>
#include <thread>
#include <limits>
#include <utility>
#include "rabit/internal/io.h"
@ -19,6 +21,7 @@
namespace rabit {
namespace engine {
AllreduceRobust::AllreduceRobust(void) {
num_local_replica = 0;
num_global_replica = 5;
@ -38,7 +41,7 @@ bool AllreduceRobust::Init(int argc, char* argv[]) {
if (AllreduceBase::Init(argc, argv)) {
// chenqin: alert user opted in experimental feature.
if (rabit_bootstrap_cache) utils::HandleLogInfo(
"[EXPERIMENTAL] rabit bootstrap cache has been enabled\n");
"[EXPERIMENTAL] bootstrap cache has been enabled\n");
checkpoint_loaded = false;
if (num_global_replica == 0) {
result_buffer_round = -1;
@ -55,24 +58,31 @@ bool AllreduceRobust::Shutdown(void) {
try {
// need to sync the exec before we shutdown, do a pseudo check point
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp,
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp,
cur_cache_seq), "Shutdown: check point must return true");
// reset result buffer
resbuf.Clear(); seq_counter = 0;
cachebuf.Clear(); cur_cache_seq = 0;
lookupbuf.Clear();
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
ActionSummary::kSpecialOp, cur_cache_seq), "Shutdown: check ack must return true");
// travis ci only osx test hang
#if defined (__APPLE__)
sleep(1);
#endif
shutdown_timeout = true;
if (rabit_timeout_task.valid()) {
rabit_timeout_task.wait();
_assert(rabit_timeout_task.get(), "expect timeout task return\n");
}
return AllreduceBase::Shutdown();
} catch (const std::exception& e) {
fprintf(stderr, "%s\n", e.what());
return false;
}
}
/*!
* \brief set parameters to the engine
* \param name parameter name
@ -98,8 +108,8 @@ int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf,
break;
}
}
utils::Assert(index == -1, "immutable cache key already exists");
utils::Assert(type_nbytes*count > 0, "can't set empty cache");
_assert(index == -1, "immutable cache key already exists");
_assert(type_nbytes*count > 0, "can't set empty cache");
void* temp = cachebuf.AllocTemp(type_nbytes, count);
cachebuf.PushTemp(cur_cache_seq, type_nbytes, count);
std::memcpy(temp, buf, type_nbytes*count);
@ -133,9 +143,9 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
size_t siz = 0;
void* temp = cachebuf.Query(index, &siz);
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
utils::Assert(siz > 0, "cache size should be greater than 0");
_assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
_assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
_assert(siz > 0, "cache size should be greater than 0");
std::memcpy(buf, temp, type_nbytes*count);
return 0;
}
@ -317,7 +327,7 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
local_rptr[local_chkpt_version][1]);
local_model->Load(&fs);
} else {
utils::Assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal);
_assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal);
}
}
// reset result buffer
@ -327,14 +337,14 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
if (global_checkpoint.length() == 0) {
version_number = 0;
} else {
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
_assert(fs.Read(&version_number, sizeof(version_number)) != 0,
"read in version number");
global_model->Load(&fs);
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
_assert(local_model == NULL || nlocal == num_local_replica + 1,
"local model inconsistent, nlocal=%d", nlocal);
}
// run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq)) {
@ -433,7 +443,7 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
local_chkpt_version = !local_chkpt_version;
}
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
ActionSummary::kSpecialOp, cur_cache_seq),
"check point must return true");
// this is the critical region where we will change all the stored models
@ -460,7 +470,7 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
// reset result buffer, mark boostrap phase complete
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
delta = utils::GetTime() - start;
@ -533,7 +543,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
if (all_links[i].size_read == 0) {
int atmark = all_links[i].sock.AtMark();
if (atmark < 0) {
utils::Assert(all_links[i].sock.BadSocket(), "must already gone bad");
_assert(all_links[i].sock.BadSocket(), "must already gone bad");
} else if (atmark > 0) {
all_links[i].size_read = 1;
} else {
@ -555,10 +565,10 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
if (len == 0) {
all_links[i].sock.Close(); continue;
} else if (len > 0) {
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
utils::Assert(all_links[i].sock.AtMark() != 1, "should already read past mark");
_assert(oob_mark == kResetMark, "wrong oob msg");
_assert(all_links[i].sock.AtMark() != 1, "should already read past mark");
} else {
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
_assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
}
// send out ack
char ack = kResetAck;
@ -579,9 +589,9 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
if (len == 0) {
all_links[i].sock.Close(); continue;
} else if (len > 0) {
utils::Assert(ack == kResetAck, "wrong Ack MSG");
_assert(ack == kResetAck, "wrong Ack MSG");
} else {
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
_assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
}
// set back to nonblock mode
all_links[i].sock.SetNonBlock(true);
@ -600,14 +610,44 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
* \return true if err_type is kSuccess, false otherwise
*/
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
// Summary: on kSuccess return true immediately; on any other result optionally
// launch the timeout task, close every link, reconnect, and return false.
// record whether the latest step succeeded; the timeout task polls this flag
// and exits early once it reads true
shutdown_timeout = err_type == kSuccess;
if (err_type == kSuccess) return true;
utils::Assert(err_link != NULL, "must know the error source");
recover_counter += 1;
_assert(err_link != NULL, "must know the error link");
recover_counter += 1;
// launch the timeout task asynchronously when rabit_timeout is enabled and
// no timeout task is currently held (the future is not valid)
if (rabit_timeout && !rabit_timeout_task.valid()) {
utils::Printf("[EXPERIMENTAL] timeout thread expires in %d second(s)\n", timeout_sec);
// NOTE(review): [=] inside a member function reaches members through `this`,
// so the task presumably requires this object to outlive it - confirm
rabit_timeout_task = std::async(std::launch::async, [=]() {
if (rabit_debug) {
// NOTE(review): get_id() returns std::thread::id, not a long; passing it
// for %ld looks non-portable - confirm
utils::Printf("[%d] timeout thread %ld starts\n", rank,
std::this_thread::get_id());
}
int time = 0;
// poll every 100ms whether rabit recovered: 10 polls per second of timeout_sec
while (time++ < 10 * timeout_sec) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (shutdown_timeout.load()) {
if (rabit_debug) {
utils::Printf("[%d] timeout task thread %ld exits\n",
rank, std::this_thread::get_id());
}
// recovered (or shut down) before the deadline
return true;
}
}
// print on tracker to help debugging
TrackerPrint("[ERROR] rank " + std::to_string(rank) + "@"+
host_uri + ":" +std::to_string(port) + " timeout\n");
// the deadline passed without recovery: report through the error handler
_error("[%d] exit due to time out %d s\n", rank, timeout_sec);
return false;
});
}
// simple way, shutdown all links
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
}
// smooth out traffic to tracker: each rank waits 10ms * rank before reconnecting
std::this_thread::sleep_for(std::chrono::milliseconds(10*rank));
ReConnectLinks("recover");
return false;
}
@ -724,8 +764,8 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
// set p_req_in
(*p_req_in)[i] = (req_in[i] != 0);
if (req_out[i] != 0) {
utils::Assert(req_in[i] == 0, "cannot get and receive request");
utils::Assert(static_cast<int>(i) == best_link, "request result inconsistent");
_assert(req_in[i] == 0, "cannot get and receive request");
_assert(static_cast<int>(i) == best_link, "request result inconsistent");
}
}
*p_recvlink = best_link;
@ -755,20 +795,20 @@ AllreduceRobust::TryRecoverData(RecoverType role,
RefLinkVector &links = tree_links;
// no need to run recovery for zero size messages
if (links.size() == 0 || size == 0) return kSuccess;
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
_assert(req_in.size() == links.size(), "TryRecoverData");
const int nlink = static_cast<int>(links.size());
{
bool req_data = role == kRequestData;
for (int i = 0; i < nlink; ++i) {
if (req_in[i]) {
utils::Assert(i != recv_link, "TryDecideRouting");
_assert(i != recv_link, "TryDecideRouting");
req_data = true;
}
}
// do not need to provide data or receive data, directly exit
if (!req_data) return kSuccess;
}
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
_assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
if (role == kPassData) {
links[recv_link].InitBuffer(1, size, reduce_buffer_size);
}
@ -835,7 +875,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
for (int i = 0; i < nlink; ++i) {
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
}
utils::Assert(min_write <= links[pid].size_read, "boundary check");
_assert(min_write <= links[pid].size_read, "boundary check");
ReturnType ret = links[pid].ReadToRingBuffer(min_write, size);
if (ret != kSuccess) {
return ReportError(&links[pid], ret);
@ -869,7 +909,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryRestoreCache(bool requester,
const int min_seq, const int max_seq) {
// clear requester and rebuild from those with most cache entries
if (requester) {
utils::Assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries");
_assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries");
cachebuf.Clear();
lookupbuf.Clear();
cur_cache_seq = 0;
@ -998,7 +1038,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
int new_version = !local_chkpt_version;
int nlocal = std::max(static_cast<int>(local_rptr[new_version].size()) - 1, 0);
// if we get here, the user must have already set up the state once
utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1,
_assert(nlocal == 1 || nlocal == num_local_replica + 1,
"TryGetResult::Checkpoint");
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
}
@ -1048,13 +1088,13 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
// kLoadBootstrapCache should be treated similar as allreduce
// when loadcheck/check/checkack runs in other nodes
if (flag != 0 && flag != ActionSummary::kLoadBootstrapCache) {
utils::Assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations");
_assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations");
}
std::string msg = std::string(caller) + " pass negative seqno "
+ std::to_string(seqno) + " flag " + std::to_string(flag)
+ " version " + std::to_string(version_number);
utils::Assert(seqno >=0, msg.c_str());
_assert(seqno >=0, msg.c_str());
ActionSummary req(flag, flag, seqno, cache_seqno);
@ -1068,7 +1108,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
if (act.check_ack()) {
if (act.check_point()) {
// if we also have check_point, do check point first
utils::Assert(!act.diff_seq(),
_assert(!act.diff_seq(),
"check ack & check pt cannot occur together with normal ops");
// if we requested checkpoint, we are free to go
if (req.check_point()) return true;
@ -1087,7 +1127,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
} else {
if (act.check_point()) {
if (act.diff_seq()) {
utils::Assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
_assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
// print checkpoint consensus flag if user turn on debug
if (rabit_debug) {
req.print_flags(rank, "checkpoint req");
@ -1112,16 +1152,16 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
if (!act.load_cache()) {
if (act.seqno() > 0) {
if (!requester) {
utils::Assert(req.check_point(), "checkpoint node should be KHaveData role");
_assert(req.check_point(), "checkpoint node should be KHaveData role");
buf = resbuf.Query(act.seqno(), &size);
utils::Assert(buf != NULL, "buf should have data from resbuf");
utils::Assert(size > 0, "buf size should be greater than 0");
_assert(buf != NULL, "buf should have data from resbuf");
_assert(size > 0, "buf size should be greater than 0");
}
if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
}
} else {
// cache seq no should be smaller than kSpecialOp
utils::Assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp,
_assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp,
"checkpoint with kSpecialOp");
int max_cache_seq = cur_cache_seq;
if (TryAllreduce(&max_cache_seq, sizeof(max_cache_seq), 1,
@ -1153,11 +1193,11 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
act.print_flags(rank, "loadcache act");
}
// load cache should not run in parallel with other states
utils::Assert(!act.load_check(),
_assert(!act.load_check(),
"load cache state expect no nodes doing load checkpoint");
utils::Assert(!act.check_point() ,
_assert(!act.check_point() ,
"load cache state expect no nodes doing checkpoint");
utils::Assert(!act.check_ack(),
_assert(!act.check_ack(),
"load cache state expect no nodes doing checkpoint ack");
// if all nodes are requester in load cache, skip
@ -1176,10 +1216,10 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
}
// assert no req with load cache set goes into seq catch up
utils::Assert(!req.load_cache(), "load cache not interacte with rest states");
_assert(!req.load_cache(), "load cache not interacte with rest states");
// no special flags, no checkpoint, check ack, load_check
utils::Assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
_assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
if (act.diff_seq()) {
bool requester = req.seqno() == act.seqno();
if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
@ -1194,7 +1234,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
// something is still incomplete try next round
}
}
utils::Assert(false, "RecoverExec: should not reach here");
_assert(false, "RecoverExec: should not reach here");
return true;
}
/*!
@ -1222,13 +1262,13 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
std::string &chkpt = *p_local_chkpt;
if (rptr.size() == 0) {
rptr.push_back(0);
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
_assert(chkpt.length() == 0, "local chkpt space inconsistent");
}
const int n = num_local_replica;
{
// backward passing, passing state in backward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
_assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_back(n + 1);
msg_back[0] = nlocal;
// backward passing one hop the request
@ -1282,7 +1322,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
{
// forward passing, passing state in forward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
_assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_forward(n + 1);
msg_forward[0] = nlocal;
// backward passing one hop the request
@ -1367,7 +1407,7 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
if (num_local_replica == 0) return kSuccess;
std::vector<size_t> &rptr = *p_local_rptr;
std::string &chkpt = *p_local_chkpt;
utils::Assert(rptr.size() == 2,
_assert(rptr.size() == 2,
"TryCheckinLocalState must have exactly 1 state");
const int n = num_local_replica;
std::vector<size_t> sizes(n + 1);
@ -1423,10 +1463,10 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
LinkRecord *read_link,
LinkRecord *write_link) {
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
utils::Assert(write_end <= read_end,
_assert(write_end <= read_end,
"RingPassing: boundary check1");
utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2");
utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3");
_assert(read_ptr <= read_end, "RingPassing: boundary check2");
_assert(write_ptr <= write_end, "RingPassing: boundary check3");
// take reference
LinkRecord &prev = *read_link, &next = *write_link;
// send recv buffer

View File

@ -10,6 +10,7 @@
*/
#ifndef RABIT_ALLREDUCE_ROBUST_H_
#define RABIT_ALLREDUCE_ROBUST_H_
#include <future>
#include <vector>
#include <string>
#include <algorithm>
@ -632,6 +633,14 @@ o * the input state must exactly one saved state(local state of current node)
int local_chkpt_version;
// if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
bool checkpoint_loaded;
// sidecar executing timeout task
std::future<bool> rabit_timeout_task;
// flag to shutdown rabit_timeout_task before timeout
std::atomic<bool> shutdown_timeout{false};
// error handler
void (* _error)(const char *fmt, ...) = utils::Error;
// assert handler
void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert;
};
} // namespace engine
} // namespace rabit

View File

@ -2,14 +2,13 @@ find_package(GTest REQUIRED)
add_executable(
unit_tests
allreduce_base_test.cpp
allreduce_mock_test.cpp
allreduce_base_test.cc allreduce_robust_test.cc allreduce_mock_test.cc
test_main.cpp)
target_link_libraries(
unit_tests
unit_tests PRIVATE
GTest::GTest GTest::Main
rabit_base rabit_mock)
rabit_base rabit_mock rabit)
target_include_directories(unit_tests PUBLIC
"$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"

View File

@ -0,0 +1,233 @@
#define RABIT_CXXTESTDEFS_H
#include <gtest/gtest.h>
#include <chrono>
#include <string>
#include <iostream>
#include "../../src/allreduce_robust.h"
// Stand-in for the engine's _error hook: instead of terminating, it checks
// (case-insensitively) that the format string handed to it is the timeout message.
// The variadic arguments are ignored.
inline void mockerr(const char *fmt, ...) {EXPECT_STRCASEEQ(fmt, "[%d] exit due to time out %d s\n");}
// Stand-in for the engine's _assert hook: accepts any condition and message
// and deliberately does nothing, so failed assertions are ignored.
inline void mockassert(bool val, const char *fmt, ...) {
  (void)val;
  (void)fmt;
}
// Shared results passed to CheckAndRecover by the tests below:
// err_type wraps kSockError and stands for a failed step,
// succ_type wraps kSuccess and stands for a successful step.
rabit::engine::AllreduceRobust::ReturnType err_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSockError);
rabit::engine::AllreduceRobust::ReturnType succ_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSuccess);
// An error that is never followed by a success: after the 1 second timeout
// has elapsed the timeout task must have finished with `false`.
TEST(allreduce_robust, sync_error_timeout)
{
  rabit::engine::AllreduceRobust m;

  // Init takes mutable char* arguments; each std::string owns a writable,
  // NUL-terminated buffer, which avoids non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0]};
  m.Init(2, argv);
  m.rank = 0;
  m.rabit_bootstrap_cache = true;
  // mock handlers so the timeout reports through gtest instead of exiting
  m._error = mockerr;
  m._assert = mockassert;

  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  // wait past the 1 second timeout
  std::this_thread::sleep_for(std::chrono::milliseconds(1500));
  EXPECT_EQ(m.rabit_timeout_task.get(), false);
}
// An error followed by a success within the timeout window: the timeout task
// must observe the recovery and finish with `true`.
TEST(allreduce_robust, sync_error_reset)
{
  rabit::engine::AllreduceRobust m;

  // Init takes mutable char* arguments; each std::string owns a writable,
  // NUL-terminated buffer, which avoids non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  m._assert = mockassert;

  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  // success arrives well before the 1 second timeout
  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  EXPECT_EQ(m.rabit_timeout_task.get(), true);
  m.Shutdown();
}
// A success followed by an unrecovered error: the earlier success must not
// prevent the timeout task from expiring with `false`.
TEST(allreduce_robust, sync_success_error_timeout)
{
  rabit::engine::AllreduceRobust m;

  // Init takes mutable char* arguments; each std::string owns a writable,
  // NUL-terminated buffer, which avoids non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  m.rabit_bootstrap_cache = true;
  // mock handlers so the timeout reports through gtest instead of exiting
  m._assert = mockassert;
  m._error = mockerr;

  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  // wait past the 1 second timeout
  std::this_thread::sleep_for(std::chrono::milliseconds(1500));
  EXPECT_EQ(m.rabit_timeout_task.get(), false);
}
// Success, error, then success again shortly after: the timeout task must
// finish with `true` even when its result is read after the timeout period.
TEST(allreduce_robust, sync_success_error_success)
{
  rabit::engine::AllreduceRobust m;

  // Init takes mutable char* arguments; each std::string owns a writable,
  // NUL-terminated buffer, which avoids non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  m.rabit_bootstrap_cache = true;
  m._assert = mockassert;

  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  // read the result only after the 1 second timeout period has passed
  std::this_thread::sleep_for(std::chrono::milliseconds(1100));
  EXPECT_EQ(m.rabit_timeout_task.get(), true);
  m.Shutdown();
}
// Two consecutive errors: the second error must not restart the countdown, so
// the task still finishes with `false` in under 2 seconds from the first error.
TEST(allreduce_robust, sync_error_no_reset_timeout)
{
  rabit::engine::AllreduceRobust m;

  // Init takes mutable char* arguments; each std::string owns a writable,
  // NUL-terminated buffer, which avoids non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  m.rabit_bootstrap_cache = true;
  // mock handlers so the timeout reports through gtest instead of exiting
  m._assert = mockassert;
  m._error = mockerr;

  auto start = std::chrono::system_clock::now();
  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(1100));
  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  m.rabit_timeout_task.wait();
  auto end = std::chrono::system_clock::now();
  std::chrono::duration<double> diff = end-start;
  EXPECT_EQ(m.rabit_timeout_task.get(), false);
  // expect second error don't overwrite/reset timeout task
  EXPECT_LT(diff.count(), 2);
}
// Only a success is reported, then the engine is shut down: exercises
// Shutdown() when CheckAndRecover never saw an error.
TEST(allreduce_robust, no_timeout_shut_down)
{
  rabit::engine::AllreduceRobust m;

  // Init takes mutable char* arguments; each std::string owns a writable,
  // NUL-terminated buffer, which avoids non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;

  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  m.Shutdown();
}
// An error is reported and the engine is shut down 10ms later, long before the
// 1 second timeout: exercises Shutdown() while the timeout task is pending.
TEST(allreduce_robust, shut_down_before_timeout)
{
  rabit::engine::AllreduceRobust m;

  // Init takes mutable char* arguments; each std::string owns a writable,
  // NUL-terminated buffer, which avoids non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  // the real assert handler stays installed here, so give CheckAndRecover
  // a non-null error link
  rabit::engine::AllreduceRobust::LinkRecord a;
  m.err_link = &a;

  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  m.Shutdown();
}

View File

@ -3,5 +3,6 @@
// Entry point of the unit-test binary: initialises googletest from the command
// line, selects the "threadsafe" death-test style, and runs every registered test.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}

View File

@ -13,7 +13,7 @@ all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_di
# this experiment test recovery with actually process exit, use keepalive to keep program alive
model_recover_10_10k:
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 rabit_bootstrap_cache=-1 rabit_debug=1 rabit_reduce_ring_mincount=1
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 rabit_bootstrap_cache=true rabit_debug=true rabit_reduce_ring_mincount=1 rabit_timeout=true rabit_timeout_sec=5
model_recover_10_10k_die_same:
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 rabit_bootstrap_cache=1