Add tests for prediction cache. (#7650)
* Extract the test from approx for other tree methods. * Add note on how it works.
This commit is contained in:
@@ -75,58 +75,5 @@ TEST(Approx, Partitioner) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Approx, PredictionCache) {
|
||||
size_t n_samples = 2048, n_features = 13;
|
||||
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
||||
|
||||
{
|
||||
omp_set_num_threads(1);
|
||||
GenericParameter ctx;
|
||||
ctx.InitAllowUnknown(Args{{"nthread", "8"}});
|
||||
std::unique_ptr<TreeUpdater> approx{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
RegTree tree;
|
||||
std::vector<RegTree *> trees{&tree};
|
||||
auto gpair = GenerateRandomGradients(n_samples);
|
||||
approx->Configure(Args{{"max_bin", "64"}});
|
||||
approx->Update(&gpair, Xy.get(), trees);
|
||||
HostDeviceVector<float> out_prediction_cached;
|
||||
out_prediction_cached.Resize(n_samples);
|
||||
auto cache = linalg::VectorView<float>{
|
||||
out_prediction_cached.HostSpan(), {out_prediction_cached.Size()}, GenericParameter::kCpuId};
|
||||
ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), cache));
|
||||
}
|
||||
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||
learner->SetParam("tree_method", "approx");
|
||||
learner->SetParam("nthread", "0");
|
||||
learner->Configure();
|
||||
|
||||
for (size_t i = 0; i < 8; ++i) {
|
||||
learner->UpdateOneIter(i, Xy);
|
||||
}
|
||||
|
||||
HostDeviceVector<float> out_prediction_cached;
|
||||
learner->Predict(Xy, false, &out_prediction_cached, 0, 0);
|
||||
|
||||
Json model{Object()};
|
||||
learner->SaveModel(&model);
|
||||
|
||||
HostDeviceVector<float> out_prediction;
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||
learner->LoadModel(model);
|
||||
learner->Predict(Xy, false, &out_prediction, 0, 0);
|
||||
}
|
||||
|
||||
auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
|
||||
auto const h_predt = out_prediction.ConstHostSpan();
|
||||
|
||||
ASSERT_EQ(h_predt.size(), h_predt_cached.size());
|
||||
for (size_t i = 0; i < h_predt.size(); ++i) {
|
||||
ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
|
||||
}
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user