Revert "Fix #3485, #3540: Don't use dropout for predicting test sets" (#3563)

* Revert "Fix #3485, #3540: Don't use dropout for predicting test sets (#3556)"

This reverts commit 44811f2330.

* Document behavior of predict() for DART booster

* Add notice to parameter.rst
This commit is contained in:
Philip Hyunsu Cho
2018-08-08 09:48:55 -07:00
committed by GitHub
parent e3e776bd58
commit 3c72654e3b
9 changed files with 61 additions and 30 deletions

View File

@@ -103,7 +103,6 @@ class GBLinear : public GradientBooster {
void PredictBatch(DMatrix *p_fmat,
HostDeviceVector<bst_float> *out_preds,
bool dropout,
unsigned ntree_limit) override {
monitor_.Start("PredictBatch");
CHECK_EQ(ntree_limit, 0U)

View File

@@ -217,7 +217,6 @@ class GBTree : public GradientBooster {
void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds,
bool dropout,
unsigned ntree_limit) override {
predictor_->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
}
@@ -357,11 +356,8 @@ class Dart : public GBTree {
// predict the leaf scores with dropout if ntree_limit = 0
void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds,
bool dropout,
unsigned ntree_limit) override {
if (dropout) {
DropTrees(ntree_limit);
}
DropTrees(ntree_limit);
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
}

View File

@@ -469,7 +469,7 @@ class LearnerImpl : public Learner {
} else if (pred_leaf) {
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
} else {
this->PredictRaw(data, out_preds, false, ntree_limit);
this->PredictRaw(data, out_preds, ntree_limit);
if (!output_margin) {
obj_->PredTransform(out_preds);
}
@@ -560,16 +560,14 @@ class LearnerImpl : public Learner {
* \brief get un-transformed prediction
* \param data training data matrix
* \param out_preds output vector that stores the prediction
* \param dropout whether dropout should be applied to prediction.
* This option is only meaningful if booster='dart'; otherwise ignored.
* \param ntree_limit limit number of trees used for boosted tree
* predictor, when it equals 0, this means we are using all the trees
*/
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
bool dropout = true, unsigned ntree_limit = 0) const {
unsigned ntree_limit = 0) const {
CHECK(gbm_ != nullptr)
<< "Predict must happen after Load or InitModel";
gbm_->PredictBatch(data, out_preds, dropout, ntree_limit);
gbm_->PredictBatch(data, out_preds, ntree_limit);
}
// model parameter