Avoid rabit calls in learner configuration (#5581)

Rory Mitchell 2020-04-24 14:59:29 +12:00 committed by GitHub
parent 92913aaf7f
commit 660be66207
3 changed files with 26 additions and 20 deletions

View File

@@ -482,22 +482,26 @@ class LearnerConfiguration : public Learner {
   }
   void ConfigureNumFeatures() {
-    // estimate feature bound
-    // TODO(hcho3): Change num_feature to 64-bit integer
-    unsigned num_feature = 0;
-    for (auto & matrix : cache_.Container()) {
-      CHECK(matrix.first);
-      CHECK(!matrix.second.ref.expired());
-      const uint64_t num_col = matrix.first->Info().num_col_;
-      CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
-          << "Unfortunately, XGBoost does not support data matrices with "
-          << std::numeric_limits<unsigned>::max() << " features or greater";
-      num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
-    }
-    // run allreduce on num_feature to find the maximum value
-    rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr, "num_feature");
-    if (num_feature > mparam_.num_feature) {
-      mparam_.num_feature = num_feature;
-    }
+    // Compute number of global features if parameter not already set
+    if (mparam_.num_feature == 0) {
+      // TODO(hcho3): Change num_feature to 64-bit integer
+      unsigned num_feature = 0;
+      for (auto& matrix : cache_.Container()) {
+        CHECK(matrix.first);
+        CHECK(!matrix.second.ref.expired());
+        const uint64_t num_col = matrix.first->Info().num_col_;
+        CHECK_LE(num_col,
+                 static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
+            << "Unfortunately, XGBoost does not support data matrices with "
+            << std::numeric_limits<unsigned>::max() << " features or greater";
+        num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
+      }
+
+      rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr,
+                                       "num_feature");
+      if (num_feature > mparam_.num_feature) {
+        mparam_.num_feature = num_feature;
+      }
+    }
     CHECK_NE(mparam_.num_feature, 0)
         << "0 feature is supplied. Are you using raw Booster interface?";

View File

@@ -20,8 +20,9 @@ num_round = 20
 bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
 # Save the model, only ask process 0 to save the model.
-bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
-xgb.rabit.tracker_print("Finished training\n")
+if xgb.rabit.get_rank() == 0:
+    bst.save_model("test.model")
+    xgb.rabit.tracker_print("Finished training\n")
 # Notify the tracker all training has been successful
 # This is only needed in distributed training.
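
The test now follows the usual rabit convention that exactly one worker performs file I/O. For context, a minimal sketch of a full worker script using this pattern (the tracker setup, data path, and parameters are assumptions, not part of this commit):

    import xgboost as xgb

    # Sketch of a rabit worker script; assumes a tracker has launched
    # this process with the required environment already set.
    xgb.rabit.init()
    dtrain = xgb.DMatrix("train.data")  # hypothetical input path
    param = {"max_depth": 2, "objective": "binary:logistic"}
    bst = xgb.train(param, dtrain, num_boost_round=10)
    if xgb.rabit.get_rank() == 0:       # only one worker writes the model
        bst.save_model("model.bin")
        xgb.rabit.tracker_print("Finished training\n")
    xgb.rabit.finalize()                # notify the tracker of success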

View File

@@ -70,8 +70,9 @@ watchlist = [(dtrain,'train')]
 num_round = 2
 bst = xgb.train(param, dtrain, num_round, watchlist)
-bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
-xgb.rabit.tracker_print("Finished training\n")
+if xgb.rabit.get_rank() == 0:
+    bst.save_model("test_issue3402.model")
+    xgb.rabit.tracker_print("Finished training\n")
 # Notify the tracker all training has been successful
 # This is only needed in distributed training.
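
The same rank-0 guard is applied in this second test script. Previously every worker wrote its own test_issue3402.model{rank} file, leaving one model file per worker behind; now a single canonical file is written and the remaining workers proceed directly to the tracker notification.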