Initial GPU support for the approx tree method. (#9414)
This commit is contained in:
@@ -20,10 +20,11 @@ class TestGrowPolicy : public ::testing::Test {
|
||||
true);
|
||||
}
|
||||
|
||||
std::unique_ptr<Learner> TrainOneIter(std::string tree_method, std::string policy,
|
||||
int32_t max_leaves, int32_t max_depth) {
|
||||
std::unique_ptr<Learner> TrainOneIter(Context const* ctx, std::string tree_method,
|
||||
std::string policy, int32_t max_leaves, int32_t max_depth) {
|
||||
std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
|
||||
learner->SetParam("tree_method", tree_method);
|
||||
learner->SetParam("device", ctx->DeviceName());
|
||||
if (max_leaves >= 0) {
|
||||
learner->SetParam("max_leaves", std::to_string(max_leaves));
|
||||
}
|
||||
@@ -63,7 +64,7 @@ class TestGrowPolicy : public ::testing::Test {
|
||||
|
||||
if (max_leaves == 0 && max_depth == 0) {
|
||||
// unconstrainted
|
||||
if (tree_method != "gpu_hist") {
|
||||
if (ctx->IsCPU()) {
|
||||
// GPU pre-allocates for all nodes.
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
}
|
||||
@@ -86,23 +87,23 @@ class TestGrowPolicy : public ::testing::Test {
|
||||
return learner;
|
||||
}
|
||||
|
||||
void TestCombination(std::string tree_method) {
|
||||
void TestCombination(Context const* ctx, std::string tree_method) {
|
||||
for (auto policy : {"depthwise", "lossguide"}) {
|
||||
// -1 means default
|
||||
for (auto leaves : {-1, 0, 3}) {
|
||||
for (auto depth : {-1, 0, 3}) {
|
||||
this->TrainOneIter(tree_method, policy, leaves, depth);
|
||||
this->TrainOneIter(ctx, tree_method, policy, leaves, depth);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestTreeGrowPolicy(std::string tree_method, std::string policy) {
|
||||
void TestTreeGrowPolicy(Context const* ctx, std::string tree_method, std::string policy) {
|
||||
{
|
||||
/**
|
||||
* max_leaves
|
||||
*/
|
||||
auto learner = this->TrainOneIter(tree_method, policy, 16, -1);
|
||||
auto learner = this->TrainOneIter(ctx, tree_method, policy, 16, -1);
|
||||
Json model{Object{}};
|
||||
learner->SaveModel(&model);
|
||||
|
||||
@@ -115,7 +116,7 @@ class TestGrowPolicy : public ::testing::Test {
|
||||
/**
|
||||
* max_depth
|
||||
*/
|
||||
auto learner = this->TrainOneIter(tree_method, policy, -1, 3);
|
||||
auto learner = this->TrainOneIter(ctx, tree_method, policy, -1, 3);
|
||||
Json model{Object{}};
|
||||
learner->SaveModel(&model);
|
||||
|
||||
@@ -133,25 +134,36 @@ class TestGrowPolicy : public ::testing::Test {
|
||||
};
|
||||
|
||||
TEST_F(TestGrowPolicy, Approx) {
|
||||
this->TestTreeGrowPolicy("approx", "depthwise");
|
||||
this->TestTreeGrowPolicy("approx", "lossguide");
|
||||
Context ctx;
|
||||
this->TestTreeGrowPolicy(&ctx, "approx", "depthwise");
|
||||
this->TestTreeGrowPolicy(&ctx, "approx", "lossguide");
|
||||
|
||||
this->TestCombination("approx");
|
||||
this->TestCombination(&ctx, "approx");
|
||||
}
|
||||
|
||||
TEST_F(TestGrowPolicy, Hist) {
|
||||
this->TestTreeGrowPolicy("hist", "depthwise");
|
||||
this->TestTreeGrowPolicy("hist", "lossguide");
|
||||
Context ctx;
|
||||
this->TestTreeGrowPolicy(&ctx, "hist", "depthwise");
|
||||
this->TestTreeGrowPolicy(&ctx, "hist", "lossguide");
|
||||
|
||||
this->TestCombination("hist");
|
||||
this->TestCombination(&ctx, "hist");
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST_F(TestGrowPolicy, GpuHist) {
|
||||
this->TestTreeGrowPolicy("gpu_hist", "depthwise");
|
||||
this->TestTreeGrowPolicy("gpu_hist", "lossguide");
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
this->TestTreeGrowPolicy(&ctx, "hist", "depthwise");
|
||||
this->TestTreeGrowPolicy(&ctx, "hist", "lossguide");
|
||||
|
||||
this->TestCombination("gpu_hist");
|
||||
this->TestCombination(&ctx, "hist");
|
||||
}
|
||||
|
||||
TEST_F(TestGrowPolicy, GpuApprox) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
this->TestTreeGrowPolicy(&ctx, "approx", "depthwise");
|
||||
this->TestTreeGrowPolicy(&ctx, "approx", "lossguide");
|
||||
|
||||
this->TestCombination(&ctx, "approx");
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user