Pass obj info by reference instead of by value. (#8889)
- Pass obj info into tree updater as const pointer. This way we don't have to initialize the learner model param before configuring gbm, hence breaking up the dependency of configurations.
This commit is contained in:
@@ -37,13 +37,13 @@ TEST(GrowHistMaker, InteractionConstraint)
|
||||
auto p_gradients = GenerateGradients(kRows);
|
||||
|
||||
Context ctx;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
{
|
||||
// With constraints
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(
|
||||
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
|
||||
@@ -61,8 +61,7 @@ TEST(GrowHistMaker, InteractionConstraint)
|
||||
RegTree tree;
|
||||
tree.param.num_feature = kCols;
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
@@ -81,8 +80,8 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
|
||||
auto p_dmat = GenerateDMatrix(rows, cols);
|
||||
auto p_gradients = GenerateGradients(rows);
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
|
||||
std::unique_ptr<DMatrix> sliced{
|
||||
@@ -110,12 +109,12 @@ TEST(GrowHistMaker, ColumnSplit) {
|
||||
|
||||
RegTree expected_tree;
|
||||
expected_tree.param.num_feature = kCols;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
{
|
||||
auto p_dmat = GenerateDMatrix(kRows, kCols);
|
||||
auto p_gradients = GenerateGradients(kRows);
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
|
||||
Reference in New Issue
Block a user