Extract Sketch Entry from hist maker. (#7503)

* Extract Sketch Entry from hist maker.

* Add a new sketch container for sorted inputs.
* Optimize bin search.
This commit is contained in:
Jiaming Yuan
2021-12-18 05:36:56 +08:00
committed by GitHub
parent b4a1236cfc
commit 9ab73f737e
15 changed files with 393 additions and 217 deletions

View File

@@ -702,11 +702,9 @@ class HistogramCuts;
/*!
* A sketch matrix storing sketches for each feature.
*/
class HostSketchContainer {
public:
using WQSketch = WQuantileSketch<float, float>;
private:
template <typename WQSketch>
class SketchContainerImpl {
protected:
std::vector<WQSketch> sketches_;
std::vector<std::set<bst_cat_t>> categories_;
std::vector<FeatureType> const feature_types_;
@@ -724,7 +722,7 @@ class HostSketchContainer {
* \param max_bins maximum number of bins for each feature.
* \param use_group whether is assigned to group to data instance.
*/
HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
common::Span<FeatureType const> feature_types, bool use_group,
int32_t n_threads);
@@ -755,20 +753,139 @@ class HostSketchContainer {
return group_ind;
}
// Gather sketches from all workers.
void GatherSketchInfo(std::vector<WQSketch::SummaryContainer> const &reduced,
void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<bst_row_t> *p_worker_segments,
std::vector<bst_row_t> *p_sketches_scan,
std::vector<WQSketch::Entry> *p_global_sketches);
std::vector<typename WQSketch::Entry> *p_global_sketches);
// Merge sketches from all workers.
void AllReduce(std::vector<WQSketch::SummaryContainer> *p_reduced,
std::vector<int32_t>* p_num_cuts);
void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
std::vector<int32_t> *p_num_cuts);
/* \brief Push a CSR matrix. */
void PushRowPage(SparsePage const &page, MetaInfo const &info,
Span<float> const hessian = {});
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
void MakeCuts(HistogramCuts* cuts);
};
/**
 * \brief Sketch container specialised for WQuantileSketch<float, float>.
 *
 * All sketching functionality is inherited from SketchContainerImpl; this class only
 * contributes the constructor declared below (defined outside this header section).
 */
class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
 public:
  /*! \brief Sketch type used by this container, identical to the base template argument. */
  using WQSketch = WQuantileSketch<float, float>;

  /**
   * \param max_bins     maximum number of bins for each feature.
   * \param info         meta info of the input data.
   * \param columns_size one size per column -- NOTE(review): presumably the number of
   *                     entries per feature; confirm against the definition.
   * \param use_group    forwarded group flag -- see SketchContainerImpl's constructor.
   * \param hessian      optional per-row hessian values -- NOTE(review): usage is not
   *                     visible from this declaration; confirm against the definition.
   * \param n_threads    number of threads to use.
   */
  HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
                      bool use_group, Span<float const> hessian, int32_t n_threads);
};
/**
 * \brief Quantile structure accepts sorted data, extracted from histmaker.
 *
 * Feature values are pushed in ascending order. Equal consecutive values are merged by
 * accumulating their weight, and a summary entry is written into `sketch->temp` whenever
 * the accumulated rank reaches `next_goal`. `Finalize` flushes the last pending value and
 * hands the temporary summary over to the sketch via `PushTemp`.
 *
 * `sketch` must be set to a valid sketch before `Init`, `Push` or `Finalize` is called.
 */
struct SortedQuantile {
  /*! \brief total sum of amount to be met */
  double sum_total{0.0};
  /*! \brief accumulated weight of all values strictly before `last_fvalue` */
  double rmin{0.0};
  /*! \brief accumulated weight of the values equal to `last_fvalue` */
  double wmin{0.0};
  /*! \brief last seen feature value */
  bst_float last_fvalue{0.0f};
  /*!
   * \brief rank that has to be reached before the next entry is written to the sketch;
   *        -1 marks the state in which no value has been pushed yet.
   */
  double next_goal{-1.0};
  // pointer to the sketch to put things in
  common::WXQuantileSketch<bst_float, bst_float>* sketch{nullptr};

  /*!
   * \brief initialize the space and reset the running statistics
   * \param max_size maximum number of entries that will be written before Finalize; one
   *                 extra slot is reserved for the entry written by Finalize.
   */
  inline void Init(unsigned max_size) {
    next_goal = -1.0f;
    rmin = wmin = 0.0f;
    sketch->temp.Reserve(max_size + 1);
    sketch->temp.size = 0;
  }

  /*!
   * \brief push a new element to sketch
   * \param fvalue feature value, comes in sorted ascending order
   * \param w weight
   * \param max_size maximum number of entries Push may write to the temporary summary
   */
  inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
    if (next_goal == -1.0f) {
      // First value seen: only remember it, nothing can be emitted yet.
      next_goal = 0.0f;
      last_fvalue = fvalue;
      wmin = w;
      return;
    }
    if (last_fvalue != fvalue) {
      // A new distinct value arrived, so the statistics of `last_fvalue` are complete.
      double rmax = rmin + wmin;
      if (rmax >= next_goal && sketch->temp.size != max_size) {
        if (sketch->temp.size == 0 ||
            last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
          // Validate the slot index before writing to it.
          CHECK_LT(sketch->temp.size, max_size) << "invalid maximum size max_size=" << max_size
                                                << ", stemp.size=" << sketch->temp.size;
          // push to sketch
          sketch->temp.data[sketch->temp.size] =
              common::WXQuantileSketch<bst_float, bst_float>::Entry(
                  static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
                  static_cast<bst_float>(wmin), last_fvalue);
          ++sketch->temp.size;
        }
        if (sketch->temp.size == max_size) {
          // Summary is full: move the goal beyond any reachable rank.
          next_goal = sum_total * 2.0f + 1e-5f;
        } else {
          // Next goal is proportional to the number of entries written so far. The
          // value is deliberately rounded through bst_float, as before.
          next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
        }
      } else {
        if (rmax >= next_goal) {
          // The goal was reached but the summary is already full.
          LOG(DEBUG) << "INFO: rmax=" << rmax << ", sum_total=" << sum_total
                     << ", next_goal=" << next_goal << ", size=" << sketch->temp.size;
        }
      }
      rmin = rmax;
      wmin = w;
      last_fvalue = fvalue;
    } else {
      // Same value as before: merge by accumulating the weight.
      wmin += w;
    }
  }

  /*!
   * \brief push final unfinished value to the sketch
   * \param max_size maximum size used for Init/Push; the final entry may occupy the
   *                 extra slot reserved by Init.
   */
  inline void Finalize(unsigned max_size) {
    double rmax = rmin + wmin;
    if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
      CHECK_LE(sketch->temp.size, max_size)
          << "Finalize: invalid maximum size, max_size=" << max_size
          << ", stemp.size=" << sketch->temp.size;
      // push to sketch
      sketch->temp.data[sketch->temp.size] = common::WXQuantileSketch<bst_float, bst_float>::Entry(
          static_cast<bst_float>(rmin), static_cast<bst_float>(rmax), static_cast<bst_float>(wmin),
          last_fvalue);
      ++sketch->temp.size;
    }
    sketch->PushTemp();
  }
};
/**
 * \brief Sketch container for inputs whose feature values arrive already sorted.
 *
 * Each feature owns one SortedQuantile that writes into the matching WXQuantileSketch
 * stored in the base class.
 */
class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
  using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
  // One sorted-input adaptor per feature. This intentionally hides Super::sketches_,
  // which holds the underlying sketches the adaptors point into.
  std::vector<SortedQuantile> sketches_;

 public:
  /**
   * \brief Construct the base container, then bind one SortedQuantile to each of its
   *        sketches and initialise both.
   *
   * \param max_bins     maximum number of bins for each feature.
   * \param info         meta info; provides the feature types and the number of columns.
   * \param columns_size per-column sizes, forwarded to the base container.
   * \param use_group    forwarded to the base container.
   * \param hessian      not used by this constructor.
   * \param n_threads    forwarded to the base container.
   */
  explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
                                 std::vector<size_t> columns_size, bool use_group,
                                 Span<float const> hessian, int32_t n_threads)
      : SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
                            n_threads} {
    monitor_.Init(__func__);
    sketches_.resize(info.num_col_);
    for (size_t fidx = 0; fidx < sketches_.size(); ++fidx) {
      SortedQuantile &quantile = sketches_[fidx];
      // Point the adaptor at the base-class sketch of the same feature.
      quantile.sketch = &Super::sketches_[fidx];
      quantile.Init(max_bins_);
      auto eps = 2.0 / max_bins;
      quantile.sketch->Init(columns_size_[fidx], eps);
    }
  }
  /**
   * \brief Push a sorted CSC page.
   */
  void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_QUANTILE_H_