Remove duplicated learning rate parameter. (#8941)

This commit is contained in:
Jiaming Yuan 2023-03-22 20:51:14 +08:00 committed by GitHub
parent a05799ed39
commit a551bed803
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 9 deletions

View File

@@ -996,8 +996,9 @@ class Dart : public GBTree {
} }
// set normalization factors // set normalization factors
inline size_t NormalizeTrees(size_t size_new_trees) { std::size_t NormalizeTrees(size_t size_new_trees) {
float lr = 1.0 * dparam_.learning_rate / size_new_trees; CHECK(tree_param_.GetInitialised());
float lr = 1.0 * tree_param_.learning_rate / size_new_trees;
size_t num_drop = idx_drop_.size(); size_t num_drop = idx_drop_.size();
if (num_drop == 0) { if (num_drop == 0) {
for (size_t i = 0; i < size_new_trees; ++i) { for (size_t i = 0; i < size_new_trees; ++i) {

View File

@@ -111,8 +111,6 @@ struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
bool one_drop; bool one_drop;
/*! \brief probability of skipping the dropout during an iteration */ /*! \brief probability of skipping the dropout during an iteration */
float skip_drop; float skip_drop;
/*! \brief learning step size for a time */
float learning_rate;
// declare parameters // declare parameters
DMLC_DECLARE_PARAMETER(DartTrainParam) { DMLC_DECLARE_PARAMETER(DartTrainParam) {
DMLC_DECLARE_FIELD(sample_type) DMLC_DECLARE_FIELD(sample_type)
@@ -136,11 +134,6 @@ struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
.set_range(0.0f, 1.0f) .set_range(0.0f, 1.0f)
.set_default(0.0f) .set_default(0.0f)
.describe("Probability of skipping the dropout during a boosting iteration."); .describe("Probability of skipping the dropout during a boosting iteration.");
DMLC_DECLARE_FIELD(learning_rate)
.set_lower_bound(0.0f)
.set_default(0.3f)
.describe("Learning rate(step size) of update.");
DMLC_DECLARE_ALIAS(learning_rate, eta);
} }
}; };