Small cleanup for histogram routines. (#9427)
* Small cleanup for histogram routines: extract the histogram training parameter from the GPU hist updater, make histograms const after construction, and unify parameter names.
This commit is contained in:
@@ -65,7 +65,7 @@ class HistEvaluator {
|
||||
* pseudo-category for missing value but here we just do a complete scan to avoid
|
||||
* making specialized histogram bin.
|
||||
*/
|
||||
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
|
||||
void EnumerateOneHot(common::HistogramCuts const &cut, common::ConstGHistRow hist,
|
||||
bst_feature_t fidx, bst_node_t nidx,
|
||||
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
|
||||
SplitEntry *p_best) const {
|
||||
@@ -143,7 +143,7 @@ class HistEvaluator {
|
||||
*/
|
||||
template <int d_step>
|
||||
void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
|
||||
common::GHistRow const &hist, bst_feature_t fidx, bst_node_t nidx,
|
||||
common::ConstGHistRow hist, bst_feature_t fidx, bst_node_t nidx,
|
||||
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
|
||||
SplitEntry *p_best) {
|
||||
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
|
||||
@@ -214,7 +214,7 @@ class HistEvaluator {
|
||||
// Returns the sum of gradients corresponding to the data points that contains
|
||||
// a non-missing value for the particular feature fid.
|
||||
template <int d_step>
|
||||
GradStats EnumerateSplit(common::HistogramCuts const &cut, const common::GHistRow &hist,
|
||||
GradStats EnumerateSplit(common::HistogramCuts const &cut, common::ConstGHistRow hist,
|
||||
bst_feature_t fidx, bst_node_t nidx,
|
||||
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
|
||||
SplitEntry *p_best) const {
|
||||
@@ -454,8 +454,8 @@ class HistEvaluator {
|
||||
right_child);
|
||||
}
|
||||
|
||||
auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
|
||||
auto const& Stats() const { return snode_; }
|
||||
[[nodiscard]] auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
|
||||
[[nodiscard]] auto const &Stats() const { return snode_; }
|
||||
|
||||
float InitRoot(GradStats const &root_sum) {
|
||||
snode_.resize(1);
|
||||
@@ -510,7 +510,7 @@ class HistMultiEvaluator {
|
||||
|
||||
template <bst_bin_t d_step>
|
||||
bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
|
||||
common::Span<common::GHistRow const> hist,
|
||||
common::Span<common::ConstGHistRow> hist,
|
||||
linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
|
||||
SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
|
||||
auto const &cut_ptr = cut.Ptrs();
|
||||
@@ -651,9 +651,9 @@ class HistMultiEvaluator {
|
||||
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
|
||||
auto best = &entry->split;
|
||||
auto parent_sum = stats_.Slice(entry->nid, linalg::All());
|
||||
std::vector<common::GHistRow> node_hist;
|
||||
std::vector<common::ConstGHistRow> node_hist;
|
||||
for (auto t_hist : hist) {
|
||||
node_hist.push_back((*t_hist)[entry->nid]);
|
||||
node_hist.emplace_back((*t_hist)[entry->nid]);
|
||||
}
|
||||
auto features_set = features[nidx_in_set]->ConstHostSpan();
|
||||
|
||||
|
||||
34
src/tree/hist/param.cc
Normal file
34
src/tree/hist/param.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#include "param.h"
|
||||
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../collective/communicator-inl.h" // for GetRank, Broadcast
|
||||
#include "xgboost/json.h" // for Object, Json
|
||||
#include "xgboost/tree_model.h" // for RegTree
|
||||
|
||||
namespace xgboost::tree {
|
||||
DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
|
||||
|
||||
void HistMakerTrainParam::CheckTreesSynchronized(RegTree const* local_tree) const {
  // Debug-only verification: ensure every distributed worker built the same
  // tree.  No-op unless `debug_synchronize` was explicitly enabled.
  if (!this->debug_synchronize) {
    return;
  }

  // Rank 0 serializes its tree into JSON; all other ranks dump an empty
  // object that is then overwritten by the broadcast below.
  Json model{Object{}};
  if (collective::GetRank() == 0) {
    local_tree->SaveModel(&model);
  }
  std::string s_model;
  Json::Dump(model, &s_model, std::ios::binary);
  collective::Broadcast(&s_model, 0);

  // Reconstruct rank 0's tree from the broadcast payload and compare it
  // against this worker's locally built tree.
  RegTree ref_tree{};  // rank 0 tree
  auto j_ref_tree = Json::Load(StringView{s_model}, std::ios::binary);
  ref_tree.LoadModel(j_ref_tree);
  CHECK(*local_tree == ref_tree);
}
|
||||
} // namespace xgboost::tree
|
||||
20
src/tree/hist/param.h
Normal file
20
src/tree/hist/param.h
Normal file
@@ -0,0 +1,20 @@
|
||||
/**
|
||||
* Copyright 2021-2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/tree_model.h" // for RegTree
|
||||
|
||||
namespace xgboost::tree {
|
||||
/**
 * @brief Training parameters shared by the histogram-based tree updaters
 *        (CPU hist/approx and GPU hist).
 */
struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
  // Whether to verify that all distributed workers produce identical trees
  // after each update round.  Brace-initialized so the flag has a defined
  // value even if it is read before UpdateAllowUnknown()/Init() runs; the
  // declared default below still applies during parameter initialization.
  bool debug_synchronize{false};

  // CHECK-fails unless this worker's tree matches rank 0's tree.
  // No-op when `debug_synchronize` is false.  Implemented in param.cc.
  void CheckTreesSynchronized(RegTree const* local_tree) const;

  // declare parameters
  DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
    DMLC_DECLARE_FIELD(debug_synchronize)
        .set_default(false)
        .describe("Check if all distributed tree are identical after tree construction.");
  }
};
|
||||
} // namespace xgboost::tree
|
||||
@@ -14,13 +14,14 @@
|
||||
#include "driver.h"
|
||||
#include "hist/evaluate_splits.h"
|
||||
#include "hist/histogram.h"
|
||||
#include "hist/param.h"
|
||||
#include "hist/sampler.h" // for SampleGradient
|
||||
#include "param.h"
|
||||
#include "param.h" // for HistMakerTrainParam
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h" // for TreeUpdater
|
||||
|
||||
@@ -42,6 +43,7 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
|
||||
class GloablApproxBuilder {
|
||||
protected:
|
||||
TrainParam const *param_;
|
||||
HistMakerTrainParam const *hist_param_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> col_sampler_;
|
||||
HistEvaluator evaluator_;
|
||||
HistogramBuilder<CPUExpandEntry> histogram_builder_;
|
||||
@@ -168,10 +170,12 @@ class GloablApproxBuilder {
|
||||
}
|
||||
|
||||
public:
|
||||
explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
|
||||
explicit GloablApproxBuilder(TrainParam const *param, HistMakerTrainParam const *hist_param,
|
||||
MetaInfo const &info, Context const *ctx,
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||
ObjInfo const *task, common::Monitor *monitor)
|
||||
: param_{param},
|
||||
hist_param_{hist_param},
|
||||
col_sampler_{std::move(column_sampler)},
|
||||
evaluator_{ctx, param_, info, col_sampler_},
|
||||
ctx_{ctx},
|
||||
@@ -259,6 +263,7 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_ =
|
||||
std::make_shared<common::ColumnSampler>();
|
||||
ObjInfo const *task_;
|
||||
HistMakerTrainParam hist_param_;
|
||||
|
||||
public:
|
||||
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task)
|
||||
@@ -266,9 +271,15 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
monitor_.Init(__func__);
|
||||
}
|
||||
|
||||
void Configure(Args const &) override {}
|
||||
void LoadConfig(Json const &) override {}
|
||||
void SaveConfig(Json *) const override {}
|
||||
void Configure(Args const &args) override { hist_param_.UpdateAllowUnknown(args); }
|
||||
void LoadConfig(Json const &in) override {
|
||||
auto const &config = get<Object const>(in);
|
||||
FromJson(config.at("hist_train_param"), &hist_param_);
|
||||
}
|
||||
void SaveConfig(Json *p_out) const override {
|
||||
auto &out = *p_out;
|
||||
out["hist_train_param"] = ToJson(hist_param_);
|
||||
}
|
||||
|
||||
void InitData(TrainParam const ¶m, HostDeviceVector<GradientPair> const *gpair,
|
||||
linalg::Matrix<GradientPair> *sampled) {
|
||||
@@ -283,8 +294,9 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m,
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
pimpl_ = std::make_unique<GloablApproxBuilder>(param, m->Info(), ctx_, column_sampler_, task_,
|
||||
&monitor_);
|
||||
CHECK(hist_param_.GetInitialised());
|
||||
pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
|
||||
column_sampler_, task_, &monitor_);
|
||||
|
||||
linalg::Matrix<GradientPair> h_gpair;
|
||||
// Obtain the hessian values for weighted sketching
|
||||
@@ -299,6 +311,7 @@ class GlobalApproxUpdater : public TreeUpdater {
|
||||
std::size_t t_idx = 0;
|
||||
for (auto p_tree : trees) {
|
||||
this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]);
|
||||
hist_param_.CheckTreesSynchronized(p_tree);
|
||||
++t_idx;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
#include "gpu_hist/gradient_based_sampler.cuh"
|
||||
#include "gpu_hist/histogram.cuh"
|
||||
#include "gpu_hist/row_partitioner.cuh"
|
||||
#include "hist/param.h"
|
||||
#include "param.h"
|
||||
#include "updater_gpu_common.cuh"
|
||||
#include "xgboost/base.h"
|
||||
@@ -47,37 +48,6 @@ namespace xgboost::tree {
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
||||
#endif // !defined(GTEST_TEST)
|
||||
|
||||
// training parameters specific to this algorithm
|
||||
// training parameters specific to this algorithm
struct GPUHistMakerTrainParam : public XGBoostParameter<GPUHistMakerTrainParam> {
  bool debug_synchronize;
  // declare parameters
  DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
    DMLC_DECLARE_FIELD(debug_synchronize)
        .set_default(false)
        .describe("Check if all distributed tree are identical after tree construction.");
  }

  // Only call this method for testing
  void CheckTreesSynchronized(RegTree const* local_tree) const {
    // Guard clause: skip the (expensive) cross-worker comparison unless the
    // debug flag is enabled.
    if (!this->debug_synchronize) {
      return;
    }
    // Rank 0 serializes its tree into the buffer; every rank then receives
    // that buffer via broadcast.
    std::string s_model;
    common::MemoryBufferStream fs(&s_model);
    if (collective::GetRank() == 0) {
      local_tree->Save(&fs);
    }
    fs.Seek(0);
    collective::Broadcast(&s_model, 0);
    // Deserialize rank 0's tree and compare it with the local one.
    RegTree reference_tree{};  // rank 0 tree
    reference_tree.Load(&fs);
    CHECK(*local_tree == reference_tree);
  }
};
|
||||
#if !defined(GTEST_TEST)
|
||||
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
|
||||
#endif // !defined(GTEST_TEST)
|
||||
|
||||
/**
|
||||
* \struct DeviceHistogramStorage
|
||||
*
|
||||
@@ -777,12 +747,12 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
|
||||
FromJson(config.at("hist_train_param"), &this->hist_maker_param_);
|
||||
initialised_ = false;
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
|
||||
out["hist_train_param"] = ToJson(hist_maker_param_);
|
||||
}
|
||||
|
||||
~GPUHistMaker() { // NOLINT
|
||||
@@ -836,6 +806,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
monitor_.Stop("InitDataOnce");
|
||||
}
|
||||
p_last_tree_ = p_tree;
|
||||
CHECK(hist_maker_param_.GetInitialised());
|
||||
}
|
||||
|
||||
void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
|
||||
@@ -869,7 +840,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
private:
|
||||
bool initialised_{false};
|
||||
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
HistMakerTrainParam hist_maker_param_;
|
||||
|
||||
DMatrix* p_last_fmat_{nullptr};
|
||||
RegTree const* p_last_tree_{nullptr};
|
||||
@@ -903,12 +874,12 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
FromJson(config.at("approx_train_param"), &this->hist_maker_param_);
|
||||
FromJson(config.at("hist_train_param"), &this->hist_maker_param_);
|
||||
initialised_ = false;
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["approx_train_param"] = ToJson(hist_maker_param_);
|
||||
out["hist_train_param"] = ToJson(hist_maker_param_);
|
||||
}
|
||||
~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); }
|
||||
|
||||
@@ -965,6 +936,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
void InitData(DMatrix* p_fmat, RegTree const* p_tree) {
|
||||
this->InitDataOnce(p_fmat);
|
||||
p_last_tree_ = p_tree;
|
||||
CHECK(hist_maker_param_.GetInitialised());
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree,
|
||||
@@ -994,7 +966,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
|
||||
private:
|
||||
bool initialised_{false};
|
||||
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
HistMakerTrainParam hist_maker_param_;
|
||||
dh::device_vector<float> hess_;
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||
std::unique_ptr<GPUHistMakerDevice> maker_;
|
||||
|
||||
@@ -4,18 +4,17 @@
|
||||
* \brief use quantized feature values to construct a tree
|
||||
* \author Philip Cho, Tianqi Checn, Egor Smirnov
|
||||
*/
|
||||
#include <algorithm> // for max, copy, transform
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for uint32_t, int32_t
|
||||
#include <memory> // for unique_ptr, allocator, make_unique, shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <ostream> // for basic_ostream, char_traits, operator<<
|
||||
#include <utility> // for move, swap
|
||||
#include <vector> // for vector
|
||||
#include <algorithm> // for max, copy, transform
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for uint32_t, int32_t
|
||||
#include <memory> // for unique_ptr, allocator, make_unique, shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <ostream> // for basic_ostream, char_traits, operator<<
|
||||
#include <utility> // for move, swap
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/aggregator.h" // for GlobalSum
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../common/hist_util.h" // for HistogramCuts, HistCollection
|
||||
#include "../common/linalg_op.h" // for begin, cbegin, cend
|
||||
#include "../common/random.h" // for ColumnSampler
|
||||
@@ -24,12 +23,12 @@
|
||||
#include "../common/transform_iterator.h" // for IndexTransformIter, MakeIndexTransformIter
|
||||
#include "../data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "common_row_partitioner.h" // for CommonRowPartitioner
|
||||
#include "dmlc/omp.h" // for omp_get_thread_num
|
||||
#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG
|
||||
#include "driver.h" // for Driver
|
||||
#include "hist/evaluate_splits.h" // for HistEvaluator, HistMultiEvaluator, UpdatePre...
|
||||
#include "hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
|
||||
#include "hist/histogram.h" // for HistogramBuilder, ConstructHistSpace
|
||||
#include "hist/param.h" // for HistMakerTrainParam
|
||||
#include "hist/sampler.h" // for SampleGradient
|
||||
#include "param.h" // for TrainParam, SplitEntryContainer, GradStats
|
||||
#include "xgboost/base.h" // for GradientPairInternal, GradientPair, bst_targ...
|
||||
@@ -117,6 +116,7 @@ class MultiTargetHistBuilder {
|
||||
private:
|
||||
common::Monitor *monitor_{nullptr};
|
||||
TrainParam const *param_{nullptr};
|
||||
HistMakerTrainParam const *hist_param_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> col_sampler_;
|
||||
std::unique_ptr<HistMultiEvaluator> evaluator_;
|
||||
// Histogram builder for each target.
|
||||
@@ -306,10 +306,12 @@ class MultiTargetHistBuilder {
|
||||
|
||||
public:
|
||||
explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam const *param,
|
||||
HistMakerTrainParam const *hist_param,
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||
ObjInfo const *task, common::Monitor *monitor)
|
||||
: monitor_{monitor},
|
||||
param_{param},
|
||||
hist_param_{hist_param},
|
||||
col_sampler_{std::move(column_sampler)},
|
||||
evaluator_{std::make_unique<HistMultiEvaluator>(ctx, info, param, col_sampler_)},
|
||||
ctx_{ctx},
|
||||
@@ -331,10 +333,14 @@ class MultiTargetHistBuilder {
|
||||
}
|
||||
};
|
||||
|
||||
class HistBuilder {
|
||||
/**
|
||||
* @brief Tree updater for single-target trees.
|
||||
*/
|
||||
class HistUpdater {
|
||||
private:
|
||||
common::Monitor *monitor_;
|
||||
TrainParam const *param_;
|
||||
HistMakerTrainParam const *hist_param_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> col_sampler_;
|
||||
std::unique_ptr<HistEvaluator> evaluator_;
|
||||
std::vector<CommonRowPartitioner> partitioner_;
|
||||
@@ -349,14 +355,14 @@ class HistBuilder {
|
||||
Context const *ctx_{nullptr};
|
||||
|
||||
public:
|
||||
explicit HistBuilder(Context const *ctx, std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||
TrainParam const *param, DMatrix const *fmat, ObjInfo const *task,
|
||||
common::Monitor *monitor)
|
||||
explicit HistUpdater(Context const *ctx, std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||
TrainParam const *param, HistMakerTrainParam const *hist_param,
|
||||
DMatrix const *fmat, ObjInfo const *task, common::Monitor *monitor)
|
||||
: monitor_{monitor},
|
||||
param_{param},
|
||||
hist_param_{hist_param},
|
||||
col_sampler_{std::move(column_sampler)},
|
||||
evaluator_{std::make_unique<HistEvaluator>(ctx, param, fmat->Info(),
|
||||
col_sampler_)},
|
||||
evaluator_{std::make_unique<HistEvaluator>(ctx, param, fmat->Info(), col_sampler_)},
|
||||
p_last_fmat_(fmat),
|
||||
histogram_builder_{new HistogramBuilder<CPUExpandEntry>},
|
||||
task_{task},
|
||||
@@ -529,7 +535,7 @@ class HistBuilder {
|
||||
std::vector<bst_node_t> *p_out_position) {
|
||||
monitor_->Start(__func__);
|
||||
if (!task_->UpdateTreeLeaf()) {
|
||||
monitor_->Stop(__func__);
|
||||
monitor_->Stop(__func__);
|
||||
return;
|
||||
}
|
||||
for (auto const &part : partitioner_) {
|
||||
@@ -541,20 +547,27 @@ class HistBuilder {
|
||||
|
||||
/*! \brief construct a tree using quantized feature values */
|
||||
class QuantileHistMaker : public TreeUpdater {
|
||||
std::unique_ptr<HistBuilder> p_impl_{nullptr};
|
||||
std::unique_ptr<HistUpdater> p_impl_{nullptr};
|
||||
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_ =
|
||||
std::make_shared<common::ColumnSampler>();
|
||||
common::Monitor monitor_;
|
||||
ObjInfo const *task_{nullptr};
|
||||
HistMakerTrainParam hist_param_;
|
||||
|
||||
public:
|
||||
explicit QuantileHistMaker(Context const *ctx, ObjInfo const *task)
|
||||
: TreeUpdater{ctx}, task_{task} {}
|
||||
void Configure(const Args &) override {}
|
||||
|
||||
void LoadConfig(Json const &) override {}
|
||||
void SaveConfig(Json *) const override {}
|
||||
void Configure(Args const &args) override { hist_param_.UpdateAllowUnknown(args); }
|
||||
void LoadConfig(Json const &in) override {
|
||||
auto const &config = get<Object const>(in);
|
||||
FromJson(config.at("hist_train_param"), &hist_param_);
|
||||
}
|
||||
void SaveConfig(Json *p_out) const override {
|
||||
auto &out = *p_out;
|
||||
out["hist_train_param"] = ToJson(hist_param_);
|
||||
}
|
||||
|
||||
[[nodiscard]] char const *Name() const override { return "grow_quantile_histmaker"; }
|
||||
|
||||
@@ -562,15 +575,17 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
common::Span<HostDeviceVector<bst_node_t>> out_position,
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
if (trees.front()->IsMultiTarget()) {
|
||||
CHECK(hist_param_.GetInitialised());
|
||||
CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
|
||||
if (!p_mtimpl_) {
|
||||
this->p_mtimpl_ = std::make_unique<MultiTargetHistBuilder>(
|
||||
ctx_, p_fmat->Info(), param, column_sampler_, task_, &monitor_);
|
||||
ctx_, p_fmat->Info(), param, &hist_param_, column_sampler_, task_, &monitor_);
|
||||
}
|
||||
} else {
|
||||
CHECK(hist_param_.GetInitialised());
|
||||
if (!p_impl_) {
|
||||
p_impl_ =
|
||||
std::make_unique<HistBuilder>(ctx_, column_sampler_, param, p_fmat, task_, &monitor_);
|
||||
p_impl_ = std::make_unique<HistUpdater>(ctx_, column_sampler_, param, &hist_param_, p_fmat,
|
||||
task_, &monitor_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -601,6 +616,8 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
UpdateTree<CPUExpandEntry>(&monitor_, h_sample_out, p_impl_.get(), p_fmat, param,
|
||||
h_out_position, *tree_it);
|
||||
}
|
||||
|
||||
hist_param_.CheckTreesSynchronized(*tree_it);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user