From 6762c45494f5eabdf3c40912cf5c005ced6d332f Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Wed, 23 Feb 2022 11:37:21 +0800
Subject: [PATCH] Small cleanup to gradient index and hist. (#7668)

* Code comments.
* Const accessor to index.
* Remove some weird variables in the `Index` class.
* Simplify the `MemStackAllocator`.
---
 src/common/column_matrix.h            |   8 +-
 src/common/hist_util.h                | 118 ++++++++++----------------
 src/common/threading_utils.h          |  37 ++++++++
 src/data/gradient_index.cc            |  47 +++++-----
 src/data/gradient_index.h             |  21 +++--
 src/data/gradient_index_format.cc     |  18 +---
 src/tree/updater_approx.h             |  10 +--
 src/tree/updater_quantile_hist.cc     |   6 +-
 src/tree/updater_quantile_hist.h      |   6 +-
 tests/cpp/common/test_hist_util.cc    |   6 +-
 tests/cpp/tree/hist/test_histogram.cc |  19 +++--
 tests/cpp/tree/test_quantile_hist.cc  |   1 -
 12 files changed, 149 insertions(+), 148 deletions(-)

diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h
index 747004cc0..051a4cd44 100644
--- a/src/common/column_matrix.h
+++ b/src/common/column_matrix.h
@@ -266,9 +266,9 @@ class ColumnMatrix {
   }

   template <typename T>
-  inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat,
-                               const size_t nrow, const size_t nfeature,
-                               const bool noMissingValues, int32_t n_threads) {
+  inline void SetIndexAllDense(T const* index, const GHistIndexMatrix& gmat, const size_t nrow,
+                               const size_t nfeature, const bool noMissingValues,
+                               int32_t n_threads) {
     T* local_index = reinterpret_cast<T*>(&index_[0]);

     /* missing values make sense only for column with type kDenseColumn,
@@ -313,7 +313,7 @@ class ColumnMatrix {
   }

   template <typename T>
-  inline void SetIndex(uint32_t* index, const GHistIndexMatrix& gmat,
+  inline void SetIndex(uint32_t const* index, const GHistIndexMatrix& gmat,
                        const size_t nfeature) {
     std::vector<size_t> num_nonzeros;
     num_nonzeros.resize(nfeature);
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index d138d102d..442bddfcd 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -197,19 +197,27 @@ enum BinTypeSize : uint32_t {
   kUint32BinsTypeSize = 4
 };

+/**
+ * \brief Optionally compressed gradient index. The compression works only with dense
+ *        data.
+ *
+ * The main body of the construction code is in gradient_index.cc; this struct is only
+ * a storage class.
+ */
 struct Index {
-  Index() {
-    SetBinTypeSize(binTypeSize_);
-  }
+  Index() { SetBinTypeSize(binTypeSize_); }
   Index(const Index& i) = delete;
   Index& operator=(Index i) = delete;
   Index(Index&& i) = delete;
   Index& operator=(Index&& i) = delete;
   uint32_t operator[](size_t i) const {
-    if (offset_ptr_ != nullptr) {
-      return func_(data_ptr_, i) + offset_ptr_[i%p_];
+    if (!bin_offset_.empty()) {
+      // dense, compressed
+      auto fidx = i % bin_offset_.size();
+      // restore the index by adding back its feature offset
+      return func_(data_.data(), i) + bin_offset_[fidx];
     } else {
-      return func_(data_ptr_, i);
+      return func_(data_.data(), i);
     }
   }
   void SetBinTypeSize(BinTypeSize binTypeSize) {
@@ -225,35 +233,32 @@ struct Index {
       func_ = &GetValueFromUint32;
       break;
     default:
-      CHECK(binTypeSize == kUint8BinsTypeSize ||
-            binTypeSize == kUint16BinsTypeSize ||
+      CHECK(binTypeSize == kUint8BinsTypeSize || binTypeSize == kUint16BinsTypeSize ||
             binTypeSize == kUint32BinsTypeSize);
     }
   }
   BinTypeSize GetBinTypeSize() const { return binTypeSize_; }
-  template<typename T>
-  T* data() const {  // NOLINT
-    return static_cast<T*>(data_ptr_);
+  template <typename T>
+  T const* data() const {  // NOLINT
+    return reinterpret_cast<T const*>(data_.data());
   }
-  uint32_t* Offset() const {
-    return offset_ptr_;
+  template <typename T>
+  T* data() {  // NOLINT
+    return reinterpret_cast<T*>(data_.data());
   }
-  size_t OffsetSize() const {
-    return offset_.size();
+  uint32_t const* Offset() const { return bin_offset_.data(); }
+  size_t OffsetSize() const { return bin_offset_.size(); }
+  size_t Size() const { return data_.size() / (binTypeSize_); }
+
+  void Resize(const size_t n_bytes) {
+    data_.resize(n_bytes);
   }
-  size_t Size() const {
-    return data_.size() / (binTypeSize_);
-  }
-  void Resize(const size_t nBytesData) {
-    data_.resize(nBytesData);
-    data_ptr_ = reinterpret_cast<void*>(data_.data());
-  }
-  void ResizeOffset(const size_t nDisps) {
-    offset_.resize(nDisps);
-    offset_ptr_ = offset_.data();
-    p_ = nDisps;
+  // set the offset used in compression; cut_ptrs is the CSC indptr in HistogramCuts
+  void SetBinOffset(std::vector<uint32_t> const& cut_ptrs) {
+    bin_offset_.resize(cut_ptrs.size() - 1);  // resize to the number of features.
+    std::copy_n(cut_ptrs.begin(), bin_offset_.size(), bin_offset_.begin());
   }
   std::vector<uint8_t>::const_iterator begin() const {  // NOLINT
     return data_.begin();
   }
@@ -270,24 +275,23 @@ struct Index {
   }

  private:
-  static uint32_t GetValueFromUint8(void *t, size_t i) {
-    return reinterpret_cast<uint8_t*>(t)[i];
+  // Functions to decompress the index.
+  static uint32_t GetValueFromUint8(uint8_t const* t, size_t i) { return t[i]; }
+  static uint32_t GetValueFromUint16(uint8_t const* t, size_t i) {
+    return reinterpret_cast<uint16_t const*>(t)[i];
   }
-  static uint32_t GetValueFromUint16(void* t, size_t i) {
-    return reinterpret_cast<uint16_t*>(t)[i];
-  }
-  static uint32_t GetValueFromUint32(void* t, size_t i) {
-    return reinterpret_cast<uint32_t*>(t)[i];
+  static uint32_t GetValueFromUint32(uint8_t const* t, size_t i) {
+    return reinterpret_cast<uint32_t const*>(t)[i];
   }

-  using Func = uint32_t (*)(void*, size_t);
+  using Func = uint32_t (*)(uint8_t const*, size_t);

   std::vector<uint8_t> data_;
-  std::vector<uint32_t> offset_;  // size of this field is equal to number of features
-  void* data_ptr_;
+  // starting position of each feature inside the cut values (the indptr of the CSC cut
+  // matrix HistogramCuts, without the last entry). Used for bin compression.
+  std::vector<uint32_t> bin_offset_;
+
   BinTypeSize binTypeSize_ {kUint8BinsTypeSize};
-  size_t p_ {1};
-  uint32_t* offset_ptr_ {nullptr};
   Func func_;
 };

@@ -304,9 +308,11 @@ int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
     }
     previous_middle = middle;

+    // index into all the bins
     auto gidx = data[middle];

     if (gidx >= fidx_begin && gidx < fidx_end) {
+      // Found the intersection.
       return static_cast<int32_t>(gidx);
     } else if (gidx < fidx_begin) {
       begin = middle;
@@ -636,42 +642,6 @@ class GHistBuilder {
   /*! \brief number of all bins over all features */
   uint32_t nbins_ { 0 };
 };
-
-/*!
- * \brief A C-style array with in-stack allocation. As long as the array is smaller than
- * MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
- * heap-allocated.
- */
-template<typename T, size_t MaxStackSize>
-class MemStackAllocator {
- public:
-  explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
-  }
-
-  T* Get() {
-    if (!ptr_) {
-      if (MaxStackSize >= required_size_) {
-        ptr_ = stack_mem_;
-      } else {
-        ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
-        do_free_ = true;
-      }
-    }
-
-    return ptr_;
-  }
-
-  ~MemStackAllocator() {
-    if (do_free_) free(ptr_);
-  }
-
-
- private:
-  T* ptr_ = nullptr;
-  bool do_free_ = false;
-  size_t required_size_;
-  T stack_mem_[MaxStackSize];
-};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_HIST_UTIL_H_
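The compression handled by `SetBinOffset` and `operator[]` is easiest to see with concrete numbers: in a dense page every row stores one entry per feature, so entry i belongs to feature i % n_features and only the within-feature bin index has to be kept. Below is a minimal self-contained sketch of that arithmetic; it is illustrative code, not part of the patch, and all names in it are made up.

// Dense bin-index compression: store bin_idx - bin_offset[fidx] in a narrow
// type, then restore it by adding the feature's offset back (as operator[] does).
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // CSC indptr of the cuts: feature 0 owns global bins [0, 4), feature 1 owns [4, 9).
  std::vector<uint32_t> cut_ptrs{0, 4, 9};
  std::vector<uint32_t> bin_offset(cut_ptrs.begin(), cut_ptrs.end() - 1);
  size_t n_features = bin_offset.size();

  // One dense row hitting global bin 3 (feature 0) and global bin 6 (feature 1).
  std::vector<uint32_t> global{3, 6};
  std::vector<uint8_t> compressed(global.size());
  for (size_t i = 0; i < global.size(); ++i) {
    // Always a small within-feature index, regardless of the total bin count.
    compressed[i] = static_cast<uint8_t>(global[i] - bin_offset[i % n_features]);
  }

  // Decode exactly as the compressed dense branch of Index::operator[] does.
  for (size_t i = 0; i < compressed.size(); ++i) {
    assert(compressed[i] + bin_offset[i % n_features] == global[i]);
  }
  return 0;
}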
diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h
index 44d15b900..4691fce7c 100644
--- a/src/common/threading_utils.h
+++ b/src/common/threading_utils.h
@@ -246,6 +246,43 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
   n_threads = std::max(n_threads, 1);
   return n_threads;
 }
+
+/*!
+ * \brief A C-style array with in-stack allocation. As long as the array is smaller than
+ *        MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
+ *        heap-allocated.
+ */
+template <typename T, size_t MaxStackSize>
+class MemStackAllocator {
+ public:
+  explicit MemStackAllocator(size_t required_size) : required_size_(required_size) {
+    if (MaxStackSize >= required_size_) {
+      ptr_ = stack_mem_;
+    } else {
+      ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
+    }
+    if (!ptr_) {
+      throw std::bad_alloc{};
+    }
+  }
+
+  ~MemStackAllocator() {
+    if (required_size_ > MaxStackSize) {
+      free(ptr_);
+    }
+  }
+  T& operator[](size_t i) { return ptr_[i]; }
+  T const& operator[](size_t i) const { return ptr_[i]; }
+
+  // FIXME(jiamingy): Remove this once we merge partitioner cleanup for hist.
+  auto Get() { return ptr_; }
+
+ private:
+  T* ptr_ = nullptr;
+  size_t required_size_;
+  T stack_mem_[MaxStackSize];
+};
 }  // namespace common
 }  // namespace xgboost
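After this change the allocator is used like a plain array. A hedged usage sketch follows, assuming the header above is reachable from the including file; the threshold value 8 and the function itself are illustrative only.

// The buffer lives on the stack when required_size <= MaxStackSize and on the
// heap otherwise; allocation now happens eagerly in the constructor, and
// operator[] replaces most uses of the old Get() pointer.
#include <cstddef>
#include "../common/threading_utils.h"  // xgboost::common::MemStackAllocator

size_t SumSizes(size_t const* sizes, size_t n) {
  xgboost::common::MemStackAllocator<size_t, 8> buf(n);  // stack-backed if n <= 8
  size_t total = 0;
  for (size_t i = 0; i < n; ++i) {
    buf[i] = sizes[i];
    total += buf[i];
  }
  return total;
}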
diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc
index abd80264d..eef8f9519 100644
--- a/src/data/gradient_index.cc
+++ b/src/data/gradient_index.cc
@@ -10,6 +10,7 @@

 #include "../common/column_matrix.h"
 #include "../common/hist_util.h"
+#include "../common/threading_utils.h"

 namespace xgboost {

@@ -34,7 +35,6 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
       std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
   auto page = batch.GetView();
   common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
-  size_t *p_part = partial_sums.Get();

   size_t block_size = batch.Size() / batch_threads;

@@ -48,10 +48,10 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
         size_t iend = (tid == (batch_threads - 1) ? batch.Size()
                                                   : (block_size * (tid + 1)));
-        size_t sum = 0;
-        for (size_t i = ibegin; i < iend; ++i) {
-          sum += page[i].size();
-          row_ptr[rbegin + 1 + i] = sum;
+        size_t running_sum = 0;
+        for (size_t ridx = ibegin; ridx < iend; ++ridx) {
+          running_sum += page[ridx].size();
+          row_ptr[rbegin + 1 + ridx] = running_sum;
         }
       });
     }

@@ -59,9 +59,9 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
 #pragma omp single
     {
       exc.Run([&]() {
-        p_part[0] = prev_sum;
+        partial_sums[0] = prev_sum;
         for (size_t i = 1; i < batch_threads; ++i) {
-          p_part[i] = p_part[i - 1] + row_ptr[rbegin + i * block_size];
+          partial_sums[i] = partial_sums[i - 1] + row_ptr[rbegin + i * block_size];
         }
       });
     }
@@ -74,55 +74,52 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
                                  : (block_size * (tid + 1)));

         for (size_t i = ibegin; i < iend; ++i) {
-          row_ptr[rbegin + 1 + i] += p_part[tid];
+          row_ptr[rbegin + 1 + i] += partial_sums[tid];
         }
       });
     }
   }
   exc.Rethrow();

-  const size_t n_offsets = cut.Ptrs().size() - 1;
-  const size_t n_index = row_ptr[rbegin + batch.Size()];
+  const size_t n_index = row_ptr[rbegin + batch.Size()];  // number of entries in this page
   ResizeIndex(n_index, isDense_);

   CHECK_GT(cut.Values().size(), 0U);

-  uint32_t *offsets = nullptr;
   if (isDense_) {
-    index.ResizeOffset(n_offsets);
-    offsets = index.Offset();
-    for (size_t i = 0; i < n_offsets; ++i) {
-      offsets[i] = cut.Ptrs()[i];
-    }
+    index.SetBinOffset(cut.Ptrs());
   }
+  uint32_t const *offsets = index.Offset();

   if (isDense_) {
+    // Inside the lambda functions, bin_idx is the index of the cut value across all
+    // features. By subtracting the starting offset of its feature, we reduce it to a
+    // smaller value that can be compressed into a narrower type.
     common::BinTypeSize curent_bin_size = index.GetBinTypeSize();
     if (curent_bin_size == common::kUint8BinsTypeSize) {
       common::Span<uint8_t> index_data_span = {index.data<uint8_t>(), n_index};
       SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto idx, auto j) {
-                     return static_cast<uint8_t>(idx - offsets[j]);
+                   [offsets](auto bin_idx, auto fidx) {
+                     return static_cast<uint8_t>(bin_idx - offsets[fidx]);
                    });
     } else if (curent_bin_size == common::kUint16BinsTypeSize) {
       common::Span<uint16_t> index_data_span = {index.data<uint16_t>(), n_index};
       SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto idx, auto j) {
-                     return static_cast<uint16_t>(idx - offsets[j]);
+                   [offsets](auto bin_idx, auto fidx) {
+                     return static_cast<uint16_t>(bin_idx - offsets[fidx]);
                    });
     } else {
       CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize);
       common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
       SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
-                   [offsets](auto idx, auto j) {
-                     return static_cast<uint32_t>(idx - offsets[j]);
+                   [offsets](auto bin_idx, auto fidx) {
+                     return static_cast<uint32_t>(bin_idx - offsets[fidx]);
                    });
     }
-
+  } else {
     /* For a sparse DMatrix we have to store the feature index for each bin in the
        index field to choose the right offset, so offset is nullptr and the index is
        not reduced. */
-  } else {
     common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
     SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins,
                  [](auto idx, auto) { return idx; });
@@ -194,11 +191,13 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType c
   if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) &&
       isDense) {
+    // compress dense index to uint8
     index.SetBinTypeSize(common::kUint8BinsTypeSize);
     index.Resize((sizeof(uint8_t)) * n_index);
   } else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
               max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) &&
              isDense) {
+    // compress dense index to uint16
    index.SetBinTypeSize(common::kUint16BinsTypeSize);
    index.Resize((sizeof(uint16_t)) * n_index);
  } else {
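The `partial_sums` passes in `PushBatch` above are a standard three-phase parallel prefix sum over the per-row entry counts: each thread builds a local running sum over its own block, one thread then accumulates the block totals, and finally every thread shifts its block by its offset. A sequential sketch of the same idea follows, stripped of the OpenMP scaffolding; it is illustrative only and assumes n_blocks >= 1 and counts.size() >= n_blocks.

// Three-phase prefix sum, mirroring how PushBatch builds row_ptr.
#include <cstddef>
#include <vector>

std::vector<size_t> PrefixSumBlocks(std::vector<size_t> const& counts, size_t n_blocks) {
  size_t block = counts.size() / n_blocks;
  std::vector<size_t> out(counts.size() + 1, 0);
  // Phase 1 (parallel in the real code): local running sums per block.
  for (size_t t = 0; t < n_blocks; ++t) {
    size_t begin = t * block;
    size_t end = (t + 1 == n_blocks) ? counts.size() : begin + block;
    size_t running = 0;
    for (size_t i = begin; i < end; ++i) {
      running += counts[i];
      out[i + 1] = running;
    }
  }
  // Phase 2 (single thread): exclusive sum of per-block totals, mirroring
  // partial_sums[i] = partial_sums[i - 1] + row_ptr[rbegin + i * block_size].
  std::vector<size_t> partial(n_blocks, 0);
  for (size_t t = 1; t < n_blocks; ++t) {
    partial[t] = partial[t - 1] + out[t * block];
  }
  // Phase 3 (parallel in the real code): shift every block by its offset.
  for (size_t t = 0; t < n_blocks; ++t) {
    size_t begin = t * block;
    size_t end = (t + 1 == n_blocks) ? counts.size() : begin + block;
    for (size_t i = begin; i < end; ++i) {
      out[i + 1] += partial[t];
    }
  }
  return out;  // out[r + 1] == counts[0] + ... + counts[r]
}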
diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h
index 83da8c784..48e6b3716 100644
--- a/src/data/gradient_index.h
+++ b/src/data/gradient_index.h
@@ -21,6 +21,13 @@ namespace xgboost {
  * index for CPU histogram. On GPU ellpack page is used.
  */
 class GHistIndexMatrix {
+  /**
+   * \brief Push a page into the index matrix. This function is only necessary because
+   *        hist has partial support for external memory.
+   *
+   * \param rbegin   The beginning row index of the current page (total rows in previous pages).
+   * \param prev_sum Total number of entries in previous pages.
+   */
   void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, size_t rbegin,
                  size_t prev_sum, uint32_t nbins, int32_t n_threads);

@@ -64,12 +71,12 @@ class GHistIndexMatrix {
     BinIdxType* index_data = index_data_span.data();
     auto const& ptrs = cut.Ptrs();
     auto const& values = cut.Values();
-    common::ParallelFor(batch_size, batch_threads, [&](omp_ulong i) {
+    common::ParallelFor(batch_size, batch_threads, [&](omp_ulong ridx) {
       const int tid = omp_get_thread_num();
-      size_t ibegin = row_ptr[rbegin + i];
-      size_t iend = row_ptr[rbegin + i + 1];
-      const size_t size = offset_vec[i + 1] - offset_vec[i];
-      SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
+      size_t ibegin = row_ptr[rbegin + ridx];    // index of the first entry for the current row
+      size_t iend = row_ptr[rbegin + ridx + 1];  // first entry for the next row
+      const size_t size = offset_vec[ridx + 1] - offset_vec[ridx];
+      SparsePage::Inst inst = {data_ptr + offset_vec[ridx], size};
       CHECK_EQ(ibegin + inst.size(), iend);
       for (bst_uint j = 0; j < inst.size(); ++j) {
         auto e = inst[j];
@@ -103,6 +110,10 @@ class GHistIndexMatrix {
     return isDense_;
   }
   void SetDense(bool is_dense) { isDense_ = is_dense; }
+  /**
+   * \brief Get a row's starting position in the index from its global row index.
+   */
+  size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; }

   bst_row_t Size() const {
     return row_ptr.empty() ? 0 : row_ptr.size() - 1;

diff --git a/src/data/gradient_index_format.cc b/src/data/gradient_index_format.cc
index 19baeb406..ff260efbf 100644
--- a/src/data/gradient_index_format.cc
+++ b/src/data/gradient_index_format.cc
@@ -16,14 +16,6 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
     }
     // indptr
     fi->Read(&page->row_ptr);
-    // offset
-    using OffsetT = std::iterator_traits<decltype(page->index.Offset())>::value_type;
-    std::vector<OffsetT> offset;
-    if (!fi->Read(&offset)) {
-      return false;
-    }
-    page->index.ResizeOffset(offset.size());
-    std::copy(offset.begin(), offset.end(), page->index.Offset());
     // data
     std::vector<uint8_t> data;
     if (!fi->Read(&data)) {
@@ -55,6 +47,9 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
       return false;
     }
     page->SetDense(is_dense);
+    if (is_dense) {
+      page->index.SetBinOffset(page->cut.Ptrs());
+    }
     return true;
   }

@@ -65,13 +60,6 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
     fo->Write(page.row_ptr);
     bytes += page.row_ptr.size() * sizeof(decltype(page.row_ptr)::value_type) +
              sizeof(uint64_t);
-    // offset
-    using OffsetT = std::iterator_traits<decltype(page.index.Offset())>::value_type;
-    std::vector<OffsetT> offset(page.index.OffsetSize());
-    std::copy(page.index.Offset(),
-              page.index.Offset() + page.index.OffsetSize(), offset.begin());
-    fo->Write(offset);
-    bytes += page.index.OffsetSize() * sizeof(OffsetT) + sizeof(uint64_t);
     // data
     std::vector<uint8_t> data(page.index.begin(), page.index.end());
     fo->Write(data);
diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h
index 158ab2b2c..ec54da19e 100644
--- a/src/tree/updater_approx.h
+++ b/src/tree/updater_approx.h
@@ -35,14 +35,12 @@ class ApproxRowPartitioner {
                               std::vector<uint32_t> const &cut_ptrs,
                               std::vector<float> const &cut_values) {
     int32_t gidx = -1;
-    auto const &row_ptr = index.row_ptr;
-    auto get_rid = [&](size_t ridx) { return row_ptr[ridx - index.base_rowid]; };
-
     if (index.IsDense()) {
-      gidx = index.index[get_rid(ridx) + fidx];
+      // RowIdx returns the starting pos of this row
+      gidx = index.index[index.RowIdx(ridx) + fidx];
     } else {
-      auto begin = get_rid(ridx);
-      auto end = get_rid(ridx + 1);
+      auto begin = index.RowIdx(ridx);
+      auto end = index.RowIdx(ridx + 1);
       auto f_begin = cut_ptrs[fidx];
       auto f_end = cut_ptrs[fidx + 1];
       gidx = common::BinarySearchBin(begin, end, index.index, f_begin, f_end);
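The dense fast path above works because a dense gradient-index page has a fixed row stride: every row stores exactly one bin per feature, so the bin of (ridx, fidx) sits at RowIdx(ridx) + fidx, while the sparse path must binary-search the row's entries for the feature's bin range. A toy sketch of the dense stride arithmetic follows; it is standalone and illustrative, not the real GHistIndexMatrix, and assumes row_ptr reduces to ridx * n_features for a fully dense page.

// Toy model of the dense lookup in SearchBin: with a fixed row stride the
// bin of (ridx, fidx) is index[RowIdx(ridx) + fidx], with no search required.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  size_t n_features = 3;
  size_t base_rowid = 0;
  // Gradient index for two dense rows (bins shown already decompressed).
  std::vector<uint32_t> index{0, 4, 9,    // row 0
                              2, 5, 10};  // row 1
  auto RowIdx = [&](size_t ridx) { return (ridx - base_rowid) * n_features; };
  assert(index[RowIdx(1) + 2] == 10);  // row 1, feature 2
  return 0;
}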
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 8c52ff382..616d1c571 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -135,7 +135,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
   {
     auto nid = RegTree::kRoot;
-    GHistRowT hist = this->histogram_builder_->Histogram()[nid];
+    auto hist = this->histogram_builder_->Histogram()[nid];
     GradientPairT grad_stat;
     if (data_layout_ == DataLayout::kDenseDataZeroBased ||
         data_layout_ == DataLayout::kDenseDataOneBased) {
@@ -149,7 +149,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
         grad_stat.Add(et.GetGrad(), et.GetHess());
       }
     } else {
-      const RowSetCollection::Elem e = row_set_collection_[nid];
+      const common::RowSetCollection::Elem e = row_set_collection_[nid];
       for (const size_t *it = e.begin; it < e.end; ++it) {
         grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
       }
@@ -229,7 +229,7 @@ template <typename GradientSumT>
 template <bool any_missing>
 void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
     const GHistIndexMatrix& gmat,
-    const ColumnMatrix& column_matrix,
+    const common::ColumnMatrix& column_matrix,
     DMatrix* p_fmat,
     RegTree* p_tree,
     const std::vector<GradientPair>& gpair_h) {

diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index 3f2b07ff9..09df175cd 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -147,7 +147,7 @@ class QuantileHistMaker: public TreeUpdater {
   // training parameter
   TrainParam param_;
   // column accessor
-  ColumnMatrix column_matrix_;
+  common::ColumnMatrix column_matrix_;
   DMatrix const* p_last_dmat_ {nullptr};
   bool is_gmat_initialized_ {false};

@@ -155,7 +155,6 @@ class QuantileHistMaker: public TreeUpdater {
   template <typename GradientSumT>
   struct Builder {
    public:
-    using GHistRowT = GHistRow<GradientSumT>;
     using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
     // constructor
     explicit Builder(const size_t n_trees, const TrainParam& param,
@@ -164,7 +163,6 @@ class QuantileHistMaker: public TreeUpdater {
         : n_trees_(n_trees),
           param_(param),
           pruner_(std::move(pruner)),
-          p_last_tree_(nullptr),
           p_last_fmat_(fmat),
           histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
           task_{task},
@@ -172,7 +170,7 @@ class QuantileHistMaker: public TreeUpdater {
       builder_monitor_.Init("Quantile::Builder");
     }
     // update one tree, growing
-    void Update(const GHistIndexMatrix& gmat, const ColumnMatrix& column_matrix,
+    void Update(const GHistIndexMatrix& gmat, const common::ColumnMatrix& column_matrix,
                 HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);

     bool UpdatePredictionCache(const DMatrix* data,

diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc
index 13fd84691..719425dee 100644
--- a/tests/cpp/common/test_hist_util.cc
+++ b/tests/cpp/common/test_hist_util.cc
@@ -306,8 +306,8 @@ TEST(HistUtil, IndexBinBound) {
 }

 template <typename T>
-void CheckIndexData(T* data_ptr, uint32_t* offsets,
-                    const GHistIndexMatrix& hmat, size_t n_cols) {
+void CheckIndexData(T const* data_ptr, uint32_t const* offsets, const GHistIndexMatrix& hmat,
+                    size_t n_cols) {
   for (size_t i = 0; i < hmat.index.Size(); ++i) {
     EXPECT_EQ(data_ptr[i] + offsets[i % n_cols], hmat.index[i]);
   }
@@ -323,7 +323,7 @@ TEST(HistUtil, IndexBinData) {
   for (auto max_bin : kBinSizes) {
     auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
     GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0));
-    uint32_t* offsets = hmat.index.Offset();
+    uint32_t const* offsets = hmat.index.Offset();
     EXPECT_EQ(hmat.index.Size(), kRows*kCols);
     switch (max_bin) {
       case kBinSizes[0]:

diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc
index 553550e33..1fa229999 100644
--- a/tests/cpp/tree/hist/test_histogram.cc
+++ b/tests/cpp/tree/hist/test_histogram.cc
@@ -6,15 +6,16 @@
 #include <gtest/gtest.h>

 #include "../../../../src/common/categorical.h"
+#include "../../../../src/common/row_set.h"
+#include "../../../../src/tree/hist/expand_entry.h"
 #include "../../../../src/tree/hist/histogram.h"
-#include "../../../../src/tree/updater_quantile_hist.h"
 #include "../../categorical_helpers.h"
 #include "../../helpers.h"

 namespace xgboost {
 namespace tree {
 namespace {
-void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples, size_t base_rowid = 0) {
+void InitRowPartitionForTest(common::RowSetCollection *row_set, size_t n_samples,
+                             size_t base_rowid = 0) {
   auto &row_indices = *row_set->Data();
   row_indices.resize(n_samples);
   std::iota(row_indices.begin(), row_indices.end(), base_rowid);
@@ -91,7 +92,7 @@ void TestSyncHist(bool is_distributed) {
   uint32_t total_bins = gmat.cut.Ptrs().back();
   histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);

-  RowSetCollection row_set_collection_;
+  common::RowSetCollection row_set_collection_;
   {
     row_set_collection_.Clear();
     std::vector<size_t> &row_indices = *row_set_collection_.Data();
@@ -256,7 +257,7 @@ void TestBuildHistogram(bool is_distributed) {

   RegTree tree;

-  RowSetCollection row_set_collection;
+  common::RowSetCollection row_set_collection;
   row_set_collection.Clear();
   std::vector<size_t> &row_indices = *row_set_collection.Data();
   row_indices.resize(kNRows);
@@ -318,7 +319,7 @@ void TestHistogramCategorical(size_t n_categories) {

   auto gpair = GenerateRandomGradients(kRows, 0, 2);

-  RowSetCollection row_set_collection;
+  common::RowSetCollection row_set_collection;
   row_set_collection.Clear();
   std::vector<size_t> &row_indices = *row_set_collection.Data();
   row_indices.resize(kRows);
@@ -381,13 +382,13 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
   std::vector<CPUExpandEntry> nodes;
   nodes.emplace_back(0, tree.GetDepth(0), 0.0f);

-  GHistRow<double> multi_page;
+  common::GHistRow<double> multi_page;
   HistogramBuilder<double, CPUExpandEntry> multi_build;
   {
     /**
     * Multi page
     */
-    std::vector<RowSetCollection> rows_set;
+    std::vector<common::RowSetCollection> rows_set;
     for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
       CHECK_LT(page.base_rowid, m->Info().num_row_);
       auto n_rows_in_node = page.Size();
@@ -417,12 +418,12 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
   }

   HistogramBuilder<double, CPUExpandEntry> single_build;
-  GHistRow<double> single_page;
+  common::GHistRow<double> single_page;
   {
     /**
     * Single page
     */
-    RowSetCollection row_set_collection;
+    common::RowSetCollection row_set_collection;
     InitRowPartitionForTest(&row_set_collection, n_samples);

     single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);

diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index fc7c43ad7..d043c5bb5 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -24,7 +24,6 @@ class QuantileHistMock : public QuantileHistMaker {
   template <typename GradientSumT>
   struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
     using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
-    using GHistRowT = typename RealImpl::GHistRowT;

     BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
                 DMatrix const *fmat, GenericParameter const* ctx)