Initial GPU support for the approx tree method. (#9414)

This commit is contained in:
Jiaming Yuan
2023-07-31 15:50:28 +08:00
committed by GitHub
parent 8f0efb4ab3
commit 912e341d57
23 changed files with 639 additions and 360 deletions

View File

@@ -62,8 +62,10 @@ class RegenTest : public ::testing::Test {
auto constexpr Iter() const { return 4; }
template <typename Page>
size_t TestTreeMethod(std::string tree_method, std::string obj, bool reset = true) const {
size_t TestTreeMethod(Context const* ctx, std::string tree_method, std::string obj,
bool reset = true) const {
auto learner = std::unique_ptr<Learner>{Learner::Create({p_fmat_})};
learner->SetParam("device", ctx->DeviceName());
learner->SetParam("tree_method", tree_method);
learner->SetParam("objective", obj);
learner->Configure();
@@ -87,40 +89,71 @@ class RegenTest : public ::testing::Test {
} // anonymous namespace
TEST_F(RegenTest, Approx) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:squarederror");
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:squarederror");
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic");
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic");
ASSERT_EQ(n, this->Iter());
}
TEST_F(RegenTest, Hist) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror");
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror");
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:logistic");
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:logistic");
ASSERT_EQ(n, 1);
}
TEST_F(RegenTest, Mixed) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", false);
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", true);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() + 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", false);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", true);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, this->Iter() + 1);
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(RegenTest, GpuHist) {
auto n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:squarederror");
TEST_F(RegenTest, GpuApprox) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:squarederror", true);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:logistic", false);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() * 2);
}
TEST_F(RegenTest, GpuHist) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>("hist", "reg:logistic");
ASSERT_EQ(n, 2);
{
Context ctx;
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic");
ASSERT_EQ(n, 2);
}
}
TEST_F(RegenTest, GpuMixed) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() + 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, this->Iter() + 1);
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost