Extract Sketch Entry from hist maker. (#7503)

* Extract Sketch Entry from hist maker.

* Add a new sketch container for sorted inputs.
* Optimize bin search.
This commit is contained in:
Jiaming Yuan
2021-12-18 05:36:56 +08:00
committed by GitHub
parent b4a1236cfc
commit 9ab73f737e
15 changed files with 393 additions and 217 deletions

View File

@@ -92,18 +92,20 @@ class HistogramCuts {
// Return the index of a cut point that is strictly greater than the input
// value, or the last available index if none exists
BinIdx SearchBin(float value, uint32_t column_id) const {
auto beg = cut_ptrs_.ConstHostVector().at(column_id);
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
const auto &values = cut_values_.ConstHostVector();
BinIdx SearchBin(float value, uint32_t column_id, std::vector<uint32_t> const& ptrs,
std::vector<float> const& values) const {
auto end = ptrs[column_id + 1];
auto beg = ptrs[column_id];
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
BinIdx idx = it - values.cbegin();
if (idx == end) {
idx -= 1;
}
idx -= !!(idx == end);
return idx;
}
BinIdx SearchBin(float value, uint32_t column_id) const {
return this->SearchBin(value, column_id, Ptrs(), Values());
}
/**
* \brief Search the bin index for numerical feature.
*/
@@ -129,7 +131,13 @@ class HistogramCuts {
}
};
inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
/**
* \brief Run CPU sketching on DMatrix.
*
* \param use_sorted Whether should we use SortedCSC for sketching, it's more efficient
* but consumes more memory.
*/
inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sorted = false,
Span<float> const hessian = {}) {
HistogramCuts out;
auto const& info = m->Info();
@@ -146,13 +154,23 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
reduced[i] += entries_per_column[i];
}
}
HostSketchContainer container(reduced, max_bins,
m->Info().feature_types.ConstHostSpan(),
HostSketchContainer::UseGroup(info), threads);
for (auto const &page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
hessian, threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(&out);
} else {
SortedSketchContainer container{
max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}
container.MakeCuts(&out);
}
container.MakeCuts(&out);
return out;
}