Support learning rate for zero-hessian objectives. (#8866)
This commit is contained in:
@@ -116,12 +116,13 @@ class ObjFunction : public Configurable {
|
||||
*
|
||||
* \param position The leaf index for each rows.
|
||||
* \param info MetaInfo providing labels and weights.
|
||||
* \param learning_rate The learning rate for current iteration.
|
||||
* \param prediction Model prediction after transformation.
|
||||
* \param group_idx The group index for this tree, 0 when it's not multi-target or multi-class.
|
||||
* \param p_tree Tree that needs to be updated.
|
||||
*/
|
||||
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
|
||||
MetaInfo const& /*info*/,
|
||||
MetaInfo const& /*info*/, float /*learning_rate*/,
|
||||
HostDeviceVector<float> const& /*prediction*/,
|
||||
std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
|
||||
|
||||
|
||||
@@ -24,6 +24,9 @@
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
struct TrainParam;
|
||||
}
|
||||
|
||||
class Json;
|
||||
struct Context;
|
||||
@@ -56,8 +59,10 @@ class TreeUpdater : public Configurable {
|
||||
* tree can be used.
|
||||
*/
|
||||
virtual bool HasNodePosition() const { return false; }
|
||||
/*!
|
||||
/**
|
||||
* \brief perform update to the tree models
|
||||
*
|
||||
* \param param Hyper-parameter for constructing trees.
|
||||
* \param gpair the gradient pair statistics of the data
|
||||
* \param data The data matrix passed to the updater.
|
||||
* \param out_position The leaf index for each row. The index is negated if that row is
|
||||
@@ -67,8 +72,8 @@ class TreeUpdater : public Configurable {
|
||||
* but maybe different random seeds, usually one tree is passed in at a time,
|
||||
* there can be multiple trees when we train random forest style model
|
||||
*/
|
||||
virtual void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* data,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
virtual void Update(tree::TrainParam const* param, HostDeviceVector<GradientPair>* gpair,
|
||||
DMatrix* data, common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree*>& out_trees) = 0;
|
||||
|
||||
/*!
|
||||
|
||||
Reference in New Issue
Block a user