An option for doing binomial+1 or epsilon-dropout from DART paper (#1922)

* An option for doing binomial+1 or epsilon-dropout from DART paper

* use callback-based discrete_distribution to make MSVC2013 happy
This commit is contained in:
Vadim Khotilovich 2017-01-05 18:23:22 -06:00 committed by Tianqi Chen
parent ce84af7923
commit d23ea5ca7d
2 changed files with 29 additions and 6 deletions

View File

@ -110,11 +110,14 @@ Additional parameters for Dart Booster
- weight of new trees are 1 / (1 + learning_rate)
- dropped trees are scaled by a factor of 1 / (1 + learning_rate)
* rate_drop [default=0.0]
- dropout rate.
- dropout rate (a fraction of previous trees to drop during the dropout).
- range: [0.0, 1.0]
* one_drop [default=0]
- when this flag is enabled, at least one tree is always dropped during the dropout (allows Binomial-plus-one or epsilon-dropout from the original DART paper).
* skip_drop [default=0.0]
- probability of skip dropout.
- Probability of skipping the dropout procedure during a boosting iteration.
- If a dropout is skipped, new trees are added in the same manner as gbtree.
- Note that non-zero skip_drop has higher priority than rate_drop or one_drop.
- range: [0.0, 1.0]
Parameters for Linear Booster

View File

@ -72,9 +72,11 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
int sample_type;
/*! \brief type of normalization algorithm */
int normalize_type;
/*! \brief how many trees are dropped */
/*! \brief fraction of trees to drop during the dropout */
float rate_drop;
/*! \brief whether to drop trees */
/*! \brief whether at least one tree should always be dropped during the dropout */
bool one_drop;
/*! \brief probability of skipping the dropout during an iteration */
float skip_drop;
/*! \brief learning step size for a time */
float learning_rate;
@ -96,11 +98,14 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
DMLC_DECLARE_FIELD(rate_drop)
.set_range(0.0f, 1.0f)
.set_default(0.0f)
.describe("Parameter of how many trees are dropped.");
.describe("Fraction of trees to drop during the dropout.");
DMLC_DECLARE_FIELD(one_drop)
.set_default(false)
.describe("Whether at least one tree should always be dropped during the dropout.");
DMLC_DECLARE_FIELD(skip_drop)
.set_range(0.0f, 1.0f)
.set_default(0.0f)
.describe("Parameter of whether to drop trees.");
.describe("Probability of skipping the dropout during a boosting iteration.");
DMLC_DECLARE_FIELD(learning_rate)
.set_lower_bound(0.0f)
.set_default(0.3f)
@ -658,12 +663,27 @@ class Dart : public GBTree {
idx_drop.push_back(i);
}
}
if (dparam.one_drop && idx_drop.empty() && !weight_drop.empty()) {
// the expression below is an ugly but MSVC2013-friendly equivalent of
// size_t i = std::discrete_distribution<size_t>(weight_drop.begin(),
// weight_drop.end())(rnd);
size_t i = std::discrete_distribution<size_t>(
weight_drop.size(), 0., static_cast<double>(weight_drop.size()),
[this](double x) -> double {
return weight_drop[static_cast<size_t>(x)];
})(rnd);
idx_drop.push_back(i);
}
} else {
for (size_t i = 0; i < weight_drop.size(); ++i) {
if (runif(rnd) < dparam.rate_drop) {
idx_drop.push_back(i);
}
}
if (dparam.one_drop && idx_drop.empty() && !weight_drop.empty()) {
size_t i = std::uniform_int_distribution<size_t>(0, weight_drop.size() - 1)(rnd);
idx_drop.push_back(i);
}
}
}
}