diff --git a/doc/parameter.md b/doc/parameter.md
index 8a1fbcbf3..ef10e0499 100644
--- a/doc/parameter.md
+++ b/doc/parameter.md
@@ -110,11 +110,14 @@ Additional parameters for Dart Booster
       - weight of new trees are 1 / (1 + learning_rate)
       - dropped trees are scaled by a factor of 1 / (1 + learning_rate)
 * rate_drop [default=0.0]
-  - dropout rate.
+  - dropout rate (a fraction of previous trees to drop during the dropout).
   - range: [0.0, 1.0]
+* one_drop [default=0]
+  - when this flag is enabled, at least one tree is always dropped during the dropout (allows Binomial-plus-one or epsilon-dropout from the original DART paper).
 * skip_drop [default=0.0]
-  - probability of skip dropout.
+  - Probability of skipping the dropout procedure during a boosting iteration.
   - If a dropout is skipped, new trees are added in the same manner as gbtree.
+  - Note that non-zero skip_drop has higher priority than rate_drop or one_drop.
   - range: [0.0, 1.0]
 
 Parameters for Linear Booster
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 266dba178..3d8c2a9db 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -72,9 +72,11 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
   int sample_type;
   /*! \brief type of normalization algorithm */
   int normalize_type;
-  /*! \brief how many trees are dropped */
+  /*! \brief fraction of trees to drop during the dropout */
   float rate_drop;
-  /*! \brief whether to drop trees */
+  /*! \brief whether at least one tree should always be dropped during the dropout */
+  bool one_drop;
+  /*! \brief probability of skipping the dropout during an iteration */
   float skip_drop;
   /*! \brief learning step size for a time */
   float learning_rate;
@@ -96,11 +98,14 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
     DMLC_DECLARE_FIELD(rate_drop)
         .set_range(0.0f, 1.0f)
         .set_default(0.0f)
-        .describe("Parameter of how many trees are dropped.");
+        .describe("Fraction of trees to drop during the dropout.");
+    DMLC_DECLARE_FIELD(one_drop)
+        .set_default(false)
+        .describe("Whether at least one tree should always be dropped during the dropout.");
     DMLC_DECLARE_FIELD(skip_drop)
         .set_range(0.0f, 1.0f)
         .set_default(0.0f)
-        .describe("Parameter of whether to drop trees.");
+        .describe("Probability of skipping the dropout during a boosting iteration.");
     DMLC_DECLARE_FIELD(learning_rate)
         .set_lower_bound(0.0f)
         .set_default(0.3f)
@@ -658,12 +663,27 @@ class Dart : public GBTree {
           idx_drop.push_back(i);
         }
       }
+      if (dparam.one_drop && idx_drop.empty() && !weight_drop.empty()) {
+        // the expression below is an ugly but MSVC2013-friendly equivalent of
+        // size_t i = std::discrete_distribution<size_t>(weight_drop.begin(),
+        //                                               weight_drop.end())(rnd);
+        size_t i = std::discrete_distribution<size_t>(
+          weight_drop.size(), 0., static_cast<double>(weight_drop.size()),
+          [this](double x) -> double {
+            return weight_drop[static_cast<size_t>(x)];
+          })(rnd);
+        idx_drop.push_back(i);
+      }
     } else {
       for (size_t i = 0; i < weight_drop.size(); ++i) {
         if (runif(rnd) < dparam.rate_drop) {
           idx_drop.push_back(i);
         }
       }
+      if (dparam.one_drop && idx_drop.empty() && !weight_drop.empty()) {
+        size_t i = std::uniform_int_distribution<size_t>(0, weight_drop.size() - 1)(rnd);
+        idx_drop.push_back(i);
+      }
     }
   }
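
A minimal standalone sketch of how the three dropout parameters interact, assuming the uniform sample_type path only; the names SelectDroppedTrees and num_trees are illustrative and not the library API. skip_drop is evaluated first, then each tree is dropped independently at rate_drop, and one_drop forces a single uniformly chosen tree when nothing else was dropped.

// Sketch only (simplified, uniform sample_type path; names are illustrative,
// not the library API): how skip_drop, rate_drop and one_drop interact.
#include <iostream>
#include <random>
#include <vector>

std::vector<size_t> SelectDroppedTrees(size_t num_trees, double skip_drop,
                                       double rate_drop, bool one_drop,
                                       std::mt19937& rnd) {
  std::vector<size_t> idx_drop;
  std::uniform_real_distribution<> runif(0.0, 1.0);
  // Non-zero skip_drop is checked first: with this probability the whole
  // dropout round is skipped, regardless of rate_drop and one_drop.
  if (skip_drop > 0.0 && runif(rnd) < skip_drop) return idx_drop;
  // Each previous tree is dropped independently with probability rate_drop,
  // i.e. a Binomial(num_trees, rate_drop) count of dropped trees.
  for (size_t i = 0; i < num_trees; ++i) {
    if (runif(rnd) < rate_drop) idx_drop.push_back(i);
  }
  // one_drop: if nothing was selected, force one uniformly chosen tree, which
  // the doc calls Binomial-plus-one or epsilon-dropout from the DART paper.
  if (one_drop && idx_drop.empty() && num_trees > 0) {
    idx_drop.push_back(std::uniform_int_distribution<size_t>(0, num_trees - 1)(rnd));
  }
  return idx_drop;
}

int main() {
  std::mt19937 rnd(42);
  for (int iter = 0; iter < 5; ++iter) {
    // Count is zero only when the skip_drop coin flip fires.
    std::cout << "dropped "
              << SelectDroppedTrees(10, 0.25, 0.1, true, rnd).size()
              << " tree(s)\n";
  }
}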
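The MSVC2013-friendly construction in the patch relies on the (count, xmin, xmax, unary_op) overload of std::discrete_distribution: with count = n over [0, n) each cell has width 1, the op is sampled at the midpoint k + 0.5, and truncating that back to size_t recovers index k, so the weights match the iterator-pair form. A small check of that equivalence, with the vector contents as illustrative assumptions:

// Sketch only: compares the iterator-pair constructor with the
// (count, xmin, xmax, unary_op) constructor used in the patch.
#include <iostream>
#include <random>
#include <vector>

int main() {
  std::vector<double> weight_drop = {0.5, 1.0, 2.0, 0.25};  // illustrative weights

  // The readable form mentioned in the patch comment.
  std::discrete_distribution<size_t> a(weight_drop.begin(), weight_drop.end());

  // The workaround: n cells of width 1 over [0, n); the weight of cell k is
  // the unary op evaluated at the midpoint k + 0.5, which truncates to k.
  std::discrete_distribution<size_t> b(
      weight_drop.size(), 0., static_cast<double>(weight_drop.size()),
      [&](double x) -> double { return weight_drop[static_cast<size_t>(x)]; });

  for (double p : a.probabilities()) std::cout << p << ' ';
  std::cout << '\n';
  for (double p : b.probabilities()) std::cout << p << ' ';  // same values
  std::cout << '\n';
}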