This commit is contained in:
Jiaming Yuan 2024-06-02 02:07:55 +08:00 committed by GitHub
parent 92cba25fe2
commit 4f48647932
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -52,7 +52,7 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
} }
} // anonymous namespace } // anonymous namespace
class GloablApproxBuilder { class GlobalApproxBuilder {
protected: protected:
TrainParam const *param_; TrainParam const *param_;
HistMakerTrainParam const *hist_param_{nullptr}; HistMakerTrainParam const *hist_param_{nullptr};
@ -161,7 +161,7 @@ class GloablApproxBuilder {
} }
public: public:
explicit GloablApproxBuilder(TrainParam const *param, HistMakerTrainParam const *hist_param, explicit GlobalApproxBuilder(TrainParam const *param, HistMakerTrainParam const *hist_param,
MetaInfo const &info, Context const *ctx, MetaInfo const &info, Context const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler, std::shared_ptr<common::ColumnSampler> column_sampler,
ObjInfo const *task, common::Monitor *monitor) ObjInfo const *task, common::Monitor *monitor)
@ -248,7 +248,7 @@ class GloablApproxBuilder {
class GlobalApproxUpdater : public TreeUpdater { class GlobalApproxUpdater : public TreeUpdater {
common::Monitor monitor_; common::Monitor monitor_;
// specializations for different histogram precision. // specializations for different histogram precision.
std::unique_ptr<GloablApproxBuilder> pimpl_; std::unique_ptr<GlobalApproxBuilder> pimpl_;
// pointer to the last DMatrix, used for update prediction cache. // pointer to the last DMatrix, used for update prediction cache.
DMatrix *cached_{nullptr}; DMatrix *cached_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_; std::shared_ptr<common::ColumnSampler> column_sampler_;
@ -289,7 +289,7 @@ class GlobalApproxUpdater : public TreeUpdater {
if (!column_sampler_) { if (!column_sampler_) {
column_sampler_ = common::MakeColumnSampler(ctx_); column_sampler_ = common::MakeColumnSampler(ctx_);
} }
pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_, pimpl_ = std::make_unique<GlobalApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
column_sampler_, task_, &monitor_); column_sampler_, task_, &monitor_);
linalg::Matrix<GradientPair> h_gpair; linalg::Matrix<GradientPair> h_gpair;