Initial GPU support for the approx tree method. (#9414)
This commit is contained in:
@@ -135,7 +135,7 @@ class TestMinSplitLoss : public ::testing::Test {
|
||||
gpair_ = GenerateRandomGradients(kRows);
|
||||
}
|
||||
|
||||
std::int32_t Update(std::string updater, float gamma) {
|
||||
std::int32_t Update(Context const* ctx, std::string updater, float gamma) {
|
||||
Args args{{"max_depth", "1"},
|
||||
{"max_leaves", "0"},
|
||||
|
||||
@@ -154,8 +154,7 @@ class TestMinSplitLoss : public ::testing::Test {
|
||||
param.UpdateAllowUnknown(args);
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
|
||||
Context ctx{MakeCUDACtx(updater == "grow_gpu_hist" ? 0 : Context::kCpuId)};
|
||||
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, ctx, &task)};
|
||||
up->Configure({});
|
||||
|
||||
RegTree tree;
|
||||
@@ -167,16 +166,16 @@ class TestMinSplitLoss : public ::testing::Test {
|
||||
}
|
||||
|
||||
public:
|
||||
void RunTest(std::string updater) {
|
||||
void RunTest(Context const* ctx, std::string updater) {
|
||||
{
|
||||
int32_t n_nodes = Update(updater, 0.01);
|
||||
int32_t n_nodes = Update(ctx, updater, 0.01);
|
||||
// This is not strictly verified, meaning the numeber `2` is whatever GPU_Hist retured
|
||||
// when writing this test, and only used for testing larger gamma (below) does prevent
|
||||
// building tree.
|
||||
ASSERT_EQ(n_nodes, 2);
|
||||
}
|
||||
{
|
||||
int32_t n_nodes = Update(updater, 100.0);
|
||||
int32_t n_nodes = Update(ctx, updater, 100.0);
|
||||
// No new nodes with gamma == 100.
|
||||
ASSERT_EQ(n_nodes, static_cast<decltype(n_nodes)>(0));
|
||||
}
|
||||
@@ -185,10 +184,25 @@ class TestMinSplitLoss : public ::testing::Test {
|
||||
|
||||
/* Exact tree method requires a pruner as an additional updater, so not tested here. */
|
||||
|
||||
TEST_F(TestMinSplitLoss, Approx) { this->RunTest("grow_histmaker"); }
|
||||
TEST_F(TestMinSplitLoss, Approx) {
|
||||
Context ctx;
|
||||
this->RunTest(&ctx, "grow_histmaker");
|
||||
}
|
||||
|
||||
TEST_F(TestMinSplitLoss, Hist) {
|
||||
Context ctx;
|
||||
this->RunTest(&ctx, "grow_quantile_histmaker");
|
||||
}
|
||||
|
||||
TEST_F(TestMinSplitLoss, Hist) { this->RunTest("grow_quantile_histmaker"); }
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST_F(TestMinSplitLoss, GpuHist) { this->RunTest("grow_gpu_hist"); }
|
||||
TEST_F(TestMinSplitLoss, GpuHist) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
this->RunTest(&ctx, "grow_gpu_hist");
|
||||
}
|
||||
|
||||
TEST_F(TestMinSplitLoss, GpuApprox) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
this->RunTest(&ctx, "grow_gpu_approx");
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user