Extract Sketch Entry from hist maker. (#7503)
* Extract Sketch Entry from hist maker.
* Add a new sketch container for sorted inputs.
* Optimize bin search.
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
#include "random.h"
|
||||
#include "column_matrix.h"
|
||||
#include "quantile.h"
|
||||
#include "./../tree/updater_quantile_hist.h"
|
||||
#include "../data/gradient_index.h"
|
||||
|
||||
#if defined(XGBOOST_MM_PREFETCH_PRESENT)
|
||||
|
||||
@@ -92,18 +92,20 @@ class HistogramCuts {
|
||||
|
||||
// Return the index of a cut point that is strictly greater than the input
|
||||
// value, or the last available index if none exists
|
||||
BinIdx SearchBin(float value, uint32_t column_id) const {
|
||||
auto beg = cut_ptrs_.ConstHostVector().at(column_id);
|
||||
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
|
||||
const auto &values = cut_values_.ConstHostVector();
|
||||
BinIdx SearchBin(float value, uint32_t column_id, std::vector<uint32_t> const& ptrs,
|
||||
std::vector<float> const& values) const {
|
||||
auto end = ptrs[column_id + 1];
|
||||
auto beg = ptrs[column_id];
|
||||
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
|
||||
BinIdx idx = it - values.cbegin();
|
||||
if (idx == end) {
|
||||
idx -= 1;
|
||||
}
|
||||
idx -= !!(idx == end);
|
||||
return idx;
|
||||
}
|
||||
|
||||
BinIdx SearchBin(float value, uint32_t column_id) const {
|
||||
return this->SearchBin(value, column_id, Ptrs(), Values());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Search the bin index for numerical feature.
|
||||
*/
|
||||
@@ -129,7 +131,13 @@ class HistogramCuts {
|
||||
}
|
||||
};
|
||||
|
||||
inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
|
||||
/**
|
||||
* \brief Run CPU sketching on DMatrix.
|
||||
*
|
||||
* \param use_sorted Whether to use SortedCSC for sketching; it is more efficient
|
||||
* but consumes more memory.
|
||||
*/
|
||||
inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sorted = false,
|
||||
Span<float> const hessian = {}) {
|
||||
HistogramCuts out;
|
||||
auto const& info = m->Info();
|
||||
@@ -146,13 +154,23 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
|
||||
reduced[i] += entries_per_column[i];
|
||||
}
|
||||
}
|
||||
HostSketchContainer container(reduced, max_bins,
|
||||
m->Info().feature_types.ConstHostSpan(),
|
||||
HostSketchContainer::UseGroup(info), threads);
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
|
||||
if (!use_sorted) {
|
||||
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
|
||||
hessian, threads);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(&out);
|
||||
} else {
|
||||
SortedSketchContainer container{
|
||||
max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, threads};
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
container.PushColPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(&out);
|
||||
}
|
||||
container.MakeCuts(&out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,40 +12,34 @@
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
HostSketchContainer::HostSketchContainer(
|
||||
std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
common::Span<FeatureType const> feature_types, bool use_group,
|
||||
int32_t n_threads)
|
||||
template <typename WQSketch>
|
||||
SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
|
||||
int32_t max_bins,
|
||||
common::Span<FeatureType const> feature_types,
|
||||
bool use_group, int32_t n_threads)
|
||||
: feature_types_(feature_types.cbegin(), feature_types.cend()),
|
||||
columns_size_{std::move(columns_size)}, max_bins_{max_bins},
|
||||
use_group_ind_{use_group}, n_threads_{n_threads} {
|
||||
columns_size_{std::move(columns_size)},
|
||||
max_bins_{max_bins},
|
||||
use_group_ind_{use_group},
|
||||
n_threads_{n_threads} {
|
||||
monitor_.Init(__func__);
|
||||
CHECK_NE(columns_size_.size(), 0);
|
||||
sketches_.resize(columns_size_.size());
|
||||
CHECK_GE(n_threads_, 1);
|
||||
categories_.resize(columns_size_.size());
|
||||
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
|
||||
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
|
||||
n_bins = std::max(n_bins, static_cast<decltype(n_bins)>(1));
|
||||
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
|
||||
if (!IsCat(this->feature_types_, i)) {
|
||||
sketches_[i].Init(columns_size_[i], eps);
|
||||
sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<bst_row_t>
|
||||
HostSketchContainer::CalcColumnSize(SparsePage const &batch,
|
||||
bst_feature_t const n_columns,
|
||||
size_t const nthreads) {
|
||||
template <typename WQSketch>
|
||||
std::vector<bst_row_t> SketchContainerImpl<WQSketch>::CalcColumnSize(SparsePage const &batch,
|
||||
bst_feature_t const n_columns,
|
||||
size_t const nthreads) {
|
||||
auto page = batch.GetView();
|
||||
std::vector<std::vector<bst_row_t>> column_sizes(nthreads);
|
||||
for (auto &column : column_sizes) {
|
||||
column.resize(n_columns, 0);
|
||||
}
|
||||
|
||||
ParallelFor(omp_ulong(page.Size()), nthreads, [&](omp_ulong i) {
|
||||
ParallelFor(page.Size(), nthreads, [&](omp_ulong i) {
|
||||
auto &local_column_sizes = column_sizes.at(omp_get_thread_num());
|
||||
auto row = page[i];
|
||||
auto const *p_row = row.data();
|
||||
@@ -54,7 +48,7 @@ HostSketchContainer::CalcColumnSize(SparsePage const &batch,
|
||||
}
|
||||
});
|
||||
std::vector<bst_row_t> entries_per_columns(n_columns, 0);
|
||||
ParallelFor(bst_omp_uint(n_columns), nthreads, [&](bst_omp_uint i) {
|
||||
ParallelFor(n_columns, nthreads, [&](bst_omp_uint i) {
|
||||
for (auto const &thread : column_sizes) {
|
||||
entries_per_columns[i] += thread[i];
|
||||
}
|
||||
@@ -62,8 +56,10 @@ HostSketchContainer::CalcColumnSize(SparsePage const &batch,
|
||||
return entries_per_columns;
|
||||
}
|
||||
|
||||
std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
|
||||
SparsePage const &batch, bst_feature_t n_columns, size_t const nthreads) {
|
||||
template <typename WQSketch>
|
||||
std::vector<bst_feature_t> SketchContainerImpl<WQSketch>::LoadBalance(SparsePage const &batch,
|
||||
bst_feature_t n_columns,
|
||||
size_t const nthreads) {
|
||||
/* Some sparse datasets have their mass concentrated on a small number of features. To
|
||||
* avoid waiting for a few threads running forever, we here distribute different number
|
||||
* of columns to different threads according to number of entries.
|
||||
@@ -101,9 +97,8 @@ std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
|
||||
|
||||
namespace {
|
||||
// Function to merge hessian and sample weights
|
||||
std::vector<float> MergeWeights(MetaInfo const &info,
|
||||
Span<float> const hessian,
|
||||
bool use_group, int32_t n_threads) {
|
||||
std::vector<float> MergeWeights(MetaInfo const &info, Span<float const> hessian, bool use_group,
|
||||
int32_t n_threads) {
|
||||
CHECK_EQ(hessian.size(), info.num_row_);
|
||||
std::vector<float> results(hessian.size());
|
||||
auto const &group_ptr = info.group_ptr_;
|
||||
@@ -148,8 +143,9 @@ std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
void HostSketchContainer::PushRowPage(
|
||||
SparsePage const &page, MetaInfo const &info, Span<float> hessian) {
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo const &info,
|
||||
Span<float const> hessian) {
|
||||
monitor_.Start(__func__);
|
||||
bst_feature_t n_columns = info.num_col_;
|
||||
auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
|
||||
@@ -216,11 +212,12 @@ void HostSketchContainer::PushRowPage(
|
||||
monitor_.Stop(__func__);
|
||||
}
|
||||
|
||||
void HostSketchContainer::GatherSketchInfo(
|
||||
std::vector<WQSketch::SummaryContainer> const &reduced,
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<size_t> *p_worker_segments,
|
||||
std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<WQSketch::Entry> *p_global_sketches) {
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches) {
|
||||
auto& worker_segments = *p_worker_segments;
|
||||
worker_segments.resize(1, 0);
|
||||
auto world = rabit::GetWorldSize();
|
||||
@@ -251,8 +248,8 @@ void HostSketchContainer::GatherSketchInfo(
|
||||
auto total = worker_segments.back();
|
||||
|
||||
auto& global_sketches = *p_global_sketches;
|
||||
global_sketches.resize(total, WQSketch::Entry{0, 0, 0, 0});
|
||||
auto worker_sketch = Span<WQSketch::Entry>{global_sketches}.subspan(
|
||||
global_sketches.resize(total, typename WQSketch::Entry{0, 0, 0, 0});
|
||||
auto worker_sketch = Span<typename WQSketch::Entry>{global_sketches}.subspan(
|
||||
worker_segments[rank], worker_segments[rank + 1] - worker_segments[rank]);
|
||||
size_t cursor = 0;
|
||||
for (auto const &sketch : reduced) {
|
||||
@@ -261,14 +258,15 @@ void HostSketchContainer::GatherSketchInfo(
|
||||
cursor += sketch.size;
|
||||
}
|
||||
|
||||
static_assert(sizeof(WQSketch::Entry) / 4 == sizeof(float), "");
|
||||
static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float), "");
|
||||
rabit::Allreduce<rabit::op::Sum>(
|
||||
reinterpret_cast<float *>(global_sketches.data()),
|
||||
global_sketches.size() * sizeof(WQSketch::Entry) / sizeof(float));
|
||||
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
|
||||
}
|
||||
|
||||
void HostSketchContainer::AllReduce(
|
||||
std::vector<WQSketch::SummaryContainer> *p_reduced,
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t>* p_num_cuts) {
|
||||
monitor_.Start(__func__);
|
||||
auto& num_cuts = *p_num_cuts;
|
||||
@@ -291,7 +289,7 @@ void HostSketchContainer::AllReduce(
|
||||
std::min(global_column_size[i],
|
||||
static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
|
||||
if (global_column_size[i] != 0) {
|
||||
WQSketch::SummaryContainer out;
|
||||
typename WQSketch::SummaryContainer out;
|
||||
sketches_[i].GetSummary(&out);
|
||||
reduced[i].Reserve(intermediate_num_cuts);
|
||||
CHECK(reduced[i].data);
|
||||
@@ -309,11 +307,11 @@ void HostSketchContainer::AllReduce(
|
||||
std::vector<size_t> worker_segments(1, 0); // CSC pointer to sketches.
|
||||
std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);
|
||||
|
||||
std::vector<WQSketch::Entry> global_sketches;
|
||||
std::vector<typename WQSketch::Entry> global_sketches;
|
||||
this->GatherSketchInfo(reduced, &worker_segments, &sketches_scan,
|
||||
&global_sketches);
|
||||
|
||||
std::vector<WQSketch::SummaryContainer> final_sketches(n_columns);
|
||||
std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);
|
||||
ParallelFor(n_columns, n_threads_, [&](auto fidx) {
|
||||
int32_t intermediate_num_cuts = num_cuts[fidx];
|
||||
auto nbytes =
|
||||
@@ -321,8 +319,8 @@ void HostSketchContainer::AllReduce(
|
||||
|
||||
for (int32_t i = 1; i < world + 1; ++i) {
|
||||
auto size = worker_segments.at(i) - worker_segments[i - 1];
|
||||
auto worker_sketches = Span<WQSketch::Entry>{global_sketches}.subspan(
|
||||
worker_segments[i - 1], size);
|
||||
auto worker_sketches =
|
||||
Span<typename WQSketch::Entry>{global_sketches}.subspan(worker_segments[i - 1], size);
|
||||
auto worker_scan =
|
||||
Span<bst_row_t>(sketches_scan)
|
||||
.subspan((i - 1) * (n_columns + 1), (n_columns + 1));
|
||||
@@ -330,8 +328,7 @@ void HostSketchContainer::AllReduce(
|
||||
auto worker_feature = worker_sketches.subspan(
|
||||
worker_scan[fidx], worker_scan[fidx + 1] - worker_scan[fidx]);
|
||||
CHECK(worker_feature.data());
|
||||
WQSummary<float, float> summary(worker_feature.data(),
|
||||
worker_feature.size());
|
||||
typename WQSketch::Summary summary(worker_feature.data(), worker_feature.size());
|
||||
auto &out = final_sketches.at(fidx);
|
||||
out.Reduce(summary, nbytes);
|
||||
}
|
||||
@@ -342,10 +339,11 @@ void HostSketchContainer::AllReduce(
|
||||
monitor_.Stop(__func__);
|
||||
}
|
||||
|
||||
void AddCutPoint(WQuantileSketch<float, float>::SummaryContainer const &summary,
|
||||
int max_bin, HistogramCuts *cuts) {
|
||||
template <typename SketchType>
|
||||
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
|
||||
HistogramCuts *cuts) {
|
||||
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
|
||||
auto& cut_values = cuts->cut_values_.HostVector();
|
||||
auto &cut_values = cuts->cut_values_.HostVector();
|
||||
for (size_t i = 1; i < required_cuts; ++i) {
|
||||
bst_float cpt = summary.data[i].value;
|
||||
if (i == 1 || cpt > cut_values.back()) {
|
||||
@@ -361,20 +359,21 @@ void AddCategories(std::set<bst_cat_t> const &categories, HistogramCuts *cuts) {
|
||||
}
|
||||
}
|
||||
|
||||
void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
|
||||
monitor_.Start(__func__);
|
||||
std::vector<WQSketch::SummaryContainer> reduced;
|
||||
std::vector<typename WQSketch::SummaryContainer> reduced;
|
||||
std::vector<int32_t> num_cuts;
|
||||
this->AllReduce(&reduced, &num_cuts);
|
||||
|
||||
cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
|
||||
std::vector<WQSketch::SummaryContainer> final_summaries(reduced.size());
|
||||
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
|
||||
|
||||
ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
|
||||
if (IsCat(feature_types_, fidx)) {
|
||||
return;
|
||||
}
|
||||
WQSketch::SummaryContainer &a = final_summaries[fidx];
|
||||
typename WQSketch::SummaryContainer &a = final_summaries[fidx];
|
||||
size_t max_num_bins = std::min(num_cuts[fidx], max_bins_);
|
||||
a.Reserve(max_num_bins + 1);
|
||||
CHECK(a.data);
|
||||
@@ -392,11 +391,11 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
|
||||
|
||||
for (size_t fid = 0; fid < reduced.size(); ++fid) {
|
||||
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
|
||||
WQSketch::SummaryContainer const& a = final_summaries[fid];
|
||||
typename WQSketch::SummaryContainer const& a = final_summaries[fid];
|
||||
if (IsCat(feature_types_, fid)) {
|
||||
AddCategories(categories_.at(fid), cuts);
|
||||
} else {
|
||||
AddCutPoint(a, max_num_bins, cuts);
|
||||
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
|
||||
// push a value that is greater than anything
|
||||
const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value
|
||||
: cuts->min_vals_.HostVector()[fid];
|
||||
@@ -413,5 +412,64 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
|
||||
}
|
||||
monitor_.Stop(__func__);
|
||||
}
|
||||
|
||||
template class SketchContainerImpl<WQuantileSketch<float, float>>;
|
||||
template class SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||
|
||||
HostSketchContainer::HostSketchContainer(int32_t max_bins, MetaInfo const &info,
|
||||
std::vector<size_t> columns_size, bool use_group,
|
||||
Span<float const> hessian, int32_t n_threads)
|
||||
: SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
|
||||
n_threads} {
|
||||
monitor_.Init(__func__);
|
||||
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
|
||||
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
|
||||
n_bins = std::max(n_bins, static_cast<decltype(n_bins)>(1));
|
||||
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
|
||||
if (!IsCat(this->feature_types_, i)) {
|
||||
sketches_[i].Init(columns_size_[i], eps);
|
||||
sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &info,
|
||||
Span<float const> hessian) {
|
||||
monitor_.Start(__func__);
|
||||
// glue these conditions using ternary operator to avoid making data copies.
|
||||
auto const &weights =
|
||||
hessian.empty() ? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight
|
||||
: info.weights_.HostVector()) // use sample weight
|
||||
: MergeWeights(info, hessian, use_group_ind_,
|
||||
n_threads_); // use hessian merged with group/sample weights
|
||||
CHECK_EQ(weights.size(), info.num_row_);
|
||||
|
||||
auto view = page.GetView();
|
||||
ParallelFor(view.Size(), n_threads_, [&](size_t fidx) {
|
||||
auto column = view[fidx];
|
||||
auto &sketch = sketches_[fidx];
|
||||
sketch.Init(max_bins_);
|
||||
// first pass
|
||||
sketch.sum_total = 0.0;
|
||||
for (auto c : column) {
|
||||
sketch.sum_total += weights[c.index];
|
||||
}
|
||||
// second pass
|
||||
if (IsCat(feature_types_, fidx)) {
|
||||
for (auto c : column) {
|
||||
categories_[fidx].emplace(AsCat(c.fvalue));
|
||||
}
|
||||
} else {
|
||||
for (auto c : column) {
|
||||
sketch.Push(c.fvalue, weights[c.index], max_bins_);
|
||||
}
|
||||
}
|
||||
|
||||
if (!IsCat(feature_types_, fidx) && !column.empty()) {
|
||||
sketch.Finalize(max_bins_);
|
||||
}
|
||||
});
|
||||
monitor_.Stop(__func__);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -702,11 +702,9 @@ class HistogramCuts;
|
||||
/*!
|
||||
* A sketch matrix storing sketches for each feature.
|
||||
*/
|
||||
class HostSketchContainer {
|
||||
public:
|
||||
using WQSketch = WQuantileSketch<float, float>;
|
||||
|
||||
private:
|
||||
template <typename WQSketch>
|
||||
class SketchContainerImpl {
|
||||
protected:
|
||||
std::vector<WQSketch> sketches_;
|
||||
std::vector<std::set<bst_cat_t>> categories_;
|
||||
std::vector<FeatureType> const feature_types_;
|
||||
@@ -724,7 +722,7 @@ class HostSketchContainer {
|
||||
* \param max_bins maximum number of bins for each feature.
|
||||
* \param use_group whether weights are assigned per group rather than per data instance.
|
||||
*/
|
||||
HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
common::Span<FeatureType const> feature_types, bool use_group,
|
||||
int32_t n_threads);
|
||||
|
||||
@@ -755,20 +753,139 @@ class HostSketchContainer {
|
||||
return group_ind;
|
||||
}
|
||||
// Gather sketches from all workers.
|
||||
void GatherSketchInfo(std::vector<WQSketch::SummaryContainer> const &reduced,
|
||||
void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<bst_row_t> *p_worker_segments,
|
||||
std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<WQSketch::Entry> *p_global_sketches);
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches);
|
||||
// Merge sketches from all workers.
|
||||
void AllReduce(std::vector<WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t>* p_num_cuts);
|
||||
void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t> *p_num_cuts);
|
||||
|
||||
/* \brief Push a CSR matrix. */
|
||||
void PushRowPage(SparsePage const &page, MetaInfo const &info,
|
||||
Span<float> const hessian = {});
|
||||
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
|
||||
|
||||
void MakeCuts(HistogramCuts* cuts);
|
||||
};
|
||||
|
||||
class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
|
||||
public:
|
||||
using WQSketch = WQuantileSketch<float, float>;
|
||||
|
||||
public:
|
||||
HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
|
||||
bool use_group, Span<float const> hessian, int32_t n_threads);
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Quantile structure that accepts sorted data; extracted from histmaker.
|
||||
*/
|
||||
struct SortedQuantile {
|
||||
/*! \brief total weight of the column being sketched; used to space out the entries */
|
||||
double sum_total{0.0};
|
||||
/*! \brief rmin: total weight seen before last_fvalue; wmin: weight accumulated on last_fvalue */
|
||||
double rmin, wmin;
|
||||
/*! \brief last seen feature value */
|
||||
bst_float last_fvalue;
|
||||
/*! \brief cumulative weight that must be reached before the next entry is pushed */
|
||||
double next_goal;
|
||||
// pointer to the sketch to put things in
|
||||
common::WXQuantileSketch<bst_float, bst_float>* sketch;
|
||||
// initialize the space
|
||||
inline void Init(unsigned max_size) {
|
||||
next_goal = -1.0f;
|
||||
rmin = wmin = 0.0f;
|
||||
sketch->temp.Reserve(max_size + 1);
|
||||
sketch->temp.size = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief push a new element to sketch
|
||||
* \param fvalue feature value, comes in sorted ascending order
|
||||
* \param w weight
|
||||
* \param max_size maximum number of entries the temporary summary may hold
|
||||
*/
|
||||
inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
|
||||
if (next_goal == -1.0f) {
|
||||
next_goal = 0.0f;
|
||||
last_fvalue = fvalue;
|
||||
wmin = w;
|
||||
return;
|
||||
}
|
||||
if (last_fvalue != fvalue) {
|
||||
double rmax = rmin + wmin;
|
||||
if (rmax >= next_goal && sketch->temp.size != max_size) {
|
||||
if (sketch->temp.size == 0 ||
|
||||
last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
|
||||
// push to sketch
|
||||
sketch->temp.data[sketch->temp.size] =
|
||||
common::WXQuantileSketch<bst_float, bst_float>::Entry(
|
||||
static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
|
||||
static_cast<bst_float>(wmin), last_fvalue);
|
||||
CHECK_LT(sketch->temp.size, max_size) << "invalid maximum size max_size=" << max_size
|
||||
<< ", stemp.size" << sketch->temp.size;
|
||||
++sketch->temp.size;
|
||||
}
|
||||
if (sketch->temp.size == max_size) {
|
||||
next_goal = sum_total * 2.0f + 1e-5f;
|
||||
} else {
|
||||
next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
|
||||
}
|
||||
} else {
|
||||
if (rmax >= next_goal) {
|
||||
LOG(DEBUG) << "INFO: rmax=" << rmax << ", sum_total=" << sum_total
|
||||
<< ", naxt_goal=" << next_goal << ", size=" << sketch->temp.size;
|
||||
}
|
||||
}
|
||||
rmin = rmax;
|
||||
wmin = w;
|
||||
last_fvalue = fvalue;
|
||||
} else {
|
||||
wmin += w;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief push final unfinished value to the sketch */
|
||||
inline void Finalize(unsigned max_size) {
|
||||
double rmax = rmin + wmin;
|
||||
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
|
||||
CHECK_LE(sketch->temp.size, max_size)
|
||||
<< "Finalize: invalid maximum size, max_size=" << max_size
|
||||
<< ", stemp.size=" << sketch->temp.size;
|
||||
// push to sketch
|
||||
sketch->temp.data[sketch->temp.size] = common::WXQuantileSketch<bst_float, bst_float>::Entry(
|
||||
static_cast<bst_float>(rmin), static_cast<bst_float>(rmax), static_cast<bst_float>(wmin),
|
||||
last_fvalue);
|
||||
++sketch->temp.size;
|
||||
}
|
||||
sketch->PushTemp();
|
||||
}
|
||||
};
|
||||
|
||||
class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
|
||||
std::vector<SortedQuantile> sketches_;
|
||||
using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||
|
||||
public:
|
||||
explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
|
||||
std::vector<size_t> columns_size, bool use_group,
|
||||
Span<float const> hessian, int32_t n_threads)
|
||||
: SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
|
||||
n_threads} {
|
||||
monitor_.Init(__func__);
|
||||
sketches_.resize(info.num_col_);
|
||||
size_t i = 0;
|
||||
for (auto &sketch : sketches_) {
|
||||
sketch.sketch = &Super::sketches_[i];
|
||||
sketch.Init(max_bins_);
|
||||
auto eps = 2.0 / max_bins;
|
||||
sketch.sketch->Init(columns_size_[i], eps);
|
||||
++i;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* \brief Push a sorted CSC page.
|
||||
*/
|
||||
void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_QUANTILE_H_
|
||||
|
||||
Reference in New Issue
Block a user