From b5c2a47b20d2b075b0ed3b1a8cdd098a280d9f94 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 21 Oct 2020 15:27:03 +0800 Subject: [PATCH] Drop single point model recovery (#6262) * Pass rabit params in JVM package. * Implement timeout using poll timeout parameter. * Remove OOB data check. --- Jenkinsfile | 14 - R-package/src/Makevars.in | 2 +- R-package/src/Makevars.win | 2 +- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 3 + rabit/CMakeLists.txt | 5 +- rabit/include/rabit/internal/socket.h | 31 +- rabit/src/CMakeLists.txt | 31 - rabit/src/README.md | 6 - rabit/src/allreduce_base.cc | 39 +- rabit/src/allreduce_base.h | 21 +- rabit/src/allreduce_mock.h | 33 +- rabit/src/allreduce_robust-inl.h | 169 -- rabit/src/allreduce_robust.cc | 1590 ----------------- rabit/src/allreduce_robust.h | 665 ------- rabit/src/engine.cc | 3 +- src/learner.cc | 6 - tests/ci_build/approx.conf.in | 12 - tests/ci_build/runxgb.sh | 13 - tests/cpp/rabit/allreduce_mock_test.cc | 53 - tests/cpp/rabit/allreduce_robust_test.cc | 235 --- tests/distributed/distributed_gpu.py | 7 +- tests/distributed/runtests-gpu.sh | 2 + 22 files changed, 63 insertions(+), 2879 deletions(-) delete mode 100644 rabit/src/CMakeLists.txt delete mode 100644 rabit/src/README.md delete mode 100644 rabit/src/allreduce_robust-inl.h delete mode 100644 rabit/src/allreduce_robust.cc delete mode 100644 rabit/src/allreduce_robust.h delete mode 100644 tests/ci_build/approx.conf.in delete mode 100755 tests/ci_build/runxgb.sh delete mode 100644 tests/cpp/rabit/allreduce_mock_test.cc delete mode 100644 tests/cpp/rabit/allreduce_robust_test.cc diff --git a/Jenkinsfile b/Jenkinsfile index 5f86758f8..690b22583 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -321,20 +321,6 @@ def TestPythonGPU(args) { } } -def TestCppRabit() { - node(nodeReq) { - unstash name: 'xgboost_rabit_tests' - unstash name: 'srcs' - echo "Test C++, rabit mock on" - def container_type = "cpu" - def docker_binary = "docker" - sh """ - ${dockerRun} ${container_type} ${docker_binary} tests/ci_build/runxgb.sh xgboost tests/ci_build/approx.conf.in - """ - deleteDir() - } -} - def TestCppGPU(args) { def nodeReq = 'linux && mgpu' def artifact_cuda_version = (args.artifact_cuda_version) ?: ref_cuda_ver diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index e0524a8eb..65a26bbbd 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -22,4 +22,4 @@ PKG_LIBS = @OPENMP_CXXFLAGS@ @OPENMP_LIB@ @ENDIAN_FLAG@ @BACKTRACE_LIB@ -pthread OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \ $(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \ $(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \ - $(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o + $(PKGROOT)/rabit/src/allreduce_base.o diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 727807af6..3583b4f17 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -34,6 +34,6 @@ PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS) OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \ $(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \ $(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \ - $(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o + $(PKGROOT)/rabit/src/allreduce_base.o $(OBJECTS) : xgblib diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala 
b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 66cce0c32..eb184f1a2 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -577,6 +577,7 @@ object XGBoost extends Serializable { logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}") val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext) val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams + val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava val sc = trainingData.sparkContext val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet, hasGroup, xgbExecParams.numWorkers) @@ -595,6 +596,8 @@ object XGBoost extends Serializable { xgbExecParams.timeoutRequestWorkers, xgbExecParams.numWorkers, xgbExecParams.killSparkContextOnWorkerFailure) + + tracker.getWorkerEnvs().putAll(xgbRabitParams) val rabitEnv = tracker.getWorkerEnvs val boostersAndMetrics = if (hasGroup) { trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster, diff --git a/rabit/CMakeLists.txt b/rabit/CMakeLists.txt index 59df3ebdd..eeb056e83 100644 --- a/rabit/CMakeLists.txt +++ b/rabit/CMakeLists.txt @@ -2,8 +2,9 @@ cmake_minimum_required(VERSION 3.3) find_package(Threads REQUIRED) -add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc) -add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc) +add_library(rabit src/allreduce_base.cc src/engine.cc src/c_api.cc) +add_library(rabit_mock_static src/allreduce_base.cc src/engine_mock.cc src/c_api.cc) + target_link_libraries(rabit Threads::Threads dmlc) target_link_libraries(rabit_mock_static Threads::Threads dmlc) diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index a6348eb6b..b93cc953d 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include "utils.h" @@ -95,18 +96,18 @@ namespace utils { static constexpr int kInvalidSocket = -1; template -int PollImpl(PollFD *pfd, int nfds, int timeout) { +int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) { #if defined(_WIN32) #if IS_MINGW() MingWError(); return -1; #else - return WSAPoll(pfd, nfds, timeout); + return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count()); #endif // IS_MINGW() #else - return poll(pfd, nfds, timeout); + return poll(pfd, nfds, std::chrono::milliseconds(timeout).count()); #endif // IS_MINGW() } @@ -608,40 +609,20 @@ struct PollHelper { const auto& pfd = fds.find(fd); return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); } - /*! - * \brief Check if the descriptor has any exception - * \param fd file descriptor to check status - */ - inline bool CheckExcept(SOCKET fd) const { - const auto& pfd = fds.find(fd); - return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); - } - /*! 
- * \brief wait for exception event on a single descriptor - * \param fd the file descriptor to wait the event for - * \param timeout the timeout counter, can be negative, which means wait until the event happen - * \return 1 if success, 0 if timeout, and -1 if error occurs - */ - inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*) - pollfd pfd; - pfd.fd = fd; - pfd.events = POLLPRI; - return PollImpl(&pfd, 1, timeout); - } /*! * \brief peform poll on the set defined, read, write, exception * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block * \return */ - inline void Poll(long timeout = -1) { // NOLINT(*) + inline void Poll(std::chrono::seconds timeout) { // NOLINT(*) std::vector fdset; fdset.reserve(fds.size()); for (auto kv : fds) { fdset.push_back(kv.second); } int ret = PollImpl(fdset.data(), fdset.size(), timeout); - if (ret == -1) { + if (ret <= 0) { Socket::Error("Poll"); } else { for (auto& pfd : fdset) { diff --git a/rabit/src/CMakeLists.txt b/rabit/src/CMakeLists.txt deleted file mode 100644 index a7aa5d03b..000000000 --- a/rabit/src/CMakeLists.txt +++ /dev/null @@ -1,31 +0,0 @@ -option(DMLC_ROOT "Specify root of external dmlc core.") - -add_library(allreduce_base "") -add_library(allreduce_mock "") - -target_sources( - allreduce_base - PRIVATE - allreduce_base.cc - PUBLIC - ${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h -) -target_sources( - allreduce_mock - PRIVATE - allreduce_robust.cc - PUBLIC - ${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h -) - -target_include_directories( - allreduce_base - PUBLIC - ${DMLC_ROOT}/include - ${CMAKE_CURRENT_LIST_DIR}/../../include) - -target_include_directories( - allreduce_mock - PUBLIC - ${DMLC_ROOT}/include - ${CMAKE_CURRENT_LIST_DIR}/../../include) diff --git a/rabit/src/README.md b/rabit/src/README.md deleted file mode 100644 index 5e55d9210..000000000 --- a/rabit/src/README.md +++ /dev/null @@ -1,6 +0,0 @@ -Source Files of Rabit -==== -* This folder contains the source files of rabit library -* The library headers are in folder [include](../include) -* The .h files in this folder are internal header files that are only used by rabit and will not be seen by users - diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index d1959eaa6..60c8f744b 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -6,8 +6,9 @@ * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou */ #define NOMINMAX +#include "rabit/base.h" +#include "rabit/internal/rabit-inl.h" #include "allreduce_base.h" -#include #ifndef _WIN32 #include @@ -208,8 +209,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) { rabit_timeout = utils::StringToBool(val); } if (!strcmp(name, "rabit_timeout_sec")) { - timeout_sec = atoi(val); - utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second"); + timeout_sec = std::chrono::seconds(atoi(val)); + utils::Assert(timeout_sec.count() >= 0, "rabit_timeout_sec should be non negative second"); } if (!strcmp(name, "rabit_enable_tcp_no_delay")) { if (!strcmp(val, "true")) { @@ -549,14 +550,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, // finish runing allreduce if (finished) break; // select must return - watcher.Poll(); - // exception handling - for (int i = 0; i < nlink; ++i) { - // recive OOB message from some link - if (watcher.CheckExcept(links[i].sock)) { - return ReportError(&links[i], kGetExcept); - } - } + watcher.Poll(timeout_sec); // read data from childs for (int i = 0; i < nlink; ++i) { if (i != 
parent_index && watcher.CheckRead(links[i].sock)) { @@ -729,14 +723,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { // finish running if (finished) break; // select - watcher.Poll(); - // exception handling - for (int i = 0; i < nlink; ++i) { - // recive OOB message from some link - if (watcher.CheckExcept(links[i].sock)) { - return ReportError(&links[i], kGetExcept); - } - } + watcher.Poll(timeout_sec); if (in_link == -2) { // probe in-link for (int i = 0; i < nlink; ++i) { @@ -819,7 +806,7 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, finished = false; } if (finished) break; - watcher.Poll(); + watcher.Poll(timeout_sec); if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { size_t size = stop_read - read_ptr; size_t start = read_ptr % total_size; @@ -831,7 +818,10 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, read_ptr += static_cast(len); } else { ReturnType ret = Errno2Return(); - if (ret != kSuccess) return ReportError(&next, ret); + if (ret != kSuccess) { + auto err = ReportError(&next, ret); + return err; + } } } if (write_ptr < read_ptr && write_ptr != stop_write) { @@ -845,7 +835,10 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, write_ptr += static_cast(len); } else { ReturnType ret = Errno2Return(); - if (ret != kSuccess) return ReportError(&prev, ret); + if (ret != kSuccess) { + auto err = ReportError(&prev, ret); + return err; + } } } } @@ -913,7 +906,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_, finished = false; } if (finished) break; - watcher.Poll(); + watcher.Poll(timeout_sec); if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read); if (ret != kSuccess) { diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h index 815b548f6..90c55e3de 100644 --- a/rabit/src/allreduce_base.h +++ b/rabit/src/allreduce_base.h @@ -12,6 +12,8 @@ #ifndef RABIT_ALLREDUCE_BASE_H_ #define RABIT_ALLREDUCE_BASE_H_ +#include +#include #include #include #include @@ -35,6 +37,7 @@ class Datatype { } namespace rabit { namespace engine { + /*! \brief implementation of basic Allreduce engine */ class AllreduceBase : public IEngine { public: @@ -103,9 +106,11 @@ class AllreduceBase : public IEngine { size_t slice_end, size_t size_prev_slice, const char *_file = _FILE, const int _line = _LINE, const char *_caller = _CALLER) override { - if (world_size == 1 || world_size == -1) return; - utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, - slice_begin, slice_end, size_prev_slice) == kSuccess, + if (world_size == 1 || world_size == -1) { + return; + } + utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin, + slice_end, size_prev_slice) == kSuccess, "AllgatherRing failed"); } /*! @@ -130,8 +135,8 @@ class AllreduceBase : public IEngine { const char *_caller = _CALLER) override { if (prepare_fun != nullptr) prepare_fun(prepare_arg); if (world_size == 1 || world_size == -1) return; - utils::Assert(TryAllreduce(sendrecvbuf_, - type_nbytes, count, reducer) == kSuccess, + utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == + kSuccess, "Allreduce failed"); } /*! 
@@ -518,9 +523,9 @@ class AllreduceBase : public IEngine { //---- data structure related to model ---- // call sequence counter, records how many calls we made so far // from last call to CheckPoint, LoadCheckPoint - int seq_counter; // NOLINT + int seq_counter{0}; // NOLINT // version number of model - int version_number; // NOLINT + int version_number {0}; // NOLINT // whether the job is running in hadoop bool hadoop_mode; // NOLINT //---- local data related to link ---- @@ -571,7 +576,7 @@ class AllreduceBase : public IEngine { // enable detailed logging bool rabit_debug = false; // NOLINT // by default, if rabit worker not recover in half an hour exit - int timeout_sec = 1800; // NOLINT + std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT // flag to enable rabit_timeout bool rabit_timeout = false; // NOLINT // Enable TCP node delay diff --git a/rabit/src/allreduce_mock.h b/rabit/src/allreduce_mock.h index 7c0a25e80..a0e725f8e 100644 --- a/rabit/src/allreduce_mock.h +++ b/rabit/src/allreduce_mock.h @@ -13,11 +13,11 @@ #include #include "rabit/internal/engine.h" #include "rabit/internal/timer.h" -#include "allreduce_robust.h" +#include "allreduce_base.h" namespace rabit { namespace engine { -class AllreduceMock : public AllreduceRobust { +class AllreduceMock : public AllreduceBase { public: // constructor AllreduceMock() { @@ -30,7 +30,7 @@ class AllreduceMock : public AllreduceRobust { // destructor ~AllreduceMock() override = default; void SetParam(const char *name, const char *val) override { - AllreduceRobust::SetParam(name, val); + AllreduceBase::SetParam(name, val); // additional parameters if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val); if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val); @@ -51,9 +51,8 @@ class AllreduceMock : public AllreduceRobust { const char *_caller = _CALLER) override { this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce"); double tstart = utils::GetTime(); - AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes, - count, reducer, prepare_fun, prepare_arg, - _file, _line, _caller); + AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer, + prepare_fun, prepare_arg, _file, _line, _caller); tsum_allreduce_ += utils::GetTime() - tstart; } void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, @@ -62,16 +61,15 @@ class AllreduceMock : public AllreduceRobust { const char *_caller = _CALLER) override { this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather"); double tstart = utils::GetTime(); - AllreduceRobust::Allgather(sendrecvbuf, total_size, - slice_begin, slice_end, - size_prev_slice, _file, _line, _caller); + AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end, + size_prev_slice, _file, _line, _caller); tsum_allgather_ += utils::GetTime() - tstart; } void Broadcast(void *sendrecvbuf_, size_t total_size, int root, const char *_file = _FILE, const int _line = _LINE, const char *_caller = _CALLER) override { this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast"); - AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller); + AllreduceBase::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller); } int LoadCheckPoint(Serializable *global_model, Serializable *local_model) override { @@ -79,11 +77,11 @@ class AllreduceMock : public AllreduceRobust { tsum_allgather_ = 0.0; time_checkpoint_ = utils::GetTime(); if (force_local_ == 0) { - return 
AllreduceRobust::LoadCheckPoint(global_model, local_model); + return AllreduceBase::LoadCheckPoint(global_model, local_model); } else { DummySerializer dum; ComboSerializer com(global_model, local_model); - return AllreduceRobust::LoadCheckPoint(&dum, &com); + return AllreduceBase::LoadCheckPoint(&dum, &com); } } void CheckPoint(const Serializable *global_model, @@ -92,18 +90,17 @@ class AllreduceMock : public AllreduceRobust { double tstart = utils::GetTime(); double tbet_chkpt = tstart - time_checkpoint_; if (force_local_ == 0) { - AllreduceRobust::CheckPoint(global_model, local_model); + AllreduceBase::CheckPoint(global_model, local_model); } else { DummySerializer dum; ComboSerializer com(global_model, local_model); - AllreduceRobust::CheckPoint(&dum, &com); + AllreduceBase::CheckPoint(&dum, &com); } time_checkpoint_ = utils::GetTime(); double tcost = utils::GetTime() - tstart; if (report_stats_ != 0 && rank == 0) { std::stringstream ss; - ss << "[v" << version_number << "] global_size=" << global_checkpoint_.length() - << ",local_size=" << (local_chkpt_[0].length() + local_chkpt_[1].length()) + ss << "[v" << version_number << "] global_size=" << ",check_tcost="<< tcost <<" sec" << ",allreduce_tcost=" << tsum_allreduce_ << " sec" << ",allgather_tcost=" << tsum_allgather_ << " sec" @@ -116,7 +113,7 @@ class AllreduceMock : public AllreduceRobust { void LazyCheckPoint(const Serializable *global_model) override { this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint"); - AllreduceRobust::LazyCheckPoint(global_model); + AllreduceBase::LazyCheckPoint(global_model); } protected: @@ -186,7 +183,7 @@ class AllreduceMock : public AllreduceRobust { if (mock_map_.count(key) != 0) { num_trial_ += 1; // data processing frameworks runs on shared process - error_("[%d]@@@Hit Mock Error:%s ", rank, name); + throw dmlc::Error(std::to_string(rank) + "@@@Hit Mock Error: " + name); } } }; diff --git a/rabit/src/allreduce_robust-inl.h b/rabit/src/allreduce_robust-inl.h deleted file mode 100644 index 7dcbbb456..000000000 --- a/rabit/src/allreduce_robust-inl.h +++ /dev/null @@ -1,169 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file allreduce_robust-inl.h - * \brief implementation of inline template function in AllreduceRobust - * - * \author Tianqi Chen - */ -#ifndef RABIT_ALLREDUCE_ROBUST_INL_H_ -#define RABIT_ALLREDUCE_ROBUST_INL_H_ -#include - -namespace rabit { -namespace engine { -/*! 
- * \brief run message passing algorithm on the allreduce tree - * the result is edge message stored in p_edge_in and p_edge_out - * \param node_value the value associated with current node - * \param p_edge_in used to store input message from each of the edge - * \param p_edge_out used to store output message from each of the edge - * \param func a function that defines the message passing rule - * Parameters of func: - * - node_value same as node_value in the main function - * - edge_in the array of input messages from each edge, - * this includes the output edge, which should be excluded - * - out_index array the index of output edge, the function should - * exclude the output edge when compute the message passing value - * Return of func: - * the function returns the output message based on the input message and node_value - * - * \tparam EdgeType type of edge message, must be simple struct - * \tparam NodeType type of node value - */ -template -inline AllreduceRobust::ReturnType -AllreduceRobust::MsgPassing(const NodeType &node_value, - std::vector *p_edge_in, - std::vector *p_edge_out, - EdgeType(*func) - (const NodeType &node_value, - const std::vector &edge_in, - size_t out_index)) { - RefLinkVector &links = tree_links; - if (links.Size() == 0) return kSuccess; - // number of links - const int nlink = static_cast(links.Size()); - // initialize the pointers - for (int i = 0; i < nlink; ++i) { - links[i].ResetSize(); - } - std::vector &edge_in = *p_edge_in; - std::vector &edge_out = *p_edge_out; - edge_in.resize(nlink); - edge_out.resize(nlink); - // stages in the process - // 0: recv messages from childs - // 1: send message to parent - // 2: recv message from parent - // 3: send message to childs - int stage = 0; - // if no childs, no need to, directly start passing message - if (nlink == static_cast(parent_index != -1)) { - utils::Assert(parent_index == 0, "parent must be 0"); - edge_out[parent_index] = func(node_value, edge_in, parent_index); - stage = 1; - } - // while we have not passed the messages out - while (true) { - // for node with no parent, directly do stage 3 - if (parent_index == -1) { - utils::Assert(stage != 2 && stage != 1, "invalie stage id"); - } - // poll helper - utils::PollHelper watcher; - bool done = (stage == 3); - for (int i = 0; i < nlink; ++i) { - watcher.WatchException(links[i].sock); - switch (stage) { - case 0: - if (i != parent_index && links[i].size_read != sizeof(EdgeType)) { - watcher.WatchRead(links[i].sock); - } - break; - case 1: - if (i == parent_index) { - watcher.WatchWrite(links[i].sock); - } - break; - case 2: - if (i == parent_index) { - watcher.WatchRead(links[i].sock); - } - break; - case 3: - if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { - watcher.WatchWrite(links[i].sock); - done = false; - } - break; - default: utils::Error("invalid stage"); - } - } - // finish all the stages, and write out message - if (done) break; - watcher.Poll(); - // exception handling - for (int i = 0; i < nlink; ++i) { - // recive OOB message from some link - if (watcher.CheckExcept(links[i].sock)) { - return ReportError(&links[i], kGetExcept); - } - } - if (stage == 0) { - bool finished = true; - // read data from childs - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - if (watcher.CheckRead(links[i].sock)) { - ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType)); - if (ret != kSuccess) return ReportError(&links[i], ret); - } - if (links[i].size_read != sizeof(EdgeType)) finished = false; - } - } - // if 
no parent, jump to stage 3, otherwise do stage 1 - if (finished) { - if (parent_index != -1) { - edge_out[parent_index] = func(node_value, edge_in, parent_index); - stage = 1; - } else { - for (int i = 0; i < nlink; ++i) { - edge_out[i] = func(node_value, edge_in, i); - } - stage = 3; - } - } - } - if (stage == 1) { - const int pid = this->parent_index; - utils::Assert(pid != -1, "MsgPassing invalid stage"); - ReturnType ret = links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType)); - if (ret != kSuccess) return ReportError(&links[pid], ret); - if (links[pid].size_write == sizeof(EdgeType)) stage = 2; - } - if (stage == 2) { - const int pid = this->parent_index; - utils::Assert(pid != -1, "MsgPassing invalid stage"); - ReturnType ret = links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType)); - if (ret != kSuccess) return ReportError(&links[pid], ret); - if (links[pid].size_read == sizeof(EdgeType)) { - for (int i = 0; i < nlink; ++i) { - if (i != pid) edge_out[i] = func(node_value, edge_in, i); - } - stage = 3; - } - } - if (stage == 3) { - for (int i = 0; i < nlink; ++i) { - if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { - ReturnType ret = links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType)); - if (ret != kSuccess) return ReportError(&links[i], ret); - } - } - } - } - return kSuccess; -} -} // namespace engine -} // namespace rabit -#endif // RABIT_ALLREDUCE_ROBUST_INL_H_ diff --git a/rabit/src/allreduce_robust.cc b/rabit/src/allreduce_robust.cc deleted file mode 100644 index 8fb5f8183..000000000 --- a/rabit/src/allreduce_robust.cc +++ /dev/null @@ -1,1590 +0,0 @@ -/*! - * Copyright (c) 2014-2019 by Contributors - * \file allreduce_robust.cc - * \brief Robust implementation of Allreduce - * - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#define NOMINMAX -#include -#include -#include -#include -#include -#include "rabit/internal/io.h" -#include "rabit/internal/timer.h" -#include "rabit/internal/utils.h" -#include "rabit/internal/engine.h" -#include "rabit/internal/rabit-inl.h" -#include "allreduce_robust.h" - -#undef _assert - -namespace rabit { -namespace engine { - -AllreduceRobust::AllreduceRobust() { - num_local_replica_ = 0; - num_global_replica_ = 5; - default_local_replica_ = 2; - seq_counter = 0; - cur_cache_seq_ = 0; - local_chkpt_version_ = 0; - result_buffer_round_ = 1; - global_lazycheck_ = nullptr; - use_local_model_ = -1; - recover_counter_ = 0; - checkpoint_loaded_ = false; - env_vars.emplace_back("rabit_global_replica"); - env_vars.emplace_back("rabit_local_replica"); -} -bool AllreduceRobust::Init(int argc, char* argv[]) { - if (AllreduceBase::Init(argc, argv)) { - // chenqin: alert user opted in experimental feature. - if (rabit_bootstrap_cache) { utils::HandleLogInfo( - "[EXPERIMENTAL] bootstrap cache has been enabled\n"); -} - checkpoint_loaded_ = false; - if (num_global_replica_ == 0) { - result_buffer_round_ = -1; - } else { - result_buffer_round_ = std::max(world_size / num_global_replica_, 1); - } - return true; - } else { - return false; - } -} -/*! 
\brief shutdown the engine */ -bool AllreduceRobust::Shutdown() { - try { - // need to sync the exec before we shutdown, do a pesudo check point - // execute checkpoint, note: when checkpoint existing, load will not happen - assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp, - cur_cache_seq_), "Shutdown: check point must return true"); - // reset result buffer - resbuf_.Clear(); seq_counter = 0; - cachebuf_.Clear(); cur_cache_seq_ = 0; - lookupbuf_.Clear(); - // execute check ack step, load happens here - assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckAck, - ActionSummary::kSpecialOp, cur_cache_seq_), "Shutdown: check ack must return true"); -// travis ci only osx test hang -#if defined (__APPLE__) - sleep(1); -#endif - shutdown_timeout_ = true; - if (rabit_timeout_task_.valid()) { - rabit_timeout_task_.wait(); - assert_(rabit_timeout_task_.get(), "expect timeout task return\n"); - } - return AllreduceBase::Shutdown(); - } catch (const std::exception& e) { - fprintf(stderr, "%s\n", e.what()); - return false; - } -} - -/*! - * \brief set parameters to the engine - * \param name parameter name - * \param val parameter value - */ -void AllreduceRobust::SetParam(const char *name, const char *val) { - AllreduceBase::SetParam(name, val); - if (!strcmp(name, "rabit_global_replica")) num_global_replica_ = atoi(val); - if (!strcmp(name, "rabit_local_replica")) { - num_local_replica_ = atoi(val); - } -} - -int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf, - const size_t type_nbytes, const size_t count) { - for (int i = 0 ; i < cur_cache_seq_; i++) { - size_t nsize = 0; - void* name = lookupbuf_.Query(i, &nsize); - if (nsize == key.length() + 1 - && strcmp(static_cast(name), key.c_str()) == 0) { - break; - } - } - // we should consider way to support duplicated signatures - // https://github.com/dmlc/xgboost/issues/5012 - // _assert(index == -1, "immutable cache key already exists"); - assert_(type_nbytes*count > 0, "can't set empty cache"); - void* temp = cachebuf_.AllocTemp(type_nbytes, count); - cachebuf_.PushTemp(cur_cache_seq_, type_nbytes, count); - std::memcpy(temp, buf, type_nbytes*count); - - std::string k(key); - void* name = lookupbuf_.AllocTemp(strlen(k.c_str()) + 1, 1); - lookupbuf_.PushTemp(cur_cache_seq_, strlen(k.c_str()) + 1, 1); - std::memcpy(name, key.c_str(), strlen(k.c_str()) + 1); - cur_cache_seq_ += 1; - return 0; -} - -int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf, - const size_t type_nbytes, const size_t count) { - // as requester sync with rest of nodes on latest cache content - if (!RecoverExec(nullptr, 0, ActionSummary::kLoadBootstrapCache, - seq_counter, cur_cache_seq_)) return -1; - - int index = -1; - for (int i = 0 ; i < cur_cache_seq_; i++) { - size_t nsize = 0; - void* name = lookupbuf_.Query(i, &nsize); - if (nsize == strlen(key.c_str()) + 1 - && strcmp(reinterpret_cast(name), key.c_str()) == 0) { - index = i; - break; - } - } - // cache doesn't exists - if (index == -1) return -1; - - size_t siz = 0; - void* temp = cachebuf_.Query(index, &siz); - utils::Assert(cur_cache_seq_ > index, "cur_cache_seq is smaller than lookup cache seq index"); - utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested"); - utils::Assert(siz > 0, "cache size should be greater than 0"); - std::memcpy(buf, temp, type_nbytes*count); - return 0; -} - -/*! 
- * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ -void AllreduceRobust::Allgather(void *sendrecvbuf, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice, - const char* _file, - const int _line, - const char* _caller) { - if (world_size == 1 || world_size == -1) return; - // genreate unique allgather signature - std::string key = std::string(_file) + "::" + std::to_string(_line) + "::" - + std::string(_caller) + "#" +std::to_string(total_size); - - // try fetch bootstrap allgather results from cache - if (!checkpoint_loaded_ && rabit_bootstrap_cache && - GetBootstrapCache(key, sendrecvbuf, total_size, 1) != -1) return; - - double start = utils::GetTime(); - bool recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq_); - - if (resbuf_.LastSeqNo() != -1 && - (result_buffer_round_ == -1 || - resbuf_.LastSeqNo() % result_buffer_round_ != rank % result_buffer_round_)) { - resbuf_.DropLast(); - } - - void *temp = resbuf_.AllocTemp(total_size, 1); - while (true) { - if (recovered) { - std::memcpy(temp, sendrecvbuf, total_size); break; - } else { - std::memcpy(temp, sendrecvbuf, total_size); - if (CheckAndRecover(TryAllgatherRing(temp, total_size, - slice_begin, slice_end, size_prev_slice))) { - std::memcpy(sendrecvbuf, temp, total_size); break; - } else { - recovered = RecoverExec(sendrecvbuf, total_size, 0, seq_counter, cur_cache_seq_); - } - } - } - double delta = utils::GetTime() - start; - // log allgather latency - if (rabit_debug) { - utils::HandleLogInfo("[%d] allgather (%s) finished version %d, seq %d, take %f seconds\n", - rank, key.c_str(), version_number, seq_counter, delta); - } - - // if bootstrap allgather, store and fetch through cache - if (checkpoint_loaded_ || !rabit_bootstrap_cache) { - resbuf_.PushTemp(seq_counter, total_size, 1); - seq_counter += 1; - } else { - SetBootstrapCache(key, sendrecvbuf, total_size, 1); - } -} - -/*! - * \brief perform in-place allreduce, on sendrecvbuf - * this function is NOT thread-safe - * \param sendrecvbuf_ buffer for both sending and recving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. 
- * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to passed into the lazy preprocessing function - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ -void AllreduceRobust::Allreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer, - PreprocFunction prepare_fun, - void *prepare_arg, - const char* _file, - const int _line, - const char* _caller) { - // skip action in single node - if (world_size == 1 || world_size == -1) { - if (prepare_fun != nullptr) prepare_fun(prepare_arg); - return; - } - - // genreate unique allreduce signature - std::string key = std::string(_file) + "::" + std::to_string(_line) + "::" - + std::string(_caller) + "#" +std::to_string(type_nbytes) + "x" + std::to_string(count); - - // try fetch bootstrap allreduce results from cache - if (!checkpoint_loaded_ && rabit_bootstrap_cache && - GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count) != -1) return; - - double start = utils::GetTime(); - bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq_); - - if (resbuf_.LastSeqNo() != -1 && - (result_buffer_round_ == -1 || - resbuf_.LastSeqNo() % result_buffer_round_ != rank % result_buffer_round_)) { - resbuf_.DropLast(); - } - - if (!recovered && prepare_fun != nullptr) prepare_fun(prepare_arg); - void *temp = resbuf_.AllocTemp(type_nbytes, count); - while (true) { - if (recovered) { - std::memcpy(temp, sendrecvbuf_, type_nbytes * count); break; - } else { - std::memcpy(temp, sendrecvbuf_, type_nbytes * count); - if (CheckAndRecover(TryAllreduce(temp, type_nbytes, count, reducer))) { - std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break; - } else { - recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq_); - } - } - } - double delta = utils::GetTime() - start; - // log allreduce latency - if (rabit_debug) { - utils::HandleLogInfo("[%d] allreduce (%s) finished version %d, seq %d, take %f seconds\n", - rank, key.c_str(), version_number, seq_counter, delta); - } - - // if bootstrap allreduce, store and fetch through cache - if (checkpoint_loaded_ || !rabit_bootstrap_cache) { - resbuf_.PushTemp(seq_counter, type_nbytes, count); - seq_counter += 1; - } else { - SetBootstrapCache(key, sendrecvbuf_, type_nbytes, count); - } -} -/*! 
- * \brief broadcast data from root to all nodes - * \param sendrecvbuf_ buffer for both sending and recving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ -void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root, - const char* _file, - const int _line, - const char* _caller) { - // skip action in single node - if (world_size == 1 || world_size == -1) return; - // genreate unique cache signature - std::string key = std::string(_file) + "::" + std::to_string(_line) + "::" - + std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root); - // try fetch bootstrap allreduce results from cache - if (!checkpoint_loaded_ && rabit_bootstrap_cache && - GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) { - return; - } - double start = utils::GetTime(); - bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq_); - // now we are free to remove the last result, if any - if (resbuf_.LastSeqNo() != -1 && - (result_buffer_round_ == -1 || - resbuf_.LastSeqNo() % result_buffer_round_ != rank % result_buffer_round_)) { - resbuf_.DropLast(); - } - void *temp = resbuf_.AllocTemp(1, total_size); - while (true) { - if (recovered) { - std::memcpy(temp, sendrecvbuf_, total_size); break; - } else { - if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) { - std::memcpy(temp, sendrecvbuf_, total_size); break; - } else { - recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq_); - } - } - } - - double delta = utils::GetTime() - start; - // log broadcast latency - if (rabit_debug) { - utils::HandleLogInfo( - "[%d] broadcast (%s) root %d finished version %d,seq %d, take %f seconds\n", - rank, key.c_str(), root, version_number, seq_counter, delta); - } - // if bootstrap broadcast, store and fetch through cache - if (checkpoint_loaded_ || !rabit_bootstrap_cache) { - resbuf_.PushTemp(seq_counter, 1, total_size); - seq_counter += 1; - } else { - SetBootstrapCache(key, sendrecvbuf_, total_size, 1); - } -} -/*! 
- * \brief load latest check point - * \param global_model pointer to the globally shared model/state - * when calling this function, the caller need to gauranttees that global_model - * is the same in all nodes - * \param local_model pointer to local model, that is specific to current node/rank - * this can be NULL when no local model is needed - * - * \return the version number of check point loaded - * if returned version == 0, this means no model has been CheckPointed - * the p_model is not touched, user should do necessary initialization by themselves - * - * Common usage example: - * int iter = rabit::LoadCheckPoint(&model); - * if (iter == 0) model.InitParameters(); - * for (i = iter; i < max_iter; ++i) { - * do many things, include allreduce - * rabit::CheckPoint(model); - * } - * - * \sa CheckPoint, VersionNumber - */ -int AllreduceRobust::LoadCheckPoint(Serializable *global_model, - Serializable *local_model) { - checkpoint_loaded_ = true; - // skip action in single node - if (world_size == 1) return 0; - this->LocalModelCheck(local_model != nullptr); - if (num_local_replica_ == 0) { - utils::Check(local_model == nullptr, - "need to set rabit_local_replica larger than 1 to checkpoint local_model"); - } - double start = utils::GetTime(); - // check if we succeed - if (RecoverExec(nullptr, 0, ActionSummary::kLoadCheck, ActionSummary::kSpecialOp, cur_cache_seq_)) { - int nlocal = std::max(static_cast(local_rptr_[local_chkpt_version_].size()) - 1, 0); - if (local_model != nullptr) { - if (nlocal == num_local_replica_ + 1) { - // load in local model - utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt_[local_chkpt_version_]), - local_rptr_[local_chkpt_version_][1]); - local_model->Load(&fs); - } else { - assert_(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal); - } - } - // reset result buffer - resbuf_.Clear(); seq_counter = 0; - // load from buffer - utils::MemoryBufferStream fs(&global_checkpoint_); - if (global_checkpoint_.length() == 0) { - version_number = 0; - } else { - assert_(fs.Read(&version_number, sizeof(version_number)) != 0, - "read in version number"); - global_model->Load(&fs); - assert_(local_model == nullptr || nlocal == num_local_replica_ + 1, - "local model inconsistent, nlocal=%d", nlocal); - } - // run another phase of check ack, if recovered from data - assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckAck, - ActionSummary::kSpecialOp, cur_cache_seq_), "check ack must return true"); - - if (!RecoverExec(nullptr, 0, ActionSummary::kLoadBootstrapCache, seq_counter, cur_cache_seq_)) { - utils::Printf("no need to load cache\n"); - } - double delta = utils::GetTime() - start; - - // log broadcast latency - if (rabit_debug) { - utils::HandleLogInfo("[%d] loadcheckpoint size %ld finished version %d, " - "seq %d, take %f seconds\n", - rank, global_checkpoint_.length(), - version_number, seq_counter, delta); - } - return version_number; - } else { - // log job fresh start - if (rabit_debug) utils::HandleLogInfo("[%d] loadcheckpoint reset\n", rank); - - // reset result buffer - resbuf_.Clear(); seq_counter = 0; version_number = 0; - // nothing loaded, a fresh start, everyone init model - return version_number; - } -} -/*! 
- * \brief internal consistency check function, - * use check to ensure user always call CheckPoint/LoadCheckPoint - * with or without local but not both, this function will set the approperiate settings - * in the first call of LoadCheckPoint/CheckPoint - * - * \param with_local whether the user calls CheckPoint with local model - */ -void AllreduceRobust::LocalModelCheck(bool with_local) { - if (use_local_model_ == -1) { - if (with_local) { - use_local_model_ = 1; - if (num_local_replica_ == 0) { - num_local_replica_ = default_local_replica_; - } - } else { - use_local_model_ = 0; - num_local_replica_ = 0; - } - } else { - utils::Check(use_local_model_ == static_cast(with_local), - "Can only call Checkpoint/LoadCheckPoint always with"\ - "or without local_model, but not mixed case"); - } -} -/*! - * \brief internal implementation of checkpoint, support both lazy and normal way - * - * \param global_model pointer to the globally shared model/state - * when calling this function, the caller need to gauranttees that global_model - * is the same in all nodes - * \param local_model pointer to local model, that is specific to current node/rank - * this can be NULL when no local state is needed - * \param lazy_checkpt whether the action is lazy checkpoint - * - * \sa CheckPoint, LazyCheckPoint - */ -void AllreduceRobust::CheckPointImpl(const Serializable *global_model, - const Serializable *local_model, - bool lazy_checkpt) { - // never do check point in single machine mode - if (world_size == 1) { - version_number += 1; return; - } - double start = utils::GetTime(); - this->LocalModelCheck(local_model != nullptr); - if (num_local_replica_ == 0) { - utils::Check(local_model == nullptr, - "need to set rabit_local_replica larger than 1 to checkpoint local_model"); - } - if (num_local_replica_ != 0) { - while (true) { - if (RecoverExec(nullptr, 0, 0, ActionSummary::kLocalCheckPoint)) break; - // save model to new version place - int new_version = !local_chkpt_version_; - - local_chkpt_[new_version].clear(); - utils::MemoryBufferStream fs(&local_chkpt_[new_version]); - if (local_model != nullptr) { - local_model->Save(&fs); - } - local_rptr_[new_version].clear(); - local_rptr_[new_version].push_back(0); - local_rptr_[new_version].push_back(local_chkpt_[new_version].length()); - if (CheckAndRecover(TryCheckinLocalState(&local_rptr_[new_version], - &local_chkpt_[new_version]))) break; - } - // run the ack phase, can be true or false - RecoverExec(nullptr, 0, 0, ActionSummary::kLocalCheckAck); - // switch pointer to new version - local_chkpt_version_ = !local_chkpt_version_; - } - // execute checkpoint, note: when checkpoint existing, load will not happen - assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckPoint, - ActionSummary::kSpecialOp, cur_cache_seq_), - "check point must return true"); - // this is the critical region where we will change all the stored models - // increase version number - version_number += 1; - // save model - if (lazy_checkpt) { - global_lazycheck_ = global_model; - } else { - global_checkpoint_.resize(0); - utils::MemoryBufferStream fs(&global_checkpoint_); - fs.Write(&version_number, sizeof(version_number)); - global_model->Save(&fs); - global_lazycheck_ = nullptr; - } - double delta = utils::GetTime() - start; - // log checkpoint latency - if (rabit_debug) { - utils::HandleLogInfo( - "[%d] checkpoint finished version %d,seq %d, take %f seconds\n", - rank, version_number, seq_counter, delta); - } - start = utils::GetTime(); - // reset result buffer, mark boostrap phase 
complete - resbuf_.Clear(); seq_counter = 0; - // execute check ack step, load happens here - assert_(RecoverExec(nullptr, 0, ActionSummary::kCheckAck, - ActionSummary::kSpecialOp, cur_cache_seq_), "check ack must return true"); - - delta = utils::GetTime() - start; - // log checkpoint ack latency - if (rabit_debug) { - utils::HandleLogInfo( - "[%d] checkpoint ack finished version %d, take %f seconds\n", rank, - version_number, delta); - } -} -/*! - * \brief reset the all the existing links by sending Out-of-Band message marker - * after this function finishes, all the messages received and sent before in all live links are discarded, - * This allows us to get a fresh start after error has happened - * - * \return this function can return kSuccess or kSockError - * when kSockError is returned, it simply means there are bad sockets in the links, - * and some link recovery proceduer is needed - */ -AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks() { - // number of links - const int nlink = static_cast(all_links.size()); - for (int i = 0; i < nlink; ++i) { - all_links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); - all_links[i].ResetSize(); - } - // read and discard data from all channels until pass mark - while (true) { - for (int i = 0; i < nlink; ++i) { - if (all_links[i].sock.BadSocket()) continue; - if (all_links[i].size_write == 0) { - char sig = kOOBReset; - ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); - // error will be filtered in next loop - if (len == sizeof(sig)) all_links[i].size_write = 1; - } - if (all_links[i].size_write == 1) { - char sig = kResetMark; - ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig)); - if (len == sizeof(sig)) all_links[i].size_write = 2; - } - } - utils::PollHelper rsel; - bool finished = true; - for (int i = 0; i < nlink; ++i) { - if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) { - rsel.WatchWrite(all_links[i].sock); finished = false; - } - } - if (finished) break; - // wait to read from the channels to discard data - rsel.Poll(); - } - for (int i = 0; i < nlink; ++i) { - if (!all_links[i].sock.BadSocket()) { - utils::PollHelper::WaitExcept(all_links[i].sock); - } - } - while (true) { - utils::PollHelper rsel; - bool finished = true; - for (int i = 0; i < nlink; ++i) { - if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) { - rsel.WatchRead(all_links[i].sock); finished = false; - } - } - if (finished) break; - rsel.Poll(); - for (int i = 0; i < nlink; ++i) { - if (all_links[i].sock.BadSocket()) continue; - if (all_links[i].size_read == 0) { - int atmark = all_links[i].sock.AtMark(); - if (atmark < 0) { - assert_(all_links[i].sock.BadSocket(), "must already gone bad"); - } else if (atmark > 0) { - all_links[i].size_read = 1; - } else { - // no at mark, read and discard data - ssize_t len = all_links[i].sock.Recv(all_links[i].buffer_head, all_links[i].buffer_size); - if (all_links[i].sock.AtMark()) all_links[i].size_read = 1; - // zero length, remote closed the connection, close socket - if (len == 0) all_links[i].sock.Close(); - } - } - } - } - // start synchronization, use blocking I/O to avoid select - for (int i = 0; i < nlink; ++i) { - if (!all_links[i].sock.BadSocket()) { - char oob_mark; - all_links[i].sock.SetNonBlock(false); - ssize_t len = all_links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL); - if (len == 0) { - all_links[i].sock.Close(); continue; - } else if (len > 0) { - assert_(oob_mark == kResetMark, "wrong oob msg"); - 
assert_(all_links[i].sock.AtMark() != 1, "should already read past mark"); - } else { - assert_(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); - } - // send out ack - char ack = kResetAck; - while (true) { - len = all_links[i].sock.Send(&ack, sizeof(ack)); - if (len == sizeof(ack)) break; - if (len == -1) { - if (errno != EAGAIN && errno != EWOULDBLOCK) break; - } - } - } - } - // wait all ack - for (int i = 0; i < nlink; ++i) { - if (!all_links[i].sock.BadSocket()) { - char ack; - ssize_t len = all_links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL); - if (len == 0) { - all_links[i].sock.Close(); continue; - } else if (len > 0) { - assert_(ack == kResetAck, "wrong Ack MSG"); - } else { - assert_(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); - } - // set back to nonblock mode - all_links[i].sock.SetNonBlock(true); - } - } - for (int i = 0; i < nlink; ++i) { - if (all_links[i].sock.BadSocket()) return kSockError; - } - return kSuccess; -} -/*! - * \brief if err_type indicates an error - * recover links according to the error type reported - * if there is no error, return true - * \param err_type the type of error happening in the system - * \return true if err_type is kSuccess, false otherwise - */ -bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { - shutdown_timeout_ = err_type == kSuccess; - if (err_type == kSuccess) return true; - - assert_(err_link != nullptr, "must know the error link"); - recover_counter_ += 1; - // async launch timeout task if enable_rabit_timeout is set - if (rabit_timeout && !rabit_timeout_task_.valid()) { - utils::Printf("[EXPERIMENTAL] timeout thread expires in %d second(s)\n", timeout_sec); - rabit_timeout_task_ = std::async(std::launch::async, [=]() { - if (rabit_debug) { - utils::Printf("[%d] timeout thread %ld starts\n", rank, - std::this_thread::get_id()); - } - int time = 0; - // check if rabit recovered every 100ms - while (time++ < 10 * timeout_sec) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (shutdown_timeout_.load()) { - if (rabit_debug) { - utils::Printf("[%d] timeout task thread %ld exits\n", - rank, std::this_thread::get_id()); - } - return true; - } - } - error_("[%d] exit due to time out %d s\n", rank, timeout_sec); - return false; - }); - } - // simple way, shutdown all links - for (auto & all_link : all_links) { - if (!all_link.sock.BadSocket()) all_link.sock.Close(); - } - // smooth out traffic to tracker - std::this_thread::sleep_for(std::chrono::milliseconds(10*rank)); - ReConnectLinks("recover"); - return false; -} -/*! 
- * \brief message passing function, used to decide the - * shortest distance to the possible source of data - * \param node_value a pair of have_data and size - * have_data whether current node have data - * size gives the size of data, if current node is kHaveData - * \param dist_in the shorest to any data source distance in each direction - * \param out_index the edge index of output link - * \return the shorest distance result of out edge specified by out_index - */ -inline std::pair -ShortestDist(const std::pair &node_value, - const std::vector< std::pair > &dist_in, - size_t out_index) { - if (node_value.first) { - return std::make_pair(1, node_value.second); - } - size_t size = 0; - int res = std::numeric_limits::max(); - for (size_t i = 0; i < dist_in.size(); ++i) { - if (i == out_index) continue; - if (dist_in[i].first == std::numeric_limits::max()) continue; - if (dist_in[i].first + 1 < res) { - res = dist_in[i].first + 1; - size = dist_in[i].second; - } - } - // add one hop - - return std::make_pair(res, size); -} -/*! - * \brief message passing function, used to decide the - * data request from each edge, whether need to request data from certain edge - * \param node_value a pair of request_data and best_link - * request_data stores whether current node need to request data - * best_link gives the best edge index to fetch the data - * \param req_in the data request from incoming edges - * \param out_index the edge index of output link - * \return the request to the output edge - */ -inline char DataRequest(const std::pair &node_value, - const std::vector &req_in, - size_t out_index) { - // whether current node need to request data - bool request_data = node_value.first; - // which edge index is the best link to request data - // can be -1, which means current node contains data - const int best_link = node_value.second; - if (static_cast(out_index) == best_link) { - if (request_data) return 1; - for (size_t i = 0; i < req_in.size(); ++i) { - if (i == out_index) continue; - if (req_in[i] != 0) return 1; - } - } - return 0; -} -/*! 
- * \brief try to decide the recovery message passing request - * \param role the current role of the node - * \param p_size used to store the size of the message, for node in state kHaveData, - * this size must be set correctly before calling the function - * for others, this surves as output parameter - * - * \param p_recvlink used to store the link current node should recv data from, if necessary - * this can be -1, which means current node have the data - * \param p_req_in used to store the resulting vector, indicating which link we should send the data to - * - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceRobust::ReturnType -AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role, - size_t *p_size, - int *p_recvlink, - std::vector *p_req_in) { - int best_link = -2; - { - // get the shortest distance to the request point - std::vector > dist_in, dist_out; - - ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size), - &dist_in, &dist_out, ShortestDist); - if (succ != kSuccess) return succ; - if (role != kHaveData) { - for (size_t i = 0; i < dist_in.size(); ++i) { - if (dist_in[i].first != std::numeric_limits::max()) { - utils::Check(best_link == -2 || *p_size == dist_in[i].second, - "[%d] Allreduce size inconsistent, distin=%lu, size=%lu, reporting=%lu\n", - rank, dist_in[i].first, *p_size, dist_in[i].second); - if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) { - best_link = static_cast(i); - *p_size = dist_in[i].second; - } - } - } - utils::Check(best_link != -2, "Too many nodes went down and we cannot recover.."); - } else { - best_link = -1; - } - } - // get the node request - std::vector req_in, req_out; - ReturnType succ = MsgPassing(std::make_pair(role == kRequestData, best_link), - &req_in, &req_out, DataRequest); - if (succ != kSuccess) return succ; - // set p_req_in - p_req_in->resize(req_in.size()); - for (size_t i = 0; i < req_in.size(); ++i) { - // set p_req_in - (*p_req_in)[i] = (req_in[i] != 0); - if (req_out[i] != 0) { - assert_(req_in[i] == 0, "cannot get and receive request"); - assert_(static_cast(i) == best_link, "request result inconsistent"); - } - } - *p_recvlink = best_link; - return kSuccess; -} -/*! 
- * \brief try to finish the data recovery request, - * this function is used together with TryDecideRouting - * \param role the current role of the node - * \param sendrecvbuf_ the buffer to store the data to be sent/recived - * - if the role is kHaveData, this stores the data to be sent - * - if the role is kRequestData, this is the buffer to store the result - * - if the role is kPassData, this will not be used, and can be NULL - * \param size the size of the data, obtained from TryDecideRouting - * \param recv_link the link index to receive data, if necessary, obtained from TryDecideRouting - * \param req_in the request of each link to send data, obtained from TryDecideRouting - * - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType, TryDecideRouting - */ -AllreduceRobust::ReturnType -AllreduceRobust::TryRecoverData(RecoverType role, - void *sendrecvbuf_, - size_t size, - int recv_link, - const std::vector &req_in) { - RefLinkVector &links = tree_links; - // no need to run recovery for zero size messages - if (links.Size() == 0 || size == 0) return kSuccess; - assert_(req_in.size() == links.Size(), "TryRecoverData"); - const int nlink = static_cast(links.Size()); - { - bool req_data = role == kRequestData; - for (int i = 0; i < nlink; ++i) { - if (req_in[i]) { - assert_(i != recv_link, "TryDecideRouting"); - req_data = true; - } - } - // do not need to provide data or receive data, directly exit - if (!req_data) return kSuccess; - } - assert_(recv_link >= 0 || role == kHaveData, "recv_link must be active"); - if (role == kPassData) { - links[recv_link].InitBuffer(1, size, reduce_buffer_size); - } - for (int i = 0; i < nlink; ++i) { - links[i].ResetSize(); - } - while (true) { - bool finished = true; - utils::PollHelper watcher; - for (int i = 0; i < nlink; ++i) { - if (i == recv_link && links[i].size_read != size) { - watcher.WatchRead(links[i].sock); - finished = false; - } - if (req_in[i] && links[i].size_write != size) { - if (role == kHaveData || - (links[recv_link].size_read != links[i].size_write)) { - watcher.WatchWrite(links[i].sock); - } - finished = false; - } - watcher.WatchException(links[i].sock); - } - if (finished) break; - watcher.Poll(); - // exception handling - for (int i = 0; i < nlink; ++i) { - if (watcher.CheckExcept(links[i].sock)) { - return ReportError(&links[i], kGetExcept); - } - } - if (role == kRequestData) { - const int pid = recv_link; - if (watcher.CheckRead(links[pid].sock)) { - ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size); - if (ret != kSuccess) { - return ReportError(&links[pid], ret); - } - } - for (int i = 0; i < nlink; ++i) { - if (req_in[i] && links[i].size_write != links[pid].size_read) { - ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - } - } - } - if (role == kHaveData) { - for (int i = 0; i < nlink; ++i) { - if (req_in[i] && links[i].size_write != size) { - ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - } - } - } - if (role == kPassData) { - const int pid = recv_link; - const size_t buffer_size = links[pid].buffer_size; - if (watcher.CheckRead(links[pid].sock)) { - size_t min_write = size; - for (int i = 0; i < nlink; ++i) { - if (req_in[i]) min_write = std::min(links[i].size_write, min_write); - } - assert_(min_write <= links[pid].size_read, "boundary check"); - ReturnType ret = 
links[pid].ReadToRingBuffer(min_write, size); - if (ret != kSuccess) { - return ReportError(&links[pid], ret); - } - } - for (int i = 0; i < nlink; ++i) { - if (req_in[i] && links[pid].size_read != links[i].size_write) { - size_t start = links[i].size_write % buffer_size; - // send out data from ring buffer - size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write); - ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite); - if (len != -1) { - links[i].size_write += len; - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) return ReportError(&links[i], ret); - } - } - } - } - } - return kSuccess; -} -/*! - * \brief try to fetch allreduce/broadcast results from rest of nodes - * as collaberative function called by all nodes, only requester node - * will pass seqno to rest of nodes and reconstruct/backfill sendrecvbuf_ - * of specific seqno from other nodes. - */ -AllreduceRobust::ReturnType AllreduceRobust::TryRestoreCache(bool requester, - const int min_seq, const int max_seq) { - // clear requester and rebuild from those with most cache entries - if (requester) { - assert_(cur_cache_seq_ <= max_seq, "requester is expected to have fewer cache entries"); - cachebuf_.Clear(); - lookupbuf_.Clear(); - cur_cache_seq_ = 0; - } - RecoverType role = requester ? kRequestData : kHaveData; - size_t size = 1; - int recv_link; - std::vector req_in; - ReturnType ret = TryDecideRouting(role, &size, &recv_link, &req_in); - if (ret != kSuccess) return ret; - // only recover missing cache entries in requester - // as tryrecoverdata is collective call, need to go through entire cache - // and only work on those missing - for (int i = 0; i < max_seq; i++) { - // restore lookup map - size_t cache_size = 0; - void* key = lookupbuf_.Query(i, &cache_size); - ret = TryRecoverData(role, &cache_size, sizeof(size_t), recv_link, req_in); - if (ret != kSuccess) return ret; - if (requester) { - key = lookupbuf_.AllocTemp(cache_size, 1); - lookupbuf_.PushTemp(i, cache_size, 1); - } - ret = TryRecoverData(role, key, cache_size, recv_link, req_in); - if (ret != kSuccess) return ret; - // restore cache content - cache_size = 0; - void* buf = cachebuf_.Query(i, &cache_size); - ret = TryRecoverData(role, &cache_size, sizeof(size_t), recv_link, req_in); - if (requester) { - buf = cachebuf_.AllocTemp(cache_size, 1); - cachebuf_.PushTemp(i, cache_size, 1); - cur_cache_seq_ +=1; - } - ret = TryRecoverData(role, buf, cache_size, recv_link, req_in); - if (ret != kSuccess) return ret; - } - - return kSuccess; -} - -/*! - * \brief try to load check point - * - * This is a collaborative function called by all nodes - * only the nodes with requester set to true really needs to load the check point - * other nodes acts as collaborative roles to complete this request - * - * \param requester whether current node is the requester - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { - // check in local data - RecoverType role = requester ? 
kRequestData : kHaveData; - ReturnType succ; - if (num_local_replica_ != 0) { - if (requester) { - // clear existing history, if any, before load - local_rptr_[local_chkpt_version_].clear(); - local_chkpt_[local_chkpt_version_].clear(); - } - // recover local checkpoint - succ = TryRecoverLocalState(&local_rptr_[local_chkpt_version_], - &local_chkpt_[local_chkpt_version_]); - if (succ != kSuccess) return succ; - int nlocal = std::max(static_cast(local_rptr_[local_chkpt_version_].size()) - 1, 0); - // check if everyone is OK - unsigned state = 0; - if (nlocal == num_local_replica_ + 1) { - // complete recovery - state = 1; - } else if (nlocal == 0) { - // get nothing - state = 2; - } else { - // partially complete state - state = 4; - } - succ = TryAllreduce(&state, sizeof(state), 1, op::Reducer); - if (succ != kSuccess) return succ; - utils::Check(state == 1 || state == 2, - "LoadCheckPoint: too many nodes fails, cannot recover local state"); - } - // do call save model if the checkpoint was lazy - if (role == kHaveData && global_lazycheck_ != nullptr) { - global_checkpoint_.resize(0); - utils::MemoryBufferStream fs(&global_checkpoint_); - fs.Write(&version_number, sizeof(version_number)); - global_lazycheck_->Save(&fs); - global_lazycheck_ = nullptr; - } - // recover global checkpoint - size_t size = this->global_checkpoint_.length(); - int recv_link; - std::vector req_in; - succ = TryDecideRouting(role, &size, &recv_link, &req_in); - if (succ != kSuccess) return succ; - if (role == kRequestData) { - global_checkpoint_.resize(size); - } - if (size == 0) return kSuccess; - return TryRecoverData(role, BeginPtr(global_checkpoint_), size, recv_link, req_in); -} -/*! - * \brief try to get the result of operation specified by seqno - * - * This is a collaborative function called by all nodes - * only the nodes with requester set to true really needs to get the result - * other nodes acts as collaborative roles to complete this request - * - * \param buf the buffer to store the result, this parameter is only used when current node is requester - * \param size the total size of the buffer, this parameter is only used when current node is requester - * \param seqno sequence number of the operation, this is unique index of a operation in current iteration - * \param requester whether current node is the requester - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceRobust::ReturnType -AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { - // if minimum sequence requested is local check point ack, - // this means all nodes have finished local check point, directly return - if (seqno == ActionSummary::kLocalCheckAck) return kSuccess; - if (seqno == ActionSummary::kLocalCheckPoint) { - // new version of local model - int new_version = !local_chkpt_version_; - int nlocal = std::max(static_cast(local_rptr_[new_version].size()) - 1, 0); - // if we goes to this place, use must have already setup the state once - assert_(nlocal == 1 || nlocal == num_local_replica_ + 1, - "TryGetResult::Checkpoint"); - return TryRecoverLocalState(&local_rptr_[new_version], &local_chkpt_[new_version]); - } - - // handles normal data recovery - RecoverType role; - if (!requester) { - sendrecvbuf = resbuf_.Query(seqno, &size); - role = sendrecvbuf != nullptr ? 
kHaveData : kPassData; - } else { - role = kRequestData; - } - int recv_link; - std::vector req_in; - // size of data - size_t data_size = size; - ReturnType succ = TryDecideRouting(role, &data_size, &recv_link, &req_in); - if (succ != kSuccess) return succ; - utils::Check(data_size != 0, "zero size check point is not allowed"); - if (role == kRequestData || role == kHaveData) { - utils::Check(data_size == size, - "Allreduce Recovered data size do not match the specification of function call.\n"\ - "Please check if calling sequence of recovered program is the " \ - "same the original one in current VersionNumber"); - } - return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in); -} -/*! - * \brief try to run recover execution for a request action described by flag and seqno, - * the function will keep blocking to run possible recovery operations before the specified action, - * until the requested result is received by a recovering procedure, - * or the function discovers that the requested action is not yet executed, and return false - * - * \param buf the buffer to store the result - * \param size the total size of the buffer - * \param flag flag information about the action \sa ActionSummary - * \param seqno sequence number of the action, if it is special action with flag set, - * seqno needs to be set to ActionSummary::kSpecialOp - * - * \return if this function can return true or false - * - true means buf already set to the - * result by recovering procedure, the action is complete, no further action is needed - * - false means this is the lastest action that has not yet been executed, need to execute the action - */ -bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, - int cache_seqno, const char* caller) { - // kLoadBootstrapCache should be treated similar as allreduce - // when loadcheck/check/checkack runs in other nodes - if (flag != 0 && flag != ActionSummary::kLoadBootstrapCache) { - assert_(seqno == ActionSummary::kSpecialOp, "must only set seqno for normal operations"); - } - - std::string msg = std::string(caller) + " pass negative seqno " - + std::to_string(seqno) + " flag " + std::to_string(flag) - + " version " + std::to_string(version_number); - assert_(seqno >=0, msg.c_str()); - - ActionSummary req(flag, flag, seqno, cache_seqno); - - while (true) { - this->ReportStatus(); - // copy to action and send to allreduce with other nodes - ActionSummary act = req; - // get the reduced action - if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue; - - if (act.CheckAck()) { - if (act.CheckPoint()) { - // if we also have check_point, do check point first - assert_(!act.DiffSeq(), - "check ack & check pt cannot occur together with normal ops"); - // if we requested checkpoint, we are free to go - if (req.CheckPoint()) return true; - } else if (act.LoadCheck()) { - // if there is only check_ack and load_check, do load_check - if (!CheckAndRecover(TryLoadCheckPoint(req.LoadCheck()))) continue; - // if requested load check, then misson complete - if (req.LoadCheck()) return true; - } else { - // there is no check point and no load check, execute check ack - if (req.CheckAck()) return true; - } - // if execute to this point - // this means the action requested has not been completed - // try next round - } else { - if (act.CheckPoint()) { - if (act.DiffSeq()) { - assert_(act.Seqno() != ActionSummary::kSpecialOp, "min seq bug"); - // print checkpoint consensus flag if user turn on debug - if (rabit_debug) { 
- req.PrintFlags(rank, "checkpoint req"); - act.PrintFlags(rank, "checkpoint act"); - } - /* - * Chen Qin - * at least one hit checkpoint_ code & at least one not hitting - * compare with version_number of req.check_point() set true with rest - * expect to be equal, means rest fall behind in sequence - * use resbuf resbuf to recover - * worker-0 worker-1 - * checkpoint(n-1) checkpoint(n-1) - * allreduce allreduce (requester) | - * broadcast V - * checkpoint(n req) - * after catch up to checkpoint n, diff_seq will be false - * */ - // assume requester is falling behind - bool requester = req.Seqno() == act.Seqno(); - // if not load cache - if (!act.LoadCache()) { - if (act.Seqno() > 0) { - if (!requester) { - assert_(req.CheckPoint(), "checkpoint node should be KHaveData role"); - buf = resbuf_.Query(act.Seqno(), &size); - assert_(buf != nullptr, "buf should have data from resbuf"); - assert_(size > 0, "buf size should be greater than 0"); - } - if (!CheckAndRecover(TryGetResult(buf, size, act.Seqno(), requester))) continue; - } - } else { - // cache seq no should be smaller than kSpecialOp - assert_(act.Seqno(SeqType::kCache) != ActionSummary::kSpecialOp, - "checkpoint with kSpecialOp"); - int max_cache_seq = cur_cache_seq_; - if (TryAllreduce(&max_cache_seq, sizeof(max_cache_seq), 1, - op::Reducer) != kSuccess) continue; - - if (TryRestoreCache(req.LoadCache(), act.Seqno(), max_cache_seq) - != kSuccess) continue; - } - if (requester) return true; - } else { - // no difference in seq no, means we are free to check point - if (req.CheckPoint()) return true; - } - } else { - // no check point - if (act.LoadCheck()) { - // all the nodes called load_check, this is an incomplete action - if (!act.DiffSeq()) return false; - // load check have higher priority, do load_check - if (!CheckAndRecover(TryLoadCheckPoint(req.LoadCheck()))) continue; - // if requested load check, then misson complete - if (req.LoadCheck()) return true; - } else { - // run all nodes in a isolated cache restore logic - if (act.LoadCache()) { - // print checkpoint consensus flag if user turn on debug - if (rabit_debug) { - req.PrintFlags(rank, "loadcache req"); - act.PrintFlags(rank, "loadcache act"); - } - // load cache should not running in parralel with other states - assert_(!act.LoadCheck(), - "load cache state expect no nodes doing load checkpoint"); - assert_(!act.CheckPoint() , - "load cache state expect no nodes doing checkpoint"); - assert_(!act.CheckAck(), - "load cache state expect no nodes doing checkpoint ack"); - - // if all nodes are requester in load cache, skip - if (act.LoadCache(SeqType::kCache)) return false; - - // bootstrap cache always restore before loadcheckpoint - // requester always have seq diff with non requester - if (act.DiffSeq()) { - // restore cache failed, retry from what's left - if (TryRestoreCache(req.LoadCache(), act.Seqno(), act.Seqno(SeqType::kCache)) - != kSuccess) continue; - } - // if requested load cache, then mission complete - if (req.LoadCache()) return true; - continue; - } - - // assert no req with load cache set goes into seq catch up - assert_(!req.LoadCache(), "load cache not interacte with rest states"); - - // no special flags, no checkpoint, check ack, load_check - assert_(act.Seqno() != ActionSummary::kSpecialOp, "min seq bug"); - if (act.DiffSeq()) { - bool requester = req.Seqno() == act.Seqno(); - if (!CheckAndRecover(TryGetResult(buf, size, act.Seqno(), requester))) continue; - if (requester) return true; - } else { - // all the request is same, - // this is most 
recent command that is yet to be executed - return false; - } - } - } - // something is still incomplete try next round - } - } - assert_(false, "RecoverExec: should not reach here"); - return true; -} -/*! - * \brief try to recover the local state, making each local state to be the result of itself - * plus replication of states in previous num_local_replica hops in the ring - * - * The input parameters must contain the valid local states available in current nodes, - * This function try ist best to "complete" the missing parts of local_rptr and local_chkpt - * If there is sufficient information in the ring, when the function returns, local_chkpt will - * contain num_local_replica + 1 checkpoints (including the chkpt of this node) - * If there is no sufficient information in the ring, this function the number of checkpoints - * will be less than the specified value - * - * \param p_local_rptr the pointer to the segment pointers in the states array - * \param p_local_chkpt the pointer to the storage of local check points - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceRobust::ReturnType -AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, - std::string *p_local_chkpt) { - // if there is no local replica, we can do nothing - if (num_local_replica_ == 0) return kSuccess; - std::vector &rptr = *p_local_rptr; - std::string &chkpt = *p_local_chkpt; - if (rptr.size() == 0) { - rptr.push_back(0); - assert_(chkpt.length() == 0, "local chkpt space inconsistent"); - } - const int n = num_local_replica_; - { - // backward passing, passing state in backward direction of the ring - const int nlocal = static_cast(rptr.size() - 1); - assert_(nlocal <= n + 1, "invalid local replica"); - std::vector msg_back(n + 1); - msg_back[0] = nlocal; - // backward passing one hop the request - ReturnType succ; - succ = RingPassing(BeginPtr(msg_back), - 1 * sizeof(int), (n+1) * sizeof(int), - 0 * sizeof(int), n * sizeof(int), - ring_next, ring_prev); - if (succ != kSuccess) return succ; - int msg_forward[2]; - msg_forward[0] = nlocal; - succ = RingPassing(msg_forward, - 1 * sizeof(int), 2 * sizeof(int), - 0 * sizeof(int), 1 * sizeof(int), - ring_prev, ring_next); - if (succ != kSuccess) return succ; - // calculate the number of things we can read from next link - int nread_end = nlocal; - for (int i = 1; i <= n; ++i) { - nread_end = std::max(nread_end, msg_back[i] - i); - } - // gives the size of forward - int nwrite_start = std::min(msg_forward[1] + 1, nread_end); - // get the size of each segments - std::vector sizes(nread_end); - for (int i = 0; i < nlocal; ++i) { - sizes[i] = rptr[i + 1] - rptr[i]; - } - // pass size through the link - succ = RingPassing(BeginPtr(sizes), - nlocal * sizeof(size_t), - nread_end * sizeof(size_t), - nwrite_start * sizeof(size_t), - nread_end * sizeof(size_t), - ring_next, ring_prev); - if (succ != kSuccess) return succ; - // update rptr - rptr.resize(nread_end + 1); - for (int i = nlocal; i < nread_end; ++i) { - rptr[i + 1] = rptr[i] + sizes[i]; - } - chkpt.resize(rptr.back()); - // pass data through the link - succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], - rptr[nwrite_start], rptr[nread_end], - ring_next, ring_prev); - if (succ != kSuccess) { - rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ; - } - } - { - // forward passing, passing state in forward direction of the ring - const int nlocal = static_cast(rptr.size() - 1); - assert_(nlocal <= n + 1, 
"invalid local replica"); - std::vector msg_forward(n + 1); - msg_forward[0] = nlocal; - // backward passing one hop the request - ReturnType succ; - succ = RingPassing(BeginPtr(msg_forward), - 1 * sizeof(int), (n+1) * sizeof(int), - 0 * sizeof(int), n * sizeof(int), - ring_prev, ring_next); - if (succ != kSuccess) return succ; - int msg_back[2]; - msg_back[0] = nlocal; - succ = RingPassing(msg_back, - 1 * sizeof(int), 2 * sizeof(int), - 0 * sizeof(int), 1 * sizeof(int), - ring_next, ring_prev); - if (succ != kSuccess) return succ; - // calculate the number of things we can read from next link - int nread_end = nlocal, nwrite_end = 1; - // have to have itself in order to get other data from prev link - if (nlocal != 0) { - for (int i = 1; i <= n; ++i) { - if (msg_forward[i] == 0) break; - nread_end = std::max(nread_end, i + 1); - nwrite_end = i + 1; - } - if (nwrite_end > n) nwrite_end = n; - } else { - nread_end = 0; nwrite_end = 0; - } - // gives the size of forward - int nwrite_start = std::min(msg_back[1] - 1, nwrite_end); - // next node miss the state of itself, cannot recover - if (nwrite_start < 0) nwrite_start = nwrite_end = 0; - // get the size of each segments - std::vector sizes(nread_end); - for (int i = 0; i < nlocal; ++i) { - sizes[i] = rptr[i + 1] - rptr[i]; - } - // pass size through the link, check consistency - succ = RingPassing(BeginPtr(sizes), - nlocal * sizeof(size_t), - nread_end * sizeof(size_t), - nwrite_start * sizeof(size_t), - nwrite_end * sizeof(size_t), - ring_prev, ring_next); - if (succ != kSuccess) return succ; - // update rptr - rptr.resize(nread_end + 1); - for (int i = nlocal; i < nread_end; ++i) { - rptr[i + 1] = rptr[i] + sizes[i]; - } - chkpt.resize(rptr.back()); - // pass data through the link - succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], - rptr[nwrite_start], rptr[nwrite_end], - ring_prev, ring_next); - if (succ != kSuccess) { - rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ; - } - } - return kSuccess; -} -/*! 
- * \brief try to checkpoint local state, this function is called in normal executation phase - * of checkpoint that contains local state - * the input state must exactly one saved state(local state of current node), - * after complete, this function will get local state from previous num_local_replica nodes and put them - * into local_chkpt and local_rptr - * - * It is also OK to call TryRecoverLocalState instead, - * TryRecoverLocalState makes less assumption about the input, and requires more communications - * - * \param p_local_rptr the pointer to the segment pointers in the states array - * \param p_local_chkpt the pointer to the storage of local check points - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType, TryRecoverLocalState - */ -AllreduceRobust::ReturnType -AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, - std::string *p_local_chkpt) { - // if there is no local replica, we can do nothing - if (num_local_replica_ == 0) return kSuccess; - std::vector &rptr = *p_local_rptr; - std::string &chkpt = *p_local_chkpt; - assert_(rptr.size() == 2, - "TryCheckinLocalState must have exactly 1 state"); - const int n = num_local_replica_; - std::vector sizes(n + 1); - sizes[0] = rptr[1] - rptr[0]; - ReturnType succ; - // pass size through the link - succ = RingPassing(BeginPtr(sizes), - 1 * sizeof(size_t), - (n + 1) * sizeof(size_t), - 0 * sizeof(size_t), - n * sizeof(size_t), - ring_prev, ring_next); - if (succ != kSuccess) return succ; - // update rptr - rptr.resize(n + 2); - for (int i = 1; i <= n; ++i) { - rptr[i + 1] = rptr[i] + sizes[i]; - } - chkpt.resize(rptr.back()); - // pass data through the link - succ = RingPassing(BeginPtr(chkpt), - rptr[1], rptr[n + 1], - rptr[0], rptr[n], - ring_prev, ring_next); - if (succ != kSuccess) { - rptr.resize(2); chkpt.resize(rptr.back()); return succ; - } - return kSuccess; -} -/*! 
- * \brief perform a ring passing to receive data from prev link, and sent data to next link - * this allows data to stream over a ring structure - * sendrecvbuf[0:read_ptr] are already provided by current node - * current node will recv sendrecvbuf[read_ptr:read_end] from prev link - * current node will send sendrecvbuf[write_ptr:write_end] to next link - * write_ptr will wait till the data is readed before sending the data - * this function requires read_end >= write_end - * - * \param sendrecvbuf_ the place to hold the incoming and outgoing data - * \param read_ptr the initial read pointer - * \param read_end the ending position to read - * \param write_ptr the initial write pointer - * \param write_end the ending position to write - * \param read_link pointer to link to previous position in ring - * \param write_link pointer to link of next position in ring - */ -AllreduceRobust::ReturnType -AllreduceRobust::RingPassing(void *sendrecvbuf_, - size_t read_ptr, - size_t read_end, - size_t write_ptr, - size_t write_end, - LinkRecord *read_link, - LinkRecord *write_link) { - if (read_link == nullptr || write_link == nullptr || read_end == 0) return kSuccess; - assert_(write_end <= read_end, - "RingPassing: boundary check1"); - assert_(read_ptr <= read_end, "RingPassing: boundary check2"); - assert_(write_ptr <= write_end, "RingPassing: boundary check3"); - // take reference - LinkRecord &prev = *read_link, &next = *write_link; - // send recv buffer - char *buf = reinterpret_cast(sendrecvbuf_); - while (true) { - bool finished = true; - utils::PollHelper watcher; - if (read_ptr != read_end) { - watcher.WatchRead(prev.sock); - finished = false; - } - if (write_ptr < read_ptr && write_ptr != write_end) { - watcher.WatchWrite(next.sock); - finished = false; - } - watcher.WatchException(prev.sock); - watcher.WatchException(next.sock); - if (finished) break; - watcher.Poll(); - if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept); - if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept); - if (read_ptr != read_end && watcher.CheckRead(prev.sock)) { - ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr); - if (len == 0) { - prev.sock.Close(); return ReportError(&prev, kRecvZeroLen); - } - if (len != -1) { - read_ptr += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) return ReportError(&prev, ret); - } - } - if (write_ptr != write_end && write_ptr < read_ptr) { - size_t nsend = std::min(write_end - write_ptr, read_ptr - write_ptr); - ssize_t len = next.sock.Send(buf + write_ptr, nsend); - if (len != -1) { - write_ptr += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) return ReportError(&prev, ret); - } - } - } - return kSuccess; -} -} // namespace engine -} // namespace rabit diff --git a/rabit/src/allreduce_robust.h b/rabit/src/allreduce_robust.h deleted file mode 100644 index 1b9134650..000000000 --- a/rabit/src/allreduce_robust.h +++ /dev/null @@ -1,665 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file allreduce_robust.h - * \brief Robust implementation of Allreduce - * using TCP non-block socket and tree-shape reduction. - * - * This implementation considers the failure of nodes - * - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#ifndef RABIT_ALLREDUCE_ROBUST_H_ -#define RABIT_ALLREDUCE_ROBUST_H_ -#include -#include -#include -#include -#include "rabit/internal/engine.h" -#include "allreduce_base.h" - -namespace rabit { -namespace engine { -/*! 
\brief implementation of fault tolerant all reduce engine */ -class AllreduceRobust : public AllreduceBase { - public: - AllreduceRobust(); - ~AllreduceRobust() override = default; - // initialize the manager - bool Init(int argc, char* argv[]) override; - /*! \brief shutdown the engine */ - bool Shutdown() override; - /*! - * \brief set parameters to the engine - * \param name parameter name - * \param val parameter value - */ - void SetParam(const char *name, const char *val) override; - /*! - * \brief perform immutable local bootstrap cache insertion - * \param key unique cache key - * \param buf buffer of allreduce/robust payload to copy - * \param buflen total number of bytes - * \return -1 if no recovery cache fetched otherwise 0 - */ - int SetBootstrapCache(const std::string &key, const void *buf, - const size_t type_nbytes, const size_t count); - /*! - * \brief perform bootstrap cache lookup if nodes in fault recovery - * \param key unique cache key - * \param buf buffer for recv allreduce/robust payload - * \param buflen total number of bytes - */ - int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes, - const size_t count); - /*! - * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ - void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, - size_t slice_end, size_t size_prev_slice, - const char *_file = _FILE, const int _line = _LINE, - const char *_caller = _CALLER) override; - /*! - * \brief perform in-place allreduce, on sendrecvbuf - * this function is NOT thread-safe - * \param sendrecvbuf_ buffer for both sending and recving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce, to intialize the data in sendrecvbuf_. 
- * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to passed into the lazy preprocessing function - * \param prepare_arg argument used to passed into the lazy preprocessing function - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ - void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer, PreprocFunction prepare_fun = nullptr, - void *prepare_arg = nullptr, const char *_file = _FILE, - const int _line = _LINE, - const char *_caller = _CALLER) override; - /*! - * \brief broadcast data from root to all nodes - * \param sendrecvbuf_ buffer for both sending and recving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ - void Broadcast(void *sendrecvbuf_, size_t total_size, int root, - const char *_file = _FILE, const int _line = _LINE, - const char *_caller = _CALLER) override; - /*! - * \brief load latest check point - * \param global_model pointer to the globally shared model/state - * when calling this function, the caller need to gauranttees that global_model - * is the same in all nodes - * \param local_model pointer to local model, that is specific to current node/rank - * this can be NULL when no local model is needed - * - * \return the version number of check point loaded - * if returned version == 0, this means no model has been CheckPointed - * the p_model is not touched, user should do necessary initialization by themselves - * - * Common usage example: - * int iter = rabit::LoadCheckPoint(&model); - * if (iter == 0) model.InitParameters(); - * for (i = iter; i < max_iter; ++i) { - * do many things, include allreduce - * rabit::CheckPoint(model); - * } - * - * \sa CheckPoint, VersionNumber - */ - int LoadCheckPoint(Serializable *global_model, - Serializable *local_model = nullptr) override; - /*! - * \brief checkpoint the model, meaning we finished a stage of execution - * every time we call check point, there is a version number which will increase by one - * - * \param global_model pointer to the globally shared model/state - * when calling this function, the caller need to gauranttees that global_model - * is the same in all nodes - * \param local_model pointer to local model, that is specific to current node/rank - * this can be NULL when no local state is needed - * - * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will - * bring replication cost in CheckPoint function. global_model do not need explicit replication. - * So only CheckPoint with global_model if possible - * - * \sa LoadCheckPoint, VersionNumber - */ - void CheckPoint(const Serializable *global_model, - const Serializable *local_model = nullptr) override { - this->CheckPointImpl(global_model, local_model, false); - } - /*! - * \brief This function can be used to replace CheckPoint for global_model only, - * when certain condition is met(see detailed expplaination). - * - * This is a "lazy" checkpoint such that only the pointer to global_model is - * remembered and no memory copy is taken. 
To use this function, the user MUST ensure that: - * The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs. - * In another words, global_model model can be changed only between last call of - * Allreduce/Broadcast and LazyCheckPoint in current version - * - * For example, suppose the calling sequence is: - * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint - * - * If user can only changes global_model in code3, then LazyCheckPoint can be used to - * improve efficiency of the program. - * \param global_model pointer to the globally shared model/state - * when calling this function, the caller need to gauranttees that global_model - * is the same in all nodes - * \sa LoadCheckPoint, CheckPoint, VersionNumber - */ - void LazyCheckPoint(const Serializable *global_model) override { - this->CheckPointImpl(global_model, nullptr, true); - } - /*! - * \brief explicitly re-init everything before calling LoadCheckPoint - * call this function when IEngine throw an exception out, - * this function is only used for test purpose - */ - void InitAfterException() override { - // simple way, shutdown all links - for (auto& link : all_links) { - if (link.sock.BadSocket()) { - link.sock.Close(); - } - } - ReConnectLinks("recover"); - } - - protected: - // constant one byte out of band message to indicate error happening - // and mark for channel cleanup - static const char kOOBReset = 95; - // and mark for channel cleanup, after OOB signal - static const char kResetMark = 97; - // and mark for channel cleanup - static const char kResetAck = 97; - /*! \brief type of roles each node can play during recovery */ - enum RecoverType { - /*! \brief current node have data */ - kHaveData = 0, - /*! \brief current node request data */ - kRequestData = 1, - /*! \brief current node only helps to pass data around */ - kPassData = 2 - }; - - enum SeqType { - /*! \brief apply to rabit seq code */ - kSeq = 0, - /*! \brief apply to rabit cache seq code */ - kCache = 1 - }; - /*! 
- * \brief summary of actions proposed in all nodes - * this data structure is used to make consensus decision - * about next action to take in the recovery mode - */ - struct ActionSummary { - // maximumly allowed sequence id - static const uint32_t kSpecialOp = (1 << 26); - // special sequence number for local state checkpoint - static const uint32_t kLocalCheckPoint = (1 << 26) - 2; - // special sequnce number for local state checkpoint ack signal - static const uint32_t kLocalCheckAck = (1 << 26) - 1; - //--------------------------------------------- - // The following are bit mask of flag used in - //---------------------------------------------- - // some node want to load check point - static const int kLoadCheck = 1; - // some node want to do check point - static const int kCheckPoint = 2; - // check point Ack, we use a two phase message in check point, - // this is the second phase of check pointing - static const int kCheckAck = 4; - // there are difference sequence number the nodes proposed - // this means we want to do recover execution of the lower sequence - // action instead of normal execution - static const int kDiffSeq = 8; - // there are nodes request load cache - static const int kLoadBootstrapCache = 16; - // constructor - ActionSummary() = default; - // constructor of action - explicit ActionSummary(int seqno_flag, int cache_flag = 0, - uint32_t minseqno = kSpecialOp, uint32_t maxseqno = kSpecialOp) { - seqcode_ = (minseqno << 5) | seqno_flag; - maxseqcode_ = (maxseqno << 5) | cache_flag; - } - // minimum number of all operations by default - // maximum number of all cache operations otherwise - inline uint32_t Seqno(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; - return code >> 5; - } - // whether the operation set contains a load_check - inline bool LoadCheck(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; - return (code & kLoadCheck) != 0; - } - // whether the operation set contains a load_cache - inline bool LoadCache(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; - return (code & kLoadBootstrapCache) != 0; - } - // whether the operation set contains a check point - inline bool CheckPoint(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; - return (code & kCheckPoint) != 0; - } - // whether the operation set contains a check ack - inline bool CheckAck(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_; - return (code & kCheckAck) != 0; - } - // whether the operation set contains different sequence number - inline bool DiffSeq() const { - return (seqcode_ & kDiffSeq) != 0; - } - // returns the operation flag of the result - inline int Flag(SeqType t = SeqType::kSeq) const { - int code = t == SeqType::kSeq ? 
seqcode_ : maxseqcode_; - return code & 31; - } - // print flags in user friendly way - inline void PrintFlags(int rank, std::string prefix ) { - utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", rank, - prefix.c_str(), Seqno(), CheckPoint(), CheckAck(), - LoadCache(), DiffSeq(), Seqno(SeqType::kCache), - LoadCache(SeqType::kCache)); - } - // reducer for Allreduce, get the result ActionSummary from all nodes - inline static void Reducer(const void *src_, void *dst_, - int len, const MPI::Datatype &dtype) { - const ActionSummary *src = static_cast(src_); - ActionSummary *dst = reinterpret_cast(dst_); - for (int i = 0; i < len; ++i) { - uint32_t min_seqno = Min(src[i].Seqno(), dst[i].Seqno()); - uint32_t max_seqno = Max(src[i].Seqno(SeqType::kCache), - dst[i].Seqno(SeqType::kCache)); - int action_flag = src[i].Flag() | dst[i].Flag(); - // if any node is not requester set to 0 otherwise 1 - int role_flag = src[i].Flag(SeqType::kCache) & dst[i].Flag(SeqType::kCache); - // if seqno is different in src and destination - int seq_diff_flag = src[i].Seqno() != dst[i].Seqno() ? kDiffSeq : 0; - // apply or to both seq diff flag as well as cache seq diff flag - dst[i] = ActionSummary(action_flag | seq_diff_flag, - role_flag, min_seqno, max_seqno); - } - } - - private: - // internel sequence code min of rabit seqno - uint32_t seqcode_; - // internal sequence code max of cache seqno - uint32_t maxseqcode_; - }; - /*! \brief data structure to remember result of Bcast and Allreduce calls*/ - class ResultBuffer{ - public: - // constructor - ResultBuffer() { - this->Clear(); - } - // clear the existing record - inline void Clear() { - seqno_.clear(); size_.clear(); - rptr_.clear(); rptr_.push_back(0); - data_.clear(); - } - // allocate temporal space - inline void *AllocTemp(size_t type_nbytes, size_t count) { - size_t size = type_nbytes * count; - size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); - utils::Assert(nhop != 0, "cannot allocate 0 size memory"); - // allocate addational nhop buffer size - data_.resize(rptr_.back() + nhop); - return BeginPtr(data_) + rptr_.back(); - } - // push the result in temp to the - inline void PushTemp(int seqid, size_t type_nbytes, size_t count) { - size_t size = type_nbytes * count; - size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); - if (seqno_.size() != 0) { - utils::Assert(seqno_.back() < seqid, "PushTemp seqid inconsistent"); - } - seqno_.push_back(seqid); - rptr_.push_back(rptr_.back() + nhop); - size_.push_back(size); - utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent"); - } - // return the stored result of seqid, if any - inline void* Query(int seqid, size_t *p_size) { - size_t idx = std::lower_bound(seqno_.begin(), - seqno_.end(), seqid) - seqno_.begin(); - if (idx == seqno_.size() || seqno_[idx] != seqid) return nullptr; - *p_size = size_[idx]; - return BeginPtr(data_) + rptr_[idx]; - } - // drop last stored result - inline void DropLast() { - utils::Assert(seqno_.size() != 0, "there is nothing to be dropped"); - seqno_.pop_back(); - rptr_.pop_back(); - size_.pop_back(); - data_.resize(rptr_.back()); - } - // the sequence number of last stored result - inline int LastSeqNo() const { - if (seqno_.size() == 0) return -1; - return seqno_.back(); - } - - private: - // sequence number of each - std::vector seqno_; - // pointer to the positions - std::vector rptr_; - // actual size of each buffer - std::vector size_; - // content of the buffer - std::vector data_; - }; - /*! 
- * \brief internal consistency check function, - * use check to ensure user always call CheckPoint/LoadCheckPoint - * with or without local but not both, this function will set the approperiate settings - * in the first call of LoadCheckPoint/CheckPoint - * - * \param with_local whether the user calls CheckPoint with local model - */ - void LocalModelCheck(bool with_local); - /*! - * \brief internal implementation of checkpoint, support both lazy and normal way - * - * \param global_model pointer to the globally shared model/state - * when calling this function, the caller need to gauranttees that global_model - * is the same in all nodes - * \param local_model pointer to local model, that is specific to current node/rank - * this can be NULL when no local state is needed - * \param lazy_checkpt whether the action is lazy checkpoint - * - * \sa CheckPoint, LazyCheckPoint - */ - void CheckPointImpl(const Serializable *global_model, - const Serializable *local_model, bool lazy_checkpt); - /*! - * \brief reset the all the existing links by sending Out-of-Band message marker - * after this function finishes, all the messages received and sent - * before in all live links are discarded, - * This allows us to get a fresh start after error has happened - * - * TODO(tqchen): this function is not yet functioning was not used by engine, - * simple resetlink and reconnect strategy is used - * - * \return this function can return kSuccess or kSockError - * when kSockError is returned, it simply means there are bad sockets in the links, - * and some link recovery proceduer is needed - */ - ReturnType TryResetLinks(); - /*! - * \brief if err_type indicates an error - * recover links according to the error type reported - * if there is no error, return true - * \param err_type the type of error happening in the system - * \return true if err_type is kSuccess, false otherwise - */ - bool CheckAndRecover(ReturnType err_type); - /*! - * \brief try to run recover execution for a request action described by flag and seqno, - * the function will keep blocking to run possible recovery operations before the specified action, - * until the requested result is received by a recovering procedure, - * or the function discovers that the requested action is not yet executed, and return false - * - * \param buf the buffer to store the result - * \param size the total size of the buffer - * \param flag flag information about the action \sa ActionSummary - * \param seqno sequence number of the action, if it is special action with flag set, - * seqno needs to be set to ActionSummary::kSpecialOp - * - * \return if this function can return true or false - * - true means buf already set to the - * result by recovering procedure, the action is complete, no further action is needed - * - false means this is the lastest action that has not yet been executed, need to execute the action - */ - bool RecoverExec(void *buf, size_t size, int flag, - int seqno = ActionSummary::kSpecialOp, - int cacheseqno = ActionSummary::kSpecialOp, - const char* caller = _CALLER); - /*! 
- * \brief try to load check point - * - * This is a collaborative function called by all nodes - * only the nodes with requester set to true really needs to load the check point - * other nodes acts as collaborative roles to complete this request - * - * \param requester whether current node is the requester - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryLoadCheckPoint(bool requester); - - /*! - * \brief try to load cache - * - * This is a collaborative function called by all nodes - * only the nodes with requester set to true really needs to load the check point - * other nodes acts as collaborative roles to complete this request - * \param requester whether current node is the requester - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryRestoreCache(bool requester, const int min_seq = ActionSummary::kSpecialOp, - const int max_seq = ActionSummary::kSpecialOp); - /*! - * \brief try to get the result of operation specified by seqno - * - * This is a collaborative function called by all nodes - * only the nodes with requester set to true really needs to get the result - * other nodes acts as collaborative roles to complete this request - * - * \param buf the buffer to store the result, this parameter is only used when current node is requester - * \param size the total size of the buffer, this parameter is only used when current node is requester - * \param seqno sequence number of the operation, this is unique index of a operation in current iteration - * \param requester whether current node is the requester - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester); - /*! - * \brief try to decide the routing strategy for recovery - * \param role the current role of the node - * \param p_size used to store the size of the message, for node in state kHaveData, - * this size must be set correctly before calling the function - * for others, this surves as output parameter - - * \param p_recvlink used to store the link current node should recv data from, if necessary - * this can be -1, which means current node have the data - * \param p_req_in used to store the resulting vector, indicating which link we should send the data to - * - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType, TryRecoverData - */ - ReturnType TryDecideRouting(RecoverType role, - size_t *p_size, - int *p_recvlink, - std::vector *p_req_in); - /*! 
- * \brief try to finish the data recovery request, - * this function is used together with TryDecideRouting - * \param role the current role of the node - * \param sendrecvbuf_ the buffer to store the data to be sent/recived - * - if the role is kHaveData, this stores the data to be sent - * - if the role is kRequestData, this is the buffer to store the result - * - if the role is kPassData, this will not be used, and can be NULL - * \param size the size of the data, obtained from TryDecideRouting - * \param recv_link the link index to receive data, if necessary, obtained from TryDecideRouting - * \param req_in the request of each link to send data, obtained from TryDecideRouting - * - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType, TryDecideRouting - */ - ReturnType TryRecoverData(RecoverType role, - void *sendrecvbuf_, - size_t size, - int recv_link, - const std::vector &req_in); - /*! - * \brief try to recover the local state, making each local state to be the result of itself - * plus replication of states in previous num_local_replica hops in the ring - * - * The input parameters must contain the valid local states available in current nodes, - * This function try ist best to "complete" the missing parts of local_rptr and local_chkpt - * If there is sufficient information in the ring, when the function returns, local_chkpt will - * contain num_local_replica + 1 checkpoints (including the chkpt of this node) - * If there is no sufficient information in the ring, this function the number of checkpoints - * will be less than the specified value - * - * \param p_local_rptr the pointer to the segment pointers in the states array - * \param p_local_chkpt the pointer to the storage of local check points - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryRecoverLocalState(std::vector *p_local_rptr, - std::string *p_local_chkpt); - /*! - * \brief try to checkpoint local state, this function is called in normal executation phase - * of checkpoint that contains local state -o * the input state must exactly one saved state(local state of current node), - * after complete, this function will get local state from previous num_local_replica nodes and put them - * into local_chkpt and local_rptr - * - * It is also OK to call TryRecoverLocalState instead, - * TryRecoverLocalState makes less assumption about the input, and requires more communications - * - * \param p_local_rptr the pointer to the segment pointers in the states array - * \param p_local_chkpt the pointer to the storage of local check points - * \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details - * \sa ReturnType, TryRecoverLocalState - */ - ReturnType TryCheckinLocalState(std::vector *p_local_rptr, - std::string *p_local_chkpt); - /*! 
- * \brief perform a ring passing to receive data from prev link, and sent data to next link - * this allows data to stream over a ring structure - * sendrecvbuf[0:read_ptr] are already provided by current node - * current node will recv sendrecvbuf[read_ptr:read_end] from prev link - * current node will send sendrecvbuf[write_ptr:write_end] to next link - * write_ptr will wait till the data is readed before sending the data - * this function requires read_end >= write_end - * - * \param sendrecvbuf_ the place to hold the incoming and outgoing data - * \param read_ptr the initial read pointer - * \param read_end the ending position to read - * \param write_ptr the initial write pointer - * \param write_end the ending position to write - * \param read_link pointer to link to previous position in ring - * \param write_link pointer to link of next position in ring - */ - ReturnType RingPassing(void *senrecvbuf_, - size_t read_ptr, - size_t read_end, - size_t write_ptr, - size_t write_end, - LinkRecord *read_link, - LinkRecord *write_link); - /*! - * \brief run message passing algorithm on the allreduce tree - * the result is edge message stored in p_edge_in and p_edge_out - * \param node_value the value associated with current node - * \param p_edge_in used to store input message from each of the edge - * \param p_edge_out used to store output message from each of the edge - * \param func a function that defines the message passing rule - * Parameters of func: - * - node_value same as node_value in the main function - * - edge_in the array of input messages from each edge, - * this includes the output edge, which should be excluded - * - out_index array the index of output edge, the function should - * exclude the output edge when compute the message passing value - * Return of func: - * the function returns the output message based on the input message and node_value - * - * \tparam EdgeType type of edge message, must be simple struct - * \tparam NodeType type of node value - */ - template - inline ReturnType MsgPassing(const NodeType &node_value, - std::vector *p_edge_in, - std::vector *p_edge_out, - EdgeType(*func) - (const NodeType &node_value, - const std::vector &edge_in, - size_t out_index)); - //---- recovery data structure ---- - // the round of result buffer, used to mode the result - int result_buffer_round_; - // result buffer of all reduce - ResultBuffer resbuf_; - // current cached allreduce/braodcast sequence number - int cur_cache_seq_; - // result buffer of cached all reduce - ResultBuffer cachebuf_; - // key of each cache entry - ResultBuffer lookupbuf_; - // last check point global model - std::string global_checkpoint_; - // lazy checkpoint of global model - const Serializable *global_lazycheck_; - // number of replica for local state/model - int num_local_replica_; - // number of default local replica - int default_local_replica_; - // flag to decide whether local model is used, -1: unknown, 0: no, 1:yes - int use_local_model_; - // number of replica for global state/model - int num_global_replica_; - // number of times recovery happens - int recover_counter_; - // --- recovery data structure for local checkpoint - // there is two version of the data structure, - // at one time one version is valid and another is used as temp memory - // pointer to memory position in the local model - // local model is stored in CSR format(like a sparse matrices) - // local_model[rptr[0]:rptr[1]] stores the model of current node - // local_model[rptr[k]:rptr[k+1]] stores the model of node 
in previous k hops
-  std::vector local_rptr_[2];
-  // storage for local model replicas
-  std::string local_chkpt_[2];
-  // version of local checkpoint can be 1 or 0
-  int local_chkpt_version_;
-  // if checkpoint were loaded, used to distinguish results boostrap cache from seqno cache
-  bool checkpoint_loaded_;
-  // sidecar executing timeout task
-  std::future rabit_timeout_task_;
-  // flag to shutdown rabit_timeout_task before timeout
-  std::atomic shutdown_timeout_{false};
-  // error handler
-  void (* error_)(const char *fmt, ...) = utils::Error;
-  // assert handler
-  void (* assert_)(bool exp, const char *fmt, ...) = utils::Assert;
-};
-} // namespace engine
-} // namespace rabit
-// implementation of inline template function
-#include "./allreduce_robust-inl.h"
-#endif // RABIT_ALLREDUCE_ROBUST_H_
diff --git a/rabit/src/engine.cc b/rabit/src/engine.cc
index ac7731fdd..616004765 100644
--- a/rabit/src/engine.cc
+++ b/rabit/src/engine.cc
@@ -12,14 +12,13 @@
 #include
 #include "rabit/internal/engine.h"
 #include "allreduce_base.h"
-#include "allreduce_robust.h"

 namespace rabit {
 namespace engine {
 // singleton sync manager
 #ifndef RABIT_USE_BASE
 #ifndef RABIT_USE_MOCK
-using Manager = AllreduceRobust;
+using Manager = AllreduceBase;
 #else
 typedef AllreduceMock Manager;
 #endif // RABIT_USE_MOCK
diff --git a/src/learner.cc b/src/learner.cc
index 0ef71f480..dd0244e87 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -809,12 +809,6 @@ class LearnerIO : public LearnerConfiguration {

     {
       std::vector saved_params;
-      // check if rabit_bootstrap_cache were set to non zero before adding to checkpoint
-      if (cfg_.find("rabit_bootstrap_cache") != cfg_.end() &&
-          (cfg_.find("rabit_bootstrap_cache"))->second != "0") {
-        std::copy(saved_configs_.begin(), saved_configs_.end(),
-                  std::back_inserter(saved_params));
-      }
       for (const auto& key : saved_params) {
         auto it = cfg_.find(key);
         if (it != cfg_.end()) {
diff --git a/tests/ci_build/approx.conf.in b/tests/ci_build/approx.conf.in
deleted file mode 100644
index b2321a79a..000000000
--- a/tests/ci_build/approx.conf.in
+++ /dev/null
@@ -1,12 +0,0 @@
-# Originally an example in demo/regression/
-tree_method=approx
-eta = 0.5
-gamma = 1.0
-seed = 0
-min_child_weight = 0
-max_depth = 5
-
-num_round = 12
-save_period = 100
-data = "demo/data/agaricus.txt.train"
-eval[test] = "demo/data/agaricus.txt.test"
diff --git a/tests/ci_build/runxgb.sh b/tests/ci_build/runxgb.sh
deleted file mode 100755
index 4825dccf5..000000000
--- a/tests/ci_build/runxgb.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/bin/bash
-
-source activate cpu_test
-
-export DMLC_SUBMIT_CLUSTER=local
-
-submit="python3 dmlc-core/tracker/dmlc-submit"
-# build xgboost with librabit mock
-# define max worker retry with dmlc-core local num atempt
-# instrument worker failure with mock=xxxx
-# check if host recovered from expectected iteration
-echo "====== 1. Fault recovery distributed test ======"
-exec $submit --cluster=local --num-workers=10 --local-num-attempt=10 $1 $2 mock=0,10,1,0 mock=1,11,1,0 mock=1,11,1,1 mock=0,11,1,0 mock=4,11,1,0 mock=9,11,1,0 mock=8,11,2,0 mock=4,11,3,0 rabit_bootstrap_cache=1 rabit_debug=1
diff --git a/tests/cpp/rabit/allreduce_mock_test.cc b/tests/cpp/rabit/allreduce_mock_test.cc
deleted file mode 100644
index 8d6f519c5..000000000
--- a/tests/cpp/rabit/allreduce_mock_test.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-#define RABIT_CXXTESTDEFS_H
-#if !defined(_WIN32)
-#include
-
-#include
-#include
-#include "../../../rabit/src/allreduce_mock.h"
-
-TEST(AllreduceMock, MockAllreduce)
-{
-  rabit::engine::AllreduceMock m;
-
-  std::string mock_str = "mock=0,0,0,0";
-  char cmd[mock_str.size()+1];
-  std::copy(mock_str.begin(), mock_str.end(), cmd);
-  cmd[mock_str.size()] = '\0';
-
-  char* argv[] = {cmd};
-  m.Init(1, argv);
-  m.rank = 0;
-  EXPECT_THROW(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), dmlc::Error);
-}
-
-TEST(AllreduceMock, MockBroadcast)
-{
-  rabit::engine::AllreduceMock m;
-  std::string mock_str = "mock=0,1,2,0";
-  char cmd[mock_str.size()+1];
-  std::copy(mock_str.begin(), mock_str.end(), cmd);
-  cmd[mock_str.size()] = '\0';
-  char* argv[] = {cmd};
-  m.Init(1, argv);
-  m.rank = 0;
-  m.version_number=1;
-  m.seq_counter=2;
-  EXPECT_THROW(m.Broadcast(nullptr,0,0), dmlc::Error);
-}
-
-TEST(AllreduceMock, MockGather)
-{
-  rabit::engine::AllreduceMock m;
-  std::string mock_str = "mock=3,13,22,0";
-  char cmd[mock_str.size()+1];
-  std::copy(mock_str.begin(), mock_str.end(), cmd);
-  cmd[mock_str.size()] = '\0';
-  char* argv[] = {cmd};
-  m.Init(1, argv);
-  m.rank = 3;
-  m.version_number=13;
-  m.seq_counter=22;
-  EXPECT_THROW({m.Allgather(nullptr,0,0,0,0);}, dmlc::Error);
-}
-#endif // !defined(_WIN32)
diff --git a/tests/cpp/rabit/allreduce_robust_test.cc b/tests/cpp/rabit/allreduce_robust_test.cc
deleted file mode 100644
index 02e19c4b6..000000000
--- a/tests/cpp/rabit/allreduce_robust_test.cc
+++ /dev/null
@@ -1,235 +0,0 @@
-#define RABIT_CXXTESTDEFS_H
-#if !defined(_WIN32)
-#include
-
-#include
-#include
-#include
-#include "../../../rabit/src/allreduce_robust.h"
-
-inline void MockErr(const char *fmt, ...) {EXPECT_STRCASEEQ(fmt, "[%d] exit due to time out %d s\n");}
-inline void MockAssert(bool val, const char *fmt, ...) {}
-rabit::engine::AllreduceRobust::ReturnType err_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSockError);
-rabit::engine::AllreduceRobust::ReturnType succ_type(rabit::engine::AllreduceRobust::ReturnTypeEnum::kSuccess);
-
-TEST(AllreduceRobust, SyncErrorTimeout)
-{
-  rabit::engine::AllreduceRobust m;
-
-  std::string rabit_timeout = "rabit_timeout=1";
-  char cmd[rabit_timeout.size()+1];
-  std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
-  cmd[rabit_timeout.size()] = '\0';
-
-  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
-  char cmd1[rabit_timeout_sec.size()+1];
-  std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
-  cmd1[rabit_timeout_sec.size()] = '\0';
-
-  char* argv[] = {cmd,cmd1};
-  m.Init(2, argv);
-  m.rank = 0;
-  m.rabit_bootstrap_cache = true;
-  m.error_ = MockErr;
-  m.assert_ = MockAssert;
-  EXPECT_EQ(m.CheckAndRecover(err_type), false);
-  std::this_thread::sleep_for(std::chrono::milliseconds(1500));
-  EXPECT_EQ(m.rabit_timeout_task_.get(), false);
-}
-
-TEST(AllreduceRobust, SyncErrorReset)
-{
-  rabit::engine::AllreduceRobust m;
-
-  std::string rabit_timeout = "rabit_timeout=1";
-  char cmd[rabit_timeout.size()+1];
-  std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
-  cmd[rabit_timeout.size()] = '\0';
-
-  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
-  char cmd1[rabit_timeout_sec.size()+1];
-  std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
-  cmd1[rabit_timeout_sec.size()] = '\0';
-
-  std::string rabit_debug = "rabit_debug=1";
-  char cmd2[rabit_debug.size()+1];
-  std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
-  cmd2[rabit_debug.size()] = '\0';
-
-  char* argv[] = {cmd, cmd1,cmd2};
-  m.Init(3, argv);
-  m.rank = 0;
-  m.assert_ = MockAssert;
-  EXPECT_EQ(m.CheckAndRecover(err_type), false);
-  std::this_thread::sleep_for(std::chrono::milliseconds(100));
-  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
-  EXPECT_EQ(m.rabit_timeout_task_.get(), true);
-  m.Shutdown();
-}
-
-TEST(AllreduceRobust, SyncSuccessErrorTimeout)
-{
-  rabit::engine::AllreduceRobust m;
-
-  std::string rabit_timeout = "rabit_timeout=1";
-  char cmd[rabit_timeout.size()+1];
-  std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
-  cmd[rabit_timeout.size()] = '\0';
-
-  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
-  char cmd1[rabit_timeout_sec.size()+1];
-  std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
-  cmd1[rabit_timeout_sec.size()] = '\0';
-
-  std::string rabit_debug = "rabit_debug=1";
-  char cmd2[rabit_debug.size()+1];
-  std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
-  cmd2[rabit_debug.size()] = '\0';
-
-  char* argv[] = {cmd, cmd1,cmd2};
-  m.Init(3, argv);
-  m.rank = 0;
-  m.rabit_bootstrap_cache = true;
-  m.assert_ = MockAssert;
-  m.error_ = MockErr;
-  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
-  std::this_thread::sleep_for(std::chrono::milliseconds(100));
-  EXPECT_EQ(m.CheckAndRecover(err_type), false);
-  std::this_thread::sleep_for(std::chrono::milliseconds(1500));
-  EXPECT_EQ(m.rabit_timeout_task_.get(), false);
-}
-
-TEST(AllreduceRobust, SyncSuccessErrorSuccess)
-{
-  rabit::engine::AllreduceRobust m;
-
-  std::string rabit_timeout = "rabit_timeout=1";
-  char cmd[rabit_timeout.size()+1];
-  std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
-  cmd[rabit_timeout.size()] = '\0';
-
-  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
-  char cmd1[rabit_timeout_sec.size()+1];
-  std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
-  cmd1[rabit_timeout_sec.size()] = '\0';
-
-  std::string rabit_debug = "rabit_debug=1";
-  char cmd2[rabit_debug.size()+1];
-  std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
-  cmd2[rabit_debug.size()] = '\0';
-
-  char* argv[] = {cmd, cmd1,cmd2};
-  m.Init(3, argv);
-  m.rank = 0;
-  m.rabit_bootstrap_cache = true;
-  m.assert_ = MockAssert;
-  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
-  std::this_thread::sleep_for(std::chrono::milliseconds(10));
-
-  EXPECT_EQ(m.CheckAndRecover(err_type), false);
-  std::this_thread::sleep_for(std::chrono::milliseconds(10));
-  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
-  std::this_thread::sleep_for(std::chrono::milliseconds(1100));
-  EXPECT_EQ(m.rabit_timeout_task_.get(), true);
-  m.Shutdown();
-}
-
-TEST(AllreduceRobust, SyncErrorNoResetTimeout)
-{
-  rabit::engine::AllreduceRobust m;
-
-  std::string rabit_timeout = "rabit_timeout=1";
-  char cmd[rabit_timeout.size()+1];
-  std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
-  cmd[rabit_timeout.size()] = '\0';
-
-  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
-  char cmd1[rabit_timeout_sec.size()+1];
-  std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
-  cmd1[rabit_timeout_sec.size()] = '\0';
-
-  std::string rabit_debug = "rabit_debug=1";
-  char cmd2[rabit_debug.size()+1];
-  std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
-  cmd2[rabit_debug.size()] = '\0';
-
-  char* argv[] = {cmd, cmd1,cmd2};
-  m.Init(3, argv);
-  m.rank = 0;
-  m.rabit_bootstrap_cache = true;
-  m.assert_ = MockAssert;
-  m.error_ = MockErr;
-  auto start = std::chrono::system_clock::now();
-
-  EXPECT_EQ(m.CheckAndRecover(err_type), false);
-  std::this_thread::sleep_for(std::chrono::milliseconds(1100));
-
-  EXPECT_EQ(m.CheckAndRecover(err_type), false);
-
-  m.rabit_timeout_task_.wait();
-  auto end = std::chrono::system_clock::now();
-  std::chrono::duration diff = end-start;
-
-  EXPECT_EQ(m.rabit_timeout_task_.get(), false);
-  // expect second error don't overwrite/reset timeout task
-  EXPECT_LT(diff.count(), 2);
-}
-
-TEST(AllreduceRobust, NoTimeoutShutDown)
-{
-  rabit::engine::AllreduceRobust m;
-
-  std::string rabit_timeout = "rabit_timeout=1";
-  char cmd[rabit_timeout.size()+1];
-  std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
-  cmd[rabit_timeout.size()] = '\0';
-
-  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
-  char cmd1[rabit_timeout_sec.size()+1];
-  std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
-  cmd1[rabit_timeout_sec.size()] = '\0';
-
-  std::string rabit_debug = "rabit_debug=1";
-  char cmd2[rabit_debug.size()+1];
-  std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
-  cmd2[rabit_debug.size()] = '\0';
-
-  char* argv[] = {cmd, cmd1,cmd2};
-  m.Init(3, argv);
-  m.rank = 0;
-
-  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
-  std::this_thread::sleep_for(std::chrono::milliseconds(10));
-  m.Shutdown();
-}
-
-TEST(AllreduceRobust, ShutDownBeforeTimeout)
-{
-  rabit::engine::AllreduceRobust m;
-
-  std::string rabit_timeout = "rabit_timeout=1";
-  char cmd[rabit_timeout.size()+1];
-  std::copy(rabit_timeout.begin(), rabit_timeout.end(), cmd);
-  cmd[rabit_timeout.size()] = '\0';
-
-  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
-  char cmd1[rabit_timeout_sec.size()+1];
-  std::copy(rabit_timeout_sec.begin(), rabit_timeout_sec.end(), cmd1);
-  cmd1[rabit_timeout_sec.size()] = '\0';
-
-  std::string rabit_debug = "rabit_debug=1";
-  char cmd2[rabit_debug.size()+1];
-  std::copy(rabit_debug.begin(), rabit_debug.end(), cmd2);
-  cmd2[rabit_debug.size()] = '\0';
-
-  char* argv[] = {cmd, cmd1,cmd2};
-  m.Init(3, argv);
-  m.rank = 0;
-  rabit::engine::AllreduceRobust::LinkRecord a;
-  m.err_link = &a;
-
-  EXPECT_EQ(m.CheckAndRecover(err_type), false);
-  std::this_thread::sleep_for(std::chrono::milliseconds(10));
-  m.Shutdown();
-}
-#endif // !defined(_WIN32)
diff --git a/tests/distributed/distributed_gpu.py b/tests/distributed/distributed_gpu.py
index f30e39b1b..a2ab6d398 100644
--- a/tests/distributed/distributed_gpu.py
+++ b/tests/distributed/distributed_gpu.py
@@ -1,8 +1,8 @@
 """Distributed GPU tests."""
 import sys
-import time
 import xgboost as xgb
 import os
+import numpy as np


 def run_test(name, params_fun):
@@ -28,7 +28,7 @@ def run_test(name, params_fun):
     # Have each worker save its model
     model_name = "test.model.%s.%d" % (name, rank)
     bst.dump_model(model_name, with_stats=True)
-    time.sleep(2)
+    xgb.rabit.allreduce(np.ones((1, 1)), xgb.rabit.Op.MAX)  # sync
     xgb.rabit.tracker_print("Finished training\n")

     if (rank == 0):
@@ -49,9 +49,6 @@ def run_test(name, params_fun):

     xgb.rabit.finalize()

-    if os.path.exists(model_name):
-        os.remove(model_name)
-

 base_params = {
     'tree_method': 'gpu_hist',
diff --git a/tests/distributed/runtests-gpu.sh b/tests/distributed/runtests-gpu.sh
index cc2d23cec..17e472482 100755
--- a/tests/distributed/runtests-gpu.sh
+++ b/tests/distributed/runtests-gpu.sh
@@ -7,6 +7,8 @@ submit="timeout 30 python ../../dmlc-core/tracker/dmlc-submit"

 echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
 $submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py basic_1x4 || exit 1
+rm test.model.*

 echo -e "\n ====== 2. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
 $submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py rf_1x4 || exit 1
+rm test.model.*
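
Note on the change to tests/distributed/distributed_gpu.py: the fixed time.sleep(2) is replaced by an allreduce so that rank 0 reads the other workers' model dumps only after every worker has written its own; since each rank must contribute before any rank receives the reduced value, the call doubles as a barrier. A minimal sketch of the same idiom, assuming the pre-2.0 xgboost.rabit Python module (init, get_rank, get_world_size, allreduce, Op.MAX, finalize); the file name is illustrative only:

    import numpy as np
    import xgboost as xgb

    xgb.rabit.init()
    rank = xgb.rabit.get_rank()
    # Each worker writes its own artifact (here just a marker file).
    with open("artifact.%d.txt" % rank, "w") as f:
        f.write("done\n")
    # The allreduce acts as a barrier: it returns only after every rank has called it,
    # so once it completes, rank 0 can safely read the files written by all ranks.
    xgb.rabit.allreduce(np.ones((1, 1)), xgb.rabit.Op.MAX)
    if rank == 0:
        print("all %d workers finished" % xgb.rabit.get_world_size())
    xgb.rabit.finalize()

Without a tracker the script runs as a single worker, so it is also usable outside dmlc-submit for a quick check of the barrier behavior.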