[rabit_bootstrap_cache] failed xgb worker recovers from other workers (#4808)

* Better recovery support.  Restarting only the failed workers.
This commit is contained in:
Chen Qin
2019-09-16 20:31:52 -07:00
committed by Jiaming Yuan
parent c89bcc4de5
commit 512f037e55
11 changed files with 111 additions and 14 deletions

View File

@@ -272,6 +272,9 @@ class LearnerImpl : public Learner {
kv.second = "cpu_predictor";
LOG(INFO) << "Switch gpu_predictor to cpu_predictor.";
}
if (saved_configs_.find(saved_param) != saved_configs_.end()) {
cfg_[saved_param] = kv.second;
}
}
}
attributes_ = std::map<std::string, std::string>(attr.begin(), attr.end());
@@ -304,6 +307,10 @@ class LearnerImpl : public Learner {
p_metric->Configure({cfg_.begin(), cfg_.end()});
}
// copy dsplit from config since it will not run again during restore
if (tparam_.dsplit == DataSplitMode::kAuto && rabit::IsDistributed()) {
tparam_.dsplit = DataSplitMode::kRow;
}
this->configured_ = true;
}
@@ -334,8 +341,15 @@ class LearnerImpl : public Learner {
}
}
{
// Write `predictor`, `gpu_id` parameters as extra attributes
for (const auto& key : std::vector<std::string>{"predictor", "gpu_id"}) {
std::vector<std::string> saved_params{"predictor", "gpu_id"};
// check whether rabit_bootstrap_cache was set to a non-zero value before adding the saved config keys to the checkpoint
if (cfg_.find("rabit_bootstrap_cache") != cfg_.end() &&
(cfg_.find("rabit_bootstrap_cache"))->second != "0") {
std::copy(saved_configs_.begin(), saved_configs_.end(),
std::back_inserter(saved_params));
}
// Write `predictor`, `gpu_id`, and (when enabled above) the saved config parameters as extra attributes
for (const auto& key : saved_params) {
auto it = cfg_.find(key);
if (it != cfg_.end()) {
mparam.contain_extra_attrs = 1;
@@ -603,7 +617,7 @@ class LearnerImpl : public Learner {
num_feature = std::max(num_feature, static_cast<unsigned>(num_col));
}
// run allreduce on num_feature to find the maximum value
rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr, "num_feature");
if (num_feature > mparam_.num_feature) {
mparam_.num_feature = num_feature;
}
@@ -650,6 +664,10 @@ class LearnerImpl : public Learner {
std::vector<std::shared_ptr<DMatrix> > cache_;
common::Monitor monitor_;
/*! \brief config keys written to the checkpoint and used to restore a failed worker */
std::set<std::string> saved_configs_ = {"max_depth", "tree_method", "dsplit",
"seed", "silent", "num_round", "gamma", "min_child_weight"};
};
std::string const LearnerImpl::kEvalMetric {"eval_metric"}; // NOLINT