* Fix #3485, #3540: Don't use dropout when predicting test sets. Dropout (for DART) should only be applied at training time.
* Add regression test
This commit is contained in:
parent
109473dae2
commit
44811f2330
@@ -76,11 +76,14 @@ class GradientBooster {
|
|||||||
* \brief generate predictions for given feature matrix
|
* \brief generate predictions for given feature matrix
|
||||||
* \param dmat feature matrix
|
* \param dmat feature matrix
|
||||||
* \param out_preds output vector to hold the predictions
|
* \param out_preds output vector to hold the predictions
|
||||||
|
* \param dropout whether dropout should be applied to prediction
|
||||||
|
* This option is only meaningful if booster='dart'; otherwise ignored.
|
||||||
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
|
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
|
||||||
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
|
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
|
||||||
*/
|
*/
|
||||||
virtual void PredictBatch(DMatrix* dmat,
|
virtual void PredictBatch(DMatrix* dmat,
|
||||||
HostDeviceVector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
|
bool dropout = true,
|
||||||
unsigned ntree_limit = 0) = 0;
|
unsigned ntree_limit = 0) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief online prediction function, predict score for one instance at a time
|
* \brief online prediction function, predict score for one instance at a time
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class GBLinear : public GradientBooster {
|
|||||||
|
|
||||||
void PredictBatch(DMatrix *p_fmat,
|
void PredictBatch(DMatrix *p_fmat,
|
||||||
HostDeviceVector<bst_float> *out_preds,
|
HostDeviceVector<bst_float> *out_preds,
|
||||||
|
bool dropout,
|
||||||
unsigned ntree_limit) override {
|
unsigned ntree_limit) override {
|
||||||
monitor_.Start("PredictBatch");
|
monitor_.Start("PredictBatch");
|
||||||
CHECK_EQ(ntree_limit, 0U)
|
CHECK_EQ(ntree_limit, 0U)
|
||||||
|
|||||||
@@ -217,6 +217,7 @@ class GBTree : public GradientBooster {
|
|||||||
|
|
||||||
void PredictBatch(DMatrix* p_fmat,
|
void PredictBatch(DMatrix* p_fmat,
|
||||||
HostDeviceVector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
|
bool dropout,
|
||||||
unsigned ntree_limit) override {
|
unsigned ntree_limit) override {
|
||||||
predictor_->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
|
predictor_->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
|
||||||
}
|
}
|
||||||
@@ -356,8 +357,11 @@ class Dart : public GBTree {
|
|||||||
// predict the leaf scores with dropout if ntree_limit = 0
|
// predict the leaf scores with dropout if ntree_limit = 0
|
||||||
void PredictBatch(DMatrix* p_fmat,
|
void PredictBatch(DMatrix* p_fmat,
|
||||||
HostDeviceVector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
|
bool dropout,
|
||||||
unsigned ntree_limit) override {
|
unsigned ntree_limit) override {
|
||||||
DropTrees(ntree_limit);
|
if (dropout) {
|
||||||
|
DropTrees(ntree_limit);
|
||||||
|
}
|
||||||
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
|
PredLoopInternal<Dart>(p_fmat, &out_preds->HostVector(), 0, ntree_limit, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -469,7 +469,7 @@ class LearnerImpl : public Learner {
|
|||||||
} else if (pred_leaf) {
|
} else if (pred_leaf) {
|
||||||
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
|
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
|
||||||
} else {
|
} else {
|
||||||
this->PredictRaw(data, out_preds, ntree_limit);
|
this->PredictRaw(data, out_preds, false, ntree_limit);
|
||||||
if (!output_margin) {
|
if (!output_margin) {
|
||||||
obj_->PredTransform(out_preds);
|
obj_->PredTransform(out_preds);
|
||||||
}
|
}
|
||||||
@@ -560,14 +560,16 @@ class LearnerImpl : public Learner {
|
|||||||
* \brief get un-transformed prediction
|
* \brief get un-transformed prediction
|
||||||
* \param data training data matrix
|
* \param data training data matrix
|
||||||
* \param out_preds output vector that stores the prediction
|
* \param out_preds output vector that stores the prediction
|
||||||
|
* \param dropout whether dropout should be applied to prediction.
|
||||||
|
* This option is only meaningful if booster='dart'; otherwise ignored.
|
||||||
* \param ntree_limit limit number of trees used for boosted tree
|
* \param ntree_limit limit number of trees used for boosted tree
|
||||||
* predictor, when it equals 0, this means we are using all the trees
|
* predictor, when it equals 0, this means we are using all the trees
|
||||||
*/
|
*/
|
||||||
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||||
unsigned ntree_limit = 0) const {
|
bool dropout = true, unsigned ntree_limit = 0) const {
|
||||||
CHECK(gbm_ != nullptr)
|
CHECK(gbm_ != nullptr)
|
||||||
<< "Predict must happen after Load or InitModel";
|
<< "Predict must happen after Load or InitModel";
|
||||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
gbm_->PredictBatch(data, out_preds, dropout, ntree_limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
// model parameter
|
// model parameter
|
||||||
|
|||||||
@@ -48,6 +48,13 @@ class TestModels(unittest.TestCase):
|
|||||||
preds2 = bst2.predict(dtest2, ntree_limit=num_round)
|
preds2 = bst2.predict(dtest2, ntree_limit=num_round)
|
||||||
# assert they are the same
|
# assert they are the same
|
||||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||||
|
# regression test for issues #3485, #3540
|
||||||
|
for _ in range(10):
|
||||||
|
bst3 = xgb.Booster(params=param, model_file='xgb.model.dart')
|
||||||
|
dtest3 = xgb.DMatrix('dtest.buffer')
|
||||||
|
preds3 = bst3.predict(dtest3)
|
||||||
|
# assert they are the same
|
||||||
|
assert np.sum(np.abs(preds3 - preds)) == 0, 'preds3 = {}, preds = {}'.format(preds3, preds)
|
||||||
|
|
||||||
# check whether sample_type and normalize_type work
|
# check whether sample_type and normalize_type work
|
||||||
num_round = 50
|
num_round = 50
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user