Drop single point model recovery (#6262)
* Pass rabit params in JVM package. * Implement timeout using poll timeout parameter. * Remove OOB data check.
This commit is contained in:
parent
81c37c28d5
commit
b5c2a47b20
14
Jenkinsfile
vendored
14
Jenkinsfile
vendored
@ -321,20 +321,6 @@ def TestPythonGPU(args) {
|
||||
}
|
||||
}
|
||||
|
||||
def TestCppRabit() {
|
||||
node(nodeReq) {
|
||||
unstash name: 'xgboost_rabit_tests'
|
||||
unstash name: 'srcs'
|
||||
echo "Test C++, rabit mock on"
|
||||
def container_type = "cpu"
|
||||
def docker_binary = "docker"
|
||||
sh """
|
||||
${dockerRun} ${container_type} ${docker_binary} tests/ci_build/runxgb.sh xgboost tests/ci_build/approx.conf.in
|
||||
"""
|
||||
deleteDir()
|
||||
}
|
||||
}
|
||||
|
||||
def TestCppGPU(args) {
|
||||
def nodeReq = 'linux && mgpu'
|
||||
def artifact_cuda_version = (args.artifact_cuda_version) ?: ref_cuda_ver
|
||||
|
||||
@ -22,4 +22,4 @@ PKG_LIBS = @OPENMP_CXXFLAGS@ @OPENMP_LIB@ @ENDIAN_FLAG@ @BACKTRACE_LIB@ -pthread
|
||||
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
|
||||
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
|
||||
$(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
|
||||
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
|
||||
$(PKGROOT)/rabit/src/allreduce_base.o
|
||||
|
||||
@ -34,6 +34,6 @@ PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS)
|
||||
OBJECTS= ./xgboost_R.o ./xgboost_custom.o ./xgboost_assert.o ./init.o \
|
||||
$(PKGROOT)/amalgamation/xgboost-all0.o $(PKGROOT)/amalgamation/dmlc-minimum0.o \
|
||||
$(PKGROOT)/rabit/src/engine.o $(PKGROOT)/rabit/src/c_api.o \
|
||||
$(PKGROOT)/rabit/src/allreduce_base.o $(PKGROOT)/rabit/src/allreduce_robust.o
|
||||
$(PKGROOT)/rabit/src/allreduce_base.o
|
||||
|
||||
$(OBJECTS) : xgblib
|
||||
|
||||
@ -577,6 +577,7 @@ object XGBoost extends Serializable {
|
||||
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
|
||||
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
|
||||
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
|
||||
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
|
||||
val sc = trainingData.sparkContext
|
||||
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
|
||||
hasGroup, xgbExecParams.numWorkers)
|
||||
@ -595,6 +596,8 @@ object XGBoost extends Serializable {
|
||||
xgbExecParams.timeoutRequestWorkers,
|
||||
xgbExecParams.numWorkers,
|
||||
xgbExecParams.killSparkContextOnWorkerFailure)
|
||||
|
||||
tracker.getWorkerEnvs().putAll(xgbRabitParams)
|
||||
val rabitEnv = tracker.getWorkerEnvs
|
||||
val boostersAndMetrics = if (hasGroup) {
|
||||
trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
|
||||
|
||||
@ -2,8 +2,9 @@ cmake_minimum_required(VERSION 3.3)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
add_library(rabit src/allreduce_base.cc src/allreduce_robust.cc src/engine.cc src/c_api.cc)
|
||||
add_library(rabit_mock_static src/allreduce_base.cc src/allreduce_robust.cc src/engine_mock.cc src/c_api.cc)
|
||||
add_library(rabit src/allreduce_base.cc src/engine.cc src/c_api.cc)
|
||||
add_library(rabit_mock_static src/allreduce_base.cc src/engine_mock.cc src/c_api.cc)
|
||||
|
||||
target_link_libraries(rabit Threads::Threads dmlc)
|
||||
target_link_libraries(rabit_mock_static Threads::Threads dmlc)
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <unordered_map>
|
||||
#include "utils.h"
|
||||
|
||||
@ -95,18 +96,18 @@ namespace utils {
|
||||
static constexpr int kInvalidSocket = -1;
|
||||
|
||||
template <typename PollFD>
|
||||
int PollImpl(PollFD *pfd, int nfds, int timeout) {
|
||||
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
|
||||
#if defined(_WIN32)
|
||||
|
||||
#if IS_MINGW()
|
||||
MingWError();
|
||||
return -1;
|
||||
#else
|
||||
return WSAPoll(pfd, nfds, timeout);
|
||||
return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
|
||||
#endif // IS_MINGW()
|
||||
|
||||
#else
|
||||
return poll(pfd, nfds, timeout);
|
||||
return poll(pfd, nfds, std::chrono::milliseconds(timeout).count());
|
||||
#endif // IS_MINGW()
|
||||
}
|
||||
|
||||
@ -608,40 +609,20 @@ struct PollHelper {
|
||||
const auto& pfd = fds.find(fd);
|
||||
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
|
||||
}
|
||||
/*!
|
||||
* \brief Check if the descriptor has any exception
|
||||
* \param fd file descriptor to check status
|
||||
*/
|
||||
inline bool CheckExcept(SOCKET fd) const {
|
||||
const auto& pfd = fds.find(fd);
|
||||
return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
|
||||
}
|
||||
/*!
|
||||
* \brief wait for exception event on a single descriptor
|
||||
* \param fd the file descriptor to wait the event for
|
||||
* \param timeout the timeout counter, can be negative, which means wait until the event happen
|
||||
* \return 1 if success, 0 if timeout, and -1 if error occurs
|
||||
*/
|
||||
inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
|
||||
pollfd pfd;
|
||||
pfd.fd = fd;
|
||||
pfd.events = POLLPRI;
|
||||
return PollImpl(&pfd, 1, timeout);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief perform poll on the set defined, read, write, exception
|
||||
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
|
||||
* \return
|
||||
*/
|
||||
inline void Poll(long timeout = -1) { // NOLINT(*)
|
||||
inline void Poll(std::chrono::seconds timeout) { // NOLINT(*)
|
||||
std::vector<pollfd> fdset;
|
||||
fdset.reserve(fds.size());
|
||||
for (auto kv : fds) {
|
||||
fdset.push_back(kv.second);
|
||||
}
|
||||
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
|
||||
if (ret == -1) {
|
||||
if (ret <= 0) {
|
||||
Socket::Error("Poll");
|
||||
} else {
|
||||
for (auto& pfd : fdset) {
|
||||
|
||||
@ -1,31 +0,0 @@
|
||||
option(DMLC_ROOT "Specify root of external dmlc core.")
|
||||
|
||||
add_library(allreduce_base "")
|
||||
add_library(allreduce_mock "")
|
||||
|
||||
target_sources(
|
||||
allreduce_base
|
||||
PRIVATE
|
||||
allreduce_base.cc
|
||||
PUBLIC
|
||||
${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h
|
||||
)
|
||||
target_sources(
|
||||
allreduce_mock
|
||||
PRIVATE
|
||||
allreduce_robust.cc
|
||||
PUBLIC
|
||||
${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
allreduce_base
|
||||
PUBLIC
|
||||
${DMLC_ROOT}/include
|
||||
${CMAKE_CURRENT_LIST_DIR}/../../include)
|
||||
|
||||
target_include_directories(
|
||||
allreduce_mock
|
||||
PUBLIC
|
||||
${DMLC_ROOT}/include
|
||||
${CMAKE_CURRENT_LIST_DIR}/../../include)
|
||||
@ -1,6 +0,0 @@
|
||||
Source Files of Rabit
|
||||
====
|
||||
* This folder contains the source files of rabit library
|
||||
* The library headers are in folder [include](../include)
|
||||
* The .h files in this folder are internal header files that are only used by rabit and will not be seen by users
|
||||
|
||||
@ -6,8 +6,9 @@
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#define NOMINMAX
|
||||
#include "rabit/base.h"
|
||||
#include "rabit/internal/rabit-inl.h"
|
||||
#include "allreduce_base.h"
|
||||
#include <rabit/base.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <netinet/tcp.h>
|
||||
@ -208,8 +209,8 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
|
||||
rabit_timeout = utils::StringToBool(val);
|
||||
}
|
||||
if (!strcmp(name, "rabit_timeout_sec")) {
|
||||
timeout_sec = atoi(val);
|
||||
utils::Assert(timeout_sec >= 0, "rabit_timeout_sec should be non negative second");
|
||||
timeout_sec = std::chrono::seconds(atoi(val));
|
||||
utils::Assert(timeout_sec.count() >= 0, "rabit_timeout_sec should be non negative second");
|
||||
}
|
||||
if (!strcmp(name, "rabit_enable_tcp_no_delay")) {
|
||||
if (!strcmp(val, "true")) {
|
||||
@ -549,14 +550,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
|
||||
// finish running allreduce
|
||||
if (finished) break;
|
||||
// select must return
|
||||
watcher.Poll();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// receive OOB message from some link
|
||||
if (watcher.CheckExcept(links[i].sock)) {
|
||||
return ReportError(&links[i], kGetExcept);
|
||||
}
|
||||
}
|
||||
watcher.Poll(timeout_sec);
|
||||
// read data from childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
|
||||
@ -729,14 +723,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
// finish running
|
||||
if (finished) break;
|
||||
// select
|
||||
watcher.Poll();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// receive OOB message from some link
|
||||
if (watcher.CheckExcept(links[i].sock)) {
|
||||
return ReportError(&links[i], kGetExcept);
|
||||
}
|
||||
}
|
||||
watcher.Poll(timeout_sec);
|
||||
if (in_link == -2) {
|
||||
// probe in-link
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
@ -819,7 +806,7 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||
finished = false;
|
||||
}
|
||||
if (finished) break;
|
||||
watcher.Poll();
|
||||
watcher.Poll(timeout_sec);
|
||||
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
||||
size_t size = stop_read - read_ptr;
|
||||
size_t start = read_ptr % total_size;
|
||||
@ -831,7 +818,10 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||
read_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return();
|
||||
if (ret != kSuccess) return ReportError(&next, ret);
|
||||
if (ret != kSuccess) {
|
||||
auto err = ReportError(&next, ret);
|
||||
return err;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (write_ptr < read_ptr && write_ptr != stop_write) {
|
||||
@ -845,7 +835,10 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
|
||||
write_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
ReturnType ret = Errno2Return();
|
||||
if (ret != kSuccess) return ReportError(&prev, ret);
|
||||
if (ret != kSuccess) {
|
||||
auto err = ReportError(&prev, ret);
|
||||
return err;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -913,7 +906,7 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
|
||||
finished = false;
|
||||
}
|
||||
if (finished) break;
|
||||
watcher.Poll();
|
||||
watcher.Poll(timeout_sec);
|
||||
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
|
||||
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
|
||||
if (ret != kSuccess) {
|
||||
|
||||
@ -12,6 +12,8 @@
|
||||
#ifndef RABIT_ALLREDUCE_BASE_H_
|
||||
#define RABIT_ALLREDUCE_BASE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
@ -35,6 +37,7 @@ class Datatype {
|
||||
}
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
|
||||
/*! \brief implementation of basic Allreduce engine */
|
||||
class AllreduceBase : public IEngine {
|
||||
public:
|
||||
@ -103,9 +106,11 @@ class AllreduceBase : public IEngine {
|
||||
size_t slice_end, size_t size_prev_slice,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size,
|
||||
slice_begin, slice_end, size_prev_slice) == kSuccess,
|
||||
if (world_size == 1 || world_size == -1) {
|
||||
return;
|
||||
}
|
||||
utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin,
|
||||
slice_end, size_prev_slice) == kSuccess,
|
||||
"AllgatherRing failed");
|
||||
}
|
||||
/*!
|
||||
@ -130,8 +135,8 @@ class AllreduceBase : public IEngine {
|
||||
const char *_caller = _CALLER) override {
|
||||
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
|
||||
if (world_size == 1 || world_size == -1) return;
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_,
|
||||
type_nbytes, count, reducer) == kSuccess,
|
||||
utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) ==
|
||||
kSuccess,
|
||||
"Allreduce failed");
|
||||
}
|
||||
/*!
|
||||
@ -518,9 +523,9 @@ class AllreduceBase : public IEngine {
|
||||
//---- data structure related to model ----
|
||||
// call sequence counter, records how many calls we made so far
|
||||
// from last call to CheckPoint, LoadCheckPoint
|
||||
int seq_counter; // NOLINT
|
||||
int seq_counter{0}; // NOLINT
|
||||
// version number of model
|
||||
int version_number; // NOLINT
|
||||
int version_number {0}; // NOLINT
|
||||
// whether the job is running in hadoop
|
||||
bool hadoop_mode; // NOLINT
|
||||
//---- local data related to link ----
|
||||
@ -571,7 +576,7 @@ class AllreduceBase : public IEngine {
|
||||
// enable detailed logging
|
||||
bool rabit_debug = false; // NOLINT
|
||||
// by default, exit if a rabit worker does not recover within half an hour
|
||||
int timeout_sec = 1800; // NOLINT
|
||||
std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT
|
||||
// flag to enable rabit_timeout
|
||||
bool rabit_timeout = false; // NOLINT
|
||||
// Enable TCP no delay
|
||||
|
||||
@ -13,11 +13,11 @@
|
||||
#include <sstream>
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "rabit/internal/timer.h"
|
||||
#include "allreduce_robust.h"
|
||||
#include "allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
class AllreduceMock : public AllreduceRobust {
|
||||
class AllreduceMock : public AllreduceBase {
|
||||
public:
|
||||
// constructor
|
||||
AllreduceMock() {
|
||||
@ -30,7 +30,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
// destructor
|
||||
~AllreduceMock() override = default;
|
||||
void SetParam(const char *name, const char *val) override {
|
||||
AllreduceRobust::SetParam(name, val);
|
||||
AllreduceBase::SetParam(name, val);
|
||||
// additional parameters
|
||||
if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val);
|
||||
if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val);
|
||||
@ -51,9 +51,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
const char *_caller = _CALLER) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||
count, reducer, prepare_fun, prepare_arg,
|
||||
_file, _line, _caller);
|
||||
AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer,
|
||||
prepare_fun, prepare_arg, _file, _line, _caller);
|
||||
tsum_allreduce_ += utils::GetTime() - tstart;
|
||||
}
|
||||
void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin,
|
||||
@ -62,16 +61,15 @@ class AllreduceMock : public AllreduceRobust {
|
||||
const char *_caller = _CALLER) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather");
|
||||
double tstart = utils::GetTime();
|
||||
AllreduceRobust::Allgather(sendrecvbuf, total_size,
|
||||
slice_begin, slice_end,
|
||||
size_prev_slice, _file, _line, _caller);
|
||||
AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end,
|
||||
size_prev_slice, _file, _line, _caller);
|
||||
tsum_allgather_ += utils::GetTime() - tstart;
|
||||
}
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast");
|
||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
|
||||
AllreduceBase::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
|
||||
}
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) override {
|
||||
@ -79,11 +77,11 @@ class AllreduceMock : public AllreduceRobust {
|
||||
tsum_allgather_ = 0.0;
|
||||
time_checkpoint_ = utils::GetTime();
|
||||
if (force_local_ == 0) {
|
||||
return AllreduceRobust::LoadCheckPoint(global_model, local_model);
|
||||
return AllreduceBase::LoadCheckPoint(global_model, local_model);
|
||||
} else {
|
||||
DummySerializer dum;
|
||||
ComboSerializer com(global_model, local_model);
|
||||
return AllreduceRobust::LoadCheckPoint(&dum, &com);
|
||||
return AllreduceBase::LoadCheckPoint(&dum, &com);
|
||||
}
|
||||
}
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
@ -92,18 +90,17 @@ class AllreduceMock : public AllreduceRobust {
|
||||
double tstart = utils::GetTime();
|
||||
double tbet_chkpt = tstart - time_checkpoint_;
|
||||
if (force_local_ == 0) {
|
||||
AllreduceRobust::CheckPoint(global_model, local_model);
|
||||
AllreduceBase::CheckPoint(global_model, local_model);
|
||||
} else {
|
||||
DummySerializer dum;
|
||||
ComboSerializer com(global_model, local_model);
|
||||
AllreduceRobust::CheckPoint(&dum, &com);
|
||||
AllreduceBase::CheckPoint(&dum, &com);
|
||||
}
|
||||
time_checkpoint_ = utils::GetTime();
|
||||
double tcost = utils::GetTime() - tstart;
|
||||
if (report_stats_ != 0 && rank == 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[v" << version_number << "] global_size=" << global_checkpoint_.length()
|
||||
<< ",local_size=" << (local_chkpt_[0].length() + local_chkpt_[1].length())
|
||||
ss << "[v" << version_number << "] global_size="
|
||||
<< ",check_tcost="<< tcost <<" sec"
|
||||
<< ",allreduce_tcost=" << tsum_allreduce_ << " sec"
|
||||
<< ",allgather_tcost=" << tsum_allgather_ << " sec"
|
||||
@ -116,7 +113,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "LazyCheckPoint");
|
||||
AllreduceRobust::LazyCheckPoint(global_model);
|
||||
AllreduceBase::LazyCheckPoint(global_model);
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -186,7 +183,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
if (mock_map_.count(key) != 0) {
|
||||
num_trial_ += 1;
|
||||
// data processing frameworks run on a shared process
|
||||
error_("[%d]@@@Hit Mock Error:%s ", rank, name);
|
||||
throw dmlc::Error(std::to_string(rank) + "@@@Hit Mock Error: " + name);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -1,169 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_robust-inl.h
|
||||
* \brief implementation of inline template function in AllreduceRobust
|
||||
*
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_ROBUST_INL_H_
|
||||
#define RABIT_ALLREDUCE_ROBUST_INL_H_
|
||||
#include <vector>
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*!
|
||||
* \brief run message passing algorithm on the allreduce tree
|
||||
* the result is edge message stored in p_edge_in and p_edge_out
|
||||
* \param node_value the value associated with current node
|
||||
* \param p_edge_in used to store input message from each of the edge
|
||||
* \param p_edge_out used to store output message from each of the edge
|
||||
* \param func a function that defines the message passing rule
|
||||
* Parameters of func:
|
||||
* - node_value same as node_value in the main function
|
||||
* - edge_in the array of input messages from each edge,
|
||||
* this includes the output edge, which should be excluded
|
||||
* - out_index array the index of output edge, the function should
|
||||
* exclude the output edge when compute the message passing value
|
||||
* Return of func:
|
||||
* the function returns the output message based on the input message and node_value
|
||||
*
|
||||
* \tparam EdgeType type of edge message, must be simple struct
|
||||
* \tparam NodeType type of node value
|
||||
*/
|
||||
template<typename NodeType, typename EdgeType>
|
||||
inline AllreduceRobust::ReturnType
|
||||
AllreduceRobust::MsgPassing(const NodeType &node_value,
|
||||
std::vector<EdgeType> *p_edge_in,
|
||||
std::vector<EdgeType> *p_edge_out,
|
||||
EdgeType(*func)
|
||||
(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index)) {
|
||||
RefLinkVector &links = tree_links;
|
||||
if (links.Size() == 0) return kSuccess;
|
||||
// number of links
|
||||
const int nlink = static_cast<int>(links.Size());
|
||||
// initialize the pointers
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
links[i].ResetSize();
|
||||
}
|
||||
std::vector<EdgeType> &edge_in = *p_edge_in;
|
||||
std::vector<EdgeType> &edge_out = *p_edge_out;
|
||||
edge_in.resize(nlink);
|
||||
edge_out.resize(nlink);
|
||||
// stages in the process
|
||||
// 0: recv messages from childs
|
||||
// 1: send message to parent
|
||||
// 2: recv message from parent
|
||||
// 3: send message to childs
|
||||
int stage = 0;
|
||||
// if there are no children, there is nothing to receive; directly start passing the message
|
||||
if (nlink == static_cast<int>(parent_index != -1)) {
|
||||
utils::Assert(parent_index == 0, "parent must be 0");
|
||||
edge_out[parent_index] = func(node_value, edge_in, parent_index);
|
||||
stage = 1;
|
||||
}
|
||||
// while we have not passed the messages out
|
||||
while (true) {
|
||||
// for node with no parent, directly do stage 3
|
||||
if (parent_index == -1) {
|
||||
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
|
||||
}
|
||||
// poll helper
|
||||
utils::PollHelper watcher;
|
||||
bool done = (stage == 3);
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
watcher.WatchException(links[i].sock);
|
||||
switch (stage) {
|
||||
case 0:
|
||||
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
|
||||
watcher.WatchRead(links[i].sock);
|
||||
}
|
||||
break;
|
||||
case 1:
|
||||
if (i == parent_index) {
|
||||
watcher.WatchWrite(links[i].sock);
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
if (i == parent_index) {
|
||||
watcher.WatchRead(links[i].sock);
|
||||
}
|
||||
break;
|
||||
case 3:
|
||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||
watcher.WatchWrite(links[i].sock);
|
||||
done = false;
|
||||
}
|
||||
break;
|
||||
default: utils::Error("invalid stage");
|
||||
}
|
||||
}
|
||||
// finish all the stages, and write out message
|
||||
if (done) break;
|
||||
watcher.Poll();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
// receive OOB message from some link
|
||||
if (watcher.CheckExcept(links[i].sock)) {
|
||||
return ReportError(&links[i], kGetExcept);
|
||||
}
|
||||
}
|
||||
if (stage == 0) {
|
||||
bool finished = true;
|
||||
// read data from childs
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index) {
|
||||
if (watcher.CheckRead(links[i].sock)) {
|
||||
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
|
||||
if (ret != kSuccess) return ReportError(&links[i], ret);
|
||||
}
|
||||
if (links[i].size_read != sizeof(EdgeType)) finished = false;
|
||||
}
|
||||
}
|
||||
// if no parent, jump to stage 3, otherwise do stage 1
|
||||
if (finished) {
|
||||
if (parent_index != -1) {
|
||||
edge_out[parent_index] = func(node_value, edge_in, parent_index);
|
||||
stage = 1;
|
||||
} else {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
edge_out[i] = func(node_value, edge_in, i);
|
||||
}
|
||||
stage = 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (stage == 1) {
|
||||
const int pid = this->parent_index;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
ReturnType ret = links[pid].WriteFromArray(&edge_out[pid], sizeof(EdgeType));
|
||||
if (ret != kSuccess) return ReportError(&links[pid], ret);
|
||||
if (links[pid].size_write == sizeof(EdgeType)) stage = 2;
|
||||
}
|
||||
if (stage == 2) {
|
||||
const int pid = this->parent_index;
|
||||
utils::Assert(pid != -1, "MsgPassing invalid stage");
|
||||
ReturnType ret = links[pid].ReadToArray(&edge_in[pid], sizeof(EdgeType));
|
||||
if (ret != kSuccess) return ReportError(&links[pid], ret);
|
||||
if (links[pid].size_read == sizeof(EdgeType)) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != pid) edge_out[i] = func(node_value, edge_in, i);
|
||||
}
|
||||
stage = 3;
|
||||
}
|
||||
}
|
||||
if (stage == 3) {
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
|
||||
ReturnType ret = links[i].WriteFromArray(&edge_out[i], sizeof(EdgeType));
|
||||
if (ret != kSuccess) return ReportError(&links[i], ret);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ALLREDUCE_ROBUST_INL_H_
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,665 +0,0 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file allreduce_robust.h
|
||||
* \brief Robust implementation of Allreduce
|
||||
* using TCP non-block socket and tree-shape reduction.
|
||||
*
|
||||
* This implementation considers the failure of nodes
|
||||
*
|
||||
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_ROBUST_H_
|
||||
#define RABIT_ALLREDUCE_ROBUST_H_
|
||||
#include <future>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "allreduce_base.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief implementation of fault tolerant all reduce engine */
|
||||
class AllreduceRobust : public AllreduceBase {
|
||||
public:
|
||||
AllreduceRobust();
|
||||
~AllreduceRobust() override = default;
|
||||
// initialize the manager
|
||||
bool Init(int argc, char* argv[]) override;
|
||||
/*! \brief shutdown the engine */
|
||||
bool Shutdown() override;
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
* \param val parameter value
|
||||
*/
|
||||
void SetParam(const char *name, const char *val) override;
|
||||
/*!
|
||||
* \brief perform immutable local bootstrap cache insertion
|
||||
* \param key unique cache key
|
||||
* \param buf buffer of allreduce/robust payload to copy
|
||||
* \param buflen total number of bytes
|
||||
* \return -1 if no recovery cache fetched otherwise 0
|
||||
*/
|
||||
int SetBootstrapCache(const std::string &key, const void *buf,
|
||||
const size_t type_nbytes, const size_t count);
|
||||
/*!
|
||||
* \brief perform bootstrap cache lookup if nodes in fault recovery
|
||||
* \param key unique cache key
|
||||
* \param buf buffer for recv allreduce/robust payload
|
||||
* \param buflen total number of bytes
|
||||
*/
|
||||
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
|
||||
const size_t count);
|
||||
/*!
|
||||
* \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf,
|
||||
* the data provided by current node k is [slice_begin, slice_end),
|
||||
* the next node's segment must start with slice_end
|
||||
* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments
|
||||
* use a ring based algorithm
|
||||
*
|
||||
* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually
|
||||
* \param total_size total size of data to be gathered
|
||||
* \param slice_begin beginning of the current slice
|
||||
* \param slice_end end of the current slice
|
||||
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin,
|
||||
size_t slice_end, size_t size_prev_slice,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override;
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf
|
||||
* this function is NOT thread-safe
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param type_nbytes the unit number of bytes the type have
|
||||
* \param count number of elements to be reduced
|
||||
* \param reducer reduce function
|
||||
* \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg)
|
||||
* will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_.
|
||||
* If the result of Allreduce can be recovered directly, then prepare_func will NOT be called
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
* \param prepare_arg argument used to passed into the lazy preprocessing function
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count,
|
||||
ReduceFunction reducer, PreprocFunction prepare_fun = nullptr,
|
||||
void *prepare_arg = nullptr, const char *_file = _FILE,
|
||||
const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override;
|
||||
/*!
|
||||
* \brief broadcast data from root to all nodes
|
||||
* \param sendrecvbuf_ buffer for both sending and recving data
|
||||
* \param size the size of the data to be broadcasted
|
||||
* \param root the root worker id to broadcast the data
|
||||
* \param _file caller file name used to generate unique cache key
|
||||
* \param _line caller line number used to generate unique cache key
|
||||
* \param _caller caller function name used to generate unique cache key
|
||||
*/
|
||||
void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
|
||||
const char *_file = _FILE, const int _line = _LINE,
|
||||
const char *_caller = _CALLER) override;
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, user should do necessary initialization by themselves
|
||||
*
|
||||
* Common usage example:
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = nullptr) override;
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = nullptr) override {
|
||||
this->CheckPointImpl(global_model, local_model, false);
|
||||
}
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
* when certain condition is met(see detailed expplaination).
|
||||
*
|
||||
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
|
||||
* In another words, global_model model can be changed only between last call of
|
||||
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||
*
|
||||
* For example, suppose the calling sequence is:
|
||||
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||
*
|
||||
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
|
||||
* improve efficiency of the program.
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
void LazyCheckPoint(const Serializable *global_model) override {
|
||||
this->CheckPointImpl(global_model, nullptr, true);
|
||||
}
|
||||
/*!
|
||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throw an exception out,
|
||||
* this function is only used for test purpose
|
||||
*/
|
||||
void InitAfterException() override {
|
||||
// simple way, shutdown all links
|
||||
for (auto& link : all_links) {
|
||||
if (link.sock.BadSocket()) {
|
||||
link.sock.Close();
|
||||
}
|
||||
}
|
||||
ReConnectLinks("recover");
|
||||
}
|
||||
|
||||
protected:
|
||||
// constant one byte out of band message to indicate error happening
|
||||
// and mark for channel cleanup
|
||||
static const char kOOBReset = 95;
|
||||
// and mark for channel cleanup, after OOB signal
|
||||
static const char kResetMark = 97;
|
||||
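// presumably the acknowledgement byte for channel cleanup; NOTE(review): it has the same value (97) as kResetMark - confirm this is intentional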
|
||||
static const char kResetAck = 97;
|
||||
/*! \brief the roles a node can play during recovery */
enum RecoverType {
  /*! \brief the current node has the data */
  kHaveData = 0,
  /*! \brief the current node requests the data */
  kRequestData = 1,
  /*! \brief the current node only helps to pass the data around */
  kPassData = 2
};
|
||||
|
||||
/*! \brief selects which kind of sequence code an operation applies to */
enum SeqType {
  /*! \brief the rabit seq code */
  kSeq = 0,
  /*! \brief the rabit cache seq code */
  kCache = 1
};
|
||||
/*!
|
||||
* \brief summary of actions proposed in all nodes
|
||||
* this data structure is used to make consensus decision
|
||||
* about next action to take in the recovery mode
|
||||
*/
|
||||
struct ActionSummary {
|
||||
// maximumly allowed sequence id
|
||||
static const uint32_t kSpecialOp = (1 << 26);
|
||||
// special sequence number for local state checkpoint
|
||||
static const uint32_t kLocalCheckPoint = (1 << 26) - 2;
|
||||
// special sequnce number for local state checkpoint ack signal
|
||||
static const uint32_t kLocalCheckAck = (1 << 26) - 1;
|
||||
//---------------------------------------------
|
||||
// The following are bit mask of flag used in
|
||||
//----------------------------------------------
|
||||
// some node want to load check point
|
||||
static const int kLoadCheck = 1;
|
||||
// some node want to do check point
|
||||
static const int kCheckPoint = 2;
|
||||
// check point Ack, we use a two phase message in check point,
|
||||
// this is the second phase of check pointing
|
||||
static const int kCheckAck = 4;
|
||||
// there are difference sequence number the nodes proposed
|
||||
// this means we want to do recover execution of the lower sequence
|
||||
// action instead of normal execution
|
||||
static const int kDiffSeq = 8;
|
||||
// there are nodes request load cache
|
||||
static const int kLoadBootstrapCache = 16;
|
||||
// constructor
|
||||
ActionSummary() = default;
|
||||
// constructor of action
|
||||
explicit ActionSummary(int seqno_flag, int cache_flag = 0,
|
||||
uint32_t minseqno = kSpecialOp, uint32_t maxseqno = kSpecialOp) {
|
||||
seqcode_ = (minseqno << 5) | seqno_flag;
|
||||
maxseqcode_ = (maxseqno << 5) | cache_flag;
|
||||
}
|
||||
// minimum number of all operations by default
|
||||
// maximum number of all cache operations otherwise
|
||||
inline uint32_t Seqno(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return code >> 5;
|
||||
}
|
||||
// whether the operation set contains a load_check
|
||||
inline bool LoadCheck(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return (code & kLoadCheck) != 0;
|
||||
}
|
||||
// whether the operation set contains a load_cache
|
||||
inline bool LoadCache(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return (code & kLoadBootstrapCache) != 0;
|
||||
}
|
||||
// whether the operation set contains a check point
|
||||
inline bool CheckPoint(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return (code & kCheckPoint) != 0;
|
||||
}
|
||||
// whether the operation set contains a check ack
|
||||
inline bool CheckAck(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return (code & kCheckAck) != 0;
|
||||
}
|
||||
// whether the operation set contains different sequence number
|
||||
inline bool DiffSeq() const {
|
||||
return (seqcode_ & kDiffSeq) != 0;
|
||||
}
|
||||
// returns the operation flag of the result
|
||||
inline int Flag(SeqType t = SeqType::kSeq) const {
|
||||
int code = t == SeqType::kSeq ? seqcode_ : maxseqcode_;
|
||||
return code & 31;
|
||||
}
|
||||
// print flags in user friendly way
|
||||
inline void PrintFlags(int rank, std::string prefix ) {
|
||||
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n", rank,
|
||||
prefix.c_str(), Seqno(), CheckPoint(), CheckAck(),
|
||||
LoadCache(), DiffSeq(), Seqno(SeqType::kCache),
|
||||
LoadCache(SeqType::kCache));
|
||||
}
|
||||
// reducer for Allreduce, get the result ActionSummary from all nodes
|
||||
inline static void Reducer(const void *src_, void *dst_,
|
||||
int len, const MPI::Datatype &dtype) {
|
||||
const ActionSummary *src = static_cast<const ActionSummary*>(src_);
|
||||
ActionSummary *dst = reinterpret_cast<ActionSummary*>(dst_);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
uint32_t min_seqno = Min(src[i].Seqno(), dst[i].Seqno());
|
||||
uint32_t max_seqno = Max(src[i].Seqno(SeqType::kCache),
|
||||
dst[i].Seqno(SeqType::kCache));
|
||||
int action_flag = src[i].Flag() | dst[i].Flag();
|
||||
// if any node is not requester set to 0 otherwise 1
|
||||
int role_flag = src[i].Flag(SeqType::kCache) & dst[i].Flag(SeqType::kCache);
|
||||
// if seqno is different in src and destination
|
||||
int seq_diff_flag = src[i].Seqno() != dst[i].Seqno() ? kDiffSeq : 0;
|
||||
// apply or to both seq diff flag as well as cache seq diff flag
|
||||
dst[i] = ActionSummary(action_flag | seq_diff_flag,
|
||||
role_flag, min_seqno, max_seqno);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// internel sequence code min of rabit seqno
|
||||
uint32_t seqcode_;
|
||||
// internal sequence code max of cache seqno
|
||||
uint32_t maxseqcode_;
|
||||
};
|
||||
/*! \brief data structure to remember result of Bcast and Allreduce calls*/
|
||||
class ResultBuffer{
|
||||
public:
|
||||
// constructor
|
||||
ResultBuffer() {
|
||||
this->Clear();
|
||||
}
|
||||
// clear the existing record
|
||||
inline void Clear() {
|
||||
seqno_.clear(); size_.clear();
|
||||
rptr_.clear(); rptr_.push_back(0);
|
||||
data_.clear();
|
||||
}
|
||||
// allocate temporal space
|
||||
inline void *AllocTemp(size_t type_nbytes, size_t count) {
|
||||
size_t size = type_nbytes * count;
|
||||
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
utils::Assert(nhop != 0, "cannot allocate 0 size memory");
|
||||
// allocate addational nhop buffer size
|
||||
data_.resize(rptr_.back() + nhop);
|
||||
return BeginPtr(data_) + rptr_.back();
|
||||
}
|
||||
// push the result in temp to the
|
||||
inline void PushTemp(int seqid, size_t type_nbytes, size_t count) {
|
||||
size_t size = type_nbytes * count;
|
||||
size_t nhop = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t);
|
||||
if (seqno_.size() != 0) {
|
||||
utils::Assert(seqno_.back() < seqid, "PushTemp seqid inconsistent");
|
||||
}
|
||||
seqno_.push_back(seqid);
|
||||
rptr_.push_back(rptr_.back() + nhop);
|
||||
size_.push_back(size);
|
||||
utils::Assert(data_.size() == rptr_.back(), "PushTemp inconsistent");
|
||||
}
|
||||
// return the stored result of seqid, if any
|
||||
inline void* Query(int seqid, size_t *p_size) {
|
||||
size_t idx = std::lower_bound(seqno_.begin(),
|
||||
seqno_.end(), seqid) - seqno_.begin();
|
||||
if (idx == seqno_.size() || seqno_[idx] != seqid) return nullptr;
|
||||
*p_size = size_[idx];
|
||||
return BeginPtr(data_) + rptr_[idx];
|
||||
}
|
||||
// drop last stored result
|
||||
inline void DropLast() {
|
||||
utils::Assert(seqno_.size() != 0, "there is nothing to be dropped");
|
||||
seqno_.pop_back();
|
||||
rptr_.pop_back();
|
||||
size_.pop_back();
|
||||
data_.resize(rptr_.back());
|
||||
}
|
||||
// the sequence number of last stored result
|
||||
inline int LastSeqNo() const {
|
||||
if (seqno_.size() == 0) return -1;
|
||||
return seqno_.back();
|
||||
}
|
||||
|
||||
private:
|
||||
// sequence number of each
|
||||
std::vector<int> seqno_;
|
||||
// pointer to the positions
|
||||
std::vector<size_t> rptr_;
|
||||
// actual size of each buffer
|
||||
std::vector<size_t> size_;
|
||||
// content of the buffer
|
||||
std::vector<uint64_t> data_;
|
||||
};
|
||||
/*!
|
||||
* \brief internal consistency check function,
|
||||
* use check to ensure user always call CheckPoint/LoadCheckPoint
|
||||
* with or without local but not both, this function will set the appropriate settings
|
||||
* in the first call of LoadCheckPoint/CheckPoint
|
||||
*
|
||||
* \param with_local whether the user calls CheckPoint with local model
|
||||
*/
|
||||
void LocalModelCheck(bool with_local);
|
||||
/*!
|
||||
* \brief internal implementation of checkpoint, support both lazy and normal way
|
||||
*
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
* \param lazy_checkpt whether the action is lazy checkpoint
|
||||
*
|
||||
* \sa CheckPoint, LazyCheckPoint
|
||||
*/
|
||||
void CheckPointImpl(const Serializable *global_model,
|
||||
const Serializable *local_model, bool lazy_checkpt);
|
||||
/*!
|
||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||
* after this function finishes, all the messages received and sent
|
||||
* before in all live links are discarded,
|
||||
* This allows us to get a fresh start after error has happened
|
||||
*
|
||||
* TODO(tqchen): this function is not yet functional and is not used by the engine;
|
||||
* simple resetlink and reconnect strategy is used
|
||||
*
|
||||
* \return this function can return kSuccess or kSockError
|
||||
* when kSockError is returned, it simply means there are bad sockets in the links,
|
||||
* and some link recovery procedure is needed
|
||||
*/
|
||||
ReturnType TryResetLinks();
|
||||
/*!
|
||||
* \brief if err_type indicates an error
|
||||
* recover links according to the error type reported
|
||||
* if there is no error, return true
|
||||
* \param err_type the type of error happening in the system
|
||||
* \return true if err_type is kSuccess, false otherwise
|
||||
*/
|
||||
bool CheckAndRecover(ReturnType err_type);
|
||||
/*!
|
||||
* \brief try to run recover execution for a request action described by flag and seqno,
|
||||
* the function will keep blocking to run possible recovery operations before the specified action,
|
||||
* until the requested result is received by a recovering procedure,
|
||||
* or the function discovers that the requested action is not yet executed, and return false
|
||||
*
|
||||
* \param buf the buffer to store the result
|
||||
* \param size the total size of the buffer
|
||||
* \param flag flag information about the action \sa ActionSummary
|
||||
* \param seqno sequence number of the action, if it is special action with flag set,
|
||||
* seqno needs to be set to ActionSummary::kSpecialOp
|
||||
*
|
||||
* \return this function can return true or false
|
||||
* - true means buf already set to the
|
||||
* result by recovering procedure, the action is complete, no further action is needed
|
||||
* - false means this is the latest action that has not yet been executed, need to execute the action
|
||||
*/
|
||||
bool RecoverExec(void *buf, size_t size, int flag,
|
||||
int seqno = ActionSummary::kSpecialOp,
|
||||
int cacheseqno = ActionSummary::kSpecialOp,
|
||||
const char* caller = _CALLER);
|
||||
/*!
|
||||
* \brief try to load check point
|
||||
*
|
||||
* This is a collaborative function called by all nodes
|
||||
* only the nodes with requester set to true really needs to load the check point
|
||||
* other nodes acts as collaborative roles to complete this request
|
||||
*
|
||||
* \param requester whether current node is the requester
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryLoadCheckPoint(bool requester);
|
||||
|
||||
/*!
|
||||
* \brief try to load cache
|
||||
*
|
||||
* This is a collaborative function called by all nodes
|
||||
* only the nodes with requester set to true really needs to load the check point
|
||||
* other nodes acts as collaborative roles to complete this request
|
||||
* \param requester whether current node is the requester
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryRestoreCache(bool requester, const int min_seq = ActionSummary::kSpecialOp,
|
||||
const int max_seq = ActionSummary::kSpecialOp);
|
||||
/*!
|
||||
* \brief try to get the result of operation specified by seqno
|
||||
*
|
||||
* This is a collaborative function called by all nodes
|
||||
* only the nodes with requester set to true really needs to get the result
|
||||
* other nodes acts as collaborative roles to complete this request
|
||||
*
|
||||
* \param buf the buffer to store the result, this parameter is only used when current node is requester
|
||||
* \param size the total size of the buffer, this parameter is only used when current node is requester
|
||||
* \param seqno sequence number of the operation, this is unique index of a operation in current iteration
|
||||
* \param requester whether current node is the requester
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryGetResult(void *buf, size_t size, int seqno, bool requester);
|
||||
/*!
|
||||
* \brief try to decide the routing strategy for recovery
|
||||
* \param role the current role of the node
|
||||
* \param p_size used to store the size of the message, for node in state kHaveData,
|
||||
* this size must be set correctly before calling the function
|
||||
* for others, this serves as output parameter
|
||||
|
||||
* \param p_recvlink used to store the link current node should recv data from, if necessary
|
||||
* this can be -1, which means current node have the data
|
||||
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
||||
*
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryRecoverData
|
||||
*/
|
||||
ReturnType TryDecideRouting(RecoverType role,
|
||||
size_t *p_size,
|
||||
int *p_recvlink,
|
||||
std::vector<bool> *p_req_in);
|
||||
/*!
|
||||
* \brief try to finish the data recovery request,
|
||||
* this function is used together with TryDecideRouting
|
||||
* \param role the current role of the node
|
||||
* \param sendrecvbuf_ the buffer to store the data to be sent/received
|
||||
* - if the role is kHaveData, this stores the data to be sent
|
||||
* - if the role is kRequestData, this is the buffer to store the result
|
||||
* - if the role is kPassData, this will not be used, and can be NULL
|
||||
* \param size the size of the data, obtained from TryDecideRouting
|
||||
* \param recv_link the link index to receive data, if necessary, obtained from TryDecideRouting
|
||||
* \param req_in the request of each link to send data, obtained from TryDecideRouting
|
||||
*
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryDecideRouting
|
||||
*/
|
||||
ReturnType TryRecoverData(RecoverType role,
|
||||
void *sendrecvbuf_,
|
||||
size_t size,
|
||||
int recv_link,
|
||||
const std::vector<bool> &req_in);
|
||||
/*!
|
||||
* \brief try to recover the local state, making each local state to be the result of itself
|
||||
* plus replication of states in previous num_local_replica hops in the ring
|
||||
*
|
||||
* The input parameters must contain the valid local states available in current nodes,
|
||||
* This function tries its best to "complete" the missing parts of local_rptr and local_chkpt
|
||||
* If there is sufficient information in the ring, when the function returns, local_chkpt will
|
||||
* contain num_local_replica + 1 checkpoints (including the chkpt of this node)
|
||||
* If there is not sufficient information in the ring, the number of checkpoints
|
||||
* will be less than the specified value
|
||||
*
|
||||
* \param p_local_rptr the pointer to the segment pointers in the states array
|
||||
* \param p_local_chkpt the pointer to the storage of local check points
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType
|
||||
*/
|
||||
ReturnType TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
std::string *p_local_chkpt);
|
||||
/*!
|
||||
* \brief try to checkpoint local state, this function is called in normal execution phase
|
||||
* of checkpoint that contains local state
|
||||
* the input state must contain exactly one saved state (local state of current node),
|
||||
* after complete, this function will get local state from previous num_local_replica nodes and put them
|
||||
* into local_chkpt and local_rptr
|
||||
*
|
||||
* It is also OK to call TryRecoverLocalState instead,
|
||||
* TryRecoverLocalState makes less assumption about the input, and requires more communications
|
||||
*
|
||||
* \param p_local_rptr the pointer to the segment pointers in the states array
|
||||
* \param p_local_chkpt the pointer to the storage of local check points
|
||||
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
||||
* \sa ReturnType, TryRecoverLocalState
|
||||
*/
|
||||
ReturnType TryCheckinLocalState(std::vector<size_t> *p_local_rptr,
|
||||
std::string *p_local_chkpt);
|
||||
/*!
|
||||
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||
* this allows data to stream over a ring structure
|
||||
* sendrecvbuf[0:read_ptr] are already provided by current node
|
||||
* current node will recv sendrecvbuf[read_ptr:read_end] from prev link
|
||||
* current node will send sendrecvbuf[write_ptr:write_end] to next link
|
||||
* write_ptr will wait till the data is read before sending the data
|
||||
* this function requires read_end >= write_end
|
||||
*
|
||||
* \param sendrecvbuf_ the place to hold the incoming and outgoing data
|
||||
* \param read_ptr the initial read pointer
|
||||
* \param read_end the ending position to read
|
||||
* \param write_ptr the initial write pointer
|
||||
* \param write_end the ending position to write
|
||||
* \param read_link pointer to link to previous position in ring
|
||||
* \param write_link pointer to link of next position in ring
|
||||
*/
|
||||
ReturnType RingPassing(void *senrecvbuf_,
|
||||
size_t read_ptr,
|
||||
size_t read_end,
|
||||
size_t write_ptr,
|
||||
size_t write_end,
|
||||
LinkRecord *read_link,
|
||||
LinkRecord *write_link);
|
||||
/*!
|
||||
* \brief run message passing algorithm on the allreduce tree
|
||||
* the result is edge message stored in p_edge_in and p_edge_out
|
||||
* \param node_value the value associated with current node
|
||||
* \param p_edge_in used to store input message from each of the edge
|
||||
* \param p_edge_out used to store output message from each of the edge
|
||||
* \param func a function that defines the message passing rule
|
||||
* Parameters of func:
|
||||
* - node_value same as node_value in the main function
|
||||
* - edge_in the array of input messages from each edge,
|
||||
* this includes the output edge, which should be excluded
|
||||
* - out_index array the index of output edge, the function should
|
||||
* exclude the output edge when compute the message passing value
|
||||
* Return of func:
|
||||
* the function returns the output message based on the input message and node_value
|
||||
*
|
||||
* \tparam EdgeType type of edge message, must be simple struct
|
||||
* \tparam NodeType type of node value
|
||||
*/
|
||||
template<typename NodeType, typename EdgeType>
|
||||
inline ReturnType MsgPassing(const NodeType &node_value,
|
||||
std::vector<EdgeType> *p_edge_in,
|
||||
std::vector<EdgeType> *p_edge_out,
|
||||
EdgeType(*func)
|
||||
(const NodeType &node_value,
|
||||
const std::vector<EdgeType> &edge_in,
|
||||
size_t out_index));
|
||||
//---- recovery data structure ----
|
||||
// the round of result buffer, used to mode the result
|
||||
int result_buffer_round_;
|
||||
// result buffer of all reduce
|
||||
ResultBuffer resbuf_;
|
||||
// current cached allreduce/broadcast sequence number
|
||||
int cur_cache_seq_;
|
||||
// result buffer of cached all reduce
|
||||
ResultBuffer cachebuf_;
|
||||
// key of each cache entry
|
||||
ResultBuffer lookupbuf_;
|
||||
// last check point global model
|
||||
std::string global_checkpoint_;
|
||||
// lazy checkpoint of global model
|
||||
const Serializable *global_lazycheck_;
|
||||
// number of replica for local state/model
|
||||
int num_local_replica_;
|
||||
// number of default local replica
|
||||
int default_local_replica_;
|
||||
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
|
||||
int use_local_model_;
|
||||
// number of replica for global state/model
|
||||
int num_global_replica_;
|
||||
// number of times recovery happens
|
||||
int recover_counter_;
|
||||
// --- recovery data structure for local checkpoint
|
||||
// there is two version of the data structure,
|
||||
// at one time one version is valid and another is used as temp memory
|
||||
// pointer to memory position in the local model
|
||||
// local model is stored in CSR format(like a sparse matrices)
|
||||
// local_model[rptr[0]:rptr[1]] stores the model of current node
|
||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops
|
||||
std::vector<size_t> local_rptr_[2];
|
||||
// storage for local model replicas
|
||||
std::string local_chkpt_[2];
|
||||
// version of local checkpoint can be 1 or 0
|
||||
int local_chkpt_version_;
|
||||
// whether a checkpoint was loaded, used to distinguish results of the bootstrap cache from the seqno cache
|
||||
bool checkpoint_loaded_;
|
||||
// sidecar executing timeout task
|
||||
std::future<bool> rabit_timeout_task_;
|
||||
// flag to shutdown rabit_timeout_task before timeout
|
||||
std::atomic<bool> shutdown_timeout_{false};
|
||||
// error handler
|
||||
void (* error_)(const char *fmt, ...) = utils::Error;
|
||||
// assert handler
|
||||
void (* assert_)(bool exp, const char *fmt, ...) = utils::Assert;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
// implementation of inline template function
|
||||
#include "./allreduce_robust-inl.h"
|
||||
#endif // RABIT_ALLREDUCE_ROBUST_H_
|
||||
@ -12,14 +12,13 @@
|
||||
#include <memory>
|
||||
#include "rabit/internal/engine.h"
|
||||
#include "allreduce_base.h"
|
||||
#include "allreduce_robust.h"
|
||||
|
||||
namespace rabit {
|
||||
namespace engine {
|
||||
// singleton sync manager
|
||||
#ifndef RABIT_USE_BASE
|
||||
#ifndef RABIT_USE_MOCK
|
||||
using Manager = AllreduceRobust;
|
||||
using Manager = AllreduceBase;
|
||||
#else
|
||||
typedef AllreduceMock Manager;
|
||||
#endif // RABIT_USE_MOCK
|
||||
|
||||
@ -809,12 +809,6 @@ class LearnerIO : public LearnerConfiguration {
|
||||
|
||||
{
|
||||
std::vector<std::string> saved_params;
|
||||
// check if rabit_bootstrap_cache were set to non zero before adding to checkpoint
|
||||
if (cfg_.find("rabit_bootstrap_cache") != cfg_.end() &&
|
||||
(cfg_.find("rabit_bootstrap_cache"))->second != "0") {
|
||||
std::copy(saved_configs_.begin(), saved_configs_.end(),
|
||||
std::back_inserter(saved_params));
|
||||
}
|
||||
for (const auto& key : saved_params) {
|
||||
auto it = cfg_.find(key);
|
||||
if (it != cfg_.end()) {
|
||||
|
||||
@ -1,12 +0,0 @@
|
||||
# Originally an example in demo/regression/
|
||||
tree_method=approx
|
||||
eta = 0.5
|
||||
gamma = 1.0
|
||||
seed = 0
|
||||
min_child_weight = 0
|
||||
max_depth = 5
|
||||
|
||||
num_round = 12
|
||||
save_period = 100
|
||||
data = "demo/data/agaricus.txt.train"
|
||||
eval[test] = "demo/data/agaricus.txt.test"
|
||||
@ -1,13 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source activate cpu_test
|
||||
|
||||
export DMLC_SUBMIT_CLUSTER=local
|
||||
|
||||
submit="python3 dmlc-core/tracker/dmlc-submit"
|
||||
# build xgboost with librabit mock
|
||||
# define max worker retry with dmlc-core local num attempt
|
||||
# instrument worker failure with mock=xxxx
|
||||
# check if host recovered from expected iteration
|
||||
echo "====== 1. Fault recovery distributed test ======"
|
||||
exec $submit --cluster=local --num-workers=10 --local-num-attempt=10 $1 $2 mock=0,10,1,0 mock=1,11,1,0 mock=1,11,1,1 mock=0,11,1,0 mock=4,11,1,0 mock=9,11,1,0 mock=8,11,2,0 mock=4,11,3,0 rabit_bootstrap_cache=1 rabit_debug=1
|
||||
@ -1,53 +0,0 @@
|
||||
#define RABIT_CXXTESTDEFS_H
|
||||
#if !defined(_WIN32)
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "../../../rabit/src/allreduce_mock.h"
|
||||
|
||||
TEST(AllreduceMock, MockAllreduce)
{
  // Initializes the mock engine with "mock=0,0,0,0", sets rank 0 and expects
  // the Allreduce call to throw dmlc::Error.
  rabit::engine::AllreduceMock m;

  // Init is handed an array of char*; point it at the std::string's own buffer
  // (contiguous and NUL-terminated) instead of copying into a variable-length
  // array, which is not standard C++.
  std::string mock_str = "mock=0,0,0,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 0;
  EXPECT_THROW(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), dmlc::Error);
}
|
||||
|
||||
TEST(AllreduceMock, MockBroadcast)
{
  // Initializes the mock engine with "mock=0,1,2,0", then sets rank 0,
  // version_number 1 and seq_counter 2 and expects Broadcast to throw dmlc::Error.
  rabit::engine::AllreduceMock m;

  // pass the std::string's own NUL-terminated buffer as the argument;
  // a variable-length char array is not standard C++
  std::string mock_str = "mock=0,1,2,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 0;
  m.version_number=1;
  m.seq_counter=2;
  EXPECT_THROW(m.Broadcast(nullptr,0,0), dmlc::Error);
}
|
||||
|
||||
TEST(AllreduceMock, MockGather)
{
  // Initializes the mock engine with "mock=3,13,22,0", then sets rank 3,
  // version_number 13 and seq_counter 22 and expects Allgather to throw dmlc::Error.
  rabit::engine::AllreduceMock m;

  // pass the std::string's own NUL-terminated buffer as the argument;
  // a variable-length char array is not standard C++
  std::string mock_str = "mock=3,13,22,0";
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);
  m.rank = 3;
  m.version_number=13;
  m.seq_counter=22;
  EXPECT_THROW({m.Allgather(nullptr,0,0,0,0);}, dmlc::Error);
}
|
||||
#endif // !defined(_WIN32)
|
||||
@ -1,235 +0,0 @@
|
||||
#define RABIT_CXXTESTDEFS_H
|
||||
#if !defined(_WIN32)
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "../../../rabit/src/allreduce_robust.h"
|
||||
|
||||
// Replacement error handler for the tests: instead of reporting anything, it checks
// (case-insensitively) that the format string it receives is the time-out message.
inline void MockErr(const char *fmt, ...) {
  const char *expected_fmt = "[%d] exit due to time out %d s\n";
  EXPECT_STRCASEEQ(fmt, expected_fmt);
}
|
||||
// Replacement assert handler for the tests: ignores every argument and does nothing.
inline void MockAssert(bool val, const char *fmt, ...) {
}
|
||||
// return value constructed from kSockError, fed to CheckAndRecover by the tests below
rabit::engine::AllreduceRobust::ReturnType err_type(
    rabit::engine::AllreduceRobust::ReturnTypeEnum::kSockError);
// return value constructed from kSuccess, fed to CheckAndRecover by the tests below
rabit::engine::AllreduceRobust::ReturnType succ_type(
    rabit::engine::AllreduceRobust::ReturnTypeEnum::kSuccess);
|
||||
|
||||
TEST(AllreduceRobust, SyncErrorTimeout)
{
  // Reports a socket error to CheckAndRecover, waits 1.5 s and expects the
  // timeout task to yield false.
  rabit::engine::AllreduceRobust m;

  // Init is handed an array of char*; use the std::string buffers directly
  // (contiguous and NUL-terminated) rather than variable-length char arrays,
  // which are not standard C++.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0]};
  m.Init(2, argv);
  m.rank = 0;
  m.rabit_bootstrap_cache = true;
  // route error and assert reporting to the test doubles
  m.error_ = MockErr;
  m.assert_ = MockAssert;
  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(1500));
  EXPECT_EQ(m.rabit_timeout_task_.get(), false);
}
|
||||
|
||||
TEST(AllreduceRobust, SyncErrorReset)
{
  // Reports a socket error followed 100 ms later by a success, and expects the
  // timeout task to yield true.
  rabit::engine::AllreduceRobust m;

  // Init is handed an array of char*; use the std::string buffers directly
  // (contiguous and NUL-terminated) rather than variable-length char arrays,
  // which are not standard C++.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  // route assert reporting to the test double
  m.assert_ = MockAssert;
  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  EXPECT_EQ(m.rabit_timeout_task_.get(), true);
  m.Shutdown();
}
|
||||
|
||||
TEST(AllreduceRobust, SyncSuccessErrorTimeout)
{
  // Reports a success, then a socket error 100 ms later, waits 1.5 s and expects
  // the timeout task to yield false.
  rabit::engine::AllreduceRobust m;

  // Init is handed an array of char*; use the std::string buffers directly
  // (contiguous and NUL-terminated) rather than variable-length char arrays,
  // which are not standard C++.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";
  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  m.rabit_bootstrap_cache = true;
  // route assert and error reporting to the test doubles
  m.assert_ = MockAssert;
  m.error_ = MockErr;
  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(1500));
  EXPECT_EQ(m.rabit_timeout_task_.get(), false);
}
|
||||
|
||||
// Scenario: success, error, success in quick succession (10 ms apart), then a
// 1.1 second wait that exceeds the 1 second timeout configured below. The test
// expects the timeout task to resolve to `true`.
// NOTE(review): `true` presumably means the second success cancelled the
// timeout started by the error -- confirm against AllreduceRobust.
TEST(AllreduceRobust, SyncSuccessErrorSuccess)
{
  rabit::engine::AllreduceRobust m;

  // Init() takes mutable `char*` arguments. The std::string objects own
  // contiguous, NUL-terminated storage for the whole test, so `&s[0]` gives a
  // valid C string without resorting to non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";

  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  m.rabit_bootstrap_cache = true;
  m.assert_ = MockAssert;

  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));

  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  // Sleep longer than the timeout before inspecting the task result.
  std::this_thread::sleep_for(std::chrono::milliseconds(1100));
  EXPECT_EQ(m.rabit_timeout_task_.get(), true);
  m.Shutdown();
}
|
||||
|
||||
// Scenario: two errors are reported 1.1 seconds apart with a 1 second timeout
// configured. The test expects the timeout task to resolve to `false` and the
// total elapsed time to stay under 2 seconds, i.e. the second error must not
// overwrite or restart the timeout task created by the first one.
TEST(AllreduceRobust, SyncErrorNoResetTimeout)
{
  rabit::engine::AllreduceRobust m;

  // Init() takes mutable `char*` arguments. The std::string objects own
  // contiguous, NUL-terminated storage for the whole test, so `&s[0]` gives a
  // valid C string without resorting to non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";

  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;
  m.rabit_bootstrap_cache = true;
  m.assert_ = MockAssert;
  m.error_ = MockErr;

  // steady_clock is monotonic, so the measured interval cannot be distorted
  // by wall-clock adjustments while the test is sleeping.
  const auto start = std::chrono::steady_clock::now();

  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(1100));

  EXPECT_EQ(m.CheckAndRecover(err_type), false);

  // Block until the timeout task has finished before taking the end time.
  m.rabit_timeout_task_.wait();
  const auto end = std::chrono::steady_clock::now();
  const std::chrono::duration<double> diff = end - start;

  EXPECT_EQ(m.rabit_timeout_task_.get(), false);
  // expect second error don't overwrite/reset timeout task
  EXPECT_LT(diff.count(), 2);
}
|
||||
|
||||
// Scenario: only a success is reported, then the engine is shut down 10 ms
// later. The test checks that CheckAndRecover(succ_type) returns true and that
// Shutdown() can be called when no error was ever reported.
TEST(AllreduceRobust, NoTimeoutShutDown)
{
  rabit::engine::AllreduceRobust m;

  // Init() takes mutable `char*` arguments. The std::string objects own
  // contiguous, NUL-terminated storage for the whole test, so `&s[0]` gives a
  // valid C string without resorting to non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";

  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;

  EXPECT_EQ(m.CheckAndRecover(succ_type), true);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  m.Shutdown();
}
|
||||
|
||||
// Scenario: an error is reported and the engine is shut down 10 ms later,
// well before the 1 second timeout configured below could elapse.
TEST(AllreduceRobust, ShutDownBeforeTimeout)
{
  rabit::engine::AllreduceRobust m;

  // Init() takes mutable `char*` arguments. The std::string objects own
  // contiguous, NUL-terminated storage for the whole test, so `&s[0]` gives a
  // valid C string without resorting to non-standard variable-length arrays.
  std::string rabit_timeout = "rabit_timeout=1";
  std::string rabit_timeout_sec = "rabit_timeout_sec=1";
  std::string rabit_debug = "rabit_debug=1";

  char* argv[] = {&rabit_timeout[0], &rabit_timeout_sec[0], &rabit_debug[0]};
  m.Init(3, argv);
  m.rank = 0;

  // Point err_link at a default-constructed record that outlives the calls
  // below.
  // NOTE(review): presumably needed so the error path has a link to report --
  // confirm against AllreduceRobust.
  rabit::engine::AllreduceRobust::LinkRecord a;
  m.err_link = &a;

  EXPECT_EQ(m.CheckAndRecover(err_type), false);
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  m.Shutdown();
}
|
||||
#endif // !defined(_WIN32)
|
||||
@ -1,8 +1,8 @@
|
||||
"""Distributed GPU tests."""
|
||||
import sys
|
||||
import time
|
||||
import xgboost as xgb
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
def run_test(name, params_fun):
|
||||
@ -28,7 +28,7 @@ def run_test(name, params_fun):
|
||||
# Have each worker save its model
|
||||
model_name = "test.model.%s.%d" % (name, rank)
|
||||
bst.dump_model(model_name, with_stats=True)
|
||||
time.sleep(2)
|
||||
xgb.rabit.allreduce(np.ones((1, 1)), xgb.rabit.Op.MAX) # sync
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
if (rank == 0):
|
||||
@ -49,9 +49,6 @@ def run_test(name, params_fun):
|
||||
|
||||
xgb.rabit.finalize()
|
||||
|
||||
if os.path.exists(model_name):
|
||||
os.remove(model_name)
|
||||
|
||||
|
||||
base_params = {
|
||||
'tree_method': 'gpu_hist',
|
||||
|
||||
@ -7,6 +7,8 @@ submit="timeout 30 python ../../dmlc-core/tracker/dmlc-submit"
|
||||
|
||||
echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
|
||||
$submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py basic_1x4 || exit 1
|
||||
rm test.model.*
|
||||
|
||||
echo -e "\n ====== 2. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
|
||||
$submit --num-workers=$(nvidia-smi -L | wc -l) python distributed_gpu.py rf_1x4 || exit 1
|
||||
rm test.model.*
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user