Initial GPU support for the approx tree method. (#9414)
This commit is contained in:
@@ -62,8 +62,10 @@ class RegenTest : public ::testing::Test {
|
||||
auto constexpr Iter() const { return 4; }
|
||||
|
||||
template <typename Page>
|
||||
size_t TestTreeMethod(std::string tree_method, std::string obj, bool reset = true) const {
|
||||
size_t TestTreeMethod(Context const* ctx, std::string tree_method, std::string obj,
|
||||
bool reset = true) const {
|
||||
auto learner = std::unique_ptr<Learner>{Learner::Create({p_fmat_})};
|
||||
learner->SetParam("device", ctx->DeviceName());
|
||||
learner->SetParam("tree_method", tree_method);
|
||||
learner->SetParam("objective", obj);
|
||||
learner->Configure();
|
||||
@@ -87,40 +89,71 @@ class RegenTest : public ::testing::Test {
|
||||
} // anonymous namespace
|
||||
|
||||
TEST_F(RegenTest, Approx) {
|
||||
auto n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:squarederror");
|
||||
Context ctx;
|
||||
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:squarederror");
|
||||
ASSERT_EQ(n, 1);
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic");
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic");
|
||||
ASSERT_EQ(n, this->Iter());
|
||||
}
|
||||
|
||||
TEST_F(RegenTest, Hist) {
|
||||
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror");
|
||||
Context ctx;
|
||||
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror");
|
||||
ASSERT_EQ(n, 1);
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:logistic");
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:logistic");
|
||||
ASSERT_EQ(n, 1);
|
||||
}
|
||||
|
||||
TEST_F(RegenTest, Mixed) {
|
||||
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", false);
|
||||
Context ctx;
|
||||
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", false);
|
||||
ASSERT_EQ(n, 1);
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", true);
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", true);
|
||||
ASSERT_EQ(n, this->Iter() + 1);
|
||||
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", false);
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", false);
|
||||
ASSERT_EQ(n, this->Iter());
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", true);
|
||||
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", true);
|
||||
ASSERT_EQ(n, this->Iter() + 1);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST_F(RegenTest, GpuHist) {
|
||||
auto n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:squarederror");
|
||||
TEST_F(RegenTest, GpuApprox) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:squarederror", true);
|
||||
ASSERT_EQ(n, 1);
|
||||
n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:logistic", false);
|
||||
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
|
||||
ASSERT_EQ(n, this->Iter());
|
||||
|
||||
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
|
||||
ASSERT_EQ(n, this->Iter() * 2);
|
||||
}
|
||||
|
||||
TEST_F(RegenTest, GpuHist) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
|
||||
ASSERT_EQ(n, 1);
|
||||
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic", false);
|
||||
ASSERT_EQ(n, 1);
|
||||
|
||||
n = this->TestTreeMethod<EllpackPage>("hist", "reg:logistic");
|
||||
ASSERT_EQ(n, 2);
|
||||
{
|
||||
Context ctx;
|
||||
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic");
|
||||
ASSERT_EQ(n, 2);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(RegenTest, GpuMixed) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", false);
|
||||
ASSERT_EQ(n, 1);
|
||||
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
|
||||
ASSERT_EQ(n, this->Iter() + 1);
|
||||
|
||||
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
|
||||
ASSERT_EQ(n, this->Iter());
|
||||
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
|
||||
ASSERT_EQ(n, this->Iter() + 1);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user