exit when allreduce/broadcast error cause timeout (#112)
* keep async timeout task * add missing pthread to cmake * add tests * Add a sleep period to avoid flushing the tracker.
This commit is contained in:
parent
af7281afe3
commit
5d1b613910
3
.gitignore
vendored
3
.gitignore
vendored
@ -47,3 +47,6 @@ mpich-3.2/
|
|||||||
cmake-build-debug/
|
cmake-build-debug/
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
|
# cmake
|
||||||
|
build/
|
||||||
|
|
||||||
|
|||||||
15
.travis.yml
15
.travis.yml
@ -8,16 +8,29 @@ osx_image: xcode10.2
|
|||||||
|
|
||||||
dist: xenial
|
dist: xenial
|
||||||
|
|
||||||
|
language: cpp
|
||||||
|
|
||||||
# Use Build Matrix to do lint and build seperately
|
# Use Build Matrix to do lint and build seperately
|
||||||
env:
|
env:
|
||||||
matrix:
|
matrix:
|
||||||
- TASK=lint LINT_LANG=cpp
|
- TASK=lint LINT_LANG=cpp
|
||||||
- TASK=lint LINT_LANG=python
|
- TASK=lint LINT_LANG=python
|
||||||
- TASK=doc
|
- TASK=doc
|
||||||
- TASK=build
|
# - TASK=build
|
||||||
- TASK=mpi-build
|
- TASK=mpi-build
|
||||||
- TASK=cmake-test
|
- TASK=cmake-test
|
||||||
|
|
||||||
|
matrix:
|
||||||
|
exclude:
|
||||||
|
- os: osx
|
||||||
|
env: TASK=lint LINT_LANG=cpp
|
||||||
|
- os: osx
|
||||||
|
env: TASK=lint LINT_LANG=python
|
||||||
|
- os: osx
|
||||||
|
env: TASK=doc
|
||||||
|
- os: osx
|
||||||
|
env: TASK=build
|
||||||
|
|
||||||
# dependent apt packages
|
# dependent apt packages
|
||||||
addons:
|
addons:
|
||||||
apt:
|
apt:
|
||||||
|
|||||||
@ -20,11 +20,16 @@ if(R_LIB OR MINGW OR WIN32)
|
|||||||
CXX_STANDARD_REQUIRED ON
|
CXX_STANDARD_REQUIRED ON
|
||||||
POSITION_INDEPENDENT_CODE ON)
|
POSITION_INDEPENDENT_CODE ON)
|
||||||
else()
|
else()
|
||||||
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
|
find_package(Threads REQUIRED)
|
||||||
add_library(rabit_base src/allreduce_base.cc src/engine_base.cc src/c_api.cc)
|
|
||||||
add_library(rabit_empty src/engine_empty.cc src/c_api.cc)
|
add_library(rabit_empty src/engine_empty.cc src/c_api.cc)
|
||||||
|
add_library(rabit_base src/allreduce_base.cc src/engine_base.cc src/c_api.cc)
|
||||||
|
|
||||||
|
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
|
||||||
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||||
add_library(rabit_mock SHARED src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
add_library(rabit_mock SHARED src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||||
|
target_link_libraries(rabit Threads::Threads)
|
||||||
|
target_link_libraries(rabit_mock_static Threads::Threads)
|
||||||
|
target_link_libraries(rabit_mock Threads::Threads)
|
||||||
|
|
||||||
set(rabit_libs rabit rabit_base rabit_empty rabit_mock rabit_mock_static)
|
set(rabit_libs rabit rabit_base rabit_empty rabit_mock rabit_mock_static)
|
||||||
set_target_properties(rabit rabit_base rabit_empty rabit_mock rabit_mock_static
|
set_target_properties(rabit rabit_base rabit_empty rabit_mock rabit_mock_static
|
||||||
|
|||||||
11
doc/guide.md
11
doc/guide.md
@ -154,6 +154,8 @@ you can also refer to [wormhole](https://github.com/dmlc/wormhole/blob/master/le
|
|||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
...
|
...
|
||||||
rabit::Init(argc, argv);
|
rabit::Init(argc, argv);
|
||||||
|
// sync on expected model size before load checkpoint, if we pass rabit_bootstrap_cache=true
|
||||||
|
rabit::Allreduce<rabit::op::Max>(&model.size(), 1);
|
||||||
// load the latest checked model
|
// load the latest checked model
|
||||||
int version = rabit::LoadCheckPoint(&model);
|
int version = rabit::LoadCheckPoint(&model);
|
||||||
// initialize the model if it is the first version
|
// initialize the model if it is the first version
|
||||||
@ -370,3 +372,12 @@ Allreduce/Broadcast calls after the checkpoint from some alive nodes.
|
|||||||
|
|
||||||
This is just a conceptual introduction to rabit's fault tolerance model. The actual implementation is more sophisticated,
|
This is just a conceptual introduction to rabit's fault tolerance model. The actual implementation is more sophisticated,
|
||||||
and can deal with more complicated cases such as multiple nodes failure and node failure during recovery phase.
|
and can deal with more complicated cases such as multiple nodes failure and node failure during recovery phase.
|
||||||
|
|
||||||
|
Rabit Timeout
|
||||||
|
---------------
|
||||||
|
|
||||||
|
In certain cases, a rabit cluster may lack the resources needed to retry failed workers.
|
||||||
|
Because the fault-tolerance model assumes infinite retries, this can cause the entire cluster to hang indefinitely.
|
||||||
|
We introduce a sidecar thread that runs when the rabit fault-tolerant runtime observes allreduce/broadcast errors.
|
||||||
|
By default, it waits for 30 minutes before the worker programs exit.
|
||||||
|
Users can opt in to this feature and change the threshold by passing rabit_timeout=true and rabit_timeout_sec=x (in seconds).
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
#ifndef RABIT_INTERNAL_UTILS_H_
|
#ifndef RABIT_INTERNAL_UTILS_H_
|
||||||
#define RABIT_INTERNAL_UTILS_H_
|
#define RABIT_INTERNAL_UTILS_H_
|
||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
|
#include <string.h>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
@ -66,6 +67,11 @@ const int kPrintBuffer = 1 << 12;
|
|||||||
* co-locate in the same process */
|
* co-locate in the same process */
|
||||||
extern bool STOP_PROCESS_ON_ERROR;
|
extern bool STOP_PROCESS_ON_ERROR;
|
||||||
|
|
||||||
|
/* \brief parse config string to bool */
|
||||||
|
inline bool StringToBool(const char* s) {
|
||||||
|
return strcasecmp(s, "true") == 0 || atoi(s) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
#ifndef RABIT_CUSTOMIZE_MSG_
|
#ifndef RABIT_CUSTOMIZE_MSG_
|
||||||
/*!
|
/*!
|
||||||
* \brief handling of Assert error, caused by inappropriate input
|
* \brief handling of Assert error, caused by inappropriate input
|
||||||
@ -86,7 +92,7 @@ inline void HandleAssertError(const char *msg) {
|
|||||||
*/
|
*/
|
||||||
inline void HandleCheckError(const char *msg) {
|
inline void HandleCheckError(const char *msg) {
|
||||||
if (STOP_PROCESS_ON_ERROR) {
|
if (STOP_PROCESS_ON_ERROR) {
|
||||||
fprintf(stderr, "%s, shutting down process", msg);
|
fprintf(stderr, "%s, shutting down process\n", msg);
|
||||||
exit(-1);
|
exit(-1);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
|
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
|
||||||
|
|||||||
@ -25,8 +25,9 @@ if [ ${TASK} == "cmake-test" ]; then
|
|||||||
mkdir build
|
mkdir build
|
||||||
cd build
|
cd build
|
||||||
cmake -DRABIT_BUILD_TESTS=ON -DRABIT_BUILD_DMLC=ON -DGTEST_ROOT=${HOME}/.local ..
|
cmake -DRABIT_BUILD_TESTS=ON -DRABIT_BUILD_DMLC=ON -DGTEST_ROOT=${HOME}/.local ..
|
||||||
#unit tests
|
# known osx gtest 1.8 issue
|
||||||
make
|
cp ${HOME}/.local/lib/*.dylib .
|
||||||
|
make -j$(nproc)
|
||||||
make test
|
make test
|
||||||
make install || exit -1
|
make install || exit -1
|
||||||
cd ../test
|
cd ../test
|
||||||
|
|||||||
@ -73,7 +73,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
if (task_id == NULL) {
|
if (task_id == NULL) {
|
||||||
task_id = getenv("mapreduce_task_id");
|
task_id = getenv("mapreduce_task_id");
|
||||||
}
|
}
|
||||||
if (hadoop_mode != 0) {
|
if (hadoop_mode) {
|
||||||
utils::Check(task_id != NULL,
|
utils::Check(task_id != NULL,
|
||||||
"hadoop_mode is set but cannot find mapred_task_id");
|
"hadoop_mode is set but cannot find mapred_task_id");
|
||||||
}
|
}
|
||||||
@ -94,7 +94,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
if (num_task == NULL) {
|
if (num_task == NULL) {
|
||||||
num_task = getenv("mapreduce_job_maps");
|
num_task = getenv("mapreduce_job_maps");
|
||||||
}
|
}
|
||||||
if (hadoop_mode != 0) {
|
if (hadoop_mode) {
|
||||||
utils::Check(num_task != NULL,
|
utils::Check(num_task != NULL,
|
||||||
"hadoop_mode is set but cannot find mapred_map_tasks");
|
"hadoop_mode is set but cannot find mapred_map_tasks");
|
||||||
}
|
}
|
||||||
@ -188,7 +188,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
|
if (!strcmp(name, "DMLC_TASK_ID")) task_id = val;
|
||||||
if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
|
if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val;
|
||||||
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
if (!strcmp(name, "rabit_world_size")) world_size = atoi(val);
|
||||||
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = atoi(val);
|
if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val);
|
||||||
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
|
if (!strcmp(name, "rabit_reduce_ring_mincount")) {
|
||||||
reduce_ring_mincount = atoi(val);
|
reduce_ring_mincount = atoi(val);
|
||||||
utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0");
|
utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0");
|
||||||
@ -209,10 +209,17 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!strcmp(name, "rabit_bootstrap_cache")) {
|
if (!strcmp(name, "rabit_bootstrap_cache")) {
|
||||||
rabit_bootstrap_cache = atoi(val);
|
rabit_bootstrap_cache = utils::StringToBool(val);
|
||||||
}
|
}
|
||||||
if (!strcmp(name, "rabit_debug")) {
|
if (!strcmp(name, "rabit_debug")) {
|
||||||
rabit_debug = atoi(val);
|
rabit_debug = utils::StringToBool(val);
|
||||||
|
}
|
||||||
|
if (!strcmp(name, "rabit_timeout")) {
|
||||||
|
rabit_timeout = utils::StringToBool(val);
|
||||||
|
}
|
||||||
|
if (!strcmp(name, "rabit_timeout_sec")) {
|
||||||
|
timeout_sec = atoi(val);
|
||||||
|
utils::Assert(rabit_timeout > 0, "rabit_timeout_sec should be greater than 0 second");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -496,7 +496,7 @@ class AllreduceBase : public IEngine {
|
|||||||
// version number of model
|
// version number of model
|
||||||
int version_number;
|
int version_number;
|
||||||
// whether the job is running in hadoop
|
// whether the job is running in hadoop
|
||||||
int hadoop_mode;
|
bool hadoop_mode;
|
||||||
//---- local data related to link ----
|
//---- local data related to link ----
|
||||||
// index of parent link, can be -1, meaning this is root of the tree
|
// index of parent link, can be -1, meaning this is root of the tree
|
||||||
int parent_index;
|
int parent_index;
|
||||||
@ -543,9 +543,13 @@ class AllreduceBase : public IEngine {
|
|||||||
// backdoor port
|
// backdoor port
|
||||||
int port = 0;
|
int port = 0;
|
||||||
// enable bootstrap cache 0 false 1 true
|
// enable bootstrap cache 0 false 1 true
|
||||||
int rabit_bootstrap_cache = 0;
|
bool rabit_bootstrap_cache = false;
|
||||||
// enable detailed logging
|
// enable detailed logging
|
||||||
int rabit_debug = 0;
|
bool rabit_debug = false;
|
||||||
|
// by default, exit if the rabit worker has not recovered within half an hour
|
||||||
|
int timeout_sec = 1800;
|
||||||
|
// flag to enable rabit_timeout
|
||||||
|
bool rabit_timeout = false;
|
||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -176,7 +176,7 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
if (mock_map.count(key) != 0) {
|
if (mock_map.count(key) != 0) {
|
||||||
num_trial += 1;
|
num_trial += 1;
|
||||||
// data processing frameworks runs on shared process
|
// data processing frameworks runs on shared process
|
||||||
utils::Error("[%d]@@@Hit Mock Error:%s\n", rank, name);
|
_error("[%d]@@@Hit Mock Error:%s ", rank, name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -8,6 +8,8 @@
|
|||||||
#define _CRT_SECURE_NO_WARNINGS
|
#define _CRT_SECURE_NO_WARNINGS
|
||||||
#define _CRT_SECURE_NO_DEPRECATE
|
#define _CRT_SECURE_NO_DEPRECATE
|
||||||
#define NOMINMAX
|
#define NOMINMAX
|
||||||
|
#include <chrono>
|
||||||
|
#include <thread>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "rabit/internal/io.h"
|
#include "rabit/internal/io.h"
|
||||||
@ -19,6 +21,7 @@
|
|||||||
|
|
||||||
namespace rabit {
|
namespace rabit {
|
||||||
namespace engine {
|
namespace engine {
|
||||||
|
|
||||||
AllreduceRobust::AllreduceRobust(void) {
|
AllreduceRobust::AllreduceRobust(void) {
|
||||||
num_local_replica = 0;
|
num_local_replica = 0;
|
||||||
num_global_replica = 5;
|
num_global_replica = 5;
|
||||||
@ -38,7 +41,7 @@ bool AllreduceRobust::Init(int argc, char* argv[]) {
|
|||||||
if (AllreduceBase::Init(argc, argv)) {
|
if (AllreduceBase::Init(argc, argv)) {
|
||||||
// chenqin: alert user opted in experimental feature.
|
// chenqin: alert user opted in experimental feature.
|
||||||
if (rabit_bootstrap_cache) utils::HandleLogInfo(
|
if (rabit_bootstrap_cache) utils::HandleLogInfo(
|
||||||
"[EXPERIMENTAL] rabit bootstrap cache has been enabled\n");
|
"[EXPERIMENTAL] bootstrap cache has been enabled\n");
|
||||||
checkpoint_loaded = false;
|
checkpoint_loaded = false;
|
||||||
if (num_global_replica == 0) {
|
if (num_global_replica == 0) {
|
||||||
result_buffer_round = -1;
|
result_buffer_round = -1;
|
||||||
@ -55,24 +58,31 @@ bool AllreduceRobust::Shutdown(void) {
|
|||||||
try {
|
try {
|
||||||
// need to sync the exec before we shutdown, do a pesudo check point
|
// need to sync the exec before we shutdown, do a pesudo check point
|
||||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp,
|
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp,
|
||||||
cur_cache_seq), "Shutdown: check point must return true");
|
cur_cache_seq), "Shutdown: check point must return true");
|
||||||
// reset result buffer
|
// reset result buffer
|
||||||
resbuf.Clear(); seq_counter = 0;
|
resbuf.Clear(); seq_counter = 0;
|
||||||
cachebuf.Clear(); cur_cache_seq = 0;
|
cachebuf.Clear(); cur_cache_seq = 0;
|
||||||
lookupbuf.Clear();
|
lookupbuf.Clear();
|
||||||
// execute check ack step, load happens here
|
// execute check ack step, load happens here
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
||||||
ActionSummary::kSpecialOp, cur_cache_seq), "Shutdown: check ack must return true");
|
ActionSummary::kSpecialOp, cur_cache_seq), "Shutdown: check ack must return true");
|
||||||
|
// travis ci only osx test hang
|
||||||
#if defined (__APPLE__)
|
#if defined (__APPLE__)
|
||||||
sleep(1);
|
sleep(1);
|
||||||
#endif
|
#endif
|
||||||
|
shutdown_timeout = true;
|
||||||
|
if (rabit_timeout_task.valid()) {
|
||||||
|
rabit_timeout_task.wait();
|
||||||
|
_assert(rabit_timeout_task.get(), "expect timeout task return\n");
|
||||||
|
}
|
||||||
return AllreduceBase::Shutdown();
|
return AllreduceBase::Shutdown();
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
fprintf(stderr, "%s\n", e.what());
|
fprintf(stderr, "%s\n", e.what());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
@ -98,8 +108,8 @@ int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
utils::Assert(index == -1, "immutable cache key already exists");
|
_assert(index == -1, "immutable cache key already exists");
|
||||||
utils::Assert(type_nbytes*count > 0, "can't set empty cache");
|
_assert(type_nbytes*count > 0, "can't set empty cache");
|
||||||
void* temp = cachebuf.AllocTemp(type_nbytes, count);
|
void* temp = cachebuf.AllocTemp(type_nbytes, count);
|
||||||
cachebuf.PushTemp(cur_cache_seq, type_nbytes, count);
|
cachebuf.PushTemp(cur_cache_seq, type_nbytes, count);
|
||||||
std::memcpy(temp, buf, type_nbytes*count);
|
std::memcpy(temp, buf, type_nbytes*count);
|
||||||
@ -133,9 +143,9 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
|||||||
|
|
||||||
size_t siz = 0;
|
size_t siz = 0;
|
||||||
void* temp = cachebuf.Query(index, &siz);
|
void* temp = cachebuf.Query(index, &siz);
|
||||||
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
_assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
|
||||||
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
_assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
|
||||||
utils::Assert(siz > 0, "cache size should be greater than 0");
|
_assert(siz > 0, "cache size should be greater than 0");
|
||||||
std::memcpy(buf, temp, type_nbytes*count);
|
std::memcpy(buf, temp, type_nbytes*count);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -317,7 +327,7 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
|
|||||||
local_rptr[local_chkpt_version][1]);
|
local_rptr[local_chkpt_version][1]);
|
||||||
local_model->Load(&fs);
|
local_model->Load(&fs);
|
||||||
} else {
|
} else {
|
||||||
utils::Assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal);
|
_assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// reset result buffer
|
// reset result buffer
|
||||||
@ -327,14 +337,14 @@ int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
|
|||||||
if (global_checkpoint.length() == 0) {
|
if (global_checkpoint.length() == 0) {
|
||||||
version_number = 0;
|
version_number = 0;
|
||||||
} else {
|
} else {
|
||||||
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
|
_assert(fs.Read(&version_number, sizeof(version_number)) != 0,
|
||||||
"read in version number");
|
"read in version number");
|
||||||
global_model->Load(&fs);
|
global_model->Load(&fs);
|
||||||
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
|
_assert(local_model == NULL || nlocal == num_local_replica + 1,
|
||||||
"local model inconsistent, nlocal=%d", nlocal);
|
"local model inconsistent, nlocal=%d", nlocal);
|
||||||
}
|
}
|
||||||
// run another phase of check ack, if recovered from data
|
// run another phase of check ack, if recovered from data
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
||||||
ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
|
ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
|
||||||
|
|
||||||
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq)) {
|
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq)) {
|
||||||
@ -433,7 +443,7 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
|
|||||||
local_chkpt_version = !local_chkpt_version;
|
local_chkpt_version = !local_chkpt_version;
|
||||||
}
|
}
|
||||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
|
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
|
||||||
ActionSummary::kSpecialOp, cur_cache_seq),
|
ActionSummary::kSpecialOp, cur_cache_seq),
|
||||||
"check point must return true");
|
"check point must return true");
|
||||||
// this is the critical region where we will change all the stored models
|
// this is the critical region where we will change all the stored models
|
||||||
@ -460,7 +470,7 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
|
|||||||
// reset result buffer, mark boostrap phase complete
|
// reset result buffer, mark boostrap phase complete
|
||||||
resbuf.Clear(); seq_counter = 0;
|
resbuf.Clear(); seq_counter = 0;
|
||||||
// execute check ack step, load happens here
|
// execute check ack step, load happens here
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck,
|
||||||
ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
|
ActionSummary::kSpecialOp, cur_cache_seq), "check ack must return true");
|
||||||
|
|
||||||
delta = utils::GetTime() - start;
|
delta = utils::GetTime() - start;
|
||||||
@ -533,7 +543,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
if (all_links[i].size_read == 0) {
|
if (all_links[i].size_read == 0) {
|
||||||
int atmark = all_links[i].sock.AtMark();
|
int atmark = all_links[i].sock.AtMark();
|
||||||
if (atmark < 0) {
|
if (atmark < 0) {
|
||||||
utils::Assert(all_links[i].sock.BadSocket(), "must already gone bad");
|
_assert(all_links[i].sock.BadSocket(), "must already gone bad");
|
||||||
} else if (atmark > 0) {
|
} else if (atmark > 0) {
|
||||||
all_links[i].size_read = 1;
|
all_links[i].size_read = 1;
|
||||||
} else {
|
} else {
|
||||||
@ -555,10 +565,10 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
all_links[i].sock.Close(); continue;
|
all_links[i].sock.Close(); continue;
|
||||||
} else if (len > 0) {
|
} else if (len > 0) {
|
||||||
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
|
_assert(oob_mark == kResetMark, "wrong oob msg");
|
||||||
utils::Assert(all_links[i].sock.AtMark() != 1, "should already read past mark");
|
_assert(all_links[i].sock.AtMark() != 1, "should already read past mark");
|
||||||
} else {
|
} else {
|
||||||
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
_assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
||||||
}
|
}
|
||||||
// send out ack
|
// send out ack
|
||||||
char ack = kResetAck;
|
char ack = kResetAck;
|
||||||
@ -579,9 +589,9 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
all_links[i].sock.Close(); continue;
|
all_links[i].sock.Close(); continue;
|
||||||
} else if (len > 0) {
|
} else if (len > 0) {
|
||||||
utils::Assert(ack == kResetAck, "wrong Ack MSG");
|
_assert(ack == kResetAck, "wrong Ack MSG");
|
||||||
} else {
|
} else {
|
||||||
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
_assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
||||||
}
|
}
|
||||||
// set back to nonblock mode
|
// set back to nonblock mode
|
||||||
all_links[i].sock.SetNonBlock(true);
|
all_links[i].sock.SetNonBlock(true);
|
||||||
@ -600,14 +610,44 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|||||||
* \return true if err_type is kSuccess, false otherwise
|
* \return true if err_type is kSuccess, false otherwise
|
||||||
*/
|
*/
|
||||||
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
||||||
|
shutdown_timeout = err_type == kSuccess;
|
||||||
if (err_type == kSuccess) return true;
|
if (err_type == kSuccess) return true;
|
||||||
utils::Assert(err_link != NULL, "must know the error source");
|
|
||||||
recover_counter += 1;
|
|
||||||
|
|
||||||
|
_assert(err_link != NULL, "must know the error link");
|
||||||
|
recover_counter += 1;
|
||||||
|
// asynchronously launch the timeout task if rabit_timeout is set
|
||||||
|
if (rabit_timeout && !rabit_timeout_task.valid()) {
|
||||||
|
utils::Printf("[EXPERIMENTAL] timeout thread expires in %d second(s)\n", timeout_sec);
|
||||||
|
rabit_timeout_task = std::async(std::launch::async, [=]() {
|
||||||
|
if (rabit_debug) {
|
||||||
|
utils::Printf("[%d] timeout thread %ld starts\n", rank,
|
||||||
|
std::this_thread::get_id());
|
||||||
|
}
|
||||||
|
int time = 0;
|
||||||
|
// check if rabit recovered every 100ms
|
||||||
|
while (time++ < 10 * timeout_sec) {
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
if (shutdown_timeout.load()) {
|
||||||
|
if (rabit_debug) {
|
||||||
|
utils::Printf("[%d] timeout task thread %ld exits\n",
|
||||||
|
rank, std::this_thread::get_id());
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// print on tracker to help debugging
|
||||||
|
TrackerPrint("[ERROR] rank " + std::to_string(rank) + "@"+
|
||||||
|
host_uri + ":" +std::to_string(port) + " timeout\n");
|
||||||
|
_error("[%d] exit due to time out %d s\n", rank, timeout_sec);
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
}
|
||||||
// simple way, shutdown all links
|
// simple way, shutdown all links
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
|
if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close();
|
||||||
}
|
}
|
||||||
|
// smooth out traffic to tracker
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(10*rank));
|
||||||
ReConnectLinks("recover");
|
ReConnectLinks("recover");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -724,8 +764,8 @@ AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
|||||||
// set p_req_in
|
// set p_req_in
|
||||||
(*p_req_in)[i] = (req_in[i] != 0);
|
(*p_req_in)[i] = (req_in[i] != 0);
|
||||||
if (req_out[i] != 0) {
|
if (req_out[i] != 0) {
|
||||||
utils::Assert(req_in[i] == 0, "cannot get and receive request");
|
_assert(req_in[i] == 0, "cannot get and receive request");
|
||||||
utils::Assert(static_cast<int>(i) == best_link, "request result inconsistent");
|
_assert(static_cast<int>(i) == best_link, "request result inconsistent");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*p_recvlink = best_link;
|
*p_recvlink = best_link;
|
||||||
@ -755,20 +795,20 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
RefLinkVector &links = tree_links;
|
RefLinkVector &links = tree_links;
|
||||||
// no need to run recovery for zero size messages
|
// no need to run recovery for zero size messages
|
||||||
if (links.size() == 0 || size == 0) return kSuccess;
|
if (links.size() == 0 || size == 0) return kSuccess;
|
||||||
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
_assert(req_in.size() == links.size(), "TryRecoverData");
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
{
|
{
|
||||||
bool req_data = role == kRequestData;
|
bool req_data = role == kRequestData;
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (req_in[i]) {
|
if (req_in[i]) {
|
||||||
utils::Assert(i != recv_link, "TryDecideRouting");
|
_assert(i != recv_link, "TryDecideRouting");
|
||||||
req_data = true;
|
req_data = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// do not need to provide data or receive data, directly exit
|
// do not need to provide data or receive data, directly exit
|
||||||
if (!req_data) return kSuccess;
|
if (!req_data) return kSuccess;
|
||||||
}
|
}
|
||||||
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
|
_assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
|
||||||
if (role == kPassData) {
|
if (role == kPassData) {
|
||||||
links[recv_link].InitBuffer(1, size, reduce_buffer_size);
|
links[recv_link].InitBuffer(1, size, reduce_buffer_size);
|
||||||
}
|
}
|
||||||
@ -835,7 +875,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
|
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
|
||||||
}
|
}
|
||||||
utils::Assert(min_write <= links[pid].size_read, "boundary check");
|
_assert(min_write <= links[pid].size_read, "boundary check");
|
||||||
ReturnType ret = links[pid].ReadToRingBuffer(min_write, size);
|
ReturnType ret = links[pid].ReadToRingBuffer(min_write, size);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
return ReportError(&links[pid], ret);
|
return ReportError(&links[pid], ret);
|
||||||
@ -869,7 +909,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryRestoreCache(bool requester,
|
|||||||
const int min_seq, const int max_seq) {
|
const int min_seq, const int max_seq) {
|
||||||
// clear requester and rebuild from those with most cache entries
|
// clear requester and rebuild from those with most cache entries
|
||||||
if (requester) {
|
if (requester) {
|
||||||
utils::Assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries");
|
_assert(cur_cache_seq <= max_seq, "requester is expected to have fewer cache entries");
|
||||||
cachebuf.Clear();
|
cachebuf.Clear();
|
||||||
lookupbuf.Clear();
|
lookupbuf.Clear();
|
||||||
cur_cache_seq = 0;
|
cur_cache_seq = 0;
|
||||||
@ -998,7 +1038,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
|||||||
int new_version = !local_chkpt_version;
|
int new_version = !local_chkpt_version;
|
||||||
int nlocal = std::max(static_cast<int>(local_rptr[new_version].size()) - 1, 0);
|
int nlocal = std::max(static_cast<int>(local_rptr[new_version].size()) - 1, 0);
|
||||||
// if we goes to this place, use must have already setup the state once
|
// if we goes to this place, use must have already setup the state once
|
||||||
utils::Assert(nlocal == 1 || nlocal == num_local_replica + 1,
|
_assert(nlocal == 1 || nlocal == num_local_replica + 1,
|
||||||
"TryGetResult::Checkpoint");
|
"TryGetResult::Checkpoint");
|
||||||
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
|
return TryRecoverLocalState(&local_rptr[new_version], &local_chkpt[new_version]);
|
||||||
}
|
}
|
||||||
@ -1048,13 +1088,13 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
|||||||
// kLoadBootstrapCache should be treated similar as allreduce
|
// kLoadBootstrapCache should be treated similar as allreduce
|
||||||
// when loadcheck/check/checkack runs in other nodes
|
// when loadcheck/check/checkack runs in other nodes
|
||||||
if (flag != 0 && flag != ActionSummary::kLoadBootstrapCache) {
|
if (flag != 0 && flag != ActionSummary::kLoadBootstrapCache) {
|
||||||
utils::Assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations");
|
_assert(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string msg = std::string(caller) + " pass negative seqno "
|
std::string msg = std::string(caller) + " pass negative seqno "
|
||||||
+ std::to_string(seqno) + " flag " + std::to_string(flag)
|
+ std::to_string(seqno) + " flag " + std::to_string(flag)
|
||||||
+ " version " + std::to_string(version_number);
|
+ " version " + std::to_string(version_number);
|
||||||
utils::Assert(seqno >=0, msg.c_str());
|
_assert(seqno >=0, msg.c_str());
|
||||||
|
|
||||||
ActionSummary req(flag, flag, seqno, cache_seqno);
|
ActionSummary req(flag, flag, seqno, cache_seqno);
|
||||||
|
|
||||||
@ -1068,7 +1108,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
|||||||
if (act.check_ack()) {
|
if (act.check_ack()) {
|
||||||
if (act.check_point()) {
|
if (act.check_point()) {
|
||||||
// if we also have check_point, do check point first
|
// if we also have check_point, do check point first
|
||||||
utils::Assert(!act.diff_seq(),
|
_assert(!act.diff_seq(),
|
||||||
"check ack & check pt cannot occur together with normal ops");
|
"check ack & check pt cannot occur together with normal ops");
|
||||||
// if we requested checkpoint, we are free to go
|
// if we requested checkpoint, we are free to go
|
||||||
if (req.check_point()) return true;
|
if (req.check_point()) return true;
|
||||||
@ -1087,7 +1127,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
|||||||
} else {
|
} else {
|
||||||
if (act.check_point()) {
|
if (act.check_point()) {
|
||||||
if (act.diff_seq()) {
|
if (act.diff_seq()) {
|
||||||
utils::Assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
|
_assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
|
||||||
// print checkpoint consensus flag if user turn on debug
|
// print checkpoint consensus flag if user turn on debug
|
||||||
if (rabit_debug) {
|
if (rabit_debug) {
|
||||||
req.print_flags(rank, "checkpoint req");
|
req.print_flags(rank, "checkpoint req");
|
||||||
@ -1112,16 +1152,16 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
|||||||
if (!act.load_cache()) {
|
if (!act.load_cache()) {
|
||||||
if (act.seqno() > 0) {
|
if (act.seqno() > 0) {
|
||||||
if (!requester) {
|
if (!requester) {
|
||||||
utils::Assert(req.check_point(), "checkpoint node should be KHaveData role");
|
_assert(req.check_point(), "checkpoint node should be KHaveData role");
|
||||||
buf = resbuf.Query(act.seqno(), &size);
|
buf = resbuf.Query(act.seqno(), &size);
|
||||||
utils::Assert(buf != NULL, "buf should have data from resbuf");
|
_assert(buf != NULL, "buf should have data from resbuf");
|
||||||
utils::Assert(size > 0, "buf size should be greater than 0");
|
_assert(size > 0, "buf size should be greater than 0");
|
||||||
}
|
}
|
||||||
if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
|
if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// cache seq no should be smaller than kSpecialOp
|
// cache seq no should be smaller than kSpecialOp
|
||||||
utils::Assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp,
|
_assert(act.seqno(SeqType::kCache) != ActionSummary::kSpecialOp,
|
||||||
"checkpoint with kSpecialOp");
|
"checkpoint with kSpecialOp");
|
||||||
int max_cache_seq = cur_cache_seq;
|
int max_cache_seq = cur_cache_seq;
|
||||||
if (TryAllreduce(&max_cache_seq, sizeof(max_cache_seq), 1,
|
if (TryAllreduce(&max_cache_seq, sizeof(max_cache_seq), 1,
|
||||||
@ -1153,11 +1193,11 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
|||||||
act.print_flags(rank, "loadcache act");
|
act.print_flags(rank, "loadcache act");
|
||||||
}
|
}
|
||||||
// load cache should not running in parralel with other states
|
// load cache should not running in parralel with other states
|
||||||
utils::Assert(!act.load_check(),
|
_assert(!act.load_check(),
|
||||||
"load cache state expect no nodes doing load checkpoint");
|
"load cache state expect no nodes doing load checkpoint");
|
||||||
utils::Assert(!act.check_point() ,
|
_assert(!act.check_point() ,
|
||||||
"load cache state expect no nodes doing checkpoint");
|
"load cache state expect no nodes doing checkpoint");
|
||||||
utils::Assert(!act.check_ack(),
|
_assert(!act.check_ack(),
|
||||||
"load cache state expect no nodes doing checkpoint ack");
|
"load cache state expect no nodes doing checkpoint ack");
|
||||||
|
|
||||||
// if all nodes are requester in load cache, skip
|
// if all nodes are requester in load cache, skip
|
||||||
@ -1176,10 +1216,10 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// assert no req with load cache set goes into seq catch up
|
// assert no req with load cache set goes into seq catch up
|
||||||
utils::Assert(!req.load_cache(), "load cache not interacte with rest states");
|
_assert(!req.load_cache(), "load cache not interacte with rest states");
|
||||||
|
|
||||||
// no special flags, no checkpoint, check ack, load_check
|
// no special flags, no checkpoint, check ack, load_check
|
||||||
utils::Assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
|
_assert(act.seqno() != ActionSummary::kSpecialOp, "min seq bug");
|
||||||
if (act.diff_seq()) {
|
if (act.diff_seq()) {
|
||||||
bool requester = req.seqno() == act.seqno();
|
bool requester = req.seqno() == act.seqno();
|
||||||
if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
|
if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue;
|
||||||
@ -1194,7 +1234,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
|
|||||||
// something is still incomplete try next round
|
// something is still incomplete try next round
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
utils::Assert(false, "RecoverExec: should not reach here");
|
_assert(false, "RecoverExec: should not reach here");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -1222,13 +1262,13 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
std::string &chkpt = *p_local_chkpt;
|
std::string &chkpt = *p_local_chkpt;
|
||||||
if (rptr.size() == 0) {
|
if (rptr.size() == 0) {
|
||||||
rptr.push_back(0);
|
rptr.push_back(0);
|
||||||
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
|
_assert(chkpt.length() == 0, "local chkpt space inconsistent");
|
||||||
}
|
}
|
||||||
const int n = num_local_replica;
|
const int n = num_local_replica;
|
||||||
{
|
{
|
||||||
// backward passing, passing state in backward direction of the ring
|
// backward passing, passing state in backward direction of the ring
|
||||||
const int nlocal = static_cast<int>(rptr.size() - 1);
|
const int nlocal = static_cast<int>(rptr.size() - 1);
|
||||||
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
_assert(nlocal <= n + 1, "invalid local replica");
|
||||||
std::vector<int> msg_back(n + 1);
|
std::vector<int> msg_back(n + 1);
|
||||||
msg_back[0] = nlocal;
|
msg_back[0] = nlocal;
|
||||||
// backward passing one hop the request
|
// backward passing one hop the request
|
||||||
@ -1282,7 +1322,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
{
|
{
|
||||||
// forward passing, passing state in forward direction of the ring
|
// forward passing, passing state in forward direction of the ring
|
||||||
const int nlocal = static_cast<int>(rptr.size() - 1);
|
const int nlocal = static_cast<int>(rptr.size() - 1);
|
||||||
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
_assert(nlocal <= n + 1, "invalid local replica");
|
||||||
std::vector<int> msg_forward(n + 1);
|
std::vector<int> msg_forward(n + 1);
|
||||||
msg_forward[0] = nlocal;
|
msg_forward[0] = nlocal;
|
||||||
// backward passing one hop the request
|
// backward passing one hop the request
|
||||||
@ -1367,7 +1407,7 @@ AllreduceRobust::TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
if (num_local_replica == 0) return kSuccess;
|
if (num_local_replica == 0) return kSuccess;
|
||||||
std::vector<size_t> &rptr = *p_local_rptr;
|
std::vector<size_t> &rptr = *p_local_rptr;
|
||||||
std::string &chkpt = *p_local_chkpt;
|
std::string &chkpt = *p_local_chkpt;
|
||||||
utils::Assert(rptr.size() == 2,
|
_assert(rptr.size() == 2,
|
||||||
"TryCheckinLocalState must have exactly 1 state");
|
"TryCheckinLocalState must have exactly 1 state");
|
||||||
const int n = num_local_replica;
|
const int n = num_local_replica;
|
||||||
std::vector<size_t> sizes(n + 1);
|
std::vector<size_t> sizes(n + 1);
|
||||||
@ -1423,10 +1463,10 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
|
|||||||
LinkRecord *read_link,
|
LinkRecord *read_link,
|
||||||
LinkRecord *write_link) {
|
LinkRecord *write_link) {
|
||||||
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
|
if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess;
|
||||||
utils::Assert(write_end <= read_end,
|
_assert(write_end <= read_end,
|
||||||
"RingPassing: boundary check1");
|
"RingPassing: boundary check1");
|
||||||
utils::Assert(read_ptr <= read_end, "RingPassing: boundary check2");
|
_assert(read_ptr <= read_end, "RingPassing: boundary check2");
|
||||||
utils::Assert(write_ptr <= write_end, "RingPassing: boundary check3");
|
_assert(write_ptr <= write_end, "RingPassing: boundary check3");
|
||||||
// take reference
|
// take reference
|
||||||
LinkRecord &prev = *read_link, &next = *write_link;
|
LinkRecord &prev = *read_link, &next = *write_link;
|
||||||
// send recv buffer
|
// send recv buffer
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
*/
|
*/
|
||||||
#ifndef RABIT_ALLREDUCE_ROBUST_H_
|
#ifndef RABIT_ALLREDUCE_ROBUST_H_
|
||||||
#define RABIT_ALLREDUCE_ROBUST_H_
|
#define RABIT_ALLREDUCE_ROBUST_H_
|
||||||
|
#include <future>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
@ -632,6 +633,14 @@ o * the input state must exactly one saved state(local state of current node)
|
|||||||
int local_chkpt_version;
|
int local_chkpt_version;
|
||||||
// if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
|
// if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
|
||||||
bool checkpoint_loaded;
|
bool checkpoint_loaded;
|
||||||
|
// sidecar executing timeout task
|
||||||
|
std::future<bool> rabit_timeout_task;
|
||||||
|
// flag to shutdown rabit_timeout_task before timeout
|
||||||
|
std::atomic<bool> shutdown_timeout{false};
|
||||||
|
// error handler
|
||||||
|
void (* _error)(const char *fmt, ...) = utils::Error;
|
||||||
|
// assert handler
|
||||||
|
void (* _assert)(bool exp, const char *fmt, ...) = utils::Assert;
|
||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
@ -2,14 +2,13 @@ find_package(GTest REQUIRED)
|
|||||||
|
|
||||||
add_executable(
|
add_executable(
|
||||||
unit_tests
|
unit_tests
|
||||||
allreduce_base_test.cpp
|
allreduce_base_test.cc allreduce_robust_test.cc allreduce_mock_test.cc
|
||||||
allreduce_mock_test.cpp
|
|
||||||
test_main.cpp)
|
test_main.cpp)
|
||||||
|
|
||||||
target_link_libraries(
|
target_link_libraries(
|
||||||
unit_tests
|
unit_tests PRIVATE
|
||||||
GTest::GTest GTest::Main
|
GTest::GTest GTest::Main
|
||||||
rabit_base rabit_mock)
|
rabit_base rabit_mock rabit)
|
||||||
|
|
||||||
target_include_directories(unit_tests PUBLIC
|
target_include_directories(unit_tests PUBLIC
|
||||||
"$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"
|
"$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"
|
||||||
|
|||||||
233
test/cpp/allreduce_robust_test.cc
Normal file
233
test/cpp/allreduce_robust_test.cc
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
#define RABIT_CXXTESTDEFS_H
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
#include "../../src/allreduce_robust.h"
|
||||||
|
|
||||||
|
inline void mockerr(const char *fmt, ...) {EXPECT_STRCASEEQ(fmt, "[%d] exit due to time out %d s\n");}
|
||||||
|
inline void mockassert(bool val, const char *fmt, ...) {}
|
||||||
|
rabit::engine::AllreduceRobust::ReturnType err_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSockError);
|
||||||
|
rabit::engine::AllreduceRobust::ReturnType succ_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSuccess);
|
||||||
|
|
||||||
|
TEST(allreduce_robust, sync_error_timeout)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceRobust m;
|
||||||
|
|
||||||
|
std::string rabit_timeout = "rabit_timeout=1";
|
||||||
|
char cmd[rabit_timeout.size()+1];
|
||||||
|
std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
|
||||||
|
cmd[rabit_timeout.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_timeout_sec = "rabit_timeout_sec=1";
|
||||||
|
char cmd1[rabit_timeout_sec.size()+1];
|
||||||
|
std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
|
||||||
|
cmd1[rabit_timeout_sec.size()] = '\0';
|
||||||
|
|
||||||
|
char* argv[] = {cmd,cmd1};
|
||||||
|
m.Init(2, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
m.rabit_bootstrap_cache = 1;
|
||||||
|
m._error = mockerr;
|
||||||
|
m._assert = mockassert;
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(err_type), false);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(1500));
|
||||||
|
EXPECT_EQ(m.rabit_timeout_task.get(), false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(allreduce_robust, sync_error_reset)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceRobust m;
|
||||||
|
|
||||||
|
std::string rabit_timeout = "rabit_timeout=1";
|
||||||
|
char cmd[rabit_timeout.size()+1];
|
||||||
|
std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
|
||||||
|
cmd[rabit_timeout.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_timeout_sec = "rabit_timeout_sec=1";
|
||||||
|
char cmd1[rabit_timeout_sec.size()+1];
|
||||||
|
std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
|
||||||
|
cmd1[rabit_timeout_sec.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_debug = "rabit_debug=1";
|
||||||
|
char cmd2[rabit_debug.size()+1];
|
||||||
|
std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
|
||||||
|
cmd2[rabit_debug.size()] = '\0';
|
||||||
|
|
||||||
|
char* argv[] = {cmd, cmd1,cmd2};
|
||||||
|
m.Init(3, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
m._assert = mockassert;
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(err_type), false);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(succ_type), true);
|
||||||
|
EXPECT_EQ(m.rabit_timeout_task.get(), true);
|
||||||
|
m.Shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(allreduce_robust, sync_success_error_timeout)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceRobust m;
|
||||||
|
|
||||||
|
std::string rabit_timeout = "rabit_timeout=1";
|
||||||
|
char cmd[rabit_timeout.size()+1];
|
||||||
|
std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
|
||||||
|
cmd[rabit_timeout.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_timeout_sec = "rabit_timeout_sec=1";
|
||||||
|
char cmd1[rabit_timeout_sec.size()+1];
|
||||||
|
std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
|
||||||
|
cmd1[rabit_timeout_sec.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_debug = "rabit_debug=1";
|
||||||
|
char cmd2[rabit_debug.size()+1];
|
||||||
|
std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
|
||||||
|
cmd2[rabit_debug.size()] = '\0';
|
||||||
|
|
||||||
|
char* argv[] = {cmd, cmd1,cmd2};
|
||||||
|
m.Init(3, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
m.rabit_bootstrap_cache = 1;
|
||||||
|
m._assert = mockassert;
|
||||||
|
m._error = mockerr;
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(succ_type), true);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(err_type), false);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(1500));
|
||||||
|
EXPECT_EQ(m.rabit_timeout_task.get(), false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(allreduce_robust, sync_success_error_success)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceRobust m;
|
||||||
|
|
||||||
|
std::string rabit_timeout = "rabit_timeout=1";
|
||||||
|
char cmd[rabit_timeout.size()+1];
|
||||||
|
std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
|
||||||
|
cmd[rabit_timeout.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_timeout_sec = "rabit_timeout_sec=1";
|
||||||
|
char cmd1[rabit_timeout_sec.size()+1];
|
||||||
|
std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
|
||||||
|
cmd1[rabit_timeout_sec.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_debug = "rabit_debug=1";
|
||||||
|
char cmd2[rabit_debug.size()+1];
|
||||||
|
std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
|
||||||
|
cmd2[rabit_debug.size()] = '\0';
|
||||||
|
|
||||||
|
char* argv[] = {cmd, cmd1,cmd2};
|
||||||
|
m.Init(3, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
m.rabit_bootstrap_cache = 1;
|
||||||
|
m._assert = mockassert;
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(succ_type), true);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||||
|
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(err_type), false);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(succ_type), true);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(1100));
|
||||||
|
EXPECT_EQ(m.rabit_timeout_task.get(), true);
|
||||||
|
m.Shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(allreduce_robust, sync_error_no_reset_timeout)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceRobust m;
|
||||||
|
|
||||||
|
std::string rabit_timeout = "rabit_timeout=1";
|
||||||
|
char cmd[rabit_timeout.size()+1];
|
||||||
|
std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
|
||||||
|
cmd[rabit_timeout.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_timeout_sec = "rabit_timeout_sec=1";
|
||||||
|
char cmd1[rabit_timeout_sec.size()+1];
|
||||||
|
std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
|
||||||
|
cmd1[rabit_timeout_sec.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_debug = "rabit_debug=1";
|
||||||
|
char cmd2[rabit_debug.size()+1];
|
||||||
|
std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
|
||||||
|
cmd2[rabit_debug.size()] = '\0';
|
||||||
|
|
||||||
|
char* argv[] = {cmd, cmd1,cmd2};
|
||||||
|
m.Init(3, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
m.rabit_bootstrap_cache = 1;
|
||||||
|
m._assert = mockassert;
|
||||||
|
m._error = mockerr;
|
||||||
|
auto start = std::chrono::system_clock::now();
|
||||||
|
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(err_type), false);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(1100));
|
||||||
|
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(err_type), false);
|
||||||
|
|
||||||
|
m.rabit_timeout_task.wait();
|
||||||
|
auto end = std::chrono::system_clock::now();
|
||||||
|
std::chrono::duration<double> diff = end-start;
|
||||||
|
|
||||||
|
EXPECT_EQ(m.rabit_timeout_task.get(), false);
|
||||||
|
// expect second error don't overwrite/reset timeout task
|
||||||
|
EXPECT_LT(diff.count(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(allreduce_robust, no_timeout_shut_down)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceRobust m;
|
||||||
|
|
||||||
|
std::string rabit_timeout = "rabit_timeout=1";
|
||||||
|
char cmd[rabit_timeout.size()+1];
|
||||||
|
std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
|
||||||
|
cmd[rabit_timeout.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_timeout_sec = "rabit_timeout_sec=1";
|
||||||
|
char cmd1[rabit_timeout_sec.size()+1];
|
||||||
|
std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
|
||||||
|
cmd1[rabit_timeout_sec.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_debug = "rabit_debug=1";
|
||||||
|
char cmd2[rabit_debug.size()+1];
|
||||||
|
std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
|
||||||
|
cmd2[rabit_debug.size()] = '\0';
|
||||||
|
|
||||||
|
char* argv[] = {cmd, cmd1,cmd2};
|
||||||
|
m.Init(3, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(succ_type), true);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||||
|
m.Shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(allreduce_robust, shut_down_before_timeout)
|
||||||
|
{
|
||||||
|
rabit::engine::AllreduceRobust m;
|
||||||
|
|
||||||
|
std::string rabit_timeout = "rabit_timeout=1";
|
||||||
|
char cmd[rabit_timeout.size()+1];
|
||||||
|
std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
|
||||||
|
cmd[rabit_timeout.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_timeout_sec = "rabit_timeout_sec=1";
|
||||||
|
char cmd1[rabit_timeout_sec.size()+1];
|
||||||
|
std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
|
||||||
|
cmd1[rabit_timeout_sec.size()] = '\0';
|
||||||
|
|
||||||
|
std::string rabit_debug = "rabit_debug=1";
|
||||||
|
char cmd2[rabit_debug.size()+1];
|
||||||
|
std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
|
||||||
|
cmd2[rabit_debug.size()] = '\0';
|
||||||
|
|
||||||
|
char* argv[] = {cmd, cmd1,cmd2};
|
||||||
|
m.Init(3, argv);
|
||||||
|
m.rank = 0;
|
||||||
|
rabit::engine::AllreduceRobust::LinkRecord a;
|
||||||
|
m.err_link = &a;
|
||||||
|
|
||||||
|
EXPECT_EQ(m.CheckAndRecover(err_type), false);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||||
|
m.Shutdown();
|
||||||
|
}
|
||||||
@ -3,5 +3,6 @@
|
|||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
::testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||||
return RUN_ALL_TESTS();
|
return RUN_ALL_TESTS();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,7 +13,7 @@ all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_di
|
|||||||
|
|
||||||
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
||||||
model_recover_10_10k:
|
model_recover_10_10k:
|
||||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 rabit_bootstrap_cache=-1 rabit_debug=1 rabit_reduce_ring_mincount=1
|
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 rabit_bootstrap_cache=true rabit_debug=true rabit_reduce_ring_mincount=1 rabit_timeout=true rabit_timeout_sec=5
|
||||||
|
|
||||||
model_recover_10_10k_die_same:
|
model_recover_10_10k_die_same:
|
||||||
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 rabit_bootstrap_cache=1
|
$(DMLC)/tracker/dmlc-submit --cluster local --num-workers=10 --local-num-attempt=20 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 rabit_bootstrap_cache=1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user