From 4acdd7c6f68debe1c39ae07ca75466d74d194dd1 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Aug 2020 01:12:00 +0800 Subject: [PATCH] Remove stop process. (#143) --- include/rabit/internal/utils.h | 23 +++++------------------ src/allreduce_base.cc | 15 --------------- src/allreduce_robust.cc | 13 +++++++------ src/engine_empty.cc | 5 ----- src/engine_mpi.cc | 5 ----- test/cpp/CMakeLists.txt | 1 + test/cpp/allreduce_mock_test.cc | 4 ++-- test/cpp/allreduce_mock_test.cpp | 7 ++++--- test/cpp/test_utils.cc | 6 ++++++ 9 files changed, 25 insertions(+), 54 deletions(-) create mode 100644 test/cpp/test_utils.cc diff --git a/include/rabit/internal/utils.h b/include/rabit/internal/utils.h index 918a913ca..5a6b43b6c 100644 --- a/include/rabit/internal/utils.h +++ b/include/rabit/internal/utils.h @@ -65,10 +65,6 @@ namespace utils { /*! \brief error message buffer length */ const int kPrintBuffer = 1 << 12; -/*! \brief we may want to keep the process alive when there are multiple workers - * co-locate in the same process */ -extern bool STOP_PROCESS_ON_ERROR; - /* \brief Case-insensitive string comparison */ inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) { #ifdef _MSC_VER @@ -89,26 +85,17 @@ inline bool StringToBool(const char* s) { * \param msg error message */ inline void HandleAssertError(const char *msg) { - if (STOP_PROCESS_ON_ERROR) { - fprintf(stderr, "AssertError:%s, shutting down process\n", msg); - exit(-1); - } else { - fprintf(stderr, "AssertError:%s, rabit is configured to keep process running\n", msg); - throw dmlc::Error(msg); - } + fprintf(stderr, + "AssertError:%s, rabit is configured to keep process running\n", msg); + throw dmlc::Error(msg); } /*! * \brief handling of Check error, caused by inappropriate input * \param msg error message */ inline void HandleCheckError(const char *msg) { - if (STOP_PROCESS_ON_ERROR) { - fprintf(stderr, "%s, shutting down process\n", msg); - exit(-1); - } else { - fprintf(stderr, "%s, rabit is configured to keep process running\n", msg); - throw dmlc::Error(msg); - } + fprintf(stderr, "%s, rabit is configured to keep process running\n", msg); + throw dmlc::Error(msg); } inline void HandlePrint(const char *msg) { printf("%s", msg); diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index ca02771a4..a5b199df8 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -13,11 +13,6 @@ #include namespace rabit { - -namespace utils { - bool STOP_PROCESS_ON_ERROR = true; -} - namespace engine { // constructor AllreduceBase::AllreduceBase(void) { @@ -48,7 +43,6 @@ AllreduceBase::AllreduceBase(void) { env_vars.push_back("DMLC_TRACKER_URI"); env_vars.push_back("DMLC_TRACKER_PORT"); env_vars.push_back("DMLC_WORKER_CONNECT_RETRY"); - env_vars.push_back("DMLC_WORKER_STOP_PROCESS_ON_ERROR"); } // initialization function @@ -200,15 +194,6 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) { connect_retry = atoi(val); } - if (!strcmp(name, "DMLC_WORKER_STOP_PROCESS_ON_ERROR")) { - if (!strcmp(val, "true")) { - rabit::utils::STOP_PROCESS_ON_ERROR = true; - } else if (!strcmp(val, "false")) { - rabit::utils::STOP_PROCESS_ON_ERROR = false; - } else { - throw std::runtime_error("invalid value of DMLC_WORKER_STOP_PROCESS_ON_ERROR"); - } - } if (!strcmp(name, "rabit_bootstrap_cache")) { rabit_bootstrap_cache = utils::StringToBool(val); } diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 1ce407d69..de962055f 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -167,8 +167,8 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size * \param _file caller file name used to generate unique cache key * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ + * \param _caller caller function name used to generate unique cache key + */ void AllreduceRobust::Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, @@ -518,8 +518,8 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model, } // execute checkpoint, note: when checkpoint existing, load will not happen _assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, - ActionSummary::kSpecialOp, cur_cache_seq), - "check point must return true"); + ActionSummary::kSpecialOp, cur_cache_seq), + "check point must return true"); // this is the critical region where we will change all the stored models // increase version number version_number += 1; @@ -550,8 +550,9 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model, delta = utils::GetTime() - start; // log checkpoint ack latency if (rabit_debug) { - utils::HandleLogInfo("[%d] checkpoint ack finished version %d, take %f seconds\n", - rank, version_number, delta); + utils::HandleLogInfo( + "[%d] checkpoint ack finished version %d, take %f seconds\n", rank, + version_number, delta); } } /*! diff --git a/src/engine_empty.cc b/src/engine_empty.cc index b8e7e5ad0..53ec85ee3 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -12,11 +12,6 @@ #include "rabit/internal/engine.h" namespace rabit { - -namespace utils { - bool STOP_PROCESS_ON_ERROR = true; -} - namespace engine { /*! \brief EmptyEngine */ class EmptyEngine : public IEngine { diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 107baf0a3..c63f13a78 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -15,11 +15,6 @@ #include "rabit/internal/utils.h" namespace rabit { - -namespace utils { - bool STOP_PROCESS_ON_ERROR = true; -} - namespace engine { /*! \brief implementation of engine using MPI */ class MPIEngine : public IEngine { diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index 979c059a8..9216d080c 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -3,6 +3,7 @@ find_package(GTest REQUIRED) add_executable( unit_tests test_io.cc + test_utils.cc allreduce_robust_test.cc allreduce_base_test.cc allreduce_mock_test.cc diff --git a/test/cpp/allreduce_mock_test.cc b/test/cpp/allreduce_mock_test.cc index e659d8ea8..5d03dc71a 100644 --- a/test/cpp/allreduce_mock_test.cc +++ b/test/cpp/allreduce_mock_test.cc @@ -17,7 +17,7 @@ TEST(allreduce_mock, mock_allreduce) char* argv[] = {cmd}; m.Init(1, argv); m.rank = 0; - EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), ""); + EXPECT_THROW(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), dmlc::Error); } TEST(allreduce_mock, mock_broadcast) @@ -32,5 +32,5 @@ TEST(allreduce_mock, mock_broadcast) m.rank = 0; m.version_number=1; m.seq_counter=2; - EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), ""); + EXPECT_THROW(m.Broadcast(nullptr,0,0), dmlc::Error); } diff --git a/test/cpp/allreduce_mock_test.cpp b/test/cpp/allreduce_mock_test.cpp index ec3190c96..5d7a1b16f 100644 --- a/test/cpp/allreduce_mock_test.cpp +++ b/test/cpp/allreduce_mock_test.cpp @@ -3,6 +3,7 @@ #include #include +#include #include "../../src/allreduce_mock.h" TEST(allreduce_mock, mock_allreduce) @@ -17,7 +18,7 @@ TEST(allreduce_mock, mock_allreduce) char* argv[] = {cmd}; m.Init(1, argv); m.rank = 0; - EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), ""); + EXPECT_THROW({m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr);}, dmlc::Error); } TEST(allreduce_mock, mock_broadcast) @@ -32,7 +33,7 @@ TEST(allreduce_mock, mock_broadcast) m.rank = 0; m.version_number=1; m.seq_counter=2; - EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), ""); + EXPECT_THROW({m.Broadcast(nullptr,0,0);}, dmlc::Error); } TEST(allreduce_mock, mock_gather) @@ -47,5 +48,5 @@ TEST(allreduce_mock, mock_gather) m.rank = 3; m.version_number=13; m.seq_counter=22; - EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), ""); + EXPECT_THROW({m.Allgather(nullptr,0,0,0,0);}, dmlc::Error); } diff --git a/test/cpp/test_utils.cc b/test/cpp/test_utils.cc new file mode 100644 index 000000000..0b8787bdd --- /dev/null +++ b/test/cpp/test_utils.cc @@ -0,0 +1,6 @@ +#include +#include + +TEST(Utils, Assert) { + EXPECT_THROW({rabit::utils::Assert(false, "foo");}, dmlc::Error); +}