Rewrite approx (#7214)

This PR rewrites the approx tree method to use codebase from hist for better performance and code sharing.

The rewrite has many benefits:
- Support for both `max_leaves` and `max_depth`.
- Support for `grow_policy`.
- Support for mono constraint.
- Support for feature weights.
- Support for easier bin configuration (`max_bin`).
- Support for categorical data.
- Faster performance for most of the datasets. (many times faster)
- Support for prediction cache.
- Significantly better performance for external memory.
- Unites the code base between approx and hist.
This commit is contained in:
Jiaming Yuan
2022-01-10 21:15:05 +08:00
committed by GitHub
parent ed95e77752
commit 001503186c
22 changed files with 635 additions and 264 deletions

View File

@@ -72,5 +72,58 @@ TEST(Approx, Partitioner) {
}
}
}
TEST(Approx, PredictionCache) {
size_t n_samples = 2048, n_features = 13;
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
{
omp_set_num_threads(1);
GenericParameter ctx;
ctx.InitAllowUnknown(Args{{"nthread", "8"}});
std::unique_ptr<TreeUpdater> approx{
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
RegTree tree;
std::vector<RegTree *> trees{&tree};
auto gpair = GenerateRandomGradients(n_samples);
approx->Configure(Args{{"max_bin", "64"}});
approx->Update(&gpair, Xy.get(), trees);
HostDeviceVector<float> out_prediction_cached;
out_prediction_cached.Resize(n_samples);
auto cache = linalg::VectorView<float>{
out_prediction_cached.HostSpan(), {out_prediction_cached.Size()}, GenericParameter::kCpuId};
ASSERT_TRUE(approx->UpdatePredictionCache(Xy.get(), cache));
}
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
learner->SetParam("tree_method", "approx");
learner->SetParam("nthread", "0");
learner->Configure();
for (size_t i = 0; i < 8; ++i) {
learner->UpdateOneIter(i, Xy);
}
HostDeviceVector<float> out_prediction_cached;
learner->Predict(Xy, false, &out_prediction_cached, 0, 0);
Json model{Object()};
learner->SaveModel(&model);
HostDeviceVector<float> out_prediction;
{
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
learner->LoadModel(model);
learner->Predict(Xy, false, &out_prediction, 0, 0);
}
auto const h_predt_cached = out_prediction_cached.ConstHostSpan();
auto const h_predt = out_prediction.ConstHostSpan();
ASSERT_EQ(h_predt.size(), h_predt_cached.size());
for (size_t i = 0; i < h_predt.size(); ++i) {
ASSERT_NEAR(h_predt[i], h_predt_cached[i], kRtEps);
}
}
} // namespace tree
} // namespace xgboost