Small cleanup to hist tree method. (#7735)
* Remove special optimization using number of bins. * Remove 1-based index for column sampling. * Remove data layout. * Unify update prediction cache.
This commit is contained in:
parent
718472dbe2
commit
996cc705af
@ -156,9 +156,8 @@ class ColumnSampler {
|
||||
* \param colsample_bytree
|
||||
* \param skip_index_0 (Optional) True to skip index 0.
|
||||
*/
|
||||
void Init(int64_t num_col, std::vector<float> feature_weights,
|
||||
float colsample_bynode, float colsample_bylevel,
|
||||
float colsample_bytree, bool skip_index_0 = false) {
|
||||
void Init(int64_t num_col, std::vector<float> feature_weights, float colsample_bynode,
|
||||
float colsample_bylevel, float colsample_bytree) {
|
||||
feature_weights_ = std::move(feature_weights);
|
||||
colsample_bylevel_ = colsample_bylevel;
|
||||
colsample_bytree_ = colsample_bytree;
|
||||
@ -169,10 +168,8 @@ class ColumnSampler {
|
||||
}
|
||||
Reset();
|
||||
|
||||
int begin_idx = skip_index_0 ? 1 : 0;
|
||||
feature_set_tree_->Resize(num_col - begin_idx);
|
||||
std::iota(feature_set_tree_->HostVector().begin(),
|
||||
feature_set_tree_->HostVector().end(), begin_idx);
|
||||
feature_set_tree_->Resize(num_col);
|
||||
std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
|
||||
|
||||
feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
|
||||
}
|
||||
|
||||
@ -55,8 +55,6 @@ class RowSetCollection {
|
||||
/*! \brief return corresponding element set given the node_id */
|
||||
inline const Elem& operator[](unsigned node_id) const {
|
||||
const Elem& e = elem_of_each_node_[node_id];
|
||||
CHECK(e.begin != nullptr)
|
||||
<< "access element that is not in the set";
|
||||
return e;
|
||||
}
|
||||
|
||||
@ -75,14 +73,10 @@ class RowSetCollection {
|
||||
CHECK_EQ(elem_of_each_node_.size(), 0U);
|
||||
|
||||
if (row_indices_.empty()) { // edge case: empty instance set
|
||||
// assign arbitrary address here, to bypass nullptr check
|
||||
// (nullptr usually indicates a nonexistent rowset, but we want to
|
||||
// indicate a valid rowset that happens to have zero length and occupies
|
||||
// the whole instance set)
|
||||
// this is okay, as BuildHist will compute (end-begin) as the set size
|
||||
const size_t* begin = reinterpret_cast<size_t*>(20);
|
||||
const size_t* end = begin;
|
||||
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
|
||||
constexpr size_t* kBegin = nullptr;
|
||||
constexpr size_t* kEnd = nullptr;
|
||||
static_assert(kEnd - kBegin == 0, "");
|
||||
elem_of_each_node_.emplace_back(Elem(kBegin, kEnd, 0));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -93,15 +87,19 @@ class RowSetCollection {
|
||||
|
||||
std::vector<size_t>* Data() { return &row_indices_; }
|
||||
// split rowset into two
|
||||
inline void AddSplit(unsigned node_id,
|
||||
unsigned left_node_id,
|
||||
unsigned right_node_id,
|
||||
size_t n_left,
|
||||
size_t n_right) {
|
||||
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
|
||||
size_t n_left, size_t n_right) {
|
||||
const Elem e = elem_of_each_node_[node_id];
|
||||
CHECK(e.begin != nullptr);
|
||||
size_t* all_begin = dmlc::BeginPtr(row_indices_);
|
||||
size_t* begin = all_begin + (e.begin - all_begin);
|
||||
|
||||
size_t* all_begin{nullptr};
|
||||
size_t* begin{nullptr};
|
||||
if (e.begin == nullptr) {
|
||||
CHECK_EQ(n_left, 0);
|
||||
CHECK_EQ(n_right, 0);
|
||||
} else {
|
||||
all_begin = dmlc::BeginPtr(row_indices_);
|
||||
begin = all_begin + (e.begin - all_begin);
|
||||
}
|
||||
|
||||
CHECK_EQ(n_left + n_right, e.Size());
|
||||
CHECK_LE(begin + n_left, e.end);
|
||||
|
||||
@ -266,6 +266,9 @@ class MemStackAllocator {
|
||||
throw std::bad_alloc{};
|
||||
}
|
||||
}
|
||||
MemStackAllocator(size_t required_size, T init) : MemStackAllocator{required_size} {
|
||||
std::fill_n(ptr_, required_size_, init);
|
||||
}
|
||||
|
||||
~MemStackAllocator() {
|
||||
if (required_size_ > MaxStackSize) {
|
||||
|
||||
@ -363,19 +363,54 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
// The column sampler must be constructed by caller since we need to preserve the rng
|
||||
// for the entire training session.
|
||||
explicit HistEvaluator(TrainParam const ¶m, MetaInfo const &info, int32_t n_threads,
|
||||
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task,
|
||||
bool skip_0_index = false)
|
||||
std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task)
|
||||
: param_{param},
|
||||
column_sampler_{std::move(sampler)},
|
||||
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
|
||||
n_threads_{n_threads},
|
||||
task_{task} {
|
||||
interaction_constraints_.Configure(param, info.num_col_);
|
||||
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree, skip_0_index);
|
||||
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
}
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
/**
|
||||
* \brief CPU implementation of update prediction cache, which calculates the leaf value
|
||||
* for the last tree and accumulates it to prediction vector.
|
||||
*
|
||||
* \param p_last_tree The last tree being updated by tree updater
|
||||
*/
|
||||
template <typename Partitioner, typename GradientSumT, typename ExpandEntry>
|
||||
void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_last_tree,
|
||||
std::vector<Partitioner> const &partitioner,
|
||||
HistEvaluator<GradientSumT, ExpandEntry> const &hist_evaluator,
|
||||
TrainParam const ¶m, linalg::VectorView<float> out_preds) {
|
||||
CHECK_GT(out_preds.Size(), 0U);
|
||||
|
||||
CHECK(p_last_tree);
|
||||
auto const &tree = *p_last_tree;
|
||||
auto const &snode = hist_evaluator.Stats();
|
||||
auto evaluator = hist_evaluator.Evaluator();
|
||||
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
|
||||
size_t n_nodes = p_last_tree->GetNodes().size();
|
||||
for (auto &part : partitioner) {
|
||||
CHECK_EQ(part.Size(), n_nodes);
|
||||
common::BlockedSpace2d space(
|
||||
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
|
||||
common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) {
|
||||
if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) {
|
||||
auto const &rowset = part[nidx];
|
||||
auto const &stats = snode[nidx];
|
||||
auto leaf_value =
|
||||
evaluator.CalcWeight(nidx, param, GradStats{stats.stats}) * param.learning_rate;
|
||||
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
out_preds(*it) += leaf_value;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
|
||||
|
||||
@ -114,34 +114,12 @@ class GloablApproxBuilder {
|
||||
return nodes.front();
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) {
|
||||
void UpdatePredictionCache(DMatrix const *data, linalg::VectorView<float> out_preds) const {
|
||||
monitor_->Start(__func__);
|
||||
// Caching prediction seems redundant for approx tree method, as sketching takes up
|
||||
// majority of training time.
|
||||
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
|
||||
CHECK(p_last_tree_);
|
||||
|
||||
size_t n_nodes = p_last_tree_->GetNodes().size();
|
||||
|
||||
auto evaluator = evaluator_.Evaluator();
|
||||
auto const &tree = *p_last_tree_;
|
||||
auto const &snode = evaluator_.Stats();
|
||||
for (auto &part : partitioner_) {
|
||||
CHECK_EQ(part.Size(), n_nodes);
|
||||
common::BlockedSpace2d space(
|
||||
part.Size(), [&](size_t node) { return part[node].Size(); }, 1024);
|
||||
common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) {
|
||||
if (tree[nidx].IsLeaf()) {
|
||||
const auto rowset = part[nidx];
|
||||
auto const &stats = snode.at(nidx);
|
||||
auto leaf_value =
|
||||
evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate;
|
||||
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
out_preds(*it) += leaf_value;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, evaluator_, param_, out_preds);
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
|
||||
@ -101,18 +101,17 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
p_last_dmat_ = dmat;
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCache(
|
||||
const DMatrix* data, linalg::VectorView<float> out_preds) {
|
||||
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data,
|
||||
linalg::VectorView<float> out_preds) {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds);
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (double_builder_) {
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds);
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds);
|
||||
} else {
|
||||
return false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename GradientSumT>
|
||||
template <bool any_missing>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
||||
@ -135,27 +134,29 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
||||
}
|
||||
|
||||
{
|
||||
auto nid = RegTree::kRoot;
|
||||
auto hist = this->histogram_builder_->Histogram()[nid];
|
||||
GradientPairT grad_stat;
|
||||
if (data_layout_ == DataLayout::kDenseDataZeroBased ||
|
||||
data_layout_ == DataLayout::kDenseDataOneBased) {
|
||||
auto const& gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
|
||||
const std::vector<uint32_t> &row_ptr = gmat.cut.Ptrs();
|
||||
const uint32_t ibegin = row_ptr[fid_least_bins_];
|
||||
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
|
||||
if (p_fmat->IsDense()) {
|
||||
/**
|
||||
* Specialized code for dense data: For dense data (with no missing value), the sum
|
||||
* of gradient histogram is equal to snode[nid]
|
||||
*/
|
||||
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
|
||||
std::vector<uint32_t> const &row_ptr = gmat.cut.Ptrs();
|
||||
CHECK_GE(row_ptr.size(), 2);
|
||||
uint32_t const ibegin = row_ptr[0];
|
||||
uint32_t const iend = row_ptr[1];
|
||||
auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot];
|
||||
auto begin = hist.data();
|
||||
for (uint32_t i = ibegin; i < iend; ++i) {
|
||||
const GradientPairT et = begin[i];
|
||||
GradientPairT const &et = begin[i];
|
||||
grad_stat.Add(et.GetGrad(), et.GetHess());
|
||||
}
|
||||
} else {
|
||||
const common::RowSetCollection::Elem e = row_set_collection[nid];
|
||||
for (const size_t *it = e.begin; it < e.end; ++it) {
|
||||
grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
|
||||
for (auto const &grad : gpair_h) {
|
||||
grad_stat.Add(grad.GetGrad(), grad.GetHess());
|
||||
}
|
||||
rabit::Allreduce<rabit::op::Sum, GradientSumT>(
|
||||
reinterpret_cast<GradientSumT *>(&grad_stat), 2);
|
||||
rabit::Allreduce<rabit::op::Sum, GradientSumT>(reinterpret_cast<GradientSumT *>(&grad_stat),
|
||||
2);
|
||||
}
|
||||
|
||||
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
|
||||
@ -164,14 +165,14 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
||||
(*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight);
|
||||
|
||||
std::vector<CPUExpandEntry> entries{node};
|
||||
builder_monitor_.Start("EvaluateSplits");
|
||||
builder_monitor_->Start("EvaluateSplits");
|
||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||
for (auto const& gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft,
|
||||
*p_tree, &entries);
|
||||
break;
|
||||
}
|
||||
builder_monitor_.Stop("EvaluateSplits");
|
||||
builder_monitor_->Stop("EvaluateSplits");
|
||||
node = entries.front();
|
||||
}
|
||||
|
||||
@ -204,7 +205,7 @@ template <typename GradientSumT>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
|
||||
const std::vector<CPUExpandEntry> &nodes_for_apply_split,
|
||||
std::vector<CPUExpandEntry> *nodes_to_evaluate, RegTree *p_tree) {
|
||||
builder_monitor_.Start("SplitSiblings");
|
||||
builder_monitor_->Start("SplitSiblings");
|
||||
auto const& row_set_collection = this->partitioner_.front().Partitions();
|
||||
for (auto const& entry : nodes_for_apply_split) {
|
||||
int nid = entry.nid;
|
||||
@ -224,7 +225,7 @@ void QuantileHistMaker::Builder<GradientSumT>::SplitSiblings(
|
||||
}
|
||||
}
|
||||
CHECK_EQ(nodes_for_subtraction_trick_.size(), nodes_for_explicit_hist_build_.size());
|
||||
builder_monitor_.Stop("SplitSiblings");
|
||||
builder_monitor_->Stop("SplitSiblings");
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
@ -235,7 +236,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
|
||||
DMatrix* p_fmat,
|
||||
RegTree* p_tree,
|
||||
const std::vector<GradientPair>& gpair_h) {
|
||||
builder_monitor_.Start("ExpandTree");
|
||||
builder_monitor_->Start("ExpandTree");
|
||||
int num_leaves = 0;
|
||||
|
||||
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
|
||||
@ -282,11 +283,11 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
|
||||
nodes_for_subtraction_trick_, p_tree);
|
||||
}
|
||||
|
||||
builder_monitor_.Start("EvaluateSplits");
|
||||
builder_monitor_->Start("EvaluateSplits");
|
||||
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
|
||||
evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(),
|
||||
gmat.cut, ft, *p_tree, &nodes_to_evaluate);
|
||||
builder_monitor_.Stop("EvaluateSplits");
|
||||
builder_monitor_->Stop("EvaluateSplits");
|
||||
|
||||
for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
|
||||
CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0);
|
||||
@ -296,7 +297,7 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
|
||||
}
|
||||
}
|
||||
}
|
||||
builder_monitor_.Stop("ExpandTree");
|
||||
builder_monitor_->Stop("ExpandTree");
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
@ -305,7 +306,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
const common::ColumnMatrix &column_matrix,
|
||||
HostDeviceVector<GradientPair> *gpair,
|
||||
DMatrix *p_fmat, RegTree *p_tree) {
|
||||
builder_monitor_.Start("Update");
|
||||
builder_monitor_->Start("Update");
|
||||
|
||||
std::vector<GradientPair>* gpair_ptr = &(gpair->HostVector());
|
||||
// in case 'num_parallel_trees != 1' no posibility to change initial gpair
|
||||
@ -316,7 +317,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
}
|
||||
p_last_fmat_mutable_ = p_fmat;
|
||||
|
||||
this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr);
|
||||
this->InitData(gmat, p_fmat, *p_tree, gpair_ptr);
|
||||
|
||||
if (column_matrix.AnyMissing()) {
|
||||
ExpandTree<true>(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr);
|
||||
@ -325,57 +326,28 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
}
|
||||
pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
|
||||
|
||||
builder_monitor_.Stop("Update");
|
||||
builder_monitor_->Stop("Update");
|
||||
}
|
||||
|
||||
template<typename GradientSumT>
|
||||
template <typename GradientSumT>
|
||||
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
const DMatrix* data,
|
||||
linalg::VectorView<float> out_preds) {
|
||||
DMatrix const *data, linalg::VectorView<float> out_preds) const {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ ||
|
||||
p_last_fmat_ != p_last_fmat_mutable_) {
|
||||
return false;
|
||||
}
|
||||
builder_monitor_.Start("UpdatePredictionCache");
|
||||
|
||||
CHECK_GT(out_preds.Size(), 0U);
|
||||
|
||||
CHECK_EQ(partitioner_.size(), 1);
|
||||
auto const &row_set_collection = this->partitioner_.front().Partitions();
|
||||
size_t n_nodes = row_set_collection.end() - row_set_collection.begin();
|
||||
common::BlockedSpace2d space(
|
||||
n_nodes, [&](size_t node) { return partitioner_.front()[node].Size(); }, 1024);
|
||||
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId);
|
||||
common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node, common::Range1d r) {
|
||||
const common::RowSetCollection::Elem rowset = row_set_collection[node];
|
||||
if (rowset.begin != nullptr && rowset.end != nullptr) {
|
||||
int nid = rowset.node_id;
|
||||
bst_float leaf_value;
|
||||
// if a node is marked as deleted by the pruner, traverse upward to locate
|
||||
// a non-deleted leaf.
|
||||
if ((*p_last_tree_)[nid].IsDeleted()) {
|
||||
while ((*p_last_tree_)[nid].IsDeleted()) {
|
||||
nid = (*p_last_tree_)[nid].Parent();
|
||||
}
|
||||
CHECK((*p_last_tree_)[nid].IsLeaf());
|
||||
}
|
||||
leaf_value = (*p_last_tree_)[nid].LeafValue();
|
||||
|
||||
for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
out_preds(*it) += leaf_value;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
builder_monitor_.Stop("UpdatePredictionCache");
|
||||
builder_monitor_->Start(__func__);
|
||||
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
|
||||
UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, *evaluator_, param_, out_preds);
|
||||
builder_monitor_->Stop(__func__);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix& fmat,
|
||||
std::vector<GradientPair>* gpair) {
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitSampling(const DMatrix &fmat,
|
||||
std::vector<GradientPair> *gpair) {
|
||||
const auto& info = fmat.Info();
|
||||
auto& rnd = common::GlobalRandom();
|
||||
std::vector<GradientPair>& gpair_ref = *gpair;
|
||||
@ -415,85 +387,46 @@ size_t QuantileHistMaker::Builder<GradientSumT>::GetNumberOfTrees() {
|
||||
}
|
||||
|
||||
template <typename GradientSumT>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitData(
|
||||
const GHistIndexMatrix &gmat, const DMatrix &fmat, const RegTree &tree,
|
||||
std::vector<GradientPair> *gpair) {
|
||||
builder_monitor_.Start("InitData");
|
||||
const auto& info = fmat.Info();
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix &gmat, DMatrix *fmat,
|
||||
const RegTree &tree,
|
||||
std::vector<GradientPair> *gpair) {
|
||||
builder_monitor_->Start("InitData");
|
||||
const auto& info = fmat->Info();
|
||||
|
||||
{
|
||||
// initialize histogram collection
|
||||
uint32_t nbins = gmat.cut.Ptrs().back();
|
||||
// initialize histogram builder
|
||||
dmlc::OMPException exc;
|
||||
this->histogram_builder_->Reset(nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin},
|
||||
this->ctx_->Threads(), 1, rabit::IsDistributed());
|
||||
size_t page_id{0};
|
||||
int32_t n_total_bins{0};
|
||||
partitioner_.clear();
|
||||
for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
if (n_total_bins == 0) {
|
||||
n_total_bins = page.cut.TotalBins();
|
||||
} else {
|
||||
CHECK_EQ(n_total_bins, page.cut.TotalBins());
|
||||
}
|
||||
partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads());
|
||||
++page_id;
|
||||
}
|
||||
histogram_builder_->Reset(n_total_bins, BatchParam{param_.max_bin, param_.sparse_threshold},
|
||||
ctx_->Threads(), page_id, rabit::IsDistributed());
|
||||
|
||||
if (param_.subsample < 1.0f) {
|
||||
CHECK_EQ(param_.sampling_method, TrainParam::kUniform)
|
||||
<< "Only uniform sampling is supported, "
|
||||
<< "gradient-based sampling is only support by GPU Hist.";
|
||||
builder_monitor_.Start("InitSampling");
|
||||
InitSampling(fmat, gpair);
|
||||
builder_monitor_.Stop("InitSampling");
|
||||
builder_monitor_->Start("InitSampling");
|
||||
InitSampling(*fmat, gpair);
|
||||
builder_monitor_->Stop("InitSampling");
|
||||
// We should check that the partitioning was done correctly
|
||||
// and each row of the dataset fell into exactly one of the categories
|
||||
}
|
||||
}
|
||||
|
||||
partitioner_.clear();
|
||||
partitioner_.emplace_back(info.num_row_, 0, this->ctx_->Threads());
|
||||
|
||||
{
|
||||
/* determine layout of data */
|
||||
const size_t nrow = info.num_row_;
|
||||
const size_t ncol = info.num_col_;
|
||||
const size_t nnz = info.num_nonzero_;
|
||||
// number of discrete bins for feature 0
|
||||
const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0];
|
||||
if (nrow * ncol == nnz) {
|
||||
// dense data with zero-based indexing
|
||||
data_layout_ = DataLayout::kDenseDataZeroBased;
|
||||
} else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) {
|
||||
// dense data with one-based indexing
|
||||
data_layout_ = DataLayout::kDenseDataOneBased;
|
||||
} else {
|
||||
// sparse data
|
||||
data_layout_ = DataLayout::kSparseData;
|
||||
}
|
||||
}
|
||||
// store a pointer to the tree
|
||||
p_last_tree_ = &tree;
|
||||
if (data_layout_ == DataLayout::kDenseDataOneBased) {
|
||||
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
|
||||
param_, info, this->ctx_->Threads(), column_sampler_, task_, true});
|
||||
} else {
|
||||
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
|
||||
param_, info, this->ctx_->Threads(), column_sampler_, task_, false});
|
||||
}
|
||||
evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
|
||||
param_, info, this->ctx_->Threads(), column_sampler_, task_});
|
||||
|
||||
if (data_layout_ == DataLayout::kDenseDataZeroBased
|
||||
|| data_layout_ == DataLayout::kDenseDataOneBased) {
|
||||
/* specialized code for dense data:
|
||||
choose the column that has a least positive number of discrete bins.
|
||||
For dense data (with no missing value),
|
||||
the sum of gradient histogram is equal to snode[nid] */
|
||||
const std::vector<uint32_t>& row_ptr = gmat.cut.Ptrs();
|
||||
const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
|
||||
uint32_t min_nbins_per_feature = 0;
|
||||
for (bst_uint i = 0; i < nfeature; ++i) {
|
||||
const uint32_t nbins = row_ptr[i + 1] - row_ptr[i];
|
||||
if (nbins > 0) {
|
||||
if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) {
|
||||
min_nbins_per_feature = nbins;
|
||||
fid_least_bins_ = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK_GT(min_nbins_per_feature, 0U);
|
||||
}
|
||||
|
||||
builder_monitor_.Stop("InitData");
|
||||
builder_monitor_->Stop("InitData");
|
||||
}
|
||||
|
||||
void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &nodes,
|
||||
|
||||
@ -276,21 +276,19 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
p_last_fmat_(fmat),
|
||||
histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
|
||||
task_{task},
|
||||
ctx_{ctx} {
|
||||
builder_monitor_.Init("Quantile::Builder");
|
||||
ctx_{ctx},
|
||||
builder_monitor_{std::make_unique<common::Monitor>()} {
|
||||
builder_monitor_->Init("Quantile::Builder");
|
||||
}
|
||||
// update one tree, growing
|
||||
void Update(const GHistIndexMatrix& gmat, const common::ColumnMatrix& column_matrix,
|
||||
HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
linalg::VectorView<float> out_preds);
|
||||
bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView<float> out_preds) const;
|
||||
|
||||
protected:
|
||||
// initialize temp data structure
|
||||
void InitData(const GHistIndexMatrix& gmat,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree,
|
||||
void InitData(const GHistIndexMatrix& gmat, DMatrix* fmat, const RegTree& tree,
|
||||
std::vector<GradientPair>* gpair);
|
||||
|
||||
size_t GetNumberOfTrees();
|
||||
@ -330,10 +328,6 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
|
||||
std::vector<GradientPair> gpair_local_;
|
||||
|
||||
/*! \brief feature with least # of bins. to be used for dense specialization
|
||||
of InitNewNode() */
|
||||
uint32_t fid_least_bins_;
|
||||
|
||||
std::unique_ptr<TreeUpdater> pruner_;
|
||||
std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
|
||||
// Right now there's only 1 partitioner in this vector, when external memory is fully
|
||||
@ -352,13 +346,12 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
|
||||
|
||||
enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
|
||||
DataLayout data_layout_;
|
||||
std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_;
|
||||
ObjInfo task_;
|
||||
// Context for number of threads
|
||||
GenericParameter const* ctx_;
|
||||
|
||||
common::Monitor builder_monitor_;
|
||||
std::unique_ptr<common::Monitor> builder_monitor_;
|
||||
};
|
||||
common::Monitor updater_monitor_;
|
||||
|
||||
|
||||
@ -254,8 +254,7 @@ TEST(GpuHist, EvaluateRootSplit) {
|
||||
std::vector<float> feature_weights;
|
||||
|
||||
maker.column_sampler.Init(kNCols, feature_weights, param.colsample_bynode,
|
||||
param.colsample_bylevel, param.colsample_bytree,
|
||||
false);
|
||||
param.colsample_bylevel, param.colsample_bytree);
|
||||
|
||||
RegTree tree;
|
||||
MetaInfo info;
|
||||
|
||||
@ -35,8 +35,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
const RegTree& tree) {
|
||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
||||
ASSERT_EQ(this->data_layout_, RealImpl::DataLayout::kSparseData);
|
||||
RealImpl::InitData(gmat, p_fmat, tree, gpair);
|
||||
|
||||
/* The creation of HistCutMatrix and GHistIndexMatrix are not technically
|
||||
* part of QuantileHist updater logic, but we include it here because
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user