Small cleanup for histogram routines. (#9427)

* Small cleanup for histogram routines.

- Extract hist train param from GPU hist.
- Make histogram const after construction.
- Unify parameter names.
This commit is contained in:
Jiaming Yuan 2023-08-02 18:28:26 +08:00 committed by GitHub
parent c2b85ab68a
commit e93a274823
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 182 additions and 111 deletions

View File

@ -68,6 +68,7 @@ OBJECTS= \
$(PKGROOT)/src/tree/updater_quantile_hist.o \
$(PKGROOT)/src/tree/updater_refresh.o \
$(PKGROOT)/src/tree/updater_sync.o \
$(PKGROOT)/src/tree/hist/param.o \
$(PKGROOT)/src/linear/linear_updater.o \
$(PKGROOT)/src/linear/updater_coordinate.o \
$(PKGROOT)/src/linear/updater_shotgun.o \

View File

@ -68,6 +68,7 @@ OBJECTS= \
$(PKGROOT)/src/tree/updater_quantile_hist.o \
$(PKGROOT)/src/tree/updater_refresh.o \
$(PKGROOT)/src/tree/updater_sync.o \
$(PKGROOT)/src/tree/hist/param.o \
$(PKGROOT)/src/linear/linear_updater.o \
$(PKGROOT)/src/linear/updater_coordinate.o \
$(PKGROOT)/src/linear/updater_shotgun.o \

View File

@ -574,7 +574,9 @@ template <typename Container, typename... S,
std::enable_if_t<!common::detail::IsSpan<Container>::value &&
!std::is_pointer_v<Container>> * = nullptr>
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) { // NOLINT
using T = typename Container::value_type;
using T = std::conditional_t<std::is_const_v<Container>,
std::add_const_t<typename Container::value_type>,
typename Container::value_type>;
std::size_t in_shape[sizeof...(S)];
detail::IndexToArr(in_shape, std::forward<S>(shape)...);
return TensorView<T, sizeof...(S)>{data, in_shape, ctx->gpu_id};

View File

@ -81,11 +81,11 @@ void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
/*!
* \brief Increment hist as dst += add in range [begin, end)
*/
void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
double* pdst = reinterpret_cast<double*>(dst.data());
void IncrementHist(GHistRow dst, ConstGHistRow add, std::size_t begin, std::size_t end) {
double *pdst = reinterpret_cast<double *>(dst.data());
const double *padd = reinterpret_cast<const double *>(add.data());
for (size_t i = 2 * begin; i < 2 * end; ++i) {
for (std::size_t i = 2 * begin; i < 2 * end; ++i) {
pdst[i] += padd[i];
}
}
@ -207,18 +207,23 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
const size_t size = row_indices.Size();
const size_t *rid = row_indices.begin;
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
auto const *p_gpair = reinterpret_cast<const float *>(gpair.data());
const BinIdxType *gradient_index = gmat.index.data<BinIdxType>();
auto const &row_ptr = gmat.row_ptr.data();
auto base_rowid = gmat.base_rowid;
const uint32_t *offsets = gmat.index.Offset();
auto get_row_ptr = [&](size_t ridx) {
uint32_t const *offsets = gmat.index.Offset();
// There's no feature-based compression if missing value is present.
if (kAnyMissing) {
CHECK(!offsets);
} else {
CHECK(offsets);
}
auto get_row_ptr = [&](bst_row_t ridx) {
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
};
auto get_rid = [&](size_t ridx) {
return kFirstPage ? ridx : (ridx - base_rowid);
};
auto get_rid = [&](bst_row_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); };
const size_t n_features =
get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]);
@ -228,7 +233,7 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
// So we need to multiply each row-index/bin-index by 2
// to work with gradient pairs as a singe row FP array
for (size_t i = 0; i < size; ++i) {
for (std::size_t i = 0; i < size; ++i) {
const size_t icol_start =
kAnyMissing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
const size_t icol_end =
@ -246,7 +251,7 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
kAnyMissing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
: icol_start_prefetch + n_features;
PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
PREFETCH_READ_T0(p_gpair + two * rid[i + Prefetch::kPrefetchOffset]);
for (size_t j = icol_start_prefetch; j < icol_end_prefetch;
j += Prefetch::GetPrefetchStep<uint32_t>()) {
PREFETCH_READ_T0(gradient_index + j);
@ -255,10 +260,10 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
const BinIdxType *gr_index_local = gradient_index + icol_start;
// The trick with pgh_t buffer helps the compiler to generate faster binary.
const float pgh_t[] = {pgh[idx_gh], pgh[idx_gh + 1]};
const float pgh_t[] = {p_gpair[idx_gh], p_gpair[idx_gh + 1]};
for (size_t j = 0; j < row_size; ++j) {
const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) +
(kAnyMissing ? 0 : offsets[j]));
const uint32_t idx_bin =
two * (static_cast<uint32_t>(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j]));
auto hist_local = hist_data + idx_bin;
*(hist_local) += pgh_t[0];
*(hist_local + 1) += pgh_t[1];
@ -281,12 +286,10 @@ void ColsWiseBuildHistKernel(Span<GradientPair const> gpair,
auto const &row_ptr = gmat.row_ptr.data();
auto base_rowid = gmat.base_rowid;
const uint32_t *offsets = gmat.index.Offset();
auto get_row_ptr = [&](size_t ridx) {
auto get_row_ptr = [&](bst_row_t ridx) {
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
};
auto get_rid = [&](size_t ridx) {
return kFirstPage ? ridx : (ridx - base_rowid);
};
auto get_rid = [&](bst_row_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); };
const size_t n_features = gmat.cut.Ptrs().size() - 1;
const size_t n_columns = n_features;

View File

@ -362,6 +362,7 @@ bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t
}
using GHistRow = Span<xgboost::GradientPairPrecise>;
using ConstGHistRow = Span<xgboost::GradientPairPrecise const>;
/*!
* \brief fill a histogram by zeros
@ -371,7 +372,7 @@ void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end);
/*!
* \brief Increment hist as dst += add in range [begin, end)
*/
void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end);
void IncrementHist(GHistRow dst, ConstGHistRow add, std::size_t begin, std::size_t end);
/*!
* \brief Copy hist from src to dst in range [begin, end)

View File

@ -136,7 +136,7 @@ class BlockedSpace2d {
// Wrapper to implement nested parallelism with simple omp parallel for
template <typename Func>
void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
const size_t num_blocks_in_space = space.Size();
std::size_t n_blocks_in_space = space.Size();
CHECK_GE(nthreads, 1);
dmlc::OMPException exc;
@ -144,11 +144,10 @@ void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
{
exc.Run([&]() {
size_t tid = omp_get_thread_num();
size_t chunck_size =
num_blocks_in_space / nthreads + !!(num_blocks_in_space % nthreads);
size_t chunck_size = n_blocks_in_space / nthreads + !!(n_blocks_in_space % nthreads);
size_t begin = chunck_size * tid;
size_t end = std::min(begin + chunck_size, num_blocks_in_space);
std::size_t begin = chunck_size * tid;
std::size_t end = std::min(begin + chunck_size, n_blocks_in_space);
for (auto i = begin; i < end; i++) {
func(space.GetFirstDimension(i), space.GetRange(i));
}

View File

@ -65,7 +65,7 @@ class HistEvaluator {
* pseudo-category for missing value but here we just do a complete scan to avoid
* making specialized histogram bin.
*/
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
void EnumerateOneHot(common::HistogramCuts const &cut, common::ConstGHistRow hist,
bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) const {
@ -143,7 +143,7 @@ class HistEvaluator {
*/
template <int d_step>
void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
common::GHistRow const &hist, bst_feature_t fidx, bst_node_t nidx,
common::ConstGHistRow hist, bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
@ -214,7 +214,7 @@ class HistEvaluator {
// Returns the sum of gradients corresponding to the data points that contains
// a non-missing value for the particular feature fid.
template <int d_step>
GradStats EnumerateSplit(common::HistogramCuts const &cut, const common::GHistRow &hist,
GradStats EnumerateSplit(common::HistogramCuts const &cut, common::ConstGHistRow hist,
bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) const {
@ -454,8 +454,8 @@ class HistEvaluator {
right_child);
}
auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
auto const& Stats() const { return snode_; }
[[nodiscard]] auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
[[nodiscard]] auto const &Stats() const { return snode_; }
float InitRoot(GradStats const &root_sum) {
snode_.resize(1);
@ -510,7 +510,7 @@ class HistMultiEvaluator {
template <bst_bin_t d_step>
bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
common::Span<common::GHistRow const> hist,
common::Span<common::ConstGHistRow> hist,
linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
auto const &cut_ptr = cut.Ptrs();
@ -651,9 +651,9 @@ class HistMultiEvaluator {
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto parent_sum = stats_.Slice(entry->nid, linalg::All());
std::vector<common::GHistRow> node_hist;
std::vector<common::ConstGHistRow> node_hist;
for (auto t_hist : hist) {
node_hist.push_back((*t_hist)[entry->nid]);
node_hist.emplace_back((*t_hist)[entry->nid]);
}
auto features_set = features[nidx_in_set]->ConstHostSpan();

34
src/tree/hist/param.cc Normal file
View File

@ -0,0 +1,34 @@
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#include "param.h"
#include <string> // for string
#include "../../collective/communicator-inl.h" // for GetRank, Broadcast
#include "xgboost/json.h" // for Object, Json
#include "xgboost/tree_model.h" // for RegTree
namespace xgboost::tree {
DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
void HistMakerTrainParam::CheckTreesSynchronized(RegTree const* local_tree) const {
if (!this->debug_synchronize) {
return;
}
std::string s_model;
Json model{Object{}};
int rank = collective::GetRank();
if (rank == 0) {
local_tree->SaveModel(&model);
}
Json::Dump(model, &s_model, std::ios::binary);
collective::Broadcast(&s_model, 0);
RegTree ref_tree{}; // rank 0 tree
auto j_ref_tree = Json::Load(StringView{s_model}, std::ios::binary);
ref_tree.LoadModel(j_ref_tree);
CHECK(*local_tree == ref_tree);
}
} // namespace xgboost::tree

20
src/tree/hist/param.h Normal file
View File

@ -0,0 +1,20 @@
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#pragma once
#include "xgboost/parameter.h"
#include "xgboost/tree_model.h" // for RegTree
namespace xgboost::tree {
struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
bool debug_synchronize;
void CheckTreesSynchronized(RegTree const* local_tree) const;
// declare parameters
DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
DMLC_DECLARE_FIELD(debug_synchronize)
.set_default(false)
.describe("Check if all distributed tree are identical after tree construction.");
}
};
} // namespace xgboost::tree

View File

@ -14,8 +14,9 @@
#include "driver.h"
#include "hist/evaluate_splits.h"
#include "hist/histogram.h"
#include "hist/param.h"
#include "hist/sampler.h" // for SampleGradient
#include "param.h"
#include "param.h" // for HistMakerTrainParam
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
@ -42,6 +43,7 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
class GloablApproxBuilder {
protected:
TrainParam const *param_;
HistMakerTrainParam const *hist_param_{nullptr};
std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator evaluator_;
HistogramBuilder<CPUExpandEntry> histogram_builder_;
@ -168,10 +170,12 @@ class GloablApproxBuilder {
}
public:
explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
explicit GloablApproxBuilder(TrainParam const *param, HistMakerTrainParam const *hist_param,
MetaInfo const &info, Context const *ctx,
std::shared_ptr<common::ColumnSampler> column_sampler,
ObjInfo const *task, common::Monitor *monitor)
: param_{param},
hist_param_{hist_param},
col_sampler_{std::move(column_sampler)},
evaluator_{ctx, param_, info, col_sampler_},
ctx_{ctx},
@ -259,6 +263,7 @@ class GlobalApproxUpdater : public TreeUpdater {
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
ObjInfo const *task_;
HistMakerTrainParam hist_param_;
public:
explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task)
@ -266,9 +271,15 @@ class GlobalApproxUpdater : public TreeUpdater {
monitor_.Init(__func__);
}
void Configure(Args const &) override {}
void LoadConfig(Json const &) override {}
void SaveConfig(Json *) const override {}
void Configure(Args const &args) override { hist_param_.UpdateAllowUnknown(args); }
void LoadConfig(Json const &in) override {
auto const &config = get<Object const>(in);
FromJson(config.at("hist_train_param"), &hist_param_);
}
void SaveConfig(Json *p_out) const override {
auto &out = *p_out;
out["hist_train_param"] = ToJson(hist_param_);
}
void InitData(TrainParam const &param, HostDeviceVector<GradientPair> const *gpair,
linalg::Matrix<GradientPair> *sampled) {
@ -283,8 +294,9 @@ class GlobalApproxUpdater : public TreeUpdater {
void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
pimpl_ = std::make_unique<GloablApproxBuilder>(param, m->Info(), ctx_, column_sampler_, task_,
&monitor_);
CHECK(hist_param_.GetInitialised());
pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
column_sampler_, task_, &monitor_);
linalg::Matrix<GradientPair> h_gpair;
// Obtain the hessian values for weighted sketching
@ -299,6 +311,7 @@ class GlobalApproxUpdater : public TreeUpdater {
std::size_t t_idx = 0;
for (auto p_tree : trees) {
this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]);
hist_param_.CheckTreesSynchronized(p_tree);
++t_idx;
}
}

View File

@ -30,6 +30,7 @@
#include "gpu_hist/gradient_based_sampler.cuh"
#include "gpu_hist/histogram.cuh"
#include "gpu_hist/row_partitioner.cuh"
#include "hist/param.h"
#include "param.h"
#include "updater_gpu_common.cuh"
#include "xgboost/base.h"
@ -47,37 +48,6 @@ namespace xgboost::tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
#endif // !defined(GTEST_TEST)
// training parameters specific to this algorithm
struct GPUHistMakerTrainParam : public XGBoostParameter<GPUHistMakerTrainParam> {
bool debug_synchronize;
// declare parameters
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
DMLC_DECLARE_FIELD(debug_synchronize)
.set_default(false)
.describe("Check if all distributed tree are identical after tree construction.");
}
// Only call this method for testing
void CheckTreesSynchronized(RegTree const* local_tree) const {
if (this->debug_synchronize) {
std::string s_model;
common::MemoryBufferStream fs(&s_model);
int rank = collective::GetRank();
if (rank == 0) {
local_tree->Save(&fs);
}
fs.Seek(0);
collective::Broadcast(&s_model, 0);
RegTree reference_tree{}; // rank 0 tree
reference_tree.Load(&fs);
CHECK(*local_tree == reference_tree);
}
}
};
#if !defined(GTEST_TEST)
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
#endif // !defined(GTEST_TEST)
/**
* \struct DeviceHistogramStorage
*
@ -777,12 +747,12 @@ class GPUHistMaker : public TreeUpdater {
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
FromJson(config.at("hist_train_param"), &this->hist_maker_param_);
initialised_ = false;
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
out["hist_train_param"] = ToJson(hist_maker_param_);
}
~GPUHistMaker() { // NOLINT
@ -836,6 +806,7 @@ class GPUHistMaker : public TreeUpdater {
monitor_.Stop("InitDataOnce");
}
p_last_tree_ = p_tree;
CHECK(hist_maker_param_.GetInitialised());
}
void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
@ -869,7 +840,7 @@ class GPUHistMaker : public TreeUpdater {
private:
bool initialised_{false};
GPUHistMakerTrainParam hist_maker_param_;
HistMakerTrainParam hist_maker_param_;
DMatrix* p_last_fmat_{nullptr};
RegTree const* p_last_tree_{nullptr};
@ -903,12 +874,12 @@ class GPUGlobalApproxMaker : public TreeUpdater {
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("approx_train_param"), &this->hist_maker_param_);
FromJson(config.at("hist_train_param"), &this->hist_maker_param_);
initialised_ = false;
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["approx_train_param"] = ToJson(hist_maker_param_);
out["hist_train_param"] = ToJson(hist_maker_param_);
}
~GPUGlobalApproxMaker() override { dh::GlobalMemoryLogger().Log(); }
@ -965,6 +936,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
void InitData(DMatrix* p_fmat, RegTree const* p_tree) {
this->InitDataOnce(p_fmat);
p_last_tree_ = p_tree;
CHECK(hist_maker_param_.GetInitialised());
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree,
@ -994,7 +966,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
private:
bool initialised_{false};
GPUHistMakerTrainParam hist_maker_param_;
HistMakerTrainParam hist_maker_param_;
dh::device_vector<float> hess_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
std::unique_ptr<GPUHistMakerDevice> maker_;

View File

@ -15,7 +15,6 @@
#include "../collective/aggregator.h" // for GlobalSum
#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed
#include "../collective/communicator.h" // for Operation
#include "../common/hist_util.h" // for HistogramCuts, HistCollection
#include "../common/linalg_op.h" // for begin, cbegin, cend
#include "../common/random.h" // for ColumnSampler
@ -24,12 +23,12 @@
#include "../common/transform_iterator.h" // for IndexTransformIter, MakeIndexTransformIter
#include "../data/gradient_index.h" // for GHistIndexMatrix
#include "common_row_partitioner.h" // for CommonRowPartitioner
#include "dmlc/omp.h" // for omp_get_thread_num
#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG
#include "driver.h" // for Driver
#include "hist/evaluate_splits.h" // for HistEvaluator, HistMultiEvaluator, UpdatePre...
#include "hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
#include "hist/histogram.h" // for HistogramBuilder, ConstructHistSpace
#include "hist/param.h" // for HistMakerTrainParam
#include "hist/sampler.h" // for SampleGradient
#include "param.h" // for TrainParam, SplitEntryContainer, GradStats
#include "xgboost/base.h" // for GradientPairInternal, GradientPair, bst_targ...
@ -117,6 +116,7 @@ class MultiTargetHistBuilder {
private:
common::Monitor *monitor_{nullptr};
TrainParam const *param_{nullptr};
HistMakerTrainParam const *hist_param_{nullptr};
std::shared_ptr<common::ColumnSampler> col_sampler_;
std::unique_ptr<HistMultiEvaluator> evaluator_;
// Histogram builder for each target.
@ -306,10 +306,12 @@ class MultiTargetHistBuilder {
public:
explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam const *param,
HistMakerTrainParam const *hist_param,
std::shared_ptr<common::ColumnSampler> column_sampler,
ObjInfo const *task, common::Monitor *monitor)
: monitor_{monitor},
param_{param},
hist_param_{hist_param},
col_sampler_{std::move(column_sampler)},
evaluator_{std::make_unique<HistMultiEvaluator>(ctx, info, param, col_sampler_)},
ctx_{ctx},
@ -331,10 +333,14 @@ class MultiTargetHistBuilder {
}
};
class HistBuilder {
/**
* @brief Tree updater for single-target trees.
*/
class HistUpdater {
private:
common::Monitor *monitor_;
TrainParam const *param_;
HistMakerTrainParam const *hist_param_{nullptr};
std::shared_ptr<common::ColumnSampler> col_sampler_;
std::unique_ptr<HistEvaluator> evaluator_;
std::vector<CommonRowPartitioner> partitioner_;
@ -349,14 +355,14 @@ class HistBuilder {
Context const *ctx_{nullptr};
public:
explicit HistBuilder(Context const *ctx, std::shared_ptr<common::ColumnSampler> column_sampler,
TrainParam const *param, DMatrix const *fmat, ObjInfo const *task,
common::Monitor *monitor)
explicit HistUpdater(Context const *ctx, std::shared_ptr<common::ColumnSampler> column_sampler,
TrainParam const *param, HistMakerTrainParam const *hist_param,
DMatrix const *fmat, ObjInfo const *task, common::Monitor *monitor)
: monitor_{monitor},
param_{param},
hist_param_{hist_param},
col_sampler_{std::move(column_sampler)},
evaluator_{std::make_unique<HistEvaluator>(ctx, param, fmat->Info(),
col_sampler_)},
evaluator_{std::make_unique<HistEvaluator>(ctx, param, fmat->Info(), col_sampler_)},
p_last_fmat_(fmat),
histogram_builder_{new HistogramBuilder<CPUExpandEntry>},
task_{task},
@ -541,20 +547,27 @@ class HistBuilder {
/*! \brief construct a tree using quantized feature values */
class QuantileHistMaker : public TreeUpdater {
std::unique_ptr<HistBuilder> p_impl_{nullptr};
std::unique_ptr<HistUpdater> p_impl_{nullptr};
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
common::Monitor monitor_;
ObjInfo const *task_{nullptr};
HistMakerTrainParam hist_param_;
public:
explicit QuantileHistMaker(Context const *ctx, ObjInfo const *task)
: TreeUpdater{ctx}, task_{task} {}
void Configure(const Args &) override {}
void LoadConfig(Json const &) override {}
void SaveConfig(Json *) const override {}
void Configure(Args const &args) override { hist_param_.UpdateAllowUnknown(args); }
void LoadConfig(Json const &in) override {
auto const &config = get<Object const>(in);
FromJson(config.at("hist_train_param"), &hist_param_);
}
void SaveConfig(Json *p_out) const override {
auto &out = *p_out;
out["hist_train_param"] = ToJson(hist_param_);
}
[[nodiscard]] char const *Name() const override { return "grow_quantile_histmaker"; }
@ -562,15 +575,17 @@ class QuantileHistMaker : public TreeUpdater {
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
if (trees.front()->IsMultiTarget()) {
CHECK(hist_param_.GetInitialised());
CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
if (!p_mtimpl_) {
this->p_mtimpl_ = std::make_unique<MultiTargetHistBuilder>(
ctx_, p_fmat->Info(), param, column_sampler_, task_, &monitor_);
ctx_, p_fmat->Info(), param, &hist_param_, column_sampler_, task_, &monitor_);
}
} else {
CHECK(hist_param_.GetInitialised());
if (!p_impl_) {
p_impl_ =
std::make_unique<HistBuilder>(ctx_, column_sampler_, param, p_fmat, task_, &monitor_);
p_impl_ = std::make_unique<HistUpdater>(ctx_, column_sampler_, param, &hist_param_, p_fmat,
task_, &monitor_);
}
}
@ -601,6 +616,8 @@ class QuantileHistMaker : public TreeUpdater {
UpdateTree<CPUExpandEntry>(&monitor_, h_sample_out, p_impl_.get(), p_fmat, param,
h_out_position, *tree_it);
}
hist_param_.CheckTreesSynchronized(*tree_it);
}
}

View File

@ -105,13 +105,13 @@ void TestBuildHist(bool use_shared_memory_histograms) {
gpair.SetDevice(0);
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
maker.row_partitioner = std::make_unique<RowPartitioner>(0, kNRows);
maker.hist.Init(0, page->Cuts().TotalBins());
maker.hist.AllocateHistograms({0});
maker.gpair = gpair.DeviceSpan();
maker.quantiser.reset(new GradientQuantiser(maker.gpair));
maker.quantiser = std::make_unique<GradientQuantiser>(maker.gpair);
maker.page = page.get();
maker.InitFeatureGroupsOnce();
@ -246,6 +246,7 @@ void UpdateTree(Context const* ctx, HostDeviceVector<GradientPair>* gpair, DMatr
ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{ctx, &task};
hist_maker.Configure(Args{});
std::vector<HostDeviceVector<bst_node_t>> position(1);
hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
@ -397,14 +398,14 @@ TEST(GpuHist, ConfigIO) {
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
updater->Configure(Args{});
Json j_updater { Object() };
Json j_updater{Object{}};
updater->SaveConfig(&j_updater);
ASSERT_TRUE(IsA<Object>(j_updater["gpu_hist_train_param"]));
ASSERT_TRUE(IsA<Object>(j_updater["hist_train_param"]));
updater->LoadConfig(j_updater);
Json j_updater_roundtrip { Object() };
Json j_updater_roundtrip{Object{}};
updater->SaveConfig(&j_updater_roundtrip);
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["gpu_hist_train_param"]));
ASSERT_TRUE(IsA<Object>(j_updater_roundtrip["hist_train_param"]));
ASSERT_EQ(j_updater, j_updater_roundtrip);
}

View File

@ -39,6 +39,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
param.UpdateAllowUnknown(
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&tree});
ASSERT_EQ(tree.NumExtraNodes(), 4);
@ -55,6 +56,7 @@ TEST(GrowHistMaker, InteractionConstraint) {
std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&tree});
ASSERT_EQ(tree.NumExtraNodes(), 10);
@ -81,6 +83,7 @@ void VerifyColumnSplit(int32_t rows, bst_feature_t cols, bool categorical,
RegTree tree{1u, cols};
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree});
Json json{Object{}};
@ -104,6 +107,7 @@ void TestColumnSplit(bool categorical) {
std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), p_dmat.get(), position, {&expected_tree});
}

View File

@ -73,6 +73,7 @@ class TestPredictionCache : public ::testing::Test {
tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_bin", "64"}});
updater->Configure(Args{});
std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Update(&param, &gpair, Xy_.get(), position, trees);
HostDeviceVector<float> out_prediction_cached;

View File

@ -13,7 +13,6 @@
#include "../../../src/tree/common_row_partitioner.h"
#include "../../../src/tree/hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
#include "../../../src/tree/param.h"
#include "../../../src/tree/split_evaluator.h"
#include "../helpers.h"
#include "test_partitioner.h"
#include "xgboost/data.h"
@ -49,7 +48,7 @@ void TestPartitioner(bst_target_t n_targets) {
auto min_value = gmat.cut.MinValues()[split_ind];
RegTree tree{n_targets, n_features};
CommonRowPartitioner partitioner{&ctx, n_samples, base_rowid, false};
if constexpr (std::is_same<ExpandEntry, CPUExpandEntry>::value) {
if constexpr (std::is_same_v<ExpandEntry, CPUExpandEntry>) {
GetSplit(&tree, min_value, &candidates);
} else {
GetMultiSplitForTest(&tree, min_value, &candidates);
@ -217,6 +216,7 @@ void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_target
RegTree tree{n_targets, cols};
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree});
Json json{Object{}};
@ -241,6 +241,7 @@ void TestColumnSplit(bst_target_t n_targets) {
std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), Xy.get(), position, {&expected_tree});
}

View File

@ -1459,6 +1459,7 @@ class TestWithDask:
tree_method: str,
) -> None:
params["tree_method"] = tree_method
params["debug_synchronize"] = True
params = dataset.set_params(params)
# It doesn't make sense to distribute a completely
# empty dataset.