Extract Sketch Entry from hist maker. (#7503)

* Extract Sketch Entry from hist maker.

* Add a new sketch container for sorted inputs.
* Optimize bin search.
This commit is contained in:
Jiaming Yuan
2021-12-18 05:36:56 +08:00
committed by GitHub
parent b4a1236cfc
commit 9ab73f737e
15 changed files with 393 additions and 217 deletions

View File

@@ -118,7 +118,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
[](auto idx, auto) { return idx; });
}
common::ParallelFor(bst_omp_uint(nbins), n_threads, [&](bst_omp_uint idx) {
common::ParallelFor(nbins, n_threads, [&](bst_omp_uint idx) {
for (int32_t tid = 0; tid < n_threads; ++tid) {
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
@@ -126,8 +126,11 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
});
}
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins, common::Span<float> hess) {
cut = common::SketchOnDMatrix(p_fmat, max_bins, hess);
void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch,
common::Span<float> hess) {
// We use sorted sketching for the approx tree method since it is more efficient in
// computation time (at the cost of higher memory usage).
cut = common::SketchOnDMatrix(p_fmat, max_bins, sorted_sketch, hess);
max_num_bins = max_bins;
const int32_t nthread = omp_get_max_threads();

View File

@@ -37,14 +37,14 @@ class GHistIndexMatrix {
size_t base_rowid{0};
GHistIndexMatrix() = default;
GHistIndexMatrix(DMatrix* x, int32_t max_bin, common::Span<float> hess = {}) {
this->Init(x, max_bin, hess);
GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, common::Span<float> hess = {}) {
this->Init(x, max_bin, sorted_sketch, hess);
}
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat, int max_num_bins, common::Span<float> hess);
void Init(SparsePage const &page, common::Span<FeatureType const> ft,
common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
bool is_dense, int32_t n_threads);
void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, common::Span<float> hess);
void Init(SparsePage const& page, common::Span<FeatureType const> ft,
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
int32_t n_threads);
// Specific method for sparse data, as there is no possibility of reducing the allocated memory.
template <typename BinIdxType, typename GetOffset>
@@ -57,7 +57,9 @@ class GHistIndexMatrix {
const size_t batch_size = batch.Size();
CHECK_LT(batch_size, offset_vec.size());
BinIdxType* index_data = index_data_span.data();
common::ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
auto const& ptrs = cut.Ptrs();
auto const& values = cut.Values();
common::ParallelFor(batch_size, batch_threads, [&](omp_ulong i) {
const int tid = omp_get_thread_num();
size_t ibegin = row_ptr[rbegin + i];
size_t iend = row_ptr[rbegin + i + 1];
@@ -71,7 +73,7 @@ class GHistIndexMatrix {
index_data[ibegin + j] = get_offset(bin_idx, j);
++hit_count_tloc_[tid * nbins + bin_idx];
} else {
uint32_t idx = cut.SearchBin(inst[j]);
uint32_t idx = cut.SearchBin(inst[j].fvalue, inst[j].index, ptrs, values);
index_data[ibegin + j] = get_offset(idx, j);
++hit_count_tloc_[tid * nbins + idx];
}

View File

@@ -94,10 +94,12 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
CHECK_GE(param.max_bin, 2);
CHECK_EQ(param.gpu_id, -1);
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.hess));
// Used only by approx.
auto sorted_sketch = param.regen;
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, sorted_sketch, param.hess));
batch_param_ = param;
CHECK_EQ(batch_param_.hess.data(), param.hess.data());
}

View File

@@ -159,12 +159,12 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam& param) {
CHECK_GE(param.max_bin, 2);
if (param.hess.empty()) {
if (param.hess.empty() && !param.regen) {
// The hist method doesn't support a full external memory implementation, so we concatenate
// all indices here.
if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
this->InitializeSparsePage();
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin});
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.regen});
this->InitializeSparsePage();
batch_param_ = param;
}
@@ -175,20 +175,23 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
this->InitializeSparsePage();
if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{})) {
if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{}) ||
param.regen) {
cache_info_.erase(id);
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
auto cuts = common::SketchOnDMatrix(this, param.max_bin, param.hess);
// Use sorted sketch for approx.
auto sorted_sketch = param.regen;
auto cuts = common::SketchOnDMatrix(this, param.max_bin, sorted_sketch, param.hess);
this->InitializeSparsePage(); // reset after use.
batch_param_ = param;
ghist_index_source_.reset();
CHECK_NE(cuts.Values().size(), 0);
auto ft = this->info_.feature_types.ConstHostSpan();
ghist_index_source_.reset(new GradientIndexPageSource(
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
this->IsDense(), param.max_bin, ft, sparse_page_source_));
ghist_index_source_.reset(
new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_,
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
this->IsDense(), param.max_bin, ft, sparse_page_source_));
} else {
CHECK(ghist_index_source_);
ghist_index_source_->Reset();