Unify test helpers for creating ctx. (#9274)

This commit is contained in:
Jiaming Yuan
2023-06-10 03:35:22 +08:00
committed by GitHub
parent ea0deeca68
commit 152e2fb072
36 changed files with 161 additions and 169 deletions

View File

@@ -33,8 +33,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
ObjInfo task{ObjInfo::kRegression};
param.Init(Args{});
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
: CreateEmptyGenericParam(Context::kCpuId));
Context ctx(updater == "grow_gpu_hist" ? MakeCUDACtx(0) : MakeCUDACtx(Context::kCpuId));
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
up->Configure(Args{});
RegTree tree{1u, kCols};
@@ -79,8 +78,7 @@ class UpdaterEtaTest : public ::testing::Test {
void RunTest(std::string updater) {
ObjInfo task{ObjInfo::kClassification};
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
: CreateEmptyGenericParam(Context::kCpuId));
Context ctx(updater == "grow_gpu_hist" ? MakeCUDACtx(0) : MakeCUDACtx(Context::kCpuId));
float eta = 0.4;
auto up_0 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
@@ -156,8 +154,7 @@ class TestMinSplitLoss : public ::testing::Test {
param.UpdateAllowUnknown(args);
ObjInfo task{ObjInfo::kRegression};
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
: CreateEmptyGenericParam(Context::kCpuId));
Context ctx{MakeCUDACtx(updater == "grow_gpu_hist" ? 0 : Context::kCpuId)};
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
up->Configure({});