An option for doing binomial+1 or epsilon-dropout from DART paper (#1922)
* An option for doing binomial+1 or epsilon-dropout from DART paper * use callback-based discrete_distribution to make MSVC2013 happy
This commit is contained in:
committed by
Tianqi Chen
parent
ce84af7923
commit
d23ea5ca7d
@@ -72,9 +72,11 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
|
||||
int sample_type;
|
||||
/*! \brief type of normalization algorithm */
|
||||
int normalize_type;
|
||||
/*! \brief how many trees are dropped */
|
||||
/*! \brief fraction of trees to drop during the dropout */
|
||||
float rate_drop;
|
||||
/*! \brief whether to drop trees */
|
||||
/*! \brief whether at least one tree should always be dropped during the dropout */
|
||||
bool one_drop;
|
||||
/*! \brief probability of skipping the dropout during an iteration */
|
||||
float skip_drop;
|
||||
/*! \brief learning step size for a time */
|
||||
float learning_rate;
|
||||
@@ -96,11 +98,14 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
|
||||
DMLC_DECLARE_FIELD(rate_drop)
|
||||
.set_range(0.0f, 1.0f)
|
||||
.set_default(0.0f)
|
||||
.describe("Parameter of how many trees are dropped.");
|
||||
.describe("Fraction of trees to drop during the dropout.");
|
||||
DMLC_DECLARE_FIELD(one_drop)
|
||||
.set_default(false)
|
||||
.describe("Whether at least one tree should always be dropped during the dropout.");
|
||||
DMLC_DECLARE_FIELD(skip_drop)
|
||||
.set_range(0.0f, 1.0f)
|
||||
.set_default(0.0f)
|
||||
.describe("Parameter of whether to drop trees.");
|
||||
.describe("Probability of skipping the dropout during a boosting iteration.");
|
||||
DMLC_DECLARE_FIELD(learning_rate)
|
||||
.set_lower_bound(0.0f)
|
||||
.set_default(0.3f)
|
||||
@@ -658,12 +663,27 @@ class Dart : public GBTree {
|
||||
idx_drop.push_back(i);
|
||||
}
|
||||
}
|
||||
if (dparam.one_drop && idx_drop.empty() && !weight_drop.empty()) {
|
||||
// the expression below is an ugly but MSVC2013-friendly equivalent of
|
||||
// size_t i = std::discrete_distribution<size_t>(weight_drop.begin(),
|
||||
// weight_drop.end())(rnd);
|
||||
size_t i = std::discrete_distribution<size_t>(
|
||||
weight_drop.size(), 0., static_cast<double>(weight_drop.size()),
|
||||
[this](double x) -> double {
|
||||
return weight_drop[static_cast<size_t>(x)];
|
||||
})(rnd);
|
||||
idx_drop.push_back(i);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < weight_drop.size(); ++i) {
|
||||
if (runif(rnd) < dparam.rate_drop) {
|
||||
idx_drop.push_back(i);
|
||||
}
|
||||
}
|
||||
if (dparam.one_drop && idx_drop.empty() && !weight_drop.empty()) {
|
||||
size_t i = std::uniform_int_distribution<size_t>(0, weight_drop.size() - 1)(rnd);
|
||||
idx_drop.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user