Fix prediction heuristic (#5955)

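* Relax the feature-count check for prediction: the input may have fewer columns than the booster has features (training still requires an exact match).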
* Relax the corresponding check in the Spark test.
* Add tests in C++.
This commit is contained in:
Jiaming Yuan
2020-07-29 19:24:07 +08:00
committed by GitHub
parent 5879acde9a
commit 75b8c22b0b
11 changed files with 103 additions and 28 deletions

View File

@@ -946,7 +946,7 @@ class LearnerImpl : public LearnerIO {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
}
this->CheckDataSplitMode();
this->ValidateDMatrix(train.get());
this->ValidateDMatrix(train.get(), true);
auto& predt = this->cache_.Cache(train, generic_parameters_.gpu_id);
@@ -972,7 +972,7 @@ class LearnerImpl : public LearnerIO {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
}
this->CheckDataSplitMode();
this->ValidateDMatrix(train.get());
this->ValidateDMatrix(train.get(), true);
this->cache_.Cache(train, generic_parameters_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair, &cache_.Entry(train.get()));
@@ -994,7 +994,7 @@ class LearnerImpl : public LearnerIO {
for (size_t i = 0; i < data_sets.size(); ++i) {
std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = this->cache_.Cache(m, generic_parameters_.gpu_id);
this->ValidateDMatrix(m.get());
this->ValidateDMatrix(m.get(), false);
this->PredictRaw(m.get(), &predt, false);
auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;
@@ -1079,11 +1079,11 @@ class LearnerImpl : public LearnerIO {
bool training,
unsigned ntree_limit = 0) const {
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
this->ValidateDMatrix(data);
this->ValidateDMatrix(data, false);
gbm_->PredictBatch(data, out_preds, training, ntree_limit);
}
void ValidateDMatrix(DMatrix* p_fmat) const {
void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {
MetaInfo const& info = p_fmat->Info();
info.Validate(generic_parameters_.gpu_id);
@@ -1092,8 +1092,15 @@ class LearnerImpl : public LearnerIO {
tparam_.dsplit == DataSplitMode::kAuto;
};
if (row_based_split()) {
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in booster.";
if (is_training) {
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in "
"booster.";
} else {
CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
<< "Number of columns does not match number of features in "
"booster.";
}
}
}