Partial rewrite EllpackPage (#5352)
This commit is contained in:
@@ -44,17 +44,35 @@ class HistogramCuts {
|
||||
using BinIdx = uint32_t;
|
||||
common::Monitor monitor_;
|
||||
|
||||
std::vector<bst_float> cut_values_;
|
||||
std::vector<uint32_t> cut_ptrs_;
|
||||
std::vector<float> min_vals_; // storing minimum value in a sketch set.
|
||||
|
||||
public:
|
||||
HostDeviceVector<bst_float> cut_values_;
|
||||
HostDeviceVector<uint32_t> cut_ptrs_;
|
||||
HostDeviceVector<float> min_vals_; // storing minimum value in a sketch set.
|
||||
|
||||
HistogramCuts();
|
||||
HistogramCuts(HistogramCuts const& that) = delete;
|
||||
/* \brief Copy constructor: sizes each cut array to match `that`, then copies its contents.
 *
 * monitor_ is not copied; it is left default-constructed.
 */
HistogramCuts(HistogramCuts const& that) {
  // Every destination is resized to its source's size before Copy() is called on it.
  cut_values_.Resize(that.cut_values_.Size());
  cut_values_.Copy(that.cut_values_);

  cut_ptrs_.Resize(that.cut_ptrs_.Size());
  cut_ptrs_.Copy(that.cut_ptrs_);

  min_vals_.Resize(that.min_vals_.Size());
  min_vals_.Copy(that.min_vals_);
}
|
||||
|
||||
/* \brief Move constructor: delegates to the move-assignment operator. */
HistogramCuts(HistogramCuts&& that) noexcept(true) {
  // `that` is a named rvalue reference, not a forwarding reference, so std::move
  // (not std::forward) is the cast that expresses the intent here.
  *this = std::move(that);
}
|
||||
HistogramCuts& operator=(HistogramCuts const& that) = delete;
|
||||
|
||||
/* \brief Copy assignment: sizes each cut array to match `that`, then copies its contents.
 *
 * monitor_ is left untouched.
 * \return reference to *this
 */
HistogramCuts& operator=(HistogramCuts const& that) {
  // Self-assignment returns immediately so Copy() is never asked to copy a
  // container onto itself.
  if (this == &that) {
    return *this;
  }
  cut_values_.Resize(that.cut_values_.Size());
  cut_ptrs_.Resize(that.cut_ptrs_.Size());
  min_vals_.Resize(that.min_vals_.Size());
  cut_values_.Copy(that.cut_values_);
  cut_ptrs_.Copy(that.cut_ptrs_);
  min_vals_.Copy(that.min_vals_);
  return *this;
}
|
||||
|
||||
HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) {
|
||||
monitor_ = std::move(that.monitor_);
|
||||
cut_ptrs_ = std::move(that.cut_ptrs_);
|
||||
@@ -67,28 +85,30 @@ class HistogramCuts {
|
||||
void Build(DMatrix* dmat, uint32_t const max_num_bins);
|
||||
/* \brief How many bins a feature has. */
|
||||
uint32_t FeatureBins(uint32_t feature) const {
|
||||
return cut_ptrs_.at(feature+1) - cut_ptrs_[feature];
|
||||
return cut_ptrs_.ConstHostVector().at(feature + 1) -
|
||||
cut_ptrs_.ConstHostVector()[feature];
|
||||
}
|
||||
|
||||
// Getters. Cuts should be of no use after building histogram indices, but currently
|
||||
// it's deeply linked with quantile_hist, gpu sketcher and gpu_hist. So we preserve
|
||||
// these for now.
|
||||
std::vector<uint32_t> const& Ptrs() const { return cut_ptrs_; }
|
||||
std::vector<float> const& Values() const { return cut_values_; }
|
||||
std::vector<float> const& MinValues() const { return min_vals_; }
|
||||
/* \brief Read-only access to the per-feature offsets stored in cut_ptrs_. */
std::vector<uint32_t> const& Ptrs() const {
  return cut_ptrs_.ConstHostVector();
}
|
||||
/* \brief Read-only access to the cut values stored in cut_values_. */
std::vector<float> const& Values() const {
  return cut_values_.ConstHostVector();
}
|
||||
/* \brief Read-only access to the minimum values stored in min_vals_. */
std::vector<float> const& MinValues() const {
  return min_vals_.ConstHostVector();
}
|
||||
|
||||
size_t TotalBins() const { return cut_ptrs_.back(); }
|
||||
/* \brief Total bin count, read from the last entry of cut_ptrs_. */
size_t TotalBins() const {
  auto const& ptrs = cut_ptrs_.ConstHostVector();
  return ptrs.back();
}
|
||||
|
||||
// Return the index of a cut point that is strictly greater than the input
|
||||
// value, or the last available index if none exists
|
||||
/* \brief Locate the bin of `value` within feature `column_id`.
 *
 * Only the cut values in [cut_ptrs_[column_id], cut_ptrs_[column_id + 1]) are searched.
 * When no cut in that range is strictly greater than `value`, the result is clamped to
 * the last cut of that same range instead of spilling over into the next feature.
 *
 * \param value     feature value to look up
 * \param column_id feature index; both `column_id` and `column_id + 1` are bounds-checked
 * \return index into the cut values array
 */
BinIdx SearchBin(float value, uint32_t column_id) const {
  const auto &ptrs = cut_ptrs_.ConstHostVector();
  auto beg = ptrs.at(column_id);
  auto end = ptrs.at(column_id + 1);
  const auto &values = cut_values_.ConstHostVector();
  const auto last = values.cbegin() + end;
  auto it = std::upper_bound(values.cbegin() + beg, last, value);
  // upper_bound signals "not found" by returning the end of the searched range. That
  // equals values.cend() only for the final feature, so compare against the range end.
  // The cbegin() check keeps the iterator from stepping before the start of the array.
  if (it == last && it != values.cbegin()) {
    it = last - 1;
  }
  BinIdx idx = it - values.cbegin();
  return idx;
}
|
||||
|
||||
@@ -133,8 +153,8 @@ class CutsBuilder {
|
||||
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
|
||||
for (size_t i = 1; i < required_cuts; ++i) {
|
||||
bst_float cpt = summary.data[i].value;
|
||||
if (i == 1 || cpt > p_cuts_->cut_values_.back()) {
|
||||
p_cuts_->cut_values_.push_back(cpt);
|
||||
if (i == 1 || cpt > p_cuts_->cut_values_.ConstHostVector().back()) {
|
||||
p_cuts_->cut_values_.HostVector().push_back(cpt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user