Extract Sketch Entry from hist maker. (#7503)
* Extract Sketch Entry from hist maker.
* Add a new sketch container for sorted inputs.
* Optimize bin search.
This commit is contained in:
@@ -118,7 +118,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
|
||||
[](auto idx, auto) { return idx; });
|
||||
}
|
||||
|
||||
common::ParallelFor(bst_omp_uint(nbins), n_threads, [&](bst_omp_uint idx) {
|
||||
common::ParallelFor(nbins, n_threads, [&](bst_omp_uint idx) {
|
||||
for (int32_t tid = 0; tid < n_threads; ++tid) {
|
||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
||||
hit_count_tloc_[tid * nbins + idx] = 0; // reset for next batch
|
||||
@@ -126,8 +126,11 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
|
||||
});
|
||||
}
|
||||
|
||||
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins, common::Span<float> hess) {
|
||||
cut = common::SketchOnDMatrix(p_fmat, max_bins, hess);
|
||||
void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch,
|
||||
common::Span<float> hess) {
|
||||
// We use sorted sketching for approx tree method since it's more efficient in
|
||||
// computation time (but higher memory usage).
|
||||
cut = common::SketchOnDMatrix(p_fmat, max_bins, sorted_sketch, hess);
|
||||
|
||||
max_num_bins = max_bins;
|
||||
const int32_t nthread = omp_get_max_threads();
|
||||
|
||||
@@ -37,14 +37,14 @@ class GHistIndexMatrix {
|
||||
size_t base_rowid{0};
|
||||
|
||||
GHistIndexMatrix() = default;
|
||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin, common::Span<float> hess = {}) {
|
||||
this->Init(x, max_bin, hess);
|
||||
GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, common::Span<float> hess = {}) {
|
||||
this->Init(x, max_bin, sorted_sketch, hess);
|
||||
}
|
||||
// Create a global histogram matrix, given cut
|
||||
void Init(DMatrix* p_fmat, int max_num_bins, common::Span<float> hess);
|
||||
void Init(SparsePage const &page, common::Span<FeatureType const> ft,
|
||||
common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
|
||||
bool is_dense, int32_t n_threads);
|
||||
void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, common::Span<float> hess);
|
||||
void Init(SparsePage const& page, common::Span<FeatureType const> ft,
|
||||
common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
|
||||
int32_t n_threads);
|
||||
|
||||
// specific method for sparse data as no possibility to reduce allocated memory
|
||||
template <typename BinIdxType, typename GetOffset>
|
||||
@@ -57,7 +57,9 @@ class GHistIndexMatrix {
|
||||
const size_t batch_size = batch.Size();
|
||||
CHECK_LT(batch_size, offset_vec.size());
|
||||
BinIdxType* index_data = index_data_span.data();
|
||||
common::ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
|
||||
auto const& ptrs = cut.Ptrs();
|
||||
auto const& values = cut.Values();
|
||||
common::ParallelFor(batch_size, batch_threads, [&](omp_ulong i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
size_t ibegin = row_ptr[rbegin + i];
|
||||
size_t iend = row_ptr[rbegin + i + 1];
|
||||
@@ -71,7 +73,7 @@ class GHistIndexMatrix {
|
||||
index_data[ibegin + j] = get_offset(bin_idx, j);
|
||||
++hit_count_tloc_[tid * nbins + bin_idx];
|
||||
} else {
|
||||
uint32_t idx = cut.SearchBin(inst[j]);
|
||||
uint32_t idx = cut.SearchBin(inst[j].fvalue, inst[j].index, ptrs, values);
|
||||
index_data[ibegin + j] = get_offset(idx, j);
|
||||
++hit_count_tloc_[tid * nbins + idx];
|
||||
}
|
||||
|
||||
@@ -94,10 +94,12 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
|
||||
if (!(batch_param_ != BatchParam{})) {
|
||||
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
}
|
||||
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
|
||||
if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
CHECK_EQ(param.gpu_id, -1);
|
||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.hess));
|
||||
// Used only by approx.
|
||||
auto sorted_sketch = param.regen;
|
||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, sorted_sketch, param.hess));
|
||||
batch_param_ = param;
|
||||
CHECK_EQ(batch_param_.hess.data(), param.hess.data());
|
||||
}
|
||||
|
||||
@@ -159,12 +159,12 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
||||
|
||||
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam& param) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
if (param.hess.empty()) {
|
||||
if (param.hess.empty() && !param.regen) {
|
||||
// hist method doesn't support full external memory implementation, so we concatenate
|
||||
// all index here.
|
||||
if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
|
||||
this->InitializeSparsePage();
|
||||
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin});
|
||||
ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.regen});
|
||||
this->InitializeSparsePage();
|
||||
batch_param_ = param;
|
||||
}
|
||||
@@ -175,20 +175,23 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
|
||||
|
||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
this->InitializeSparsePage();
|
||||
if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{})) {
|
||||
if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{}) ||
|
||||
param.regen) {
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
auto cuts = common::SketchOnDMatrix(this, param.max_bin, param.hess);
|
||||
// Use sorted sketch for approx.
|
||||
auto sorted_sketch = param.regen;
|
||||
auto cuts = common::SketchOnDMatrix(this, param.max_bin, sorted_sketch, param.hess);
|
||||
this->InitializeSparsePage(); // reset after use.
|
||||
|
||||
batch_param_ = param;
|
||||
ghist_index_source_.reset();
|
||||
CHECK_NE(cuts.Values().size(), 0);
|
||||
auto ft = this->info_.feature_types.ConstHostSpan();
|
||||
ghist_index_source_.reset(new GradientIndexPageSource(
|
||||
this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
||||
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
|
||||
this->IsDense(), param.max_bin, ft, sparse_page_source_));
|
||||
ghist_index_source_.reset(
|
||||
new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_,
|
||||
this->n_batches_, cache_info_.at(id), param, std::move(cuts),
|
||||
this->IsDense(), param.max_bin, ft, sparse_page_source_));
|
||||
} else {
|
||||
CHECK(ghist_index_source_);
|
||||
ghist_index_source_->Reset();
|
||||
|
||||
Reference in New Issue
Block a user