diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index 6671f05e3..9a87e5222 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -125,26 +125,25 @@ class HistogramCuts {
   /**
    * \brief Search the bin index for numerical feature.
    */
-  bst_bin_t SearchBin(Entry const& e) const {
-    return SearchBin(e.fvalue, e.index);
-  }
+  bst_bin_t SearchBin(Entry const& e) const { return SearchBin(e.fvalue, e.index); }
 
   /**
    * \brief Search the bin index for categorical feature.
    */
-  bst_bin_t SearchCatBin(Entry const &e) const {
+  bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
     auto const &ptrs = this->Ptrs();
     auto const &vals = this->Values();
-    auto end = ptrs.at(e.index + 1) + vals.cbegin();
-    auto beg = ptrs[e.index] + vals.cbegin();
+    auto end = ptrs.at(fidx + 1) + vals.cbegin();
+    auto beg = ptrs[fidx] + vals.cbegin();
     // Truncates the value in case it's not perfectly rounded.
-    auto v = static_cast<float>(common::AsCat(e.fvalue));
+    auto v = static_cast<float>(common::AsCat(value));
     auto bin_idx = std::lower_bound(beg, end, v) - vals.cbegin();
-    if (bin_idx == ptrs.at(e.index + 1)) {
+    if (bin_idx == ptrs.at(fidx + 1)) {
       bin_idx -= 1;
     }
     return bin_idx;
   }
+  bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
 };
 
 /**
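The hist_util.h change splits SearchCatBin so callers that hold a raw (value, feature) pair instead of an Entry can reuse the search; the Entry overload becomes a thin forwarder. The lookup itself is a binary search over the per-feature slice of the flattened cut-value array. Below is a minimal std-only sketch of that logic; the container layout mirrors HistogramCuts, but the names and types are illustrative, not the XGBoost API.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Bins of feature `fidx` occupy vals[ptrs[fidx] .. ptrs[fidx + 1]) and the
    // returned index is global across all features, as in HistogramCuts.
    std::int32_t SearchCatBin(float value, std::size_t fidx,
                              std::vector<std::size_t> const& ptrs,
                              std::vector<float> const& vals) {
      auto beg = vals.cbegin() + ptrs[fidx];
      auto end = vals.cbegin() + ptrs[fidx + 1];
      // Truncate toward the integer category in case the float is not exactly rounded.
      auto v = static_cast<float>(static_cast<std::int32_t>(value));
      auto bin_idx = std::lower_bound(beg, end, v) - vals.cbegin();
      if (bin_idx == static_cast<std::ptrdiff_t>(ptrs[fidx + 1])) {
        bin_idx -= 1;  // clamp a category beyond the last recorded one into the final bin
      }
      return static_cast<std::int32_t>(bin_idx);
    }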
diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc
index 791bb47e7..6f8c5ee9f 100644
--- a/src/data/gradient_index.cc
+++ b/src/data/gradient_index.cc
@@ -17,85 +17,15 @@ namespace xgboost {
 GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnMatrix>()} {}
 
-GHistIndexMatrix::GHistIndexMatrix(DMatrix *x, int32_t max_bin, double sparse_thresh,
-                                   bool sorted_sketch, int32_t n_threads,
+GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
+                                   double sparse_thresh, bool sorted_sketch, int32_t n_threads,
                                    common::Span<float> hess) {
-  this->Init(x, max_bin, sparse_thresh, sorted_sketch, n_threads, hess);
-}
-
-GHistIndexMatrix::~GHistIndexMatrix() = default;
-
-void GHistIndexMatrix::PushBatch(SparsePage const &batch,
-                                 common::Span<FeatureType const> ft,
-                                 size_t rbegin, size_t prev_sum, uint32_t nbins,
-                                 int32_t n_threads) {
-  auto page = batch.GetView();
-  auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
-  common::PartialSum(n_threads, it, it + page.Size(), prev_sum, row_ptr.begin() + rbegin);
-  // The number of threads is pegged to the batch size. If the OMP block is parallelized
-  // on anything other than the batch/block size, it should be reassigned
-  const size_t batch_threads =
-      std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
-
-  const size_t n_index = row_ptr[rbegin + batch.Size()];  // number of entries in this page
-  ResizeIndex(n_index, isDense_);
-
-  CHECK_GT(cut.Values().size(), 0U);
-
-  if (isDense_) {
-    index.SetBinOffset(cut.Ptrs());
-  }
-  uint32_t const *offsets = index.Offset();
-
-  if (isDense_) {
-    // Inside the lambda functions, bin_idx is the index for cut value across all
-    // features. By subtracting it with starting pointer of each feature, we can reduce
-    // it to smaller value and compress it to smaller types.
-    common::BinTypeSize curent_bin_size = index.GetBinTypeSize();
-    if (curent_bin_size == common::kUint8BinsTypeSize) {
-      common::Span<uint8_t> index_data_span = {index.data<uint8_t>(), n_index};
-      SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto bin_idx, auto fidx) {
-                     return static_cast<uint8_t>(bin_idx - offsets[fidx]);
-                   });
-    } else if (curent_bin_size == common::kUint16BinsTypeSize) {
-      common::Span<uint16_t> index_data_span = {index.data<uint16_t>(), n_index};
-      SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto bin_idx, auto fidx) {
-                     return static_cast<uint16_t>(bin_idx - offsets[fidx]);
-                   });
-    } else {
-      CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
-      common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
-      SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto bin_idx, auto fidx) {
-                     return static_cast<uint32_t>(bin_idx - offsets[fidx]);
-                   });
-    }
-  } else {
-    /* For sparse DMatrix we have to store index of feature for each bin
-       in index field to chose right offset. So offset is nullptr and index is
-       not reduced */
-    common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
-    SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                 [](auto idx, auto) { return idx; });
-  }
-
-  common::ParallelFor(nbins, n_threads, [&](bst_omp_uint idx) {
-    for (int32_t tid = 0; tid < n_threads; ++tid) {
-      hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
-      hit_count_tloc_[tid * nbins + idx] = 0;  // reset for next batch
-    }
-  });
-}
-
-void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch,
-                            int32_t n_threads, common::Span<float> hess) {
+  CHECK(p_fmat->SingleColBlock());
   // We use sorted sketching for approx tree method since it's more efficient in
   // computation time (but higher memory usage).
-  cut = common::SketchOnDMatrix(p_fmat, max_bins, n_threads, sorted_sketch, hess);
+  cut = common::SketchOnDMatrix(p_fmat, max_bins_per_feat, n_threads, sorted_sketch, hess);
 
-  max_num_bins = max_bins;
+  max_num_bins = max_bins_per_feat;
   const uint32_t nbins = cut.Ptrs().back();
   hit_count.resize(nbins, 0);
   hit_count_tloc_.resize(n_threads * nbins, 0);
@@ -108,16 +38,12 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
   row_ptr.resize(new_size);
   row_ptr[0] = 0;
 
-  size_t rbegin = 0;
-  size_t prev_sum = 0;
   const bool isDense = p_fmat->IsDense();
   this->isDense_ = isDense;
   auto ft = p_fmat->Info().feature_types.ConstHostSpan();
   for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
-    this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
-    prev_sum = row_ptr[rbegin + batch.Size()];
-    rbegin += batch.Size();
+    this->PushBatch(batch, ft, nbins, n_threads);
   }
   this->columns_ = std::make_unique<common::ColumnMatrix>();
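Folding Init into the constructor works because the constructor now accepts exactly one in-memory block (the new CHECK(p_fmat->SingleColBlock()) guard), so the rbegin/prev_sum offsets that used to carry state across external-memory pages have nothing left to track. For a single batch, the row_ptr that common::PartialSum fills is an ordinary prefix sum over per-row entry counts starting from zero, as in this std-only sketch (illustrative, not the XGBoost implementation):

    #include <cstddef>
    #include <numeric>
    #include <vector>

    // row_ptr[i + 1] = row_sizes[0] + ... + row_sizes[i], so row i of the
    // quantized index occupies [row_ptr[i], row_ptr[i + 1]).
    std::vector<std::size_t> MakeRowPtr(std::vector<std::size_t> const& row_sizes) {
      std::vector<std::size_t> row_ptr(row_sizes.size() + 1, 0);
      std::partial_sum(row_sizes.cbegin(), row_sizes.cend(), row_ptr.begin() + 1);
      return row_ptr;
    }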
@@ -131,6 +57,59 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
   }
 }
 
+GHistIndexMatrix::~GHistIndexMatrix() = default;
+
+void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
+                                 bst_bin_t n_total_bins, int32_t n_threads) {
+  auto page = batch.GetView();
+  auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
+  common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
+  // The number of threads is pegged to the batch size. If the OMP block is parallelized
+  // on anything other than the batch/block size, it should be reassigned
+  const size_t batch_threads =
+      std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
+
+  const size_t n_index = row_ptr[batch.Size()];  // number of entries in this page
+  ResizeIndex(n_index, isDense_);
+
+  CHECK_GT(cut.Values().size(), 0U);
+
+  if (isDense_) {
+    index.SetBinOffset(cut.Ptrs());
+  }
+  uint32_t const *offsets = index.Offset();
+
+  auto n_bins_total = cut.TotalBins();
+  auto is_valid = [](auto) { return true; };  // SparsePage always contains valid entries
+  data::SparsePageAdapterBatch adapter_batch{page};
+  if (isDense_) {
+    // Inside the lambda functions, bin_idx is the index for cut value across all
+    // features. By subtracting it with starting pointer of each feature, we can reduce
+    // it to smaller value and compress it to smaller types.
+    common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) {
+      using T = decltype(dtype);
+      common::Span<T> index_data_span = {index.data<T>(), index.Size()};
+      SetIndexData(
+          index_data_span, ft, batch_threads, adapter_batch, is_valid, n_bins_total,
+          [offsets](auto bin_idx, auto fidx) { return static_cast<T>(bin_idx - offsets[fidx]); });
+    });
+  } else {
+    /* For sparse DMatrix we have to store index of feature for each bin
+       in index field to chose right offset. So offset is nullptr and index is
+       not reduced */
+    common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
+    SetIndexData(index_data_span, ft, batch_threads, adapter_batch, is_valid, n_bins_total,
+                 [](auto idx, auto) { return idx; });
+  }
+
+  common::ParallelFor(n_total_bins, n_threads, [&](bst_omp_uint idx) {
+    for (int32_t tid = 0; tid < n_threads; ++tid) {
+      hit_count[idx] += hit_count_tloc_[tid * n_total_bins + idx];
+      hit_count_tloc_[tid * n_total_bins + idx] = 0;  // reset for next batch
+    }
+  });
+}
+
 void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
                             common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
                             bool isDense, double sparse_thresh, int32_t n_threads) {
@@ -148,10 +127,7 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span
-  this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
+  this->PushBatch(batch, ft, nbins, n_threads);
   this->columns_ = std::make_unique<common::ColumnMatrix>();
   if (!std::isnan(sparse_thresh)) {
     this->columns_->Init(batch, *this, sparse_thresh, n_threads);
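The new PushBatch replaces the uint8/uint16/uint32 if-else ladder with common::DispatchBinType, which maps the runtime bin width onto a compile-time type and instantiates one generic lambda per width. A self-contained sketch of the dispatch pattern, with stand-in names rather than the real XGBoost declarations:

    #include <cstdint>

    enum class BinTypeSize { kUint8 = 1, kUint16 = 2, kUint32 = 4 };

    // Turn a runtime size tag into a compile-time type by calling the
    // functor with a value of that type.
    template <typename Fn>
    void DispatchBinType(BinTypeSize size, Fn&& fn) {
      switch (size) {
        case BinTypeSize::kUint8:  fn(std::uint8_t{});  return;
        case BinTypeSize::kUint16: fn(std::uint16_t{}); return;
        case BinTypeSize::kUint32: fn(std::uint32_t{}); return;
      }
    }

    // Usage: the generic lambda is instantiated once per index width.
    //   DispatchBinType(size, [&](auto dtype) {
    //     using T = decltype(dtype);
    //     /* write compressed bin indices as T */
    //   });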
diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h
index 3d179e0fd..7074a3d9d 100644
--- a/src/data/gradient_index.h
+++ b/src/data/gradient_index.h
@@ -10,6 +10,7 @@
 #include "../common/categorical.h"
 #include "../common/hist_util.h"
 #include "../common/threading_utils.h"
+#include "adapter.h"
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 
@@ -32,8 +33,38 @@ class GHistIndexMatrix {
    * \param rbegin The beginning row index of current page. (total rows in previous pages)
    * \param prev_sum Total number of entries in previous pages.
    */
-  void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, size_t rbegin,
-                 size_t prev_sum, uint32_t nbins, int32_t n_threads);
+  void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft,
+                 bst_bin_t n_total_bins, int32_t n_threads);
+
+  template <typename Batch, typename BinIdxType, typename GetOffset, typename IsValid>
+  void SetIndexData(common::Span<BinIdxType> index_data_span, common::Span<FeatureType const> ft,
+                    size_t batch_threads, Batch const& batch, IsValid&& is_valid, size_t nbins,
+                    GetOffset&& get_offset) {
+    auto batch_size = batch.Size();
+    BinIdxType* index_data = index_data_span.data();
+    auto const& ptrs = cut.Ptrs();
+    auto const& values = cut.Values();
+    common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
+      auto line = batch.GetLine(i);
+      size_t ibegin = row_ptr[i];  // index of first entry for current block
+      size_t k = 0;
+      auto tid = omp_get_thread_num();
+      for (size_t j = 0; j < line.Size(); ++j) {
+        data::COOTuple elem = line.GetElement(j);
+        if (is_valid(elem)) {
+          bst_bin_t bin_idx{-1};
+          if (common::IsCat(ft, elem.column_idx)) {
+            bin_idx = cut.SearchCatBin(elem.value, elem.column_idx);
+          } else {
+            bin_idx = cut.SearchBin(elem.value, elem.column_idx, ptrs, values);
+          }
+          index_data[ibegin + k] = get_offset(bin_idx, j);
+          ++hit_count_tloc_[tid * nbins + bin_idx];
+          ++k;
+        }
+      }
+    });
+  }
 
  public:
   /*! \brief row pointer to rows by element position */
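The templated SetIndexData above accepts any adapter batch, so unlike the SparsePage-only version it may encounter missing values: entries are filtered through is_valid, and the running counter k, not the column loop index j, decides the write position, so the compressed indices stay densely packed. A std-only sketch of that filter-and-pack step for one row; the COOTuple layout mirrors the diff, everything else is illustrative:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct COOTuple {
      std::size_t row_idx;
      std::size_t column_idx;
      float value;
    };

    // Writes the bin index of every valid entry of one row, starting at
    // `ibegin` (the row's offset in the output, i.e. row_ptr[i]). Returns the
    // number of entries written. `hit_count` stands in for the per-thread
    // slice of hit_count_tloc_.
    template <typename IsValid, typename SearchBin>
    std::size_t PackRow(std::vector<COOTuple> const& line, std::size_t ibegin,
                        IsValid&& is_valid, SearchBin&& search_bin,
                        std::vector<std::uint32_t>* out,
                        std::vector<std::size_t>* hit_count) {
      std::size_t k = 0;  // counts valid entries; using j would leave holes for missing ones
      for (auto const& elem : line) {
        if (is_valid(elem)) {
          std::uint32_t bin_idx = search_bin(elem.value, elem.column_idx);
          (*out)[ibegin + k] = bin_idx;
          ++(*hit_count)[bin_idx];
          ++k;
        }
      }
      return k;
    }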
@@ -50,52 +81,15 @@ class GHistIndexMatrix {
   size_t base_rowid{0};
 
   GHistIndexMatrix();
-  GHistIndexMatrix(DMatrix* x, int32_t max_bin, double sparse_thresh, bool sorted_sketch,
-                   int32_t n_threads, common::Span<float> hess = {});
+  GHistIndexMatrix(DMatrix* x, bst_bin_t max_bins_per_feat, double sparse_thresh,
+                   bool sorted_sketch, int32_t n_threads, common::Span<float> hess = {});
   ~GHistIndexMatrix();
-  // Create a global histogram matrix, given cut
-  void Init(DMatrix* p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch,
-            int32_t n_threads, common::Span<float> hess);
+  // Create a global histogram matrix, given cut. Used by external memory
   void Init(SparsePage const& page, common::Span<FeatureType const> ft,
             common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
             double sparse_thresh, int32_t n_threads);
 
-  // specific method for sparse data as no possibility to reduce allocated memory
-  template <typename BinIdxType, typename GetOffset>
-  void SetIndexData(common::Span<BinIdxType> index_data_span,
-                    common::Span<FeatureType const> ft,
-                    size_t batch_threads, const SparsePage &batch,
-                    size_t rbegin, size_t nbins, GetOffset get_offset) {
-    const xgboost::Entry *data_ptr = batch.data.HostVector().data();
-    const std::vector<bst_row_t> &offset_vec = batch.offset.HostVector();
-    const size_t batch_size = batch.Size();
-    CHECK_LT(batch_size, offset_vec.size());
-    BinIdxType* index_data = index_data_span.data();
-    auto const& ptrs = cut.Ptrs();
-    auto const& values = cut.Values();
-    common::ParallelFor(batch_size, batch_threads, [&](omp_ulong ridx) {
-      const int tid = omp_get_thread_num();
-      size_t ibegin = row_ptr[rbegin + ridx];    // index of first entry for current block
-      size_t iend = row_ptr[rbegin + ridx + 1];  // first entry for next block
-      const size_t size = offset_vec[ridx + 1] - offset_vec[ridx];
-      SparsePage::Inst inst = {data_ptr + offset_vec[ridx], size};
-      CHECK_EQ(ibegin + inst.size(), iend);
-      for (bst_uint j = 0; j < inst.size(); ++j) {
-        auto e = inst[j];
-        if (common::IsCat(ft, e.index)) {
-          bst_bin_t bin_idx = cut.SearchCatBin(e);
-          index_data[ibegin + j] = get_offset(bin_idx, j);
-          ++hit_count_tloc_[tid * nbins + bin_idx];
-        } else {
-          bst_bin_t bin_idx = cut.SearchBin(e.fvalue, e.index, ptrs, values);
-          index_data[ibegin + j] = get_offset(bin_idx, j);
-          ++hit_count_tloc_[tid * nbins + bin_idx];
-        }
-      }
-    });
-  }
-
   void ResizeIndex(const size_t n_index, const bool isDense);
 
   void GetFeatureCounts(size_t* counts) const {
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 4b6b0e91d..cdd38468a 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -107,22 +107,5 @@ TEST(DenseColumnWithMissing, Test) {
     });
   }
 }
-
-void TestGHistIndexMatrixCreation(size_t nthreads) {
-  size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  /* This should create multiple sparse pages */
-  std::unique_ptr<DMatrix> dmat{CreateSparsePageDMatrix(kEntries)};
-  GHistIndexMatrix gmat(dmat.get(), 256, 0.5f, false, common::OmpGetNumThreads(nthreads));
-}
-
-TEST(HistIndexCreationWithExternalMemory, Test) {
-  // Vary the number of threads to make sure that the last batch
-  // is distributed properly to the available number of threads
-  // in the thread pool
-  TestGHistIndexMatrixCreation(20);
-  TestGHistIndexMatrixCreation(30);
-  TestGHistIndexMatrixCreation(40);
-}
 }  // namespace common
 }  // namespace xgboost
diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc
index 6bf12a060..8f5d7d3d7 100644
--- a/tests/cpp/data/test_gradient_index.cc
+++ b/tests/cpp/data/test_gradient_index.cc
@@ -44,9 +44,7 @@ TEST(GradientIndex, FromCategoricalBasic) {
   h_ft.resize(kCols, FeatureType::kCategorical);
 
   BatchParam p(max_bins, 0.8);
-  GHistIndexMatrix gidx;
-
-  gidx.Init(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {});
+  GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {});
 
   auto x_copy = x;
   std::sort(x_copy.begin(), x_copy.end());
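One detail shared by the old and new PushBatch, and exercised end to end by the histogram test below: each thread tallies bin hits into its own [n_threads x n_bins] slice of hit_count_tloc_, and a single pass over the bins folds the slices into the shared hit_count. A std-only sketch with sequential loops standing in for common::ParallelFor:

    #include <cstddef>
    #include <vector>

    // Reduce per-thread counters laid out as [n_threads x n_bins] into the
    // shared hit_count, resetting the thread-local slots for the next batch.
    void ReduceHitCounts(std::size_t n_threads, std::size_t n_bins,
                         std::vector<std::size_t>* tloc,
                         std::vector<std::size_t>* hit_count) {
      for (std::size_t bin = 0; bin < n_bins; ++bin) {  // parallel over bins in XGBoost
        for (std::size_t tid = 0; tid < n_threads; ++tid) {
          (*hit_count)[bin] += (*tloc)[tid * n_bins + bin];
          (*tloc)[tid * n_bins + bin] = 0;  // reset for the next batch
        }
      }
    }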
diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc
index c0bd62629..d20c3d0d8 100644
--- a/tests/cpp/tree/hist/test_histogram.cc
+++ b/tests/cpp/tree/hist/test_histogram.cc
@@ -413,10 +413,16 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
     single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);
     SparsePage concat;
-    GHistIndexMatrix gmat;
     std::vector<float> hess(m->Info().num_row_, 1.0f);
-    gmat.Init(m.get(), batch_param.max_bin, std::numeric_limits<float>::quiet_NaN(), false,
-              common::OmpGetNumThreads(0), hess);
+    for (auto const& page : m->GetBatches<SparsePage>()) {
+      concat.Push(page);
+    }
+
+    auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, common::OmpGetNumThreads(0),
+                                       false, hess);
+    GHistIndexMatrix gmat;
+    gmat.Init(concat, {}, cut, batch_param.max_bin, false, std::numeric_limits<float>::quiet_NaN(),
+              common::OmpGetNumThreads(0));
     single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair);
     single_page = single_build.Histogram()[0];
   }
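With the DMatrix overload of Init gone, the rewritten test reproduces the single-page reference by hand: it concatenates every SparsePage into one page, sketches the quantile cuts once over the whole DMatrix, and feeds the concatenated page to the SparsePage overload of Init. In CSR terms, concatenation appends the value arrays and appends the row offsets shifted by the entry count so far, as in this toy sketch (assumed types, not xgboost::SparsePage):

    #include <cstddef>
    #include <vector>

    struct CsrPage {
      std::vector<std::size_t> offset{0};  // row boundaries, offset[0] == 0
      std::vector<float> data;             // entry values

      // Append another page's rows after this page's rows.
      void Push(CsrPage const& page) {
        std::size_t shift = data.size();
        data.insert(data.end(), page.data.begin(), page.data.end());
        for (std::size_t i = 1; i < page.offset.size(); ++i) {
          offset.push_back(page.offset[i] + shift);
        }
      }
    };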