Support learning rate for zero-hessian objectives. (#8866)
This commit is contained in:
@@ -5,11 +5,10 @@
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
namespace xgboost::tree {
|
||||
std::shared_ptr<DMatrix> GenerateDMatrix(std::size_t rows, std::size_t cols){
|
||||
return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix();
|
||||
}
|
||||
@@ -45,11 +44,11 @@ TEST(GrowHistMaker, InteractionConstraint)
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
updater->Configure(Args{
|
||||
{"interaction_constraints", "[[0, 1]]"},
|
||||
{"num_feature", std::to_string(kCols)}});
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(
|
||||
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
updater->Update(p_gradients.get(), p_dmat.get(), position, {&tree});
|
||||
updater->Update(¶m, p_gradients.get(), p_dmat.get(), position, {&tree});
|
||||
|
||||
ASSERT_EQ(tree.NumExtraNodes(), 4);
|
||||
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
||||
@@ -64,9 +63,10 @@ TEST(GrowHistMaker, InteractionConstraint)
|
||||
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
updater->Update(p_gradients.get(), p_dmat.get(), position, {&tree});
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
updater->Update(¶m, p_gradients.get(), p_dmat.get(), position, {&tree});
|
||||
|
||||
ASSERT_EQ(tree.NumExtraNodes(), 10);
|
||||
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
||||
@@ -83,7 +83,6 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
updater->Configure(Args{{"num_feature", std::to_string(cols)}});
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
|
||||
std::unique_ptr<DMatrix> sliced{
|
||||
@@ -91,7 +90,9 @@ void TestColumnSplit(int32_t rows, int32_t cols, RegTree const& expected_tree) {
|
||||
|
||||
RegTree tree;
|
||||
tree.param.num_feature = cols;
|
||||
updater->Update(p_gradients.get(), sliced.get(), position, {&tree});
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
updater->Update(¶m, p_gradients.get(), sliced.get(), position, {&tree});
|
||||
|
||||
EXPECT_EQ(tree.NumExtraNodes(), 10);
|
||||
EXPECT_EQ(tree[0].SplitIndex(), 1);
|
||||
@@ -115,14 +116,13 @@ TEST(GrowHistMaker, ColumnSplit) {
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_histmaker", &ctx, ObjInfo{ObjInfo::kRegression})};
|
||||
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
updater->Update(p_gradients.get(), p_dmat.get(), position, {&expected_tree});
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
updater->Update(¶m, p_gradients.get(), p_dmat.get(), position, {&expected_tree});
|
||||
}
|
||||
|
||||
auto constexpr kWorldSize = 2;
|
||||
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit, kRows, kCols, std::cref(expected_tree));
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::tree
|
||||
|
||||
Reference in New Issue
Block a user