Pass obj info by reference instead of by value. (#8889)

- Pass obj info into tree updater as const pointer.

This way we don't have to initialize the learner model param before configuring gbm, hence
breaking up the dependency of configurations.
This commit is contained in:
Jiaming Yuan
2023-03-11 01:38:28 +08:00
committed by GitHub
parent 54e001bbf4
commit 6deaec8027
18 changed files with 125 additions and 112 deletions

View File

@@ -170,8 +170,8 @@ void TestHistogramIndexImpl() {
// Build 2 matrices and build a histogram maker with that
Context ctx(CreateEmptyGenericParam(0));
tree::GPUHistMaker hist_maker{&ctx, ObjInfo{ObjInfo::kRegression}},
hist_maker_ext{&ctx, ObjInfo{ObjInfo::kRegression}};
ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task};
std::unique_ptr<DMatrix> hist_maker_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
@@ -240,7 +240,8 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
param.UpdateAllowUnknown(args);
Context ctx(CreateEmptyGenericParam(0));
tree::GPUHistMaker hist_maker{&ctx,ObjInfo{ObjInfo::kRegression}};
ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{&ctx, &task};
std::vector<HostDeviceVector<bst_node_t>> position(1);
hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
@@ -385,8 +386,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
TEST(GpuHist, ConfigIO) {
Context ctx(CreateEmptyGenericParam(0));
std::unique_ptr<TreeUpdater> updater{
TreeUpdater::Create("grow_gpu_hist", &ctx, ObjInfo{ObjInfo::kRegression})};
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
updater->Configure(Args{});
Json j_updater { Object() };