Remove public access to tree model param. (#8902)

* Make tree model param a private member.
* Number of features and targets are immutable after construction.

This is to reduce the number of places where we can run configuration.
This commit is contained in:
Jiaming Yuan
2023-03-13 20:55:10 +08:00
committed by GitHub
parent 5ba3509dd3
commit 9bade7203a
14 changed files with 149 additions and 159 deletions

View File

@@ -40,8 +40,7 @@ TEST(GrowHistMaker, InteractionConstraint)
ObjInfo task{ObjInfo::kRegression};
{
// With constraints
RegTree tree;
tree.param.num_feature = kCols;
RegTree tree{1, kCols};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
TrainParam param;
@@ -58,8 +57,7 @@ TEST(GrowHistMaker, InteractionConstraint)
}
{
// Without constraints
RegTree tree;
tree.param.num_feature = kCols;
RegTree tree{1u, kCols};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1);
@@ -76,7 +74,7 @@ TEST(GrowHistMaker, InteractionConstraint)
}
namespace {
void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
void TestColumnSplit(int32_t rows, bst_feature_t cols, RegTree const& expected_tree) {
auto p_dmat = GenerateDMatrix(rows, cols);
auto p_gradients = GenerateGradients(rows);
Context ctx;
@@ -87,8 +85,7 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
std::unique_ptr<DMatrix> sliced{
p_dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
RegTree tree;
tree.param.num_feature = cols;
RegTree tree{1u, cols};
TrainParam param;
param.Init(Args{});
updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree});
@@ -107,8 +104,7 @@ TEST(GrowHistMaker, ColumnSplit) {
auto constexpr kRows = 32;
auto constexpr kCols = 16;
RegTree expected_tree;
expected_tree.param.num_feature = kCols;
RegTree expected_tree{1u, kCols};
ObjInfo task{ObjInfo::kRegression};
{
auto p_dmat = GenerateDMatrix(kRows, kCols);

View File

@@ -17,8 +17,8 @@ TEST(MultiTargetTree, JsonIO) {
linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, Context::kCpuId};
tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
left_weight.HostView(), right_weight.HostView());
ASSERT_EQ(tree.param.num_nodes, 3);
ASSERT_EQ(tree.param.size_leaf_vector, 3);
ASSERT_EQ(tree.NumNodes(), 3);
ASSERT_EQ(tree.NumTargets(), 3);
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
ASSERT_EQ(tree.Size(), 3);
@@ -26,20 +26,19 @@ TEST(MultiTargetTree, JsonIO) {
tree.SaveModel(&jtree);
auto check_jtree = [](Json jtree, RegTree const& tree) {
ASSERT_EQ(get<String const>(jtree["tree_param"]["num_nodes"]),
std::to_string(tree.param.num_nodes));
ASSERT_EQ(get<String const>(jtree["tree_param"]["num_nodes"]), std::to_string(tree.NumNodes()));
ASSERT_EQ(get<F32Array const>(jtree["base_weights"]).size(),
tree.param.num_nodes * tree.param.size_leaf_vector);
ASSERT_EQ(get<I32Array const>(jtree["parents"]).size(), tree.param.num_nodes);
ASSERT_EQ(get<I32Array const>(jtree["left_children"]).size(), tree.param.num_nodes);
ASSERT_EQ(get<I32Array const>(jtree["right_children"]).size(), tree.param.num_nodes);
tree.NumNodes() * tree.NumTargets());
ASSERT_EQ(get<I32Array const>(jtree["parents"]).size(), tree.NumNodes());
ASSERT_EQ(get<I32Array const>(jtree["left_children"]).size(), tree.NumNodes());
ASSERT_EQ(get<I32Array const>(jtree["right_children"]).size(), tree.NumNodes());
};
check_jtree(jtree, tree);
RegTree loaded;
loaded.LoadModel(jtree);
ASSERT_TRUE(loaded.IsMultiTarget());
ASSERT_EQ(loaded.param.num_nodes, 3);
ASSERT_EQ(loaded.NumNodes(), 3);
Json jtree1{Object{}};
loaded.SaveModel(&jtree1);

View File

@@ -32,8 +32,7 @@ TEST(Updater, Prune) {
auto ctx = CreateEmptyGenericParam(GPUIDX);
// prepare tree
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg);
RegTree tree = RegTree{1u, kCols};
std::vector<RegTree*> trees {&tree};
// prepare pruner
TrainParam param;

View File

@@ -28,9 +28,8 @@ TEST(Updater, Refresh) {
{"num_feature", std::to_string(kCols)},
{"reg_lambda", "1"}};
RegTree tree = RegTree();
RegTree tree = RegTree{1u, kCols};
auto ctx = CreateEmptyGenericParam(GPUIDX);
tree.param.UpdateAllowUnknown(cfg);
std::vector<RegTree*> trees{&tree};
ObjInfo task{ObjInfo::kRegression};

View File

@@ -11,9 +11,8 @@
namespace xgboost {
TEST(Tree, ModelShape) {
bst_feature_t n_features = std::numeric_limits<uint32_t>::max();
RegTree tree;
tree.param.UpdateAllowUnknown(Args{{"num_feature", std::to_string(n_features)}});
ASSERT_EQ(tree.param.num_feature, n_features);
RegTree tree{1u, n_features};
ASSERT_EQ(tree.NumFeatures(), n_features);
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/tree.model";
@@ -27,7 +26,7 @@ TEST(Tree, ModelShape) {
RegTree new_tree;
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(tmp_file.c_str(), "r"));
new_tree.Load(fi.get());
ASSERT_EQ(new_tree.param.num_feature, n_features);
ASSERT_EQ(new_tree.NumFeatures(), n_features);
}
{
// json
@@ -39,7 +38,7 @@ TEST(Tree, ModelShape) {
auto j_loaded = Json::Load(StringView{dumped.data(), dumped.size()});
new_tree.LoadModel(j_loaded);
ASSERT_EQ(new_tree.param.num_feature, n_features);
ASSERT_EQ(new_tree.NumFeatures(), n_features);
}
{
// ubjson
@@ -51,7 +50,7 @@ TEST(Tree, ModelShape) {
auto j_loaded = Json::Load(StringView{dumped.data(), dumped.size()}, std::ios::binary);
new_tree.LoadModel(j_loaded);
ASSERT_EQ(new_tree.param.num_feature, n_features);
ASSERT_EQ(new_tree.NumFeatures(), n_features);
}
}
@@ -488,8 +487,7 @@ TEST(Tree, JsonIO) {
RegTree loaded_tree;
loaded_tree.LoadModel(j_tree);
ASSERT_EQ(loaded_tree.param.num_nodes, 3);
ASSERT_EQ(loaded_tree.NumNodes(), 3);
ASSERT_TRUE(loaded_tree == tree);
auto left = tree[0].LeftChild();

View File

@@ -37,8 +37,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
: CreateEmptyGenericParam(Context::kCpuId));
auto up = std::unique_ptr<TreeUpdater>{TreeUpdater::Create(updater, &ctx, &task)};
up->Configure(Args{});
RegTree tree;
tree.param.num_feature = kCols;
RegTree tree{1u, kCols};
std::vector<HostDeviceVector<bst_node_t>> position(1);
up->Update(&param, &gpairs_, p_dmat_.get(), position, {&tree});
@@ -95,16 +94,14 @@ class UpdaterEtaTest : public ::testing::Test {
param1.Init(Args{{"eta", "1.0"}});
for (size_t iter = 0; iter < 4; ++iter) {
RegTree tree_0;
RegTree tree_0{1u, kCols};
{
tree_0.param.num_feature = kCols;
std::vector<HostDeviceVector<bst_node_t>> position(1);
up_0->Update(&param0, &gpairs_, p_dmat_.get(), position, {&tree_0});
}
RegTree tree_1;
RegTree tree_1{1u, kCols};
{
tree_1.param.num_feature = kCols;
std::vector<HostDeviceVector<bst_node_t>> position(1);
up_1->Update(&param1, &gpairs_, p_dmat_.get(), position, {&tree_1});
}