Check number of trees in inplace predict. (#7409)
This commit is contained in:
parent
97d7582457
commit
ca6f980932
@ -273,6 +273,7 @@ class GBTree : public GradientBooster {
|
||||
uint32_t tree_begin, tree_end;
|
||||
std::tie(tree_begin, tree_end) =
|
||||
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
|
||||
std::vector<Predictor const *> predictors{
|
||||
cpu_predictor_.get(),
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
|
||||
@ -452,4 +452,47 @@ TEST(GBTree, FeatureScore) {
|
||||
test_eq("gain");
|
||||
test_eq("cover");
|
||||
}
|
||||
|
||||
TEST(GBTree, PredictRange) {
|
||||
size_t n_samples = 1000, n_features = 10, n_classes = 4;
|
||||
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
|
||||
|
||||
std::unique_ptr<Learner> learner{Learner::Create({m})};
|
||||
learner->SetParam("num_class", std::to_string(n_classes));
|
||||
|
||||
learner->Configure();
|
||||
for (size_t i = 0; i < 2; ++i) {
|
||||
learner->UpdateOneIter(i, m);
|
||||
}
|
||||
HostDeviceVector<float> out_predt;
|
||||
ASSERT_THROW(learner->Predict(m, false, &out_predt, 0, 3), dmlc::Error);
|
||||
|
||||
auto m_1 =
|
||||
RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
|
||||
HostDeviceVector<float> out_predt_full;
|
||||
learner->Predict(m_1, false, &out_predt_full, 0, 0);
|
||||
ASSERT_TRUE(std::equal(out_predt.HostVector().begin(), out_predt.HostVector().end(),
|
||||
out_predt_full.HostVector().begin()));
|
||||
|
||||
{
|
||||
// inplace predict
|
||||
HostDeviceVector<float> raw_storage;
|
||||
auto raw = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateArrayInterface(&raw_storage);
|
||||
std::shared_ptr<data::ArrayAdapter> x{new data::ArrayAdapter{StringView{raw}}};
|
||||
|
||||
HostDeviceVector<float>* out_predt;
|
||||
learner->InplacePredict(x, nullptr, PredictionType::kValue,
|
||||
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 2);
|
||||
auto h_out_predt = out_predt->HostVector();
|
||||
learner->InplacePredict(x, nullptr, PredictionType::kValue,
|
||||
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 0);
|
||||
auto h_out_predt_full = out_predt->HostVector();
|
||||
|
||||
ASSERT_TRUE(std::equal(h_out_predt.begin(), h_out_predt.end(), h_out_predt_full.begin()));
|
||||
|
||||
ASSERT_THROW(learner->InplacePredict(x, nullptr, PredictionType::kValue,
|
||||
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 3),
|
||||
dmlc::Error);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user