Fix prediction heuristic (#5955)
* Relax check for prediction.
* Relax test in spark test.
* Add tests in C++.
This commit is contained in:
@@ -946,7 +946,7 @@ class LearnerImpl : public LearnerIO {
|
||||
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train.get());
|
||||
this->ValidateDMatrix(train.get(), true);
|
||||
|
||||
auto& predt = this->cache_.Cache(train, generic_parameters_.gpu_id);
|
||||
|
||||
@@ -972,7 +972,7 @@ class LearnerImpl : public LearnerIO {
|
||||
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train.get());
|
||||
this->ValidateDMatrix(train.get(), true);
|
||||
this->cache_.Cache(train, generic_parameters_.gpu_id);
|
||||
|
||||
gbm_->DoBoost(train.get(), in_gpair, &cache_.Entry(train.get()));
|
||||
@@ -994,7 +994,7 @@ class LearnerImpl : public LearnerIO {
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
std::shared_ptr<DMatrix> m = data_sets[i];
|
||||
auto &predt = this->cache_.Cache(m, generic_parameters_.gpu_id);
|
||||
this->ValidateDMatrix(m.get());
|
||||
this->ValidateDMatrix(m.get(), false);
|
||||
this->PredictRaw(m.get(), &predt, false);
|
||||
|
||||
auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;
|
||||
@@ -1079,11 +1079,11 @@ class LearnerImpl : public LearnerIO {
|
||||
bool training,
|
||||
unsigned ntree_limit = 0) const {
|
||||
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
|
||||
this->ValidateDMatrix(data);
|
||||
this->ValidateDMatrix(data, false);
|
||||
gbm_->PredictBatch(data, out_preds, training, ntree_limit);
|
||||
}
|
||||
|
||||
void ValidateDMatrix(DMatrix* p_fmat) const {
|
||||
void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {
|
||||
MetaInfo const& info = p_fmat->Info();
|
||||
info.Validate(generic_parameters_.gpu_id);
|
||||
|
||||
@@ -1092,8 +1092,15 @@ class LearnerImpl : public LearnerIO {
|
||||
tparam_.dsplit == DataSplitMode::kAuto;
|
||||
};
|
||||
if (row_based_split()) {
|
||||
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
|
||||
<< "Number of columns does not match number of features in booster.";
|
||||
if (is_training) {
|
||||
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
|
||||
<< "Number of columns does not match number of features in "
|
||||
"booster.";
|
||||
} else {
|
||||
CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
|
||||
<< "Number of columns does not match number of features in "
|
||||
"booster.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user