[Breaking] Don't drop trees during DART prediction by default (#5115)

* Simplify DropTrees calling logic

* Add `training` parameter to the prediction methods.

* [Breaking] Add `training` parameter to the C API.

* Change for R and Python custom objective.

* Correct comment.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Kodi Arfer
2020-01-13 08:48:30 -05:00
committed by Jiaming Yuan
parent 7b65698187
commit f100b8d878
23 changed files with 214 additions and 140 deletions

View File

@@ -694,7 +694,7 @@ class LearnerImpl : public Learner {
this->ValidateDMatrix(train);
monitor_.Start("PredictRaw");
this->PredictRaw(train, &preds_[train]);
this->PredictRaw(train, &preds_[train], true);
monitor_.Stop("PredictRaw");
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
@@ -735,7 +735,7 @@ class LearnerImpl : public Learner {
for (size_t i = 0; i < data_sets.size(); ++i) {
DMatrix * dmat = data_sets[i];
this->ValidateDMatrix(dmat);
this->PredictRaw(data_sets[i], &preds_[dmat]);
this->PredictRaw(data_sets[i], &preds_[dmat], false);
obj_->EvalTransform(&preds_[dmat]);
for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
@@ -799,6 +799,7 @@ class LearnerImpl : public Learner {
void Predict(DMatrix* data, bool output_margin,
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
bool training,
bool pred_leaf, bool pred_contribs, bool approx_contribs,
bool pred_interactions) override {
int multiple_predictions = static_cast<int>(pred_leaf) +
@@ -814,7 +815,7 @@ class LearnerImpl : public Learner {
} else if (pred_leaf) {
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
} else {
this->PredictRaw(data, out_preds, ntree_limit);
this->PredictRaw(data, out_preds, training, ntree_limit);
if (!output_margin) {
obj_->PredTransform(out_preds);
}
@@ -832,13 +833,15 @@ class LearnerImpl : public Learner {
* \param out_preds output vector that stores the prediction
* \param ntree_limit limit number of trees used for boosted tree
* predictor, when it equals 0, this means we are using all the trees
* \param training allow dropout when the DART booster is being used
*/
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
bool training,
unsigned ntree_limit = 0) const {
CHECK(gbm_ != nullptr)
<< "Predict must happen after Load or configuration";
this->ValidateDMatrix(data);
gbm_->PredictBatch(data, out_preds, ntree_limit);
gbm_->PredictBatch(data, out_preds, training, ntree_limit);
}
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {