diff --git a/include/rabit/internal/utils.h b/include/rabit/internal/utils.h index 4b36378bd..387f78f7a 100644 --- a/include/rabit/internal/utils.h +++ b/include/rabit/internal/utils.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #ifndef RABIT_STRICT_CXX98_ @@ -61,22 +62,36 @@ namespace utils { /*! \brief error message buffer length */ const int kPrintBuffer = 1 << 12; +/*! \brief whether an Assert/Check error terminates the process; when false the handlers throw + * std::runtime_error instead, which keeps the process alive when multiple workers are co-located in it */ +extern bool STOP_PROCESS_ON_ERROR; + #ifndef RABIT_CUSTOMIZE_MSG_ /*! * \brief handling of Assert error, caused by inappropriate input * \param msg error message */ inline void HandleAssertError(const char *msg) { - fprintf(stderr, "AssertError:%s\n", msg); - exit(-1); + if (STOP_PROCESS_ON_ERROR) { + fprintf(stderr, "AssertError:%s, shutting down process\n", msg); + exit(-1); + } else { + fprintf(stderr, "AssertError:%s, rabit is configured to keep process running\n", msg); + throw std::runtime_error(msg); + } } /*! 
* \brief handling of Check error, caused by inappropriate input * \param msg error message */ inline void HandleCheckError(const char *msg) { - fprintf(stderr, "%s\n", msg); - exit(-1); + if (STOP_PROCESS_ON_ERROR) { + fprintf(stderr, "%s, shutting down process\n", msg); + exit(-1); + } else { + fprintf(stderr, "%s, rabit is configured to keep process running\n", msg); + throw std::runtime_error(msg); + } } inline void HandlePrint(const char *msg) { printf("%s", msg); diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 603516997..cdff4446e 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -14,6 +14,11 @@ #include "./allreduce_base.h" namespace rabit { + +namespace utils { + bool STOP_PROCESS_ON_ERROR = true; +} + namespace engine { // constructor AllreduceBase::AllreduceBase(void) { @@ -48,6 +53,7 @@ AllreduceBase::AllreduceBase(void) { env_vars.push_back("DMLC_TRACKER_URI"); env_vars.push_back("DMLC_TRACKER_PORT"); env_vars.push_back("DMLC_WORKER_CONNECT_RETRY"); + env_vars.push_back("DMLC_WORKER_STOP_PROCESS_ON_ERROR"); } // initialization function @@ -190,6 +196,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) { connect_retry = atoi(val); } + if (!strcmp(name, "DMLC_WORKER_STOP_PROCESS_ON_ERROR")) { + if (!strcmp(val, "true")) { + rabit::utils::STOP_PROCESS_ON_ERROR = true; + } else if (!strcmp(val, "false")) { + rabit::utils::STOP_PROCESS_ON_ERROR = false; + } else { + throw std::runtime_error("invalid value of DMLC_WORKER_STOP_PROCESS_ON_ERROR"); + } + } } /*! * \brief initialize connection to the tracker diff --git a/src/engine_empty.cc b/src/engine_empty.cc index 8177410ad..c3b210d9a 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -13,6 +13,11 @@ #include "../include/rabit/internal/engine.h" namespace rabit { + +namespace utils { + bool STOP_PROCESS_ON_ERROR = true; +} + namespace engine { /*! 
\brief EmptyEngine */ class EmptyEngine : public IEngine { diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 35283ad5a..6dc6cb7ca 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -15,6 +15,11 @@ #include "../include/rabit/internal/utils.h" namespace rabit { + +namespace utils { + bool STOP_PROCESS_ON_ERROR = true; +} + namespace engine { /*! \brief implementation of engine using MPI */ class MPIEngine : public IEngine {