Extract Sketch Entry from hist maker. (#7503)
* Extract Sketch Entry from hist maker. * Add a new sketch container for sorted inputs. * Optimize bin search.
This commit is contained in:
@@ -702,11 +702,9 @@ class HistogramCuts;
|
||||
/*!
|
||||
* A sketch matrix storing sketches for each feature.
|
||||
*/
|
||||
class HostSketchContainer {
|
||||
public:
|
||||
using WQSketch = WQuantileSketch<float, float>;
|
||||
|
||||
private:
|
||||
template <typename WQSketch>
|
||||
class SketchContainerImpl {
|
||||
protected:
|
||||
std::vector<WQSketch> sketches_;
|
||||
std::vector<std::set<bst_cat_t>> categories_;
|
||||
std::vector<FeatureType> const feature_types_;
|
||||
@@ -724,7 +722,7 @@ class HostSketchContainer {
|
||||
* \param max_bins maximum number of bins for each feature.
|
||||
* \param use_group whether data instances are assigned to groups.
|
||||
*/
|
||||
HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
|
||||
common::Span<FeatureType const> feature_types, bool use_group,
|
||||
int32_t n_threads);
|
||||
|
||||
@@ -755,20 +753,139 @@ class HostSketchContainer {
|
||||
return group_ind;
|
||||
}
|
||||
// Gather sketches from all workers.
|
||||
void GatherSketchInfo(std::vector<WQSketch::SummaryContainer> const &reduced,
|
||||
void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<bst_row_t> *p_worker_segments,
|
||||
std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<WQSketch::Entry> *p_global_sketches);
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches);
|
||||
// Merge sketches from all workers.
|
||||
void AllReduce(std::vector<WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t>* p_num_cuts);
|
||||
void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t> *p_num_cuts);
|
||||
|
||||
/* \brief Push a CSR matrix. */
|
||||
void PushRowPage(SparsePage const &page, MetaInfo const &info,
|
||||
Span<float> const hessian = {});
|
||||
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
|
||||
|
||||
void MakeCuts(HistogramCuts* cuts);
|
||||
};
|
||||
|
||||
/*!
 * \brief Sketch container built on SketchContainerImpl instantiated with
 *        WQuantileSketch<float, float>.
 *
 * Only the constructor is declared here; everything else is inherited from the base.
 */
class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
 public:
  /*! \brief Sketch type this container is instantiated with. */
  using WQSketch = WQuantileSketch<float, float>;

  /*!
   * \brief Construct the container.  The definition lives outside this header section.
   */
  HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
                      bool use_group, Span<float const> hessian, int32_t n_threads);
};
|
||||
|
||||
/**
|
||||
* \brief Quantile structure accepts sorted data, extracted from histmaker.
|
||||
*/
|
||||
struct SortedQuantile {
|
||||
/*! \brief total sum of amount to be met */
|
||||
double sum_total{0.0};
|
||||
/*! \brief statistics used in the sketch */
|
||||
double rmin, wmin;
|
||||
/*! \brief last seen feature value */
|
||||
bst_float last_fvalue;
|
||||
/*! \brief current size of sketch */
|
||||
double next_goal;
|
||||
// pointer to the sketch to put things in
|
||||
common::WXQuantileSketch<bst_float, bst_float>* sketch;
|
||||
// initialize the space
|
||||
inline void Init(unsigned max_size) {
|
||||
next_goal = -1.0f;
|
||||
rmin = wmin = 0.0f;
|
||||
sketch->temp.Reserve(max_size + 1);
|
||||
sketch->temp.size = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief push a new element to sketch
|
||||
* \param fvalue feature value, comes in sorted ascending order
|
||||
* \param w weight
|
||||
* \param max_size
|
||||
*/
|
||||
inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
|
||||
if (next_goal == -1.0f) {
|
||||
next_goal = 0.0f;
|
||||
last_fvalue = fvalue;
|
||||
wmin = w;
|
||||
return;
|
||||
}
|
||||
if (last_fvalue != fvalue) {
|
||||
double rmax = rmin + wmin;
|
||||
if (rmax >= next_goal && sketch->temp.size != max_size) {
|
||||
if (sketch->temp.size == 0 ||
|
||||
last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
|
||||
// push to sketch
|
||||
sketch->temp.data[sketch->temp.size] =
|
||||
common::WXQuantileSketch<bst_float, bst_float>::Entry(
|
||||
static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
|
||||
static_cast<bst_float>(wmin), last_fvalue);
|
||||
CHECK_LT(sketch->temp.size, max_size) << "invalid maximum size max_size=" << max_size
|
||||
<< ", stemp.size" << sketch->temp.size;
|
||||
++sketch->temp.size;
|
||||
}
|
||||
if (sketch->temp.size == max_size) {
|
||||
next_goal = sum_total * 2.0f + 1e-5f;
|
||||
} else {
|
||||
next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
|
||||
}
|
||||
} else {
|
||||
if (rmax >= next_goal) {
|
||||
LOG(DEBUG) << "INFO: rmax=" << rmax << ", sum_total=" << sum_total
|
||||
<< ", naxt_goal=" << next_goal << ", size=" << sketch->temp.size;
|
||||
}
|
||||
}
|
||||
rmin = rmax;
|
||||
wmin = w;
|
||||
last_fvalue = fvalue;
|
||||
} else {
|
||||
wmin += w;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief push final unfinished value to the sketch */
|
||||
inline void Finalize(unsigned max_size) {
|
||||
double rmax = rmin + wmin;
|
||||
if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
|
||||
CHECK_LE(sketch->temp.size, max_size)
|
||||
<< "Finalize: invalid maximum size, max_size=" << max_size
|
||||
<< ", stemp.size=" << sketch->temp.size;
|
||||
// push to sketch
|
||||
sketch->temp.data[sketch->temp.size] = common::WXQuantileSketch<bst_float, bst_float>::Entry(
|
||||
static_cast<bst_float>(rmin), static_cast<bst_float>(rmax), static_cast<bst_float>(wmin),
|
||||
last_fvalue);
|
||||
++sketch->temp.size;
|
||||
}
|
||||
sketch->PushTemp();
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * \brief Sketch container for sorted inputs.
 *
 * Keeps one SortedQuantile per column; each one writes into the base-class sketch stored
 * at the same index.
 */
class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
  using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
  // Per-column helpers.  This intentionally hides Super::sketches_, which holds the
  // sketches these helpers point to.
  std::vector<SortedQuantile> sketches_;

 public:
  /**
   * \brief Set up one SortedQuantile per column and initialize the sketch it points to.
   */
  explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
                                 std::vector<size_t> columns_size, bool use_group,
                                 Span<float const> hessian, int32_t n_threads)
      : Super{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group, n_threads} {
    monitor_.Init(__func__);
    sketches_.resize(info.num_col_);
    // Same value for every column, so compute it once.
    double const eps = 2.0 / max_bins;
    for (size_t fidx = 0; fidx < sketches_.size(); ++fidx) {
      auto &column = sketches_[fidx];
      column.sketch = &Super::sketches_[fidx];
      column.Init(max_bins_);
      column.sketch->Init(columns_size_[fidx], eps);
    }
  }

  /**
   * \brief Push a sorted CSC page.
   */
  void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_QUANTILE_H_
|
||||
|
||||
Reference in New Issue
Block a user