Support learning rate for zero-hessian objectives. (#8866)

This commit is contained in:
Jiaming Yuan
2023-03-06 20:33:28 +08:00
committed by GitHub
parent 173096a6a7
commit 228a46e8ad
34 changed files with 464 additions and 434 deletions

View File

@@ -76,7 +76,7 @@ void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
}
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
std::int32_t group_idx, MetaInfo const& info,
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
auto& tree = *p_tree;
@@ -87,7 +87,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
size_t n_leaf = nidx.size();
if (nptr.empty()) {
std::vector<float> quantiles;
UpdateLeafValues(&quantiles, nidx, p_tree);
UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree);
return;
}
@@ -133,12 +133,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
quantiles.at(k) = q;
});
UpdateLeafValues(&quantiles, nidx, p_tree);
UpdateLeafValues(&quantiles, nidx, learning_rate, p_tree);
}
#if !defined(XGBOOST_USE_CUDA)
void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
MetaInfo const&, HostDeviceVector<float> const&, float, RegTree*) {
MetaInfo const&, float learning_rate, HostDeviceVector<float> const&,
float, RegTree*) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)

View File

@@ -140,7 +140,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
}
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
std::int32_t group_idx, MetaInfo const& info,
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
dh::device_vector<size_t> ridx;
@@ -151,7 +151,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
if (nptr.Empty()) {
std::vector<float> quantiles;
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), p_tree);
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), learning_rate, p_tree);
}
HostDeviceVector<float> quantiles;
@@ -186,7 +186,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
w_it + d_weights.size(), &quantiles);
}
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), p_tree);
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), learning_rate, p_tree);
}
} // namespace detail
} // namespace obj

View File

@@ -36,7 +36,7 @@ inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
}
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
RegTree* p_tree) {
float learning_rate, RegTree* p_tree) {
auto& tree = *p_tree;
auto& quantiles = *p_quantiles;
auto const& h_node_idx = nidx;
@@ -71,7 +71,7 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
auto nidx = h_node_idx[i];
auto q = quantiles[i];
CHECK(tree[nidx].IsLeaf());
tree[nidx].SetLeaf(q);
tree[nidx].SetLeaf(q * learning_rate);
}
}
@@ -85,24 +85,24 @@ inline std::size_t IdxY(MetaInfo const& info, bst_group_t group_idx) {
}
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
std::int32_t group_idx, MetaInfo const& info,
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
std::int32_t group_idx, MetaInfo const& info,
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
} // namespace detail
inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
std::int32_t group_idx, MetaInfo const& info,
std::int32_t group_idx, MetaInfo const& info, float learning_rate,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
if (ctx->IsCPU()) {
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, predt, alpha,
p_tree);
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, learning_rate,
predt, alpha, p_tree);
} else {
position.SetDevice(ctx->gpu_id);
detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, predt, alpha,
p_tree);
detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, learning_rate,
predt, alpha, p_tree);
}
}
} // namespace obj

View File

@@ -183,10 +183,11 @@ class QuantileRegression : public ObjFunction {
}
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, std::int32_t group_idx,
RegTree* p_tree) const override {
float learning_rate, HostDeviceVector<float> const& prediction,
std::int32_t group_idx, RegTree* p_tree) const override {
auto alpha = param_.quantile_alpha[group_idx];
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, alpha, p_tree);
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, learning_rate, prediction,
alpha, p_tree);
}
void Configure(Args const& args) override {

View File

@@ -742,9 +742,10 @@ class MeanAbsoluteError : public ObjFunction {
}
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, std::int32_t group_idx,
RegTree* p_tree) const override {
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, 0.5, p_tree);
float learning_rate, HostDeviceVector<float> const& prediction,
std::int32_t group_idx, RegTree* p_tree) const override {
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, learning_rate, prediction, 0.5,
p_tree);
}
const char* DefaultEvalMetric() const override { return "mae"; }