Support learning rate for zero-hessian objectives. (#8866)

This commit is contained in:
Jiaming Yuan
2023-03-06 20:33:28 +08:00
committed by GitHub
parent 173096a6a7
commit 228a46e8ad
34 changed files with 464 additions and 434 deletions

View File

@@ -32,15 +32,14 @@
#include "xgboost/string_view.h"
#include "xgboost/tree_updater.h"
namespace xgboost {
namespace gbm {
namespace xgboost::gbm {
DMLC_REGISTRY_FILE_TAG(gbtree);
void GBTree::Configure(const Args& cfg) {
void GBTree::Configure(Args const& cfg) {
this->cfg_ = cfg;
std::string updater_seq = tparam_.updater_seq;
tparam_.UpdateAllowUnknown(cfg);
tree_param_.UpdateAllowUnknown(cfg);
model_.Configure(cfg);
@@ -235,9 +234,11 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
CHECK_EQ(model_.param.num_parallel_tree, trees.size());
CHECK_EQ(model_.param.num_parallel_tree, 1)
<< "Boosting random forest is not supported for current objective.";
CHECK_EQ(trees.size(), model_.param.num_parallel_tree);
for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
auto const& position = node_position.at(tree_idx);
obj->UpdateTreeLeaf(position, p_fmat->Info(), predictions, group_idx, trees[tree_idx].get());
obj->UpdateTreeLeaf(position, p_fmat->Info(), tree_param_.learning_rate / trees.size(),
predictions, group_idx, trees[tree_idx].get());
}
}
@@ -388,9 +389,15 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
CHECK(out_position);
out_position->resize(new_trees.size());
// Rescale learning rate according to the size of trees
auto lr = tree_param_.learning_rate;
tree_param_.learning_rate /= static_cast<float>(new_trees.size());
for (auto& up : updaters_) {
up->Update(gpair, p_fmat, common::Span<HostDeviceVector<bst_node_t>>{*out_position}, new_trees);
up->Update(&tree_param_, gpair, p_fmat,
common::Span<HostDeviceVector<bst_node_t>>{*out_position}, new_trees);
}
tree_param_.learning_rate = lr;
}
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
@@ -404,6 +411,8 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
void GBTree::LoadConfig(Json const& in) {
CHECK_EQ(get<String>(in["name"]), "gbtree");
FromJson(in["gbtree_train_param"], &tparam_);
FromJson(in["tree_train_param"], &tree_param_);
// Process type cannot be kUpdate from loaded model
// This would cause all trees to be pushed to trees_to_update
// e.g. updating a model, then saving and loading it would result in an empty model
@@ -451,6 +460,7 @@ void GBTree::SaveConfig(Json* p_out) const {
auto& out = *p_out;
out["name"] = String("gbtree");
out["gbtree_train_param"] = ToJson(tparam_);
out["tree_train_param"] = ToJson(tree_param_);
// Process type cannot be kUpdate from loaded model
// This would cause all trees to be pushed to trees_to_update
@@ -1058,5 +1068,4 @@ XGBOOST_REGISTER_GBM(Dart, "dart")
GBTree* p = new Dart(booster_config, ctx);
return p;
});
} // namespace gbm
} // namespace xgboost
} // namespace xgboost::gbm

View File

@@ -20,6 +20,7 @@
#include "../common/common.h"
#include "../common/timer.h"
#include "../tree/param.h" // TrainParam
#include "gbtree_model.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
@@ -405,8 +406,8 @@ class GBTree : public GradientBooster {
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
}
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
std::string format) const override {
[[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
std::string format) const override {
return model_.DumpModel(fmap, with_stats, this->ctx_->Threads(), format);
}
@@ -428,6 +429,8 @@ class GBTree : public GradientBooster {
GBTreeModel model_;
// training parameter
GBTreeTrainParam tparam_;
// Tree training parameter
tree::TrainParam tree_param_;
// ----training fields----
bool showed_updater_warning_ {false};
bool specified_updater_ {false};