Avoid rabit calls in learner configuration (#5581)
This commit is contained in:
parent
92913aaf7f
commit
660be66207
@ -482,22 +482,26 @@ class LearnerConfiguration : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ConfigureNumFeatures() {
|
void ConfigureNumFeatures() {
|
||||||
// estimate feature bound
|
// Compute number of global features if parameter not already set
|
||||||
// TODO(hcho3): Change num_feature to 64-bit integer
|
if (mparam_.num_feature == 0) {
|
||||||
unsigned num_feature = 0;
|
// TODO(hcho3): Change num_feature to 64-bit integer
|
||||||
for (auto & matrix : cache_.Container()) {
|
unsigned num_feature = 0;
|
||||||
CHECK(matrix.first);
|
for (auto& matrix : cache_.Container()) {
|
||||||
CHECK(!matrix.second.ref.expired());
|
CHECK(matrix.first);
|
||||||
const uint64_t num_col = matrix.first->Info().num_col_;
|
CHECK(!matrix.second.ref.expired());
|
||||||
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
|
const uint64_t num_col = matrix.first->Info().num_col_;
|
||||||
<< "Unfortunately, XGBoost does not support data matrices with "
|
CHECK_LE(num_col,
|
||||||
<< std::numeric_limits<unsigned>::max() << " features or greater";
|
static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
|
||||||
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
|
<< "Unfortunately, XGBoost does not support data matrices with "
|
||||||
}
|
<< std::numeric_limits<unsigned>::max() << " features or greater";
|
||||||
// run allreduce on num_feature to find the maximum value
|
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
|
||||||
rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr, "num_feature");
|
}
|
||||||
if (num_feature > mparam_.num_feature) {
|
|
||||||
mparam_.num_feature = num_feature;
|
rabit::Allreduce<rabit::op::Max>(&num_feature, 1, nullptr, nullptr,
|
||||||
|
"num_feature");
|
||||||
|
if (num_feature > mparam_.num_feature) {
|
||||||
|
mparam_.num_feature = num_feature;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
CHECK_NE(mparam_.num_feature, 0)
|
CHECK_NE(mparam_.num_feature, 0)
|
||||||
<< "0 feature is supplied. Are you using raw Booster interface?";
|
<< "0 feature is supplied. Are you using raw Booster interface?";
|
||||||
|
|||||||
@ -20,8 +20,9 @@ num_round = 20
|
|||||||
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||||
|
|
||||||
# Save the model, only ask process 0 to save the model.
|
# Save the model, only ask process 0 to save the model.
|
||||||
bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
|
if xgb.rabit.get_rank() == 0:
|
||||||
xgb.rabit.tracker_print("Finished training\n")
|
bst.save_model("test.model")
|
||||||
|
xgb.rabit.tracker_print("Finished training\n")
|
||||||
|
|
||||||
# Notify the tracker all training has been successful
|
# Notify the tracker all training has been successful
|
||||||
# This is only needed in distributed training.
|
# This is only needed in distributed training.
|
||||||
|
|||||||
@ -70,8 +70,9 @@ watchlist = [(dtrain,'train')]
|
|||||||
num_round = 2
|
num_round = 2
|
||||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||||
|
|
||||||
bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
|
if xgb.rabit.get_rank() == 0:
|
||||||
xgb.rabit.tracker_print("Finished training\n")
|
bst.save_model("test_issue3402.model")
|
||||||
|
xgb.rabit.tracker_print("Finished training\n")
|
||||||
|
|
||||||
# Notify the tracker all training has been successful
|
# Notify the tracker all training has been successful
|
||||||
# This is only needed in distributed training.
|
# This is only needed in distributed training.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user