Small cleanup to hist tree method. (#7735)

* Remove special optimization using number of bins.
* Remove 1-based index for column sampling.
* Remove data layout.
* Unify update prediction cache.
Jiaming Yuan 2022-03-20 03:44:55 +08:00 committed by GitHub
parent 718472dbe2
commit 996cc705af
9 changed files with 140 additions and 205 deletions


@ -156,9 +156,8 @@ class ColumnSampler {
* \param colsample_bytree
* \param skip_index_0 (Optional) True to skip index 0.
*/
void Init(int64_t num_col, std::vector<float> feature_weights,
float colsample_bynode, float colsample_bylevel,
float colsample_bytree, bool skip_index_0 = false) {
void Init(int64_t num_col, std::vector<float> feature_weights, float colsample_bynode,
float colsample_bylevel, float colsample_bytree) {
feature_weights_ = std::move(feature_weights);
colsample_bylevel_ = colsample_bylevel;
colsample_bytree_ = colsample_bytree;
@ -169,10 +168,8 @@ class ColumnSampler {
}
Reset();
int begin_idx = skip_index_0 ? 1 : 0;
feature_set_tree_->Resize(num_col - begin_idx);
std::iota(feature_set_tree_->HostVector().begin(),
feature_set_tree_->HostVector().end(), begin_idx);
feature_set_tree_->Resize(num_col);
std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
}
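With the data-layout special cases gone, the `skip_index_0` flag has no remaining caller, so the tree-level feature set always covers the full range [0, num_col). A minimal sketch of the simplified initialization, using a plain std::vector in place of xgboost's HostDeviceVector:

```cpp
#include <cstdint>
#include <memory>
#include <numeric>
#include <vector>

// Sketch only: every feature index from 0 to num_col - 1 is a sampling
// candidate; there is no skip_index_0 offset to thread through anymore.
std::shared_ptr<std::vector<int64_t>> InitFeatureSet(int64_t num_col) {
  auto feature_set = std::make_shared<std::vector<int64_t>>(num_col);
  std::iota(feature_set->begin(), feature_set->end(), int64_t{0});
  return feature_set;
}
```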


@ -55,8 +55,6 @@ class RowSetCollection {
/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr)
<< "access element that is not in the set";
return e;
}
@ -75,14 +73,10 @@ class RowSetCollection {
CHECK_EQ(elem_of_each_node_.size(), 0U);
if (row_indices_.empty()) { // edge case: empty instance set
// assign arbitrary address here, to bypass nullptr check
// (nullptr usually indicates a nonexistent rowset, but we want to
// indicate a valid rowset that happens to have zero length and occupies
// the whole instance set)
// this is okay, as BuildHist will compute (end-begin) as the set size
const size_t* begin = reinterpret_cast<size_t*>(20);
const size_t* end = begin;
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
constexpr size_t* kBegin = nullptr;
constexpr size_t* kEnd = nullptr;
static_assert(kEnd - kBegin == 0, "");
elem_of_each_node_.emplace_back(Elem(kBegin, kEnd, 0));
return;
}
@ -93,15 +87,19 @@ class RowSetCollection {
std::vector<size_t>* Data() { return &row_indices_; }
// split rowset into two
inline void AddSplit(unsigned node_id,
unsigned left_node_id,
unsigned right_node_id,
size_t n_left,
size_t n_right) {
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
size_t n_left, size_t n_right) {
const Elem e = elem_of_each_node_[node_id];
CHECK(e.begin != nullptr);
size_t* all_begin = dmlc::BeginPtr(row_indices_);
size_t* begin = all_begin + (e.begin - all_begin);
size_t* all_begin{nullptr};
size_t* begin{nullptr};
if (e.begin == nullptr) {
CHECK_EQ(n_left, 0);
CHECK_EQ(n_right, 0);
} else {
all_begin = dmlc::BeginPtr(row_indices_);
begin = all_begin + (e.begin - all_begin);
}
CHECK_EQ(n_left + n_right, e.Size());
CHECK_LE(begin + n_left, e.end);

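The old code manufactured a fake address (`reinterpret_cast<size_t*>(20)`) purely to get past the non-null check in operator[]; with that check removed, a pair of null pointers suffices, since equal pointers subtract to zero and the set size stays well defined. A self-contained sketch of the invariant the static_assert above documents (simplified Elem, not the real class):

```cpp
#include <cassert>
#include <cstddef>

// Simplified stand-in for RowSetCollection::Elem. Two null pointers compare
// equal and subtract to zero, so an all-null Elem is a valid empty set --
// no sentinel address is needed.
struct Elem {
  const std::size_t* begin{nullptr};
  const std::size_t* end{nullptr};
  std::size_t Size() const { return end - begin; }
};

int main() {
  Elem empty;  // models the empty instance set after this change
  assert(empty.Size() == 0);
  return 0;
}
```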

@ -266,6 +266,9 @@ class MemStackAllocator {
throw std::bad_alloc{};
}
}
MemStackAllocator(size_t required_size, T init) : MemStackAllocator{required_size} {
std::fill_n(ptr_, required_size_, init);
}
~MemStackAllocator() {
if (required_size_ > MaxStackSize) {

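The added constructor delegates to the allocating constructor and then fills the buffer with an initial value. A self-contained sketch of the same delegate-then-fill pattern, with a hypothetical StackBuffer standing in for MemStackAllocator:

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical stand-in, not the real class: small sizes live on the stack,
// large ones on the heap; the second constructor delegates, then fills.
template <typename T, std::size_t MaxStackSize = 128>
class StackBuffer {
 public:
  explicit StackBuffer(std::size_t n) : n_{n} {
    ptr_ = (n_ <= MaxStackSize) ? stack_ : new T[n_];
  }
  StackBuffer(std::size_t n, T init) : StackBuffer{n} {  // delegate, then fill
    std::fill_n(ptr_, n_, init);
  }
  StackBuffer(StackBuffer const&) = delete;  // owns its buffer
  ~StackBuffer() {
    if (n_ > MaxStackSize) delete[] ptr_;
  }
  T* data() { return ptr_; }

 private:
  std::size_t n_;
  T stack_[MaxStackSize];
  T* ptr_{nullptr};
};
```

With this, e.g. `StackBuffer<float> buf(n, 0.0f);` yields a zero-filled scratch buffer without a separate fill call at every use site.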

@ -363,19 +363,54 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
// The column sampler must be constructed by the caller since we need to preserve the
// rng for the entire training session.
explicit HistEvaluator(TrainParam const &param, MetaInfo const &info, int32_t n_threads,
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task,
bool skip_0_index = false)
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task)
: param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
n_threads_{n_threads},
task_{task} {
interaction_constraints_.Configure(param, info.num_col_);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree, skip_0_index);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
}
};
} // namespace tree
} // namespace xgboost
/**
 * \brief CPU implementation of update prediction cache, which calculates the leaf value
 * for the last tree and accumulates it into the prediction vector.
*
* \param p_last_tree The last tree being updated by tree updater
*/
template <typename Partitioner, typename GradientSumT, typename ExpandEntry>
void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_last_tree,
std::vector<Partitioner> const &partitioner,
HistEvaluator<GradientSumT, ExpandEntry> const &hist_evaluator,
TrainParam const &param, linalg::VectorView<float> out_preds) {
CHECK_GT(out_preds.Size(), 0U);
CHECK(p_last_tree);
auto const &tree = *p_last_tree;
auto const &snode = hist_evaluator.Stats();
auto evaluator = hist_evaluator.Evaluator();
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
size_t n_nodes = p_last_tree->GetNodes().size();
for (auto &part : partitioner) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) {
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
auto const &rowset = part[nidx];
auto const &stats = snode[nidx];
auto leaf_value =
evaluator.CalcWeight(nidx, param, GradStats{stats.stats}) * param.learning_rate;
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
}
}
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_

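UpdatePredictionCacheImpl is the piece that unifies the prediction-cache update: the approx and hist updaters below both delegate to it. Stripped of the BlockedSpace2d / ParallelFor2d sharding, the per-leaf work reduces to the following serial sketch (hypothetical simplified types):

```cpp
#include <cstddef>
#include <vector>

// Serial sketch of the per-leaf work inside UpdatePredictionCacheImpl; the
// real code shards each leaf's row range across threads.
struct LeafRows {
  int nidx;                       // node index in the last trained tree
  std::vector<std::size_t> rows;  // rows partitioned into this leaf
};

void AccumulateLeafValues(std::vector<LeafRows> const& leaves,
                          std::vector<float> const& leaf_weight,  // CalcWeight per node
                          float learning_rate, std::vector<float>* out_preds) {
  for (auto const& leaf : leaves) {
    float const leaf_value = leaf_weight[leaf.nidx] * learning_rate;
    for (std::size_t row : leaf.rows) {
      (*out_preds)[row] += leaf_value;  // add the new tree's contribution
    }
  }
}
```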

@ -114,34 +114,12 @@ class GloablApproxBuilder {
return nodes.front();
}
void UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) {
void UpdatePredictionCache(DMatrix const *data, linalg::VectorView<float> out_preds) const {
monitor_->Start(__func__);
// Caching the prediction seems redundant for the approx tree method, as sketching
// takes up the majority of training time.
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
CHECK(p_last_tree_);
size_t n_nodes = p_last_tree_->GetNodes().size();
auto evaluator = evaluator_.Evaluator();
auto const &tree = *p_last_tree_;
auto const &snode = evaluator_.Stats();
for (auto &part : partitioner_) {
CHECK_EQ(part.Size(), n_nodes);
common::BlockedSpace2d space(
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) {
if (tree[nidx].IsLeaf()) {
const auto rowset = part[nidx];
auto const &stats = snode.at(nidx);
auto leaf_value =
evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate;
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
}
UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, evaluator_, param_, out_preds);
monitor_->Stop(__func__);
}


@ -101,18 +101,17 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
p_last_dmat_ = dmat;
}
bool QuantileHistMaker::UpdatePredictionCache(
const DMatrix* data, linalg::VectorView<float> out_preds) {
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
linalg::VectorView<float> out_preds) {
if (hist_maker_param_.single_precision_histogram && float_builder_) {
return float_builder_->UpdatePredictionCache(data, out_preds);
} else if (double_builder_) {
return double_builder_->UpdatePredictionCache(data, out_preds);
} else {
return false;
}
}
template <typename GradientSumT>
template <bool any_missing>
void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
@ -135,27 +134,29 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
}
{
auto nid = RegTree::kRoot;
auto hist = this->histogram_builder_->Histogram()[nid];
GradientPairT grad_stat;
if (data_layout_ == DataLayout::kDenseDataZeroBased ||
data_layout_ == DataLayout::kDenseDataOneBased) {
auto const& gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
const std::vector<uint32_t> &row_ptr = gmat.cut.Ptrs();
const uint32_t ibegin = row_ptr[fid_least_bins_];
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
if (p_fmat->IsDense()) {
/**
 * Specialized code for dense data: for dense data (with no missing values), the sum of
 * any single feature's gradient histogram equals the total gradient sum snode[nid].
 */
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
std::vector<uint32_t> const &row_ptr = gmat.cut.Ptrs();
CHECK_GE(row_ptr.size(), 2);
uint32_t const ibegin = row_ptr[0];
uint32_t const iend = row_ptr[1];
auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot];
auto begin = hist.data();
for (uint32_t i = ibegin; i < iend; ++i) {
const GradientPairT et = begin[i];
GradientPairT const &et = begin[i];
grad_stat.Add(et.GetGrad(), et.GetHess());
}
} else {
const common::RowSetCollection::Elem e = row_set_collection[nid];
for (const size_t *it = e.begin; it < e.end; ++it) {
grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
for (auto const &grad : gpair_h) {
grad_stat.Add(grad.GetGrad(), grad.GetHess());
}
rabit::Allreduce<rabit::op::Sum, GradientSumT>(reinterpret_cast<GradientSumT *>(&grad_stat), 2);
}
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
@ -164,14 +165,14 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
(*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
std::vector<CPUExpandEntry> entries{node};
builder_monitor_.Start("EvaluateSplits");
builder_monitor_->Start("EvaluateSplits");
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
for (auto const& gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft,
*p_tree, &entries);
break;
}
builder_monitor_.Stop("EvaluateSplits");
builder_monitor_->Stop("EvaluateSplits");
node = entries.front();
}
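The dense branch above relies on a simple identity: with no missing values, every row lands in exactly one bin of every feature's histogram, so summing the bins of feature 0 alone (row_ptr[0] through row_ptr[1]) recovers the full gradient sum. A toy check of that identity:

```cpp
#include <cassert>
#include <vector>

struct GradPair { double grad{0}, hess{0}; };

int main() {
  // Dense toy data: each row's gradient lands in one bin of *every* feature.
  std::vector<GradPair> gpair{{0.5, 1.0}, {-0.25, 1.0}, {0.75, 1.0}};

  // Histogram for feature 0 only (2 bins): rows 0,2 -> bin 0, row 1 -> bin 1.
  std::vector<GradPair> hist_f0(2);
  hist_f0[0] = {0.5 + 0.75, 2.0};
  hist_f0[1] = {-0.25, 1.0};

  // Summing feature 0's bins equals summing all gradient pairs -- the
  // identity the dense branch of InitRoot exploits.
  GradPair from_hist, from_rows;
  for (auto const& b : hist_f0) { from_hist.grad += b.grad; from_hist.hess += b.hess; }
  for (auto const& g : gpair)   { from_rows.grad += g.grad; from_rows.hess += g.hess; }
  assert(from_hist.grad == from_rows.grad && from_hist.hess == from_rows.hess);
  return 0;
}
```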
@ -204,7 +205,7 @@ template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
const std::vector<CPUExpandEntry> &nodes_for_apply_split,
std::vector<CPUExpandEntry> *nodes_to_evaluate, RegTree *p_tree) {
builder_monitor_.Start("SplitSiblings");
builder_monitor_->Start("SplitSiblings");
auto const& row_set_collection = this->partitioner_.front().Partitions();
for (auto const& entry : nodes_for_apply_split) {
int nid = entry.nid;
@ -224,7 +225,7 @@ void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
}
}
CHECK_EQ(nodes_for_subtraction_trick_.size(), nodes_for_explicit_hist_build_.size());
builder_monitor_.Stop("SplitSiblings");
builder_monitor_->Stop("SplitSiblings");
}
template<typename GradientSumT>
@ -235,7 +236,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
DMatrix* p_fmat,
RegTree* p_tree,
const std::vector<GradientPair>& gpair_h) {
builder_monitor_.Start("ExpandTree");
builder_monitor_->Start("ExpandTree");
int num_leaves = 0;
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
@ -282,11 +283,11 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
nodes_for_subtraction_trick_, p_tree);
}
builder_monitor_.Start("EvaluateSplits");
builder_monitor_->Start("EvaluateSplits");
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(),
gmat.cut, ft, *p_tree, &nodes_to_evaluate);
builder_monitor_.Stop("EvaluateSplits");
builder_monitor_->Stop("EvaluateSplits");
for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0);
@ -296,7 +297,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
}
}
}
builder_monitor_.Stop("ExpandTree");
builder_monitor_->Stop("ExpandTree");
}
template <typename GradientSumT>
@ -305,7 +306,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
const common::ColumnMatrix &column_matrix,
HostDeviceVector<GradientPair> *gpair,
DMatrix *p_fmat, RegTree *p_tree) {
builder_monitor_.Start("Update");
builder_monitor_->Start("Update");
std::vector<GradientPair>* gpair_ptr = &(gpair->HostVector());
// in case 'num_parallel_trees != 1' there is no way to change the initial gpair
@ -316,7 +317,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
}
p_last_fmat_mutable_ = p_fmat;
this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr);
this->InitData(gmat, p_fmat, *p_tree, gpair_ptr);
if (column_matrix.AnyMissing()) {
ExpandTree<true>(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr);
@ -325,57 +326,28 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
}
pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
builder_monitor_.Stop("Update");
builder_monitor_->Stop("Update");
}
template <typename GradientSumT>
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
const DMatrix* data,
linalg::VectorView<float> out_preds) {
DMatrix const *data, linalg::VectorView<float> out_preds) const {
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
p_last_fmat_ != p_last_fmat_mutable_) {
return false;
}
builder_monitor_.Start("UpdatePredictionCache");
CHECK_GT(out_preds.Size(), 0U);
CHECK_EQ(partitioner_.size(), 1);
auto const &row_set_collection = this->partitioner_.front().Partitions();
size_t n_nodes = row_set_collection.end() - row_set_collection.begin();
common::BlockedSpace2d space(
n_nodes, [&](size_t node) { return partitioner_.front()[node].Size(); }, 1024);
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node, common::Range1d r) {
const common::RowSetCollection::Elem rowset = row_set_collection[node];
if (rowset.begin != nullptr && rowset.end != nullptr) {
int nid = rowset.node_id;
bst_float leaf_value;
// if a node is marked as deleted by the pruner, traverse upward to locate
// a non-deleted leaf.
if ((*p_last_tree_)[nid].IsDeleted()) {
while ((*p_last_tree_)[nid].IsDeleted()) {
nid = (*p_last_tree_)[nid].Parent();
}
CHECK((*p_last_tree_)[nid].IsLeaf());
}
leaf_value = (*p_last_tree_)[nid].LeafValue();
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
out_preds(*it) += leaf_value;
}
}
});
builder_monitor_.Stop("UpdatePredictionCache");
builder_monitor_->Start(__func__);
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, *evaluator_, param_, out_preds);
builder_monitor_->Stop(__func__);
return true;
}
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
std::vector<GradientPair>* gpair) {
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix &fmat,
std::vector<GradientPair> *gpair) {
const auto& info = fmat.Info();
auto& rnd = common::GlobalRandom();
std::vector<GradientPair>& gpair_ref = *gpair;
@ -415,85 +387,46 @@ size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
}
template <typename GradientSumT>
void QuantileHistMaker::Builder<GradientSumT>::InitData(
const GHistIndexMatrix &gmat, const DMatrix &fmat, const RegTree &tree,
std::vector<GradientPair> *gpair) {
builder_monitor_.Start("InitData");
const auto& info = fmat.Info();
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix &gmat, DMatrix *fmat,
const RegTree &tree,
std::vector<GradientPair> *gpair) {
builder_monitor_->Start("InitData");
const auto& info = fmat->Info();
{
// initialize histogram collection
uint32_t nbins = gmat.cut.Ptrs().back();
// initialize histogram builder
dmlc::OMPException exc;
this->histogram_builder_->Reset(nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
this->ctx_->Threads(), 1, rabit::IsDistributed());
size_t page_id{0};
int32_t n_total_bins{0};
partitioner_.clear();
for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads());
++page_id;
}
histogram_builder_->Reset(n_total_bins, BatchParam{param_.max_bin, param_.sparse_threshold},
ctx_->Threads(), page_id, rabit::IsDistributed());
if (param_.subsample < 1.0f) {
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
<< "Only uniform sampling is supported, "
<< "gradient-based sampling is only support by GPU Hist.";
builder_monitor_.Start("InitSampling");
InitSampling(fmat, gpair);
builder_monitor_.Stop("InitSampling");
builder_monitor_->Start("InitSampling");
InitSampling(*fmat, gpair);
builder_monitor_->Stop("InitSampling");
// We should check that the partitioning was done correctly
// and each row of the dataset fell into exactly one of the categories
}
}
partitioner_.clear();
partitioner_.emplace_back(info.num_row_, 0, this->ctx_->Threads());
{
/* determine layout of data */
const size_t nrow = info.num_row_;
const size_t ncol = info.num_col_;
const size_t nnz = info.num_nonzero_;
// number of discrete bins for feature 0
const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0];
if (nrow * ncol == nnz) {
// dense data with zero-based indexing
data_layout_ = DataLayout::kDenseDataZeroBased;
} else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) {
// dense data with one-based indexing
data_layout_ = DataLayout::kDenseDataOneBased;
} else {
// sparse data
data_layout_ = DataLayout::kSparseData;
}
}
// store a pointer to the tree
p_last_tree_ = &tree;
if (data_layout_ == DataLayout::kDenseDataOneBased) {
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
param_, info, this->ctx_->Threads(), column_sampler_, task_, true});
} else {
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
param_, info, this->ctx_->Threads(), column_sampler_, task_, false});
}
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
param_, info, this->ctx_->Threads(), column_sampler_, task_});
if (data_layout_ == DataLayout::kDenseDataZeroBased
|| data_layout_ == DataLayout::kDenseDataOneBased) {
/* specialized code for dense data:
choose the column that has the least positive number of discrete bins.
For dense data (with no missing value),
the sum of the gradient histogram is equal to snode[nid] */
const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
uint32_t min_nbins_per_feature = 0;
for (bst_uint i = 0; i < nfeature; ++i) {
const uint32_t nbins = row_ptr[i + 1] - row_ptr[i];
if (nbins > 0) {
if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) {
min_nbins_per_feature = nbins;
fid_least_bins_ = i;
}
}
}
CHECK_GT(min_nbins_per_feature, 0U);
}
builder_monitor_.Stop("InitData");
builder_monitor_->Stop("InitData");
}
void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes,

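The reworked InitData derives partitioning from the batch iterator itself: one row partitioner per GHistIndexMatrix page, plus a consistency check that every page reports the same total bin count, which prepares this path for external memory. A sketch of the loop's shape with hypothetical stand-in types:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

struct PageView {           // stand-in for a GHistIndexMatrix page
  std::size_t n_rows;
  std::size_t base_rowid;
  std::int32_t total_bins;
};

struct Partitioner {        // one partitioner per external-memory page
  Partitioner(std::size_t n_rows, std::size_t base_rowid, std::int32_t n_threads) {
    (void)n_rows; (void)base_rowid; (void)n_threads;  // placeholder body
  }
};

std::vector<Partitioner> BuildPartitioners(std::vector<PageView> const& pages,
                                           std::int32_t n_threads,
                                           std::int32_t* n_total_bins) {
  std::vector<Partitioner> partitioners;
  *n_total_bins = 0;
  for (auto const& page : pages) {
    if (*n_total_bins == 0) {
      *n_total_bins = page.total_bins;          // first page sets the bin count
    } else if (*n_total_bins != page.total_bins) {
      throw std::runtime_error("inconsistent number of bins across pages");
    }
    partitioners.emplace_back(page.n_rows, page.base_rowid, n_threads);
  }
  return partitioners;
}
```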

@ -276,21 +276,19 @@ class QuantileHistMaker: public TreeUpdater {
p_last_fmat_(fmat),
histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
task_{task},
ctx_{ctx} {
builder_monitor_.Init("Quantile::Builder");
ctx_{ctx},
builder_monitor_{std::make_unique<common::Monitor>()} {
builder_monitor_->Init("Quantile::Builder");
}
// update one tree, growing
void Update(const GHistIndexMatrix& gmat, const common::ColumnMatrix& column_matrix,
HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
bool UpdatePredictionCache(const DMatrix* data,
linalg::VectorView<float> out_preds);
bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView<float> out_preds) const;
protected:
// initialize temp data structure
void InitData(const GHistIndexMatrix& gmat,
const DMatrix& fmat,
const RegTree& tree,
void InitData(const GHistIndexMatrix& gmat, DMatrix* fmat, const RegTree& tree,
std::vector<GradientPair>* gpair);
size_t GetNumberOfTrees();
@ -330,10 +328,6 @@ class QuantileHistMaker: public TreeUpdater {
std::vector<GradientPair> gpair_local_;
/*! \brief feature with least # of bins. to be used for dense specialization
of InitNewNode() */
uint32_t fid_least_bins_;
std::unique_ptr<TreeUpdater> pruner_;
std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
// Right now there's only 1 partitioner in this vector, when external memory is fully
@ -352,13 +346,12 @@ class QuantileHistMaker: public TreeUpdater {
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
DataLayout data_layout_;
std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_;
ObjInfo task_;
// Context for number of threads
GenericParameter const* ctx_;
common::Monitor builder_monitor_;
std::unique_ptr<common::Monitor> builder_monitor_;
};
common::Monitor updater_monitor_;

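Moving the monitor behind a unique_ptr is presumably what lets UpdatePredictionCache become a const member: const-ness applies to the pointer, not the pointee, so the non-const Start/Stop calls still compile. A minimal illustration:

```cpp
#include <memory>
#include <string>

struct Monitor {
  void Start(const std::string&) {}  // non-const members
  void Stop(const std::string&) {}
};

class Builder {
 public:
  // const method: with a plain `Monitor monitor_` member this would not
  // compile, but a unique_ptr member still yields a mutable pointee.
  bool UpdatePredictionCache() const {
    monitor_->Start(__func__);
    // ... update cached predictions ...
    monitor_->Stop(__func__);
    return true;
  }

 private:
  std::unique_ptr<Monitor> monitor_{std::make_unique<Monitor>()};
};
```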

@ -254,8 +254,7 @@ TEST(GpuHist, EvaluateRootSplit) {
std::vector<float> feature_weights;
maker.column_sampler.Init(kNCols, feature_weights, param.colsample_bynode,
param.colsample_bylevel, param.colsample_bytree,
false);
param.colsample_bylevel, param.colsample_bytree);
RegTree tree;
MetaInfo info;


@ -35,8 +35,7 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<GradientPair>* gpair,
DMatrix* p_fmat,
const RegTree& tree) {
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
ASSERT_EQ(this->data_layout_, RealImpl::DataLayout::kSparseData);
RealImpl::InitData(gmat, p_fmat, tree, gpair);
/* The creation of HistCutMatrix and GHistIndexMatrix is not technically
* part of QuantileHist updater logic, but we include it here because