Pass obj info by reference instead of by value. (#8889)
- Pass obj info into tree updater as const pointer. This way we don't have to initialize the learner model param before configuring gbm, hence breaking up the dependency of configurations.
This commit is contained in:
@@ -2,9 +2,13 @@
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/task.h> // for ObjInfo
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h"
|
||||
|
||||
@@ -26,12 +30,12 @@ class UpdaterTreeStatTest : public ::testing::Test {
|
||||
|
||||
void RunTest(std::string updater) {
|
||||
tree::TrainParam param;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
param.Init(Args{});
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
auto up = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up->Configure(Args{});
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
@@ -74,18 +78,18 @@ class UpdaterEtaTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
void RunTest(std::string updater) {
|
||||
ObjInfo task{ObjInfo::kClassification};
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
|
||||
float eta = 0.4;
|
||||
auto up_0 = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
||||
auto up_0 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up_0->Configure(Args{});
|
||||
tree::TrainParam param0;
|
||||
param0.Init(Args{{"eta", std::to_string(eta)}});
|
||||
|
||||
auto up_1 = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kClassification})};
|
||||
auto up_1 = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up_1->Configure(Args{{"eta", "1.0"}});
|
||||
tree::TrainParam param1;
|
||||
param1.Init(Args{{"eta", "1.0"}});
|
||||
@@ -153,11 +157,11 @@ class TestMinSplitLoss : public ::testing::Test {
|
||||
{"gamma", std::to_string(gamma)}};
|
||||
tree::TrainParam param;
|
||||
param.UpdateAllowUnknown(args);
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
|
||||
Context ctx(updater == "grow_gpu_hist" ? CreateEmptyGenericParam(0)
|
||||
: CreateEmptyGenericParam(Context::kCpuId));
|
||||
auto up = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
|
||||
up->Configure({});
|
||||
|
||||
RegTree tree;
|
||||
|
||||
Reference in New Issue
Block a user