Support learning rate for zero-hessian objectives. (#8866)
This commit is contained in:
@@ -76,7 +76,7 @@ void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
|
||||
}
|
||||
|
||||
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
|
||||
std::int32_t group_idx, MetaInfo const& info,
|
||||
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
||||
auto& tree = *p_tree;
|
||||
|
||||
@@ -87,7 +87,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
|
||||
size_t n_leaf = nidx.size();
|
||||
if (nptr.empty()) {
|
||||
std::vector<float> quantiles;
|
||||
UpdateLeafValues(&quantiles, nidx, p_tree);
|
||||
UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -133,12 +133,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
|
||||
quantiles.at(k) = q;
|
||||
});
|
||||
|
||||
UpdateLeafValues(&quantiles, nidx, p_tree);
|
||||
UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree);
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
|
||||
MetaInfo const&, HostDeviceVector<float> const&, float, RegTree*) {
|
||||
MetaInfo const&, float learning_rate, HostDeviceVector<float> const&,
|
||||
float, RegTree*) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
@@ -140,7 +140,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
||||
}
|
||||
|
||||
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
|
||||
std::int32_t group_idx, MetaInfo const& info,
|
||||
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
||||
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
|
||||
dh::device_vector<size_t> ridx;
|
||||
@@ -151,7 +151,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
||||
|
||||
if (nptr.Empty()) {
|
||||
std::vector<float> quantiles;
|
||||
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), p_tree);
|
||||
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), learning_rate, p_tree);
|
||||
}
|
||||
|
||||
HostDeviceVector<float> quantiles;
|
||||
@@ -186,7 +186,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
||||
w_it + d_weights.size(), &quantiles);
|
||||
}
|
||||
|
||||
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), p_tree);
|
||||
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), learning_rate, p_tree);
|
||||
}
|
||||
} // namespace detail
|
||||
} // namespace obj
|
||||
|
||||
@@ -36,7 +36,7 @@ inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
|
||||
}
|
||||
|
||||
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
|
||||
RegTree* p_tree) {
|
||||
float learning_rate, RegTree* p_tree) {
|
||||
auto& tree = *p_tree;
|
||||
auto& quantiles = *p_quantiles;
|
||||
auto const& h_node_idx = nidx;
|
||||
@@ -71,7 +71,7 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
|
||||
auto nidx = h_node_idx[i];
|
||||
auto q = quantiles[i];
|
||||
CHECK(tree[nidx].IsLeaf());
|
||||
tree[nidx].SetLeaf(q);
|
||||
tree[nidx].SetLeaf(q * learning_rate);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,24 +85,24 @@ inline std::size_t IdxY(MetaInfo const& info, bst_group_t group_idx) {
|
||||
}
|
||||
|
||||
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
|
||||
std::int32_t group_idx, MetaInfo const& info,
|
||||
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
|
||||
|
||||
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
|
||||
std::int32_t group_idx, MetaInfo const& info,
|
||||
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
|
||||
} // namespace detail
|
||||
|
||||
inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
|
||||
std::int32_t group_idx, MetaInfo const& info,
|
||||
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
|
||||
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
|
||||
if (ctx->IsCPU()) {
|
||||
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, predt, alpha,
|
||||
p_tree);
|
||||
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
|
||||
predt, alpha, p_tree);
|
||||
} else {
|
||||
position.SetDevice(ctx->gpu_id);
|
||||
detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, predt, alpha,
|
||||
p_tree);
|
||||
detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, learning_rate,
|
||||
predt, alpha, p_tree);
|
||||
}
|
||||
}
|
||||
} // namespace obj
|
||||
|
||||
@@ -183,10 +183,11 @@ class QuantileRegression : public ObjFunction {
|
||||
}
|
||||
|
||||
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
|
||||
HostDeviceVector<float> const& prediction, std::int32_t group_idx,
|
||||
RegTree* p_tree) const override {
|
||||
float learning_rate, HostDeviceVector<float> const& prediction,
|
||||
std::int32_t group_idx, RegTree* p_tree) const override {
|
||||
auto alpha = param_.quantile_alpha[group_idx];
|
||||
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, alpha, p_tree);
|
||||
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, learning_rate, prediction,
|
||||
alpha, p_tree);
|
||||
}
|
||||
|
||||
void Configure(Args const& args) override {
|
||||
|
||||
@@ -742,9 +742,10 @@ class MeanAbsoluteError : public ObjFunction {
|
||||
}
|
||||
|
||||
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
|
||||
HostDeviceVector<float> const& prediction, std::int32_t group_idx,
|
||||
RegTree* p_tree) const override {
|
||||
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, 0.5, p_tree);
|
||||
float learning_rate, HostDeviceVector<float> const& prediction,
|
||||
std::int32_t group_idx, RegTree* p_tree) const override {
|
||||
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, learning_rate, prediction, 0.5,
|
||||
p_tree);
|
||||
}
|
||||
|
||||
const char* DefaultEvalMetric() const override { return "mae"; }
|
||||
|
||||
Reference in New Issue
Block a user