Support learning rate for zero-hessian objectives. (#8866)
This commit is contained in:
@@ -32,15 +32,14 @@
|
||||
#include "xgboost/string_view.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
namespace xgboost::gbm {
|
||||
DMLC_REGISTRY_FILE_TAG(gbtree);
|
||||
|
||||
void GBTree::Configure(const Args& cfg) {
|
||||
void GBTree::Configure(Args const& cfg) {
|
||||
this->cfg_ = cfg;
|
||||
std::string updater_seq = tparam_.updater_seq;
|
||||
tparam_.UpdateAllowUnknown(cfg);
|
||||
tree_param_.UpdateAllowUnknown(cfg);
|
||||
|
||||
model_.Configure(cfg);
|
||||
|
||||
@@ -235,9 +234,11 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
|
||||
CHECK_EQ(model_.param.num_parallel_tree, trees.size());
|
||||
CHECK_EQ(model_.param.num_parallel_tree, 1)
|
||||
<< "Boosting random forest is not supported for current objective.";
|
||||
CHECK_EQ(trees.size(), model_.param.num_parallel_tree);
|
||||
for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
|
||||
auto const& position = node_position.at(tree_idx);
|
||||
obj->UpdateTreeLeaf(position, p_fmat->Info(), predictions, group_idx, trees[tree_idx].get());
|
||||
obj->UpdateTreeLeaf(position, p_fmat->Info(), tree_param_.learning_rate / trees.size(),
|
||||
predictions, group_idx, trees[tree_idx].get());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,9 +389,15 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
|
||||
|
||||
CHECK(out_position);
|
||||
out_position->resize(new_trees.size());
|
||||
|
||||
// Rescale learning rate according to the size of trees
|
||||
auto lr = tree_param_.learning_rate;
|
||||
tree_param_.learning_rate /= static_cast<float>(new_trees.size());
|
||||
for (auto& up : updaters_) {
|
||||
up->Update(gpair, p_fmat, common::Span<HostDeviceVector<bst_node_t>>{*out_position}, new_trees);
|
||||
up->Update(&tree_param_, gpair, p_fmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>>{*out_position}, new_trees);
|
||||
}
|
||||
tree_param_.learning_rate = lr;
|
||||
}
|
||||
|
||||
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
|
||||
@@ -404,6 +411,8 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
|
||||
void GBTree::LoadConfig(Json const& in) {
|
||||
CHECK_EQ(get<String>(in["name"]), "gbtree");
|
||||
FromJson(in["gbtree_train_param"], &tparam_);
|
||||
FromJson(in["tree_train_param"], &tree_param_);
|
||||
|
||||
// Process type cannot be kUpdate from loaded model
|
||||
// This would cause all trees to be pushed to trees_to_update
|
||||
// e.g. updating a model, then saving and loading it would result in an empty model
|
||||
@@ -451,6 +460,7 @@ void GBTree::SaveConfig(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("gbtree");
|
||||
out["gbtree_train_param"] = ToJson(tparam_);
|
||||
out["tree_train_param"] = ToJson(tree_param_);
|
||||
|
||||
// Process type cannot be kUpdate from loaded model
|
||||
// This would cause all trees to be pushed to trees_to_update
|
||||
@@ -1058,5 +1068,4 @@ XGBOOST_REGISTER_GBM(Dart, "dart")
|
||||
GBTree* p = new Dart(booster_config, ctx);
|
||||
return p;
|
||||
});
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::gbm
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../tree/param.h" // TrainParam
|
||||
#include "gbtree_model.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
@@ -405,8 +406,8 @@ class GBTree : public GradientBooster {
|
||||
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
|
||||
}
|
||||
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||
std::string format) const override {
|
||||
[[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||
std::string format) const override {
|
||||
return model_.DumpModel(fmap, with_stats, this->ctx_->Threads(), format);
|
||||
}
|
||||
|
||||
@@ -428,6 +429,8 @@ class GBTree : public GradientBooster {
|
||||
GBTreeModel model_;
|
||||
// training parameter
|
||||
GBTreeTrainParam tparam_;
|
||||
// Tree training parameter
|
||||
tree::TrainParam tree_param_;
|
||||
// ----training fields----
|
||||
bool showed_updater_warning_ {false};
|
||||
bool specified_updater_ {false};
|
||||
|
||||
Reference in New Issue
Block a user