Remove stop process. (#143)
This commit is contained in:
parent
e6cd74ead3
commit
4acdd7c6f6
@ -65,10 +65,6 @@ namespace utils {
|
||||
/*! \brief error message buffer length */
|
||||
const int kPrintBuffer = 1 << 12;
|
||||
|
||||
/*! \brief we may want to keep the process alive when there are multiple workers
|
||||
* co-locate in the same process */
|
||||
extern bool STOP_PROCESS_ON_ERROR;
|
||||
|
||||
/* \brief Case-insensitive string comparison */
|
||||
inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) {
|
||||
#ifdef _MSC_VER
|
||||
@ -89,26 +85,17 @@ inline bool StringToBool(const char* s) {
|
||||
* \param msg error message
|
||||
*/
|
||||
inline void HandleAssertError(const char *msg) {
|
||||
if (STOP_PROCESS_ON_ERROR) {
|
||||
fprintf(stderr, "AssertError:%s, shutting down process\n", msg);
|
||||
exit(-1);
|
||||
} else {
|
||||
fprintf(stderr, "AssertError:%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
}
|
||||
fprintf(stderr,
|
||||
"AssertError:%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
}
|
||||
/*!
|
||||
* \brief handling of Check error, caused by inappropriate input
|
||||
* \param msg error message
|
||||
*/
|
||||
inline void HandleCheckError(const char *msg) {
|
||||
if (STOP_PROCESS_ON_ERROR) {
|
||||
fprintf(stderr, "%s, shutting down process\n", msg);
|
||||
exit(-1);
|
||||
} else {
|
||||
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
}
|
||||
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
|
||||
throw dmlc::Error(msg);
|
||||
}
|
||||
inline void HandlePrint(const char *msg) {
|
||||
printf("%s", msg);
|
||||
|
||||
@ -13,11 +13,6 @@
|
||||
#include <map>
|
||||
|
||||
namespace rabit {
|
||||
|
||||
namespace utils {
|
||||
bool STOP_PROCESS_ON_ERROR = true;
|
||||
}
|
||||
|
||||
namespace engine {
|
||||
// constructor
|
||||
AllreduceBase::AllreduceBase(void) {
|
||||
@ -48,7 +43,6 @@ AllreduceBase::AllreduceBase(void) {
|
||||
env_vars.push_back("DMLC_TRACKER_URI");
|
||||
env_vars.push_back("DMLC_TRACKER_PORT");
|
||||
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
|
||||
env_vars.push_back("DMLC_WORKER_STOP_PROCESS_ON_ERROR");
|
||||
}
|
||||
|
||||
// initialization function
|
||||
@ -200,15 +194,6 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
|
||||
connect_retry = atoi(val);
|
||||
}
|
||||
if (!strcmp(name, "DMLC_WORKER_STOP_PROCESS_ON_ERROR")) {
|
||||
if (!strcmp(val, "true")) {
|
||||
rabit::utils::STOP_PROCESS_ON_ERROR = true;
|
||||
} else if (!strcmp(val, "false")) {
|
||||
rabit::utils::STOP_PROCESS_ON_ERROR = false;
|
||||
} else {
|
||||
throw std::runtime_error("invalid value of DMLC_WORKER_STOP_PROCESS_ON_ERROR");
|
||||
}
|
||||
}
|
||||
if (!strcmp(name, "rabit_bootstrap_cache")) {
|
||||
rabit_bootstrap_cache = utils::StringToBool(val);
|
||||
}
|
||||
|
||||
@ -167,8 +167,8 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void AllreduceRobust::Allgather(void *sendrecvbuf,
|
||||
size_t total_size,
|
||||
size_t slice_begin,
|
||||
@ -518,8 +518,8 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
|
||||
}
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
|
||||
ActionSummary::kSpecialOp, cur_cache_seq),
|
||||
"check point must return true");
|
||||
ActionSummary::kSpecialOp, cur_cache_seq),
|
||||
"check point must return true");
|
||||
// this is the critical region where we will change all the stored models
|
||||
// increase version number
|
||||
version_number += 1;
|
||||
@ -550,8 +550,9 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
|
||||
delta = utils::GetTime() - start;
|
||||
// log checkpoint ack latency
|
||||
if (rabit_debug) {
|
||||
utils::HandleLogInfo("[%d] checkpoint ack finished version %d, take %f seconds\n",
|
||||
rank, version_number, delta);
|
||||
utils::HandleLogInfo(
|
||||
"[%d] checkpoint ack finished version %d, take %f seconds\n", rank,
|
||||
version_number, delta);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
|
||||
@ -12,11 +12,6 @@
|
||||
#include "rabit/internal/engine.h"
|
||||
|
||||
namespace rabit {
|
||||
|
||||
namespace utils {
|
||||
bool STOP_PROCESS_ON_ERROR = true;
|
||||
}
|
||||
|
||||
namespace engine {
|
||||
/*! \brief EmptyEngine */
|
||||
class EmptyEngine : public IEngine {
|
||||
|
||||
@ -15,11 +15,6 @@
|
||||
#include "rabit/internal/utils.h"
|
||||
|
||||
namespace rabit {
|
||||
|
||||
namespace utils {
|
||||
bool STOP_PROCESS_ON_ERROR = true;
|
||||
}
|
||||
|
||||
namespace engine {
|
||||
/*! \brief implementation of engine using MPI */
|
||||
class MPIEngine : public IEngine {
|
||||
|
||||
@ -3,6 +3,7 @@ find_package(GTest REQUIRED)
|
||||
add_executable(
|
||||
unit_tests
|
||||
test_io.cc
|
||||
test_utils.cc
|
||||
allreduce_robust_test.cc
|
||||
allreduce_base_test.cc
|
||||
allreduce_mock_test.cc
|
||||
|
||||
@ -17,7 +17,7 @@ TEST(allreduce_mock, mock_allreduce)
|
||||
char* argv[] = {cmd};
|
||||
m.Init(1, argv);
|
||||
m.rank = 0;
|
||||
EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
|
||||
EXPECT_THROW(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(allreduce_mock, mock_broadcast)
|
||||
@ -32,5 +32,5 @@ TEST(allreduce_mock, mock_broadcast)
|
||||
m.rank = 0;
|
||||
m.version_number=1;
|
||||
m.seq_counter=2;
|
||||
EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
|
||||
EXPECT_THROW(m.Broadcast(nullptr,0,0), dmlc::Error);
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <dmlc/logging.h>
|
||||
#include "../../src/allreduce_mock.h"
|
||||
|
||||
TEST(allreduce_mock, mock_allreduce)
|
||||
@ -17,7 +18,7 @@ TEST(allreduce_mock, mock_allreduce)
|
||||
char* argv[] = {cmd};
|
||||
m.Init(1, argv);
|
||||
m.rank = 0;
|
||||
EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
|
||||
EXPECT_THROW({m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr);}, dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(allreduce_mock, mock_broadcast)
|
||||
@ -32,7 +33,7 @@ TEST(allreduce_mock, mock_broadcast)
|
||||
m.rank = 0;
|
||||
m.version_number=1;
|
||||
m.seq_counter=2;
|
||||
EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
|
||||
EXPECT_THROW({m.Broadcast(nullptr,0,0);}, dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(allreduce_mock, mock_gather)
|
||||
@ -47,5 +48,5 @@ TEST(allreduce_mock, mock_gather)
|
||||
m.rank = 3;
|
||||
m.version_number=13;
|
||||
m.seq_counter=22;
|
||||
EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), "");
|
||||
EXPECT_THROW({m.Allgather(nullptr,0,0,0,0);}, dmlc::Error);
|
||||
}
|
||||
|
||||
6
test/cpp/test_utils.cc
Normal file
6
test/cpp/test_utils.cc
Normal file
@ -0,0 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <rabit/internal/utils.h>
|
||||
|
||||
TEST(Utils, Assert) {
|
||||
EXPECT_THROW({rabit::utils::Assert(false, "foo");}, dmlc::Error);
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user