Support learning rate for zero-hessian objectives. (#8866)

This commit is contained in:
Jiaming Yuan
2023-03-06 20:33:28 +08:00
committed by GitHub
parent 173096a6a7
commit 228a46e8ad
34 changed files with 464 additions and 434 deletions

View File

@@ -116,12 +116,13 @@ class ObjFunction : public Configurable {
*
* \param position The leaf index for each rows.
* \param info MetaInfo providing labels and weights.
* \param learning_rate The learning rate for current iteration.
* \param prediction Model prediction after transformation.
* \param group_idx The group index for this tree, 0 when it's not multi-target or multi-class.
* \param p_tree Tree that needs to be updated.
*/
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
MetaInfo const& /*info*/,
MetaInfo const& /*info*/, float /*learning_rate*/,
HostDeviceVector<float> const& /*prediction*/,
std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}

View File

@@ -24,6 +24,9 @@
#include <vector>
namespace xgboost {
namespace tree {
struct TrainParam;
}
class Json;
struct Context;
@@ -56,8 +59,10 @@ class TreeUpdater : public Configurable {
* tree can be used.
*/
virtual bool HasNodePosition() const { return false; }
/*!
/**
* \brief perform update to the tree models
*
* \param param Hyper-parameter for constructing trees.
* \param gpair the gradient pair statistics of the data
* \param data The data matrix passed to the updater.
* \param out_position The leaf index for each row. The index is negated if that row is
@@ -67,8 +72,8 @@ class TreeUpdater : public Configurable {
* but maybe different random seeds, usually one tree is passed in at a time,
* there can be multiple trees when we train random forest style model
*/
virtual void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* data,
common::Span<HostDeviceVector<bst_node_t>> out_position,
virtual void Update(tree::TrainParam const* param, HostDeviceVector<GradientPair>* gpair,
DMatrix* data, common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& out_trees) = 0;
/*!