Pass infomation about objective to tree methods. (#7385)

* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
Jiaming Yuan
2021-11-04 01:52:44 +08:00
committed by GitHub
parent ccdabe4512
commit 4100827971
28 changed files with 178 additions and 69 deletions

View File

@@ -275,7 +275,8 @@ void TestHistogramIndexImpl() {
int constexpr kNRows = 1000, kNCols = 10;
// Build 2 matrices and build a histogram maker with that
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker, hist_maker_ext;
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}},
hist_maker_ext{ObjInfo{ObjInfo::kRegression}};
std::unique_ptr<DMatrix> hist_maker_dmat(
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
@@ -333,7 +334,7 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPa
{"gamma", std::to_string(gamma)}
};
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
GenericParameter generic_param(CreateEmptyGenericParam(0));
hist_maker.Configure(args, &generic_param);
@@ -394,7 +395,7 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
{"sampling_method", sampling_method},
};
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
GenericParameter generic_param(CreateEmptyGenericParam(0));
hist_maker.Configure(args, &generic_param);
@@ -539,7 +540,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
TEST(GpuHist, ConfigIO) {
GenericParameter generic_param(CreateEmptyGenericParam(0));
std::unique_ptr<TreeUpdater> updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) };
std::unique_ptr<TreeUpdater> updater{
TreeUpdater::Create("grow_gpu_hist", &generic_param, ObjInfo{ObjInfo::kRegression})};
updater->Configure(Args{});
Json j_updater { Object() };