Pass infomation about objective to tree methods. (#7385)
* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
@@ -275,7 +275,8 @@ void TestHistogramIndexImpl() {
|
||||
int constexpr kNRows = 1000, kNCols = 10;
|
||||
|
||||
// Build 2 matrices and build a histogram maker with that
|
||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker, hist_maker_ext;
|
||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}},
|
||||
hist_maker_ext{ObjInfo{ObjInfo::kRegression}};
|
||||
std::unique_ptr<DMatrix> hist_maker_dmat(
|
||||
CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true));
|
||||
|
||||
@@ -333,7 +334,7 @@ int32_t TestMinSplitLoss(DMatrix* dmat, float gamma, HostDeviceVector<GradientPa
|
||||
{"gamma", std::to_string(gamma)}
|
||||
};
|
||||
|
||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
|
||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
|
||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||
hist_maker.Configure(args, &generic_param);
|
||||
|
||||
@@ -394,7 +395,7 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
{"sampling_method", sampling_method},
|
||||
};
|
||||
|
||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker;
|
||||
tree::GPUHistMakerSpecialised<GradientPairPrecise> hist_maker{ObjInfo{ObjInfo::kRegression}};
|
||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||
hist_maker.Configure(args, &generic_param);
|
||||
|
||||
@@ -539,7 +540,8 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
|
||||
|
||||
TEST(GpuHist, ConfigIO) {
|
||||
GenericParameter generic_param(CreateEmptyGenericParam(0));
|
||||
std::unique_ptr<TreeUpdater> updater {TreeUpdater::Create("grow_gpu_hist", &generic_param) };
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_gpu_hist", &generic_param, ObjInfo{ObjInfo::kRegression})};
|
||||
updater->Configure(Args{});
|
||||
|
||||
Json j_updater { Object() };
|
||||
|
||||
@@ -34,7 +34,8 @@ TEST(GrowHistMaker, InteractionConstraint) {
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater { TreeUpdater::Create("grow_histmaker", ¶m) };
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", ¶m, ObjInfo{ObjInfo::kRegression})};
|
||||
updater->Configure(Args{
|
||||
{"interaction_constraints", "[[0, 1]]"},
|
||||
{"num_feature", std::to_string(kCols)}});
|
||||
@@ -51,7 +52,8 @@ TEST(GrowHistMaker, InteractionConstraint) {
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater { TreeUpdater::Create("grow_histmaker", ¶m) };
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", ¶m, ObjInfo{ObjInfo::kRegression})};
|
||||
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
||||
updater->Update(&gradients, p_dmat.get(), {&tree});
|
||||
|
||||
|
||||
@@ -38,7 +38,8 @@ TEST(Updater, Prune) {
|
||||
tree.param.UpdateAllowUnknown(cfg);
|
||||
std::vector<RegTree*> trees {&tree};
|
||||
// prepare pruner
|
||||
std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune", &lparam));
|
||||
std::unique_ptr<TreeUpdater> pruner(
|
||||
TreeUpdater::Create("prune", &lparam, ObjInfo{ObjInfo::kRegression}));
|
||||
pruner->Configure(cfg);
|
||||
|
||||
// loss_chg < min_split_loss;
|
||||
|
||||
@@ -28,7 +28,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
BuilderMock(const TrainParam ¶m, std::unique_ptr<TreeUpdater> pruner,
|
||||
DMatrix const *fmat)
|
||||
: RealImpl(1, param, std::move(pruner), fmat) {}
|
||||
: RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}) {}
|
||||
|
||||
public:
|
||||
void TestInitData(const GHistIndexMatrix& gmat,
|
||||
@@ -230,7 +230,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
explicit QuantileHistMock(
|
||||
const std::vector<std::pair<std::string, std::string> >& args,
|
||||
const bool single_precision_histogram = false, bool batch = true) :
|
||||
cfg_{args} {
|
||||
QuantileHistMaker{ObjInfo{ObjInfo::kRegression}}, cfg_{args} {
|
||||
QuantileHistMaker::Configure(args);
|
||||
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||
if (single_precision_histogram) {
|
||||
|
||||
@@ -32,7 +32,8 @@ TEST(Updater, Refresh) {
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
tree.param.UpdateAllowUnknown(cfg);
|
||||
std::vector<RegTree*> trees {&tree};
|
||||
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh", &lparam));
|
||||
std::unique_ptr<TreeUpdater> refresher(
|
||||
TreeUpdater::Create("refresh", &lparam, ObjInfo{ObjInfo::kRegression}));
|
||||
|
||||
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||
|
||||
@@ -23,7 +23,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
|
||||
void RunTest(std::string updater) {
|
||||
auto tparam = CreateEmptyGenericParam(0);
|
||||
auto up = std::unique_ptr<TreeUpdater>{
|
||||
TreeUpdater::Create(updater, &tparam)};
|
||||
TreeUpdater::Create(updater, &tparam, ObjInfo{ObjInfo::kRegression})};
|
||||
up->Configure(Args{});
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
Reference in New Issue
Block a user