Small cleanup for histogram routines. (#9427)

* Small cleanup for histogram routines.

- Extract hist train param from GPU hist.
- Make histogram const after construction.
- Unify parameter names.
This commit is contained in:
Jiaming Yuan
2023-08-02 18:28:26 +08:00
committed by GitHub
parent c2b85ab68a
commit e93a274823
17 changed files with 182 additions and 111 deletions

View File

@@ -39,6 +39,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
param.UpdateAllowUnknown(
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&tree});
ASSERT_EQ(tree.NumExtraNodes(), 4);
@@ -55,6 +56,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&tree});
ASSERT_EQ(tree.NumExtraNodes(), 10);
@@ -81,6 +83,7 @@ void VerifyColumnSplit(int32_t rows, bst_feature_t cols, bool categorical,
RegTree tree{1u, cols};
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree});
Json json{Object{}};
@@ -104,6 +107,7 @@ void TestColumnSplit(bool categorical) {
std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&expected_tree});
}