Extract Sketch Entry from hist maker. (#7503)
* Extract Sketch Entry from hist maker. * Add a new sketch container for sorted inputs. * Optimize bin search.
This commit is contained in:
@@ -92,18 +92,20 @@ class HistogramCuts {
|
||||
|
||||
// Return the index of a cut point that is strictly greater than the input
|
||||
// value, or the last available index if none exists
|
||||
BinIdx SearchBin(float value, uint32_t column_id) const {
|
||||
auto beg = cut_ptrs_.ConstHostVector().at(column_id);
|
||||
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
|
||||
const auto &values = cut_values_.ConstHostVector();
|
||||
BinIdx SearchBin(float value, uint32_t column_id, std::vector<uint32_t> const& ptrs,
|
||||
std::vector<float> const& values) const {
|
||||
auto end = ptrs[column_id + 1];
|
||||
auto beg = ptrs[column_id];
|
||||
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
|
||||
BinIdx idx = it - values.cbegin();
|
||||
if (idx == end) {
|
||||
idx -= 1;
|
||||
}
|
||||
idx -= !!(idx == end);
|
||||
return idx;
|
||||
}
|
||||
|
||||
BinIdx SearchBin(float value, uint32_t column_id) const {
|
||||
return this->SearchBin(value, column_id, Ptrs(), Values());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Search the bin index for numerical feature.
|
||||
*/
|
||||
@@ -129,7 +131,13 @@ class HistogramCuts {
|
||||
}
|
||||
};
|
||||
|
||||
inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
|
||||
/**
|
||||
* \brief Run CPU sketching on DMatrix.
|
||||
*
|
||||
* \param use_sorted Whether should we use SortedCSC for sketching, it's more efficient
|
||||
* but consumes more memory.
|
||||
*/
|
||||
inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sorted = false,
|
||||
Span<float> const hessian = {}) {
|
||||
HistogramCuts out;
|
||||
auto const& info = m->Info();
|
||||
@@ -146,13 +154,23 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
|
||||
reduced[i] += entries_per_column[i];
|
||||
}
|
||||
}
|
||||
HostSketchContainer container(reduced, max_bins,
|
||||
m->Info().feature_types.ConstHostSpan(),
|
||||
HostSketchContainer::UseGroup(info), threads);
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
|
||||
if (!use_sorted) {
|
||||
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
|
||||
hessian, threads);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(&out);
|
||||
} else {
|
||||
SortedSketchContainer container{
|
||||
max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, threads};
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
container.PushColPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(&out);
|
||||
}
|
||||
container.MakeCuts(&out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user