Avoid rabit calls in learner configuration (#5581)

This commit is contained in:
Rory Mitchell 2020-04-24 14:59:29 +12:00 committed by GitHub
parent 92913aaf7f
commit 660be66207
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 20 deletions

View File

@ -482,23 +482,27 @@ class LearnerConfiguration : public Learner {
}
void ConfigureNumFeatures() {
// estimate feature bound
// Compute number of global features if parameter not already set
if (mparam_.num_feature == 0) {
// TODO(hcho3): Change num_feature to 64-bit integer
unsigned num_feature = 0;
for (auto& matrix : cache_.Container()) {
CHECK(matrix.first);
CHECK(!matrix.second.ref.expired());
const uint64_t num_col = matrix.first->Info().num_col_;
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
CHECK_LE(num_col,
static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
<< "Unfortunately, XGBoost does not support data matrices with "
<< std::numeric_limits<unsigned>::max() << " features or greater";
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
}
// run allreduce on num_feature to find the maximum value
rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr, "num_feature");
rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr,
"num_feature");
if (num_feature > mparam_.num_feature) {
mparam_.num_feature = num_feature;
}
}
CHECK_NE(mparam_.num_feature, 0)
<< "0 feature is supplied. Are you using raw Booster interface?";
// Remove these once binary IO is gone.

View File

@ -20,7 +20,8 @@ num_round = 20
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
# Save the model, only ask process 0 to save the model.
bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model")
xgb.rabit.tracker_print("Finished training\n")
# Notify the tracker all training has been successful

View File

@ -70,7 +70,8 @@ watchlist = [(dtrain,'train')]
num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist)
bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
if xgb.rabit.get_rank() == 0:
bst.save_model("test_issue3402.model")
xgb.rabit.tracker_print("Finished training\n")
# Notify the tracker all training has been successful