[Breaking] Don't drop trees during DART prediction by default (#5115)
* Simplify DropTrees calling logic * Add `training` parameter for prediction method. * [Breaking]: Add `training` to C API. * Change for R and Python custom objective. * Correct comment. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu> Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -694,7 +694,7 @@ class LearnerImpl : public Learner {
|
||||
this->ValidateDMatrix(train);
|
||||
|
||||
monitor_.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds_[train]);
|
||||
this->PredictRaw(train, &preds_[train], true);
|
||||
monitor_.Stop("PredictRaw");
|
||||
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
|
||||
|
||||
@@ -735,7 +735,7 @@ class LearnerImpl : public Learner {
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
DMatrix * dmat = data_sets[i];
|
||||
this->ValidateDMatrix(dmat);
|
||||
this->PredictRaw(data_sets[i], &preds_[dmat]);
|
||||
this->PredictRaw(data_sets[i], &preds_[dmat], false);
|
||||
obj_->EvalTransform(&preds_[dmat]);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
@@ -799,6 +799,7 @@ class LearnerImpl : public Learner {
|
||||
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
bool training,
|
||||
bool pred_leaf, bool pred_contribs, bool approx_contribs,
|
||||
bool pred_interactions) override {
|
||||
int multiple_predictions = static_cast<int>(pred_leaf) +
|
||||
@@ -814,7 +815,7 @@ class LearnerImpl : public Learner {
|
||||
} else if (pred_leaf) {
|
||||
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
|
||||
} else {
|
||||
this->PredictRaw(data, out_preds, ntree_limit);
|
||||
this->PredictRaw(data, out_preds, training, ntree_limit);
|
||||
if (!output_margin) {
|
||||
obj_->PredTransform(out_preds);
|
||||
}
|
||||
@@ -832,13 +833,15 @@ class LearnerImpl : public Learner {
|
||||
* \param out_preds output vector that stores the prediction
|
||||
* \param ntree_limit limit number of trees used for boosted tree
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
* \param training allow dropout when the DART booster is being used
|
||||
*/
|
||||
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
bool training,
|
||||
unsigned ntree_limit = 0) const {
|
||||
CHECK(gbm_ != nullptr)
|
||||
<< "Predict must happen after Load or configuration";
|
||||
this->ValidateDMatrix(data);
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
gbm_->PredictBatch(data, out_preds, training, ntree_limit);
|
||||
}
|
||||
|
||||
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
|
||||
|
||||
Reference in New Issue
Block a user