Refactor for GHistIndex. (#7923)
* Pass the sparse page as an adapter, which prepares for the quantile DMatrix.
* Remove old external memory code such as `rbegin` and the extra `Init` functions.
* Simplify type dispatch.
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "adapter.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@@ -32,8 +33,38 @@ class GHistIndexMatrix {
|
||||
* \param rbegin The beginning row index of current page. (total rows in previous pages)
|
||||
* \param prev_sum Total number of entries in previous pages.
|
||||
*/
|
||||
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, size_t rbegin,
|
||||
size_t prev_sum, uint32_t nbins, int32_t n_threads);
|
||||
void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft,
|
||||
bst_bin_t n_total_bins, int32_t n_threads);
|
||||
|
||||
/*!
 * \brief Quantise every valid element of an adapter batch into a bin index, store the
 *        (offset-adjusted) index, and count per-thread bin hits.
 *
 * \tparam Batch      Batch type; must provide Size() and GetLine(i), where a line exposes
 *                    Size() and GetElement(j) convertible to data::COOTuple.
 * \tparam BinIdxType Element type of the output index buffer.
 * \tparam GetOffset  Callable invoked as get_offset(bin_idx, j); its result is stored.
 * \tparam IsValid    Predicate invoked as is_valid(elem); elements for which it is false
 *                    are skipped.
 *
 * \param index_data_span Output buffer receiving one value per valid element.
 * \param ft              Feature types, used to tell categorical features apart.
 * \param batch_threads   Thread count handed to common::ParallelFor.
 * \param batch           Input rows.
 * \param is_valid        Element filter (see IsValid).
 * \param nbins           Stride of one thread's slice inside hit_count_tloc_.
 * \param get_offset      Transformation applied to the bin index before storing.
 */
template <typename Batch, typename BinIdxType, typename GetOffset, typename IsValid>
|
||||
void SetIndexData(common::Span<BinIdxType> index_data_span, common::Span<FeatureType const> ft,
|
||||
size_t batch_threads, Batch const& batch, IsValid&& is_valid, size_t nbins,
|
||||
GetOffset&& get_offset) {
|
||||
auto batch_size = batch.Size();
|
||||
BinIdxType* index_data = index_data_span.data();
|
||||
// Fetch the cut pointers and cut values once; both are passed to SearchBin for every
// non-categorical element below.
auto const& ptrs = cut.Ptrs();
|
||||
auto const& values = cut.Values();
|
||||
// One loop iteration per row of the batch.
common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
|
||||
auto line = batch.GetLine(i);
|
||||
size_t ibegin = row_ptr[i];  // position in index_data of the first entry written for row i
|
||||
// k counts the valid elements written so far for this row; it can lag behind j when
// is_valid rejects elements.
size_t k = 0;
|
||||
auto tid = omp_get_thread_num();
|
||||
for (size_t j = 0; j < line.Size(); ++j) {
|
||||
data::COOTuple elem = line.GetElement(j);
|
||||
if (is_valid(elem)) {
|
||||
// The -1 initial value is always overwritten by one of the two branches below.
bst_bin_t bin_idx{-1};
|
||||
if (common::IsCat(ft, elem.column_idx)) {
|
||||
bin_idx = cut.SearchCatBin(elem.value, elem.column_idx);
|
||||
} else {
|
||||
bin_idx = cut.SearchBin(elem.value, elem.column_idx, ptrs, values);
|
||||
}
|
||||
// Note: the write position uses k (valid elements only) while get_offset receives j
// (position within the line).
// NOTE(review): presumably j is meant to identify the feature for dense data - confirm.
index_data[ibegin + k] = get_offset(bin_idx, j);
|
||||
// Each thread increments its own nbins-wide slice of hit_count_tloc_.
// NOTE(review): assumes hit_count_tloc_ holds at least (number of threads * nbins)
// entries - confirm where it is resized.
++hit_count_tloc_[tid * nbins + bin_idx];
|
||||
++k;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
/*! \brief row pointer to rows by element position */
|
||||
@@ -50,52 +81,15 @@ class GHistIndexMatrix {
|
||||
size_t base_rowid{0};
|
||||
|
||||
GHistIndexMatrix();
|
||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin, double sparse_thresh, bool sorted_sketch,
|
||||
int32_t n_threads, common::Span<float> hess = {});
|
||||
GHistIndexMatrix(DMatrix* x, bst_bin_t max_bins_per_feat, double sparse_thresh,
|
||||
bool sorted_sketch, int32_t n_threads, common::Span<float> hess = {});
|
||||
~GHistIndexMatrix();
|
||||
|
||||
// Create a global histogram matrix, given cut
|
||||
void Init(DMatrix* p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch,
|
||||
int32_t n_threads, common::Span<float> hess);
|
||||
// Create a global histogram matrix, given cut. Used by external memory
|
||||
void Init(SparsePage const& page, common::Span<FeatureType const> ft,
|
||||
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
|
||||
double sparse_thresh, int32_t n_threads);
|
||||
|
||||
// Specific method for sparse data, as there is no possibility to reduce the allocated memory.
|
||||
/*!
 * \brief Quantise every entry of a SparsePage into a bin index, store the
 *        (offset-adjusted) index, and count per-thread bin hits.
 *
 * NOTE(review): the commit message lists `rbegin` among the old external-memory code
 * being removed, so this overload looks like the pre-refactor variant - confirm.
 *
 * \tparam BinIdxType Element type of the output index buffer.
 * \tparam GetOffset  Callable invoked as get_offset(bin_idx, j); its result is stored.
 *
 * \param index_data_span Output buffer receiving one value per entry.
 * \param ft              Feature types, used to tell categorical features apart.
 * \param batch_threads   Thread count handed to common::ParallelFor.
 * \param batch           Page whose host-side data and offset vectors are read.
 * \param rbegin          Added to the page-local row index when indexing row_ptr.
 * \param nbins           Stride of one thread's slice inside hit_count_tloc_.
 * \param get_offset      Transformation applied to the bin index before storing.
 */
template <typename BinIdxType, typename GetOffset>
|
||||
void SetIndexData(common::Span<BinIdxType> index_data_span,
|
||||
common::Span<FeatureType const> ft,
|
||||
size_t batch_threads, const SparsePage &batch,
|
||||
size_t rbegin, size_t nbins, GetOffset get_offset) {
|
||||
const xgboost::Entry *data_ptr = batch.data.HostVector().data();
|
||||
const std::vector<bst_row_t> &offset_vec = batch.offset.HostVector();
|
||||
const size_t batch_size = batch.Size();
|
||||
// The loop below reads offset_vec[ridx + 1], so the offset vector must be strictly
// longer than the number of rows.
CHECK_LT(batch_size, offset_vec.size());
|
||||
BinIdxType* index_data = index_data_span.data();
|
||||
// Fetch the cut pointers and cut values once; both are passed to SearchBin for every
// non-categorical entry below.
auto const& ptrs = cut.Ptrs();
|
||||
auto const& values = cut.Values();
|
||||
// One loop iteration per row of the page.
common::ParallelFor(batch_size, batch_threads, [&](omp_ulong ridx) {
|
||||
const int tid = omp_get_thread_num();
|
||||
size_t ibegin = row_ptr[rbegin + ridx]; // index of first entry for current block
|
||||
size_t iend = row_ptr[rbegin + ridx + 1]; // first entry for next block
|
||||
// Build a view over this row's entries from the page's offset vector.
const size_t size = offset_vec[ridx + 1] - offset_vec[ridx];
|
||||
SparsePage::Inst inst = {data_ptr + offset_vec[ridx], size};
|
||||
// The row length recorded in row_ptr must match the row length in the page.
CHECK_EQ(ibegin + inst.size(), iend);
|
||||
for (bst_uint j = 0; j < inst.size(); ++j) {
|
||||
auto e = inst[j];
|
||||
// Both branches store get_offset(bin_idx, j) and bump the calling thread's counter;
// they differ only in how the bin is searched.
if (common::IsCat(ft, e.index)) {
|
||||
bst_bin_t bin_idx = cut.SearchCatBin(e);
|
||||
index_data[ibegin + j] = get_offset(bin_idx, j);
|
||||
++hit_count_tloc_[tid * nbins + bin_idx];
|
||||
} else {
|
||||
bst_bin_t bin_idx = cut.SearchBin(e.fvalue, e.index, ptrs, values);
|
||||
index_data[ibegin + j] = get_offset(bin_idx, j);
|
||||
// NOTE(review): assumes hit_count_tloc_ holds at least (number of threads * nbins)
// entries - confirm where it is resized.
++hit_count_tloc_[tid * nbins + bin_idx];
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ResizeIndex(const size_t n_index, const bool isDense);
|
||||
|
||||
void GetFeatureCounts(size_t* counts) const {
|
||||
|
||||
Reference in New Issue
Block a user