Avoid rabit calls in learner configuration (#5581)
This commit is contained in:
parent
92913aaf7f
commit
660be66207
@ -482,22 +482,26 @@ class LearnerConfiguration : public Learner {
|
||||
}
|
||||
|
||||
void ConfigureNumFeatures() {
|
||||
// estimate feature bound
|
||||
// TODO(hcho3): Change num_feature to 64-bit integer
|
||||
unsigned num_feature = 0;
|
||||
for (auto & matrix : cache_.Container()) {
|
||||
CHECK(matrix.first);
|
||||
CHECK(!matrix.second.ref.expired());
|
||||
const uint64_t num_col = matrix.first->Info().num_col_;
|
||||
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
|
||||
<< "Unfortunately, XGBoost does not support data matrices with "
|
||||
<< std::numeric_limits<unsigned>::max() << " features or greater";
|
||||
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
|
||||
}
|
||||
// run allreduce on num_feature to find the maximum value
|
||||
rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr, "num_feature");
|
||||
if (num_feature > mparam_.num_feature) {
|
||||
mparam_.num_feature = num_feature;
|
||||
// Compute number of global features if parameter not already set
|
||||
if (mparam_.num_feature == 0) {
|
||||
// TODO(hcho3): Change num_feature to 64-bit integer
|
||||
unsigned num_feature = 0;
|
||||
for (auto& matrix : cache_.Container()) {
|
||||
CHECK(matrix.first);
|
||||
CHECK(!matrix.second.ref.expired());
|
||||
const uint64_t num_col = matrix.first->Info().num_col_;
|
||||
CHECK_LE(num_col,
|
||||
static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
|
||||
<< "Unfortunately, XGBoost does not support data matrices with "
|
||||
<< std::numeric_limits<unsigned>::max() << " features or greater";
|
||||
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
|
||||
}
|
||||
|
||||
rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr,
|
||||
"num_feature");
|
||||
if (num_feature > mparam_.num_feature) {
|
||||
mparam_.num_feature = num_feature;
|
||||
}
|
||||
}
|
||||
CHECK_NE(mparam_.num_feature, 0)
|
||||
<< "0 feature is supplied. Are you using raw Booster interface?";
|
||||
|
||||
@ -20,8 +20,9 @@ num_round = 20
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||
|
||||
# Save the model, only ask process 0 to save the model.
|
||||
bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
if xgb.rabit.get_rank() == 0:
|
||||
bst.save_model("test.model")
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
# Notify the tracker all training has been successful
|
||||
# This is only needed in distributed training.
|
||||
|
||||
@ -70,8 +70,9 @@ watchlist = [(dtrain,'train')]
|
||||
num_round = 2
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||
|
||||
bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
if xgb.rabit.get_rank() == 0:
|
||||
bst.save_model("test_issue3402.model")
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
# Notify the tracker all training has been successful
|
||||
# This is only needed in distributed training.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user