Support learning rate for zero-hessian objectives. (#8866)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/reduce.h>
|
||||
@@ -756,7 +756,6 @@ class GPUHistMaker : public TreeUpdater {
|
||||
void Configure(const Args& args) override {
|
||||
// Used in test to count how many configurations are performed
|
||||
LOG(DEBUG) << "[GPU Hist]: Configure";
|
||||
param_.UpdateAllowUnknown(args);
|
||||
hist_maker_param_.UpdateAllowUnknown(args);
|
||||
dh::CheckComputeCapability();
|
||||
initialised_ = false;
|
||||
@@ -768,32 +767,26 @@ class GPUHistMaker : public TreeUpdater {
|
||||
auto const& config = get<Object const>(in);
|
||||
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||
initialised_ = false;
|
||||
FromJson(config.at("train_param"), ¶m_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
|
||||
out["train_param"] = ToJson(param_);
|
||||
}
|
||||
|
||||
~GPUHistMaker() { // NOLINT
|
||||
dh::GlobalMemoryLogger().Log();
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
monitor_.Start("Update");
|
||||
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
|
||||
// build tree
|
||||
try {
|
||||
size_t t_idx{0};
|
||||
for (xgboost::RegTree* tree : trees) {
|
||||
this->UpdateTree(gpair, dmat, tree, &out_position[t_idx]);
|
||||
this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]);
|
||||
|
||||
if (hist_maker_param_.debug_synchronize) {
|
||||
this->CheckTreesSynchronized(tree);
|
||||
@@ -804,12 +797,10 @@ class GPUHistMaker : public TreeUpdater {
|
||||
} catch (const std::exception& e) {
|
||||
LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
|
||||
}
|
||||
|
||||
param_.learning_rate = lr;
|
||||
monitor_.Stop("Update");
|
||||
}
|
||||
|
||||
void InitDataOnce(DMatrix* dmat) {
|
||||
void InitDataOnce(TrainParam const* param, DMatrix* dmat) {
|
||||
CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device";
|
||||
info_ = &dmat->Info();
|
||||
|
||||
@@ -818,24 +809,24 @@ class GPUHistMaker : public TreeUpdater {
|
||||
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
||||
|
||||
BatchParam batch_param{
|
||||
ctx_->gpu_id,
|
||||
param_.max_bin,
|
||||
ctx_->gpu_id,
|
||||
param->max_bin,
|
||||
};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
info_->feature_types.SetDevice(ctx_->gpu_id);
|
||||
maker.reset(new GPUHistMakerDevice<GradientSumT>(
|
||||
ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, param_,
|
||||
ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, *param,
|
||||
column_sampling_seed, info_->num_col_, batch_param));
|
||||
|
||||
p_last_fmat_ = dmat;
|
||||
initialised_ = true;
|
||||
}
|
||||
|
||||
void InitData(DMatrix* dmat, RegTree const* p_tree) {
|
||||
void InitData(TrainParam const* param, DMatrix* dmat, RegTree const* p_tree) {
|
||||
if (!initialised_) {
|
||||
monitor_.Start("InitDataOnce");
|
||||
this->InitDataOnce(dmat);
|
||||
this->InitDataOnce(param, dmat);
|
||||
monitor_.Stop("InitDataOnce");
|
||||
}
|
||||
p_last_tree_ = p_tree;
|
||||
@@ -856,10 +847,10 @@ class GPUHistMaker : public TreeUpdater {
|
||||
CHECK(*local_tree == reference_tree);
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
||||
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
monitor_.Start("InitData");
|
||||
this->InitData(p_fmat, p_tree);
|
||||
this->InitData(param, p_fmat, p_tree);
|
||||
monitor_.Stop("InitData");
|
||||
|
||||
gpair->SetDevice(ctx_->gpu_id);
|
||||
@@ -878,7 +869,6 @@ class GPUHistMaker : public TreeUpdater {
|
||||
return result;
|
||||
}
|
||||
|
||||
TrainParam param_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
|
||||
std::unique_ptr<GPUHistMakerDevice<GradientSumT>> maker; // NOLINT
|
||||
|
||||
Reference in New Issue
Block a user