From 4d81c741e91c7660648f02d77b61ede33cef8c8d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 22 Mar 2022 00:13:20 +0800 Subject: [PATCH] External memory support for hist (#7531) * Generate column matrix from gHistIndex. * Avoid synchronization with the sparse page once the cache is written. * Cleanups: Remove member variables/functions, change the update routine to look like approx and gpu_hist. * Remove pruner. --- amalgamation/xgboost-all0.cc | 9 +- demo/guide-python/external_memory.py | 14 +- demo/guide-python/feature_weights.py | 2 +- doc/tutorials/external_memory.rst | 13 +- src/common/column_matrix.h | 291 ++++++++------ src/data/ellpack_page_source.cu | 9 +- src/data/ellpack_page_source.h | 22 +- src/data/gradient_index.cc | 28 +- src/data/gradient_index.h | 7 +- src/data/gradient_index_format.cc | 7 +- src/data/gradient_index_page_source.cc | 11 +- src/data/gradient_index_page_source.h | 7 +- src/data/sparse_page_dmatrix.cc | 22 +- src/data/sparse_page_dmatrix.cu | 3 + src/data/sparse_page_source.h | 52 ++- src/tree/hist/histogram.h | 26 +- src/tree/hist/param.cc | 10 + src/tree/updater_approx.cc | 35 +- src/tree/updater_quantile_hist.cc | 367 +++++++----------- src/tree/updater_quantile_hist.h | 102 ++--- tests/cpp/common/test_column_matrix.cc | 14 +- .../test_gradient_index_page_raw_format.cc | 19 +- tests/cpp/tree/hist/test_histogram.cc | 6 + tests/cpp/tree/test_quantile_hist.cc | 157 +------- tests/python/test_data_iterator.py | 16 +- 25 files changed, 563 insertions(+), 686 deletions(-) create mode 100644 src/tree/hist/param.cc diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 46203387b..86df12d95 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -48,17 +48,18 @@ #include "../src/predictor/cpu_predictor.cc" // trees +#include "../src/tree/constraints.cc" +#include "../src/tree/hist/param.cc" #include "../src/tree/param.cc" #include "../src/tree/tree_model.cc" #include "../src/tree/tree_updater.cc" +#include "../src/tree/updater_approx.cc" #include "../src/tree/updater_colmaker.cc" -#include "../src/tree/updater_quantile_hist.cc" +#include "../src/tree/updater_histmaker.cc" #include "../src/tree/updater_prune.cc" +#include "../src/tree/updater_quantile_hist.cc" #include "../src/tree/updater_refresh.cc" #include "../src/tree/updater_sync.cc" -#include "../src/tree/updater_histmaker.cc" -#include "../src/tree/updater_approx.cc" -#include "../src/tree/constraints.cc" // linear #include "../src/linear/linear_updater.cc" diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py index 3e864a53e..703ee8f6c 100644 --- a/demo/guide-python/external_memory.py +++ b/demo/guide-python/external_memory.py @@ -7,6 +7,9 @@ instead of Quantile DMatrix. The feature is not ready for production use yet. .. versionadded:: 1.5.0 + +See :doc:`the tutorial ` for more details. + """ import os import xgboost @@ -77,9 +80,14 @@ def main(tmpdir: str) -> xgboost.Booster: missing = np.NaN Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False) - # Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some - # caveats. This is still an experimental feature. - booster = xgboost.train({"tree_method": "approx"}, Xy, evals=[(Xy, "Train")]) + # Other tree methods including ``hist`` and ``gpu_hist`` also work, see tutorial in + # doc for details. 
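    # A minimal sketch, not executed by this demo: per the comment above and the
    # external memory tutorial updated in this patch, the same iterator-backed
    # ``Xy`` can also be trained with the CPU ``hist`` method, for which the
    # tutorial recommends keeping the default ``depthwise`` grow policy, e.g.
    #
    #     booster_hist = xgboost.train(
    #         {"tree_method": "hist", "grow_policy": "depthwise", "max_depth": 2},
    #         Xy,
    #         evals=[(Xy, "Train")],
    #         num_boost_round=10,
    #     )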
+ booster = xgboost.train( + {"tree_method": "approx", "max_depth": 2}, + Xy, + evals=[(Xy, "Train")], + num_boost_round=10, + ) return booster diff --git a/demo/guide-python/feature_weights.py b/demo/guide-python/feature_weights.py index f0b4907aa..34c8ed440 100644 --- a/demo/guide-python/feature_weights.py +++ b/demo/guide-python/feature_weights.py @@ -27,7 +27,7 @@ def main(args): dtrain.set_info(feature_weights=fw) bst = xgboost.train({'tree_method': 'hist', - 'colsample_bynode': 0.5}, + 'colsample_bynode': 0.2}, dtrain, num_boost_round=10, evals=[(dtrain, 'd')]) feature_map = bst.get_fscore() diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst index b9acf09cb..e90f4fcb4 100644 --- a/doc/tutorials/external_memory.rst +++ b/doc/tutorials/external_memory.rst @@ -127,9 +127,12 @@ the tree method still concatenate all the chunks into 1 final histogram index du performance reason, but in compressed format. So its scalability has an upper bound but still has lower memory cost in general. -******** -CPU Hist -******** +*********** +CPU Version +*********** -It's limited by the same factor of GPU Hist, except that gradient based sampling is not -yet supported on CPU. +For CPU histogram based tree methods (``approx``, ``hist``) it's recommended to use +``grow_policy=depthwise`` for performance reason. Iterating over data batches is slow, +with ``depthwise`` policy XGBoost can build a entire layer of tree nodes with a few +iterations, while with ``lossguide`` XGBoost needs to iterate over the data set for each +tree node. diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 051a4cd44..d289db05e 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017 by Contributors + * Copyright 2017-2022 by Contributors * \file column_matrix.h * \brief Utility for fast column-wise access * \author Philip Cho @@ -8,21 +8,22 @@ #ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_ #define XGBOOST_COMMON_COLUMN_MATRIX_H_ +#include + +#include #include -#include #include -#include "hist_util.h" +#include + #include "../data/gradient_index.h" +#include "hist_util.h" namespace xgboost { namespace common { class ColumnMatrix; /*! \brief column type */ -enum ColumnType { - kDenseColumn, - kSparseColumn -}; +enum ColumnType : uint8_t { kDenseColumn, kSparseColumn }; /*! \brief a column storage, to be used with ApplySplit. Note that each bin id is stored as index[i] + index_base. 
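
A minimal standalone sketch of the storage scheme described in the comment above, using hypothetical cut pointers rather than anything from this header: each feature stores its bin ids relative to the feature's first global bin (its index_base), so one byte per entry suffices whenever a feature has at most 256 bins, and the global id is recovered by adding the base back.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical cut pointers: feature 0 owns global bins [0, 3), feature 1 owns [3, 7).
  std::vector<uint32_t> cut_ptrs{0, 3, 7};
  // Global bin ids of one row, one entry per feature.
  std::vector<uint32_t> global_bins{2, 5};

  std::vector<uint8_t> stored(global_bins.size());
  for (size_t fid = 0; fid < global_bins.size(); ++fid) {
    // index_base is the first global bin of the feature, i.e. cut_ptrs[fid].
    stored[fid] = static_cast<uint8_t>(global_bins[fid] - cut_ptrs[fid]);
  }
  for (size_t fid = 0; fid < global_bins.size(); ++fid) {
    // Recover the global bin id as stored index + index_base.
    assert(stored[fid] + cut_ptrs[fid] == global_bins[fid]);
  }
  return 0;
}
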
@@ -34,9 +35,7 @@ class Column { static constexpr int32_t kMissingId = -1; Column(ColumnType type, common::Span index, const uint32_t index_base) - : type_(type), - index_(index), - index_base_(index_base) {} + : type_(type), index_(index), index_base_(index_base) {} virtual ~Column() = default; @@ -65,12 +64,11 @@ class Column { }; template -class SparseColumn: public Column { +class SparseColumn : public Column { public: - SparseColumn(ColumnType type, common::Span index, - uint32_t index_base, common::Span row_ind) - : Column(type, index, index_base), - row_ind_(row_ind) {} + SparseColumn(ColumnType type, common::Span index, uint32_t index_base, + common::Span row_ind) + : Column(type, index, index_base), row_ind_(row_ind) {} const size_t* GetRowData() const { return row_ind_.data(); } @@ -98,9 +96,7 @@ class SparseColumn: public Column { return p - row_data; } - size_t GetRowIdx(size_t idx) const { - return row_ind_.data()[idx]; - } + size_t GetRowIdx(size_t idx) const { return row_ind_.data()[idx]; } private: /* indexes of rows */ @@ -108,11 +104,10 @@ class SparseColumn: public Column { }; template -class DenseColumn: public Column { +class DenseColumn : public Column { public: - DenseColumn(ColumnType type, common::Span index, - uint32_t index_base, const std::vector& missing_flags, - size_t feature_offset) + DenseColumn(ColumnType type, common::Span index, uint32_t index_base, + const std::vector& missing_flags, size_t feature_offset) : Column(type, index, index_base), missing_flags_(missing_flags), feature_offset_(feature_offset) {} @@ -126,9 +121,7 @@ class DenseColumn: public Column { } } - size_t GetInitialState(const size_t first_row_id) const { - return 0; - } + size_t GetInitialState(const size_t first_row_id) const { return 0; } private: /* flags for missing values in dense columns */ @@ -141,28 +134,26 @@ class DenseColumn: public Column { class ColumnMatrix { public: // get number of features - inline bst_uint GetNumFeature() const { - return static_cast(type_.size()); - } + bst_feature_t GetNumFeature() const { return static_cast(type_.size()); } // construct column matrix from GHistIndexMatrix - inline void Init(const GHistIndexMatrix& gmat, double sparse_threshold, int32_t n_threads) { - const int32_t nfeature = static_cast(gmat.cut.Ptrs().size() - 1); + inline void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold, + int32_t n_threads) { + auto const nfeature = static_cast(gmat.cut.Ptrs().size() - 1); const size_t nrow = gmat.row_ptr.size() - 1; // identify type of each column feature_counts_.resize(nfeature); type_.resize(nfeature); std::fill(feature_counts_.begin(), feature_counts_.end(), 0); uint32_t max_val = std::numeric_limits::max(); - for (int32_t fid = 0; fid < nfeature; ++fid) { + for (bst_feature_t fid = 0; fid < nfeature; ++fid) { CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val); } bool all_dense = gmat.IsDense(); gmat.GetFeatureCounts(&feature_counts_[0]); // classify features - for (int32_t fid = 0; fid < nfeature; ++fid) { - if (static_cast(feature_counts_[fid]) - < sparse_threshold * nrow) { + for (bst_feature_t fid = 0; fid < nfeature; ++fid) { + if (static_cast(feature_counts_[fid]) < sparse_threshold * nrow) { type_[fid] = kSparseColumn; all_dense = false; } else { @@ -175,7 +166,7 @@ class ColumnMatrix { feature_offsets_.resize(nfeature + 1); size_t accum_index_ = 0; feature_offsets_[0] = accum_index_; - for (int32_t fid = 1; fid < nfeature + 1; ++fid) { + for (bst_feature_t fid = 1; fid < nfeature 
+ 1; ++fid) { if (type_[fid - 1] == kDenseColumn) { accum_index_ += static_cast(nrow); } else { @@ -197,6 +188,7 @@ class ColumnMatrix { const bool noMissingValues = NoMissingValues(gmat.row_ptr[nrow], nrow, nfeature); any_missing_ = !noMissingValues; + missing_flags_.clear(); if (noMissingValues) { missing_flags_.resize(feature_offsets_[nfeature], false); } else { @@ -207,33 +199,33 @@ class ColumnMatrix { if (all_dense) { BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize(); if (gmat_bin_size == kUint8BinsTypeSize) { - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + SetIndexAllDense(page, gmat.index.data(), gmat, nrow, nfeature, noMissingValues, n_threads); } else if (gmat_bin_size == kUint16BinsTypeSize) { - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + SetIndexAllDense(page, gmat.index.data(), gmat, nrow, nfeature, noMissingValues, n_threads); } else { CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize); - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + SetIndexAllDense(page, gmat.index.data(), gmat, nrow, nfeature, noMissingValues, n_threads); } - /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize - but for ColumnMatrix we still have a chance to reduce the memory consumption */ + /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize + but for ColumnMatrix we still have a chance to reduce the memory consumption */ } else { if (bins_type_size_ == kUint8BinsTypeSize) { - SetIndex(gmat.index.data(), gmat, nfeature); + SetIndex(page, gmat.index.data(), gmat, nfeature); } else if (bins_type_size_ == kUint16BinsTypeSize) { - SetIndex(gmat.index.data(), gmat, nfeature); + SetIndex(page, gmat.index.data(), gmat, nfeature); } else { - CHECK_EQ(bins_type_size_, kUint32BinsTypeSize); - SetIndex(gmat.index.data(), gmat, nfeature); + CHECK_EQ(bins_type_size_, kUint32BinsTypeSize); + SetIndex(page, gmat.index.data(), gmat, nfeature); } } } /* Set the number of bytes based on numeric limit of maximum number of bins provided by user */ void SetTypeSize(size_t max_num_bins) { - if ( (max_num_bins - 1) <= static_cast(std::numeric_limits::max()) ) { + if ((max_num_bins - 1) <= static_cast(std::numeric_limits::max())) { bins_type_size_ = kUint8BinsTypeSize; } else if ((max_num_bins - 1) <= static_cast(std::numeric_limits::max())) { bins_type_size_ = kUint16BinsTypeSize; @@ -250,24 +242,24 @@ class ColumnMatrix { const size_t feature_offset = feature_offsets_[fid]; // to get right place for certain feature const size_t column_size = feature_offsets_[fid + 1] - feature_offset; - common::Span bin_index = { reinterpret_cast( - &index_[feature_offset * bins_type_size_]), - column_size }; + common::Span bin_index = { + reinterpret_cast(&index_[feature_offset * bins_type_size_]), + column_size}; std::unique_ptr > res; if (type_[fid] == ColumnType::kDenseColumn) { CHECK_EQ(any_missing, any_missing_); res.reset(new DenseColumn(type_[fid], bin_index, index_base_[fid], - missing_flags_, feature_offset)); + missing_flags_, feature_offset)); } else { res.reset(new SparseColumn(type_[fid], bin_index, index_base_[fid], - {&row_ind_[feature_offset], column_size})); + {&row_ind_[feature_offset], column_size})); } return res; } template - inline void SetIndexAllDense(T const* index, const GHistIndexMatrix& gmat, const size_t nrow, - const size_t nfeature, const bool noMissingValues, + inline void SetIndexAllDense(SparsePage const& page, T const* index, const GHistIndexMatrix& 
gmat, + const size_t nrow, const size_t nfeature, const bool noMissingValues, int32_t n_threads) { T* local_index = reinterpret_cast(&index_[0]); @@ -275,98 +267,155 @@ class ColumnMatrix { and if no missing values were observed it could be handled much faster. */ if (noMissingValues) { ParallelFor(nrow, n_threads, [&](auto rid) { - const size_t ibegin = rid*nfeature; - const size_t iend = (rid+1)*nfeature; + const size_t ibegin = rid * nfeature; + const size_t iend = (rid + 1) * nfeature; size_t j = 0; for (size_t i = ibegin; i < iend; ++i, ++j) { - const size_t idx = feature_offsets_[j]; - local_index[idx + rid] = index[i]; + const size_t idx = feature_offsets_[j]; + local_index[idx + rid] = index[i]; } }); } else { /* to handle rows in all batches, sum of all batch sizes equal to gmat.row_ptr.size() - 1 */ - size_t rbegin = 0; - for (const auto &batch : gmat.p_fmat->GetBatches()) { - const xgboost::Entry* data_ptr = batch.data.HostVector().data(); - const std::vector& offset_vec = batch.offset.HostVector(); - const size_t batch_size = batch.Size(); - CHECK_LT(batch_size, offset_vec.size()); - for (size_t rid = 0; rid < batch_size; ++rid) { - const size_t size = offset_vec[rid + 1] - offset_vec[rid]; - SparsePage::Inst inst = {data_ptr + offset_vec[rid], size}; - const size_t ibegin = gmat.row_ptr[rbegin + rid]; - const size_t iend = gmat.row_ptr[rbegin + rid + 1]; - CHECK_EQ(ibegin + inst.size(), iend); - size_t j = 0; - size_t fid = 0; - for (size_t i = ibegin; i < iend; ++i, ++j) { - fid = inst[j].index; - const size_t idx = feature_offsets_[fid]; - /* rbegin allows to store indexes from specific SparsePage batch */ - local_index[idx + rbegin + rid] = index[i]; - missing_flags_[idx + rbegin + rid] = false; - } - } - rbegin += batch.Size(); + auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) { + // T* begin = &local_index[feature_offsets_[fid]]; + const size_t idx = feature_offsets_[fid]; + /* rbegin allows to store indexes from specific SparsePage batch */ + local_index[idx + rid] = bin_id; + + missing_flags_[idx + rid] = false; + }; + this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx); + } + } + + // FIXME(jiamingy): In the future we might want to simply use binary search to simplify + // this and remove the dependency on SparsePage. This way we can have quantilized + // matrix for host similar to `DeviceQuantileDMatrix`. 
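
A minimal sketch of the binary-search idea mentioned in the FIXME above, with hypothetical cut pointers (std::upper_bound standing in for the binary search): the feature owning a global bin id is the last feature whose cut pointer does not exceed that id, so the SparsePage would not be needed to recover it.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical cut pointers: bins of feature 0 are [0, 4), feature 1 [4, 9), feature 2 [9, 12).
  std::vector<uint32_t> cut_ptrs{0, 4, 9, 12};
  // A few global bin ids as they would appear in the gradient index.
  std::vector<uint32_t> bin_ids{1, 6, 10};

  for (size_t i = 0; i < bin_ids.size(); ++i) {
    // The feature id is the index of the last cut pointer that is <= the bin id.
    auto it = std::upper_bound(cut_ptrs.cbegin(), cut_ptrs.cend(), bin_ids[i]);
    auto fid = static_cast<size_t>(std::distance(cut_ptrs.cbegin(), it)) - 1;
    assert(fid == i);  // bin 1 -> feature 0, bin 6 -> feature 1, bin 10 -> feature 2
  }
  return 0;
}
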
+ template + void SetIndexSparse(SparsePage const& batch, T* index, const GHistIndexMatrix& gmat, + const size_t nfeature, BinFn&& assign_bin) { + std::vector num_nonzeros(nfeature, 0ul); + const xgboost::Entry* data_ptr = batch.data.HostVector().data(); + const std::vector& offset_vec = batch.offset.HostVector(); + auto rbegin = 0; + const size_t batch_size = gmat.Size(); + CHECK_LT(batch_size, offset_vec.size()); + + for (size_t rid = 0; rid < batch_size; ++rid) { + const size_t ibegin = gmat.row_ptr[rbegin + rid]; + const size_t iend = gmat.row_ptr[rbegin + rid + 1]; + const size_t size = offset_vec[rid + 1] - offset_vec[rid]; + SparsePage::Inst inst = {data_ptr + offset_vec[rid], size}; + + CHECK_EQ(ibegin + inst.size(), iend); + size_t j = 0; + for (size_t i = ibegin; i < iend; ++i, ++j) { + const uint32_t bin_id = index[i]; + auto fid = inst[j].index; + assign_bin(bin_id, rid, fid); } } } - template - inline void SetIndex(uint32_t const* index, const GHistIndexMatrix& gmat, + template + inline void SetIndex(SparsePage const& page, uint32_t const* index, const GHistIndexMatrix& gmat, const size_t nfeature) { + T* local_index = reinterpret_cast(&index_[0]); std::vector num_nonzeros; num_nonzeros.resize(nfeature); std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0); - T* local_index = reinterpret_cast(&index_[0]); - size_t rbegin = 0; - for (const auto &batch : gmat.p_fmat->GetBatches()) { - const xgboost::Entry* data_ptr = batch.data.HostVector().data(); - const std::vector& offset_vec = batch.offset.HostVector(); - const size_t batch_size = batch.Size(); - CHECK_LT(batch_size, offset_vec.size()); - for (size_t rid = 0; rid < batch_size; ++rid) { - const size_t ibegin = gmat.row_ptr[rbegin + rid]; - const size_t iend = gmat.row_ptr[rbegin + rid + 1]; - size_t fid = 0; - const size_t size = offset_vec[rid + 1] - offset_vec[rid]; - SparsePage::Inst inst = {data_ptr + offset_vec[rid], size}; - - CHECK_EQ(ibegin + inst.size(), iend); - size_t j = 0; - for (size_t i = ibegin; i < iend; ++i, ++j) { - const uint32_t bin_id = index[i]; - - fid = inst[j].index; - if (type_[fid] == kDenseColumn) { - T* begin = &local_index[feature_offsets_[fid]]; - begin[rid + rbegin] = bin_id - index_base_[fid]; - missing_flags_[feature_offsets_[fid] + rid + rbegin] = false; - } else { - T* begin = &local_index[feature_offsets_[fid]]; - begin[num_nonzeros[fid]] = bin_id - index_base_[fid]; - row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid + rbegin; - ++num_nonzeros[fid]; - } - } + auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) { + if (type_[fid] == kDenseColumn) { + T* begin = &local_index[feature_offsets_[fid]]; + begin[rid] = bin_id - index_base_[fid]; + missing_flags_[feature_offsets_[fid] + rid] = false; + } else { + T* begin = &local_index[feature_offsets_[fid]]; + begin[num_nonzeros[fid]] = bin_id - index_base_[fid]; + row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid; + ++num_nonzeros[fid]; } - rbegin += batch.Size(); - } - } - BinTypeSize GetTypeSize() const { - return bins_type_size_; + }; + this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx); } + BinTypeSize GetTypeSize() const { return bins_type_size_; } + // This is just an utility function - bool NoMissingValues(const size_t n_elements, - const size_t n_row, const size_t n_features) { + bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) { return n_elements == n_features * n_row; } // And this returns part of state - bool AnyMissing() const { - return 
any_missing_; + bool AnyMissing() const { return any_missing_; } + + // IO procedures for external memory. + bool Read(dmlc::SeekStream* fi, uint32_t const* index_base) { + fi->Read(&index_); + fi->Read(&feature_counts_); +#if !DMLC_LITTLE_ENDIAN + // s390x + std::vector::type> int_types; + fi->Read(&int_types); + type_.resize(int_types.size()); + std::transform( + int_types.begin(), int_types.end(), type_.begin(), + [](std::underlying_type::type i) { return static_cast(i); }); +#else + fi->Read(&type_); +#endif // !DMLC_LITTLE_ENDIAN + + fi->Read(&row_ind_); + fi->Read(&feature_offsets_); + index_base_ = index_base; +#if !DMLC_LITTLE_ENDIAN + std::underlying_type::type v; + fi->Read(&v); + bins_type_size_ = static_cast(v); +#else + fi->Read(&bins_type_size_); +#endif + + fi->Read(&any_missing_); + return true; + } + + size_t Write(dmlc::Stream* fo) const { + size_t bytes{0}; + + auto write_vec = [&](auto const& vec) { + fo->Write(vec); + bytes += vec.size() * sizeof(typename std::remove_reference_t::value_type) + + sizeof(uint64_t); + }; + write_vec(index_); + write_vec(feature_counts_); +#if !DMLC_LITTLE_ENDIAN + // s390x + std::vector::type> int_types(type_.size()); + std::transform(type_.begin(), type_.end(), int_types.begin(), [](ColumnType t) { + return static_cast::type>(t); + }); + write_vec(int_types); +#else + write_vec(type_); +#endif // !DMLC_LITTLE_ENDIAN + write_vec(row_ind_); + write_vec(feature_offsets_); + +#if !DMLC_LITTLE_ENDIAN + auto v = static_cast::type>(bins_type_size_); + fo->Write(v); +#else + fo->Write(bins_type_size_); +#endif // DMLC_LITTLE_ENDIAN + bytes += sizeof(bins_type_size_); + fo->Write(any_missing_); + bytes += sizeof(any_missing_); + + return bytes; } private: diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 6d79250a0..872cb0cc6 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2019-2021 XGBoost contributors + * Copyright 2019-2022 XGBoost contributors */ #include #include @@ -12,6 +12,13 @@ namespace data { void EllpackPageSource::Fetch() { dh::safe_cuda(cudaSetDevice(param_.gpu_id)); if (!this->ReadCache()) { + if (count_ != 0 && !sync_) { + // source is initialized to be the 0th page during construction, so when count_ is 0 + // there's no need to increment the source. + ++(*source_); + } + // This is not read from cache so we still need it to be synced with sparse page source. + CHECK_EQ(count_, source_->Iter()); auto const &csr = source_->Page(); this->page_.reset(new EllpackPage{}); auto *impl = this->page_->Impl(); diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index 9a1551d53..dc0802472 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -1,5 +1,5 @@ /*! 
- * Copyright 2019-2021 by XGBoost Contributors + * Copyright 2019-2022 by XGBoost Contributors */ #ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_ @@ -25,15 +25,17 @@ class EllpackPageSource : public PageSourceIncMixIn { std::unique_ptr cuts_; public: - EllpackPageSource( - float missing, int nthreads, bst_feature_t n_features, size_t n_batches, - std::shared_ptr cache, BatchParam param, - std::unique_ptr cuts, bool is_dense, - size_t row_stride, common::Span feature_types, - std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), - is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)}, - feature_types_{feature_types}, cuts_{std::move(cuts)} { + EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches, + std::shared_ptr cache, BatchParam param, + std::unique_ptr cuts, bool is_dense, size_t row_stride, + common::Span feature_types, + std::shared_ptr source) + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false), + is_dense_{is_dense}, + row_stride_{row_stride}, + param_{std::move(param)}, + feature_types_{feature_types}, + cuts_{std::move(cuts)} { this->source_ = source; this->Fetch(); } diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index eef8f9519..907a8c45c 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -144,7 +144,6 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh, hit_count.resize(nbins, 0); hit_count_tloc_.resize(n_threads * nbins, 0); - this->p_fmat = p_fmat; size_t new_size = 1; for (const auto &batch : p_fmat->GetBatches()) { new_size += batch.Size(); @@ -164,6 +163,16 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh, prev_sum = row_ptr[rbegin + batch.Size()]; rbegin += batch.Size(); } + this->columns_ = std::make_unique(); + + // hessian is empty when hist tree method is used or when dataset is empty + if (hess.empty() && !std::isnan(sparse_thresh)) { + // hist + CHECK(!sorted_sketch); + for (auto const &page : p_fmat->GetBatches()) { + this->columns_->Init(page, *this, sparse_thresh, n_threads); + } + } } void GHistIndexMatrix::Init(SparsePage const &batch, common::Span ft, @@ -187,6 +196,10 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::SpanPushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads); + this->columns_ = std::make_unique(); + if (!std::isnan(sparse_thresh)) { + this->columns_->Init(batch, *this, sparse_thresh, n_threads); + } } void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) { @@ -205,4 +218,17 @@ void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) { index.Resize((sizeof(uint32_t)) * n_index); } } + +common::ColumnMatrix const &GHistIndexMatrix::Transpose() const { + CHECK(columns_); + return *columns_; +} + +bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) { + return this->columns_->Read(fi, this->cut.Ptrs().data()); +} + +size_t GHistIndexMatrix::WriteColumnPage(dmlc::Stream *fo) const { + return this->columns_->Write(fo); +} } // namespace xgboost diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 48e6b3716..5a41d7b2a 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -40,7 +40,6 @@ class GHistIndexMatrix { std::vector hit_count; /*! \brief The corresponding cuts */ common::HistogramCuts cut; - DMatrix* p_fmat; /*! \brief max_bin for each feature. */ size_t max_num_bins; /*! 
\brief base row index for current page (used by external memory) */ @@ -119,8 +118,12 @@ class GHistIndexMatrix { return row_ptr.empty() ? 0 : row_ptr.size() - 1; } + bool ReadColumnPage(dmlc::SeekStream* fi); + size_t WriteColumnPage(dmlc::Stream* fo) const; + + common::ColumnMatrix const& Transpose() const; + private: - // unused at the moment: https://github.com/dmlc/xgboost/pull/7531 std::unique_ptr columns_; std::vector hit_count_tloc_; bool isDense_; diff --git a/src/data/gradient_index_format.cc b/src/data/gradient_index_format.cc index ff260efbf..4b3fd0ea0 100644 --- a/src/data/gradient_index_format.cc +++ b/src/data/gradient_index_format.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2021 XGBoost contributors + * Copyright 2021-2022 XGBoost contributors */ #include "sparse_page_writer.h" #include "gradient_index.h" @@ -7,7 +7,6 @@ namespace xgboost { namespace data { - class GHistIndexRawFormat : public SparsePageFormat { public: bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override { @@ -50,6 +49,8 @@ class GHistIndexRawFormat : public SparsePageFormat { if (is_dense) { page->index.SetBinOffset(page->cut.Ptrs()); } + + page->ReadColumnPage(fi); return true; } @@ -81,6 +82,8 @@ class GHistIndexRawFormat : public SparsePageFormat { bytes += sizeof(page.base_rowid); fo->Write(page.IsDense()); bytes += sizeof(page.IsDense()); + + bytes += page.WriteColumnPage(fo); return bytes; } }; diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index 9ec69d904..09d8ada80 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -7,11 +7,18 @@ namespace xgboost { namespace data { void GradientIndexPageSource::Fetch() { if (!this->ReadCache()) { + if (count_ != 0 && !sync_) { + // source is initialized to be the 0th page during construction, so when count_ is 0 + // there's no need to increment the source. + ++(*source_); + } + // This is not read from cache so we still need it to be synced with sparse page source. 
+ CHECK_EQ(count_, source_->Iter()); auto const& csr = source_->Page(); this->page_.reset(new GHistIndexMatrix()); CHECK_NE(cuts_.Values().size(), 0); - this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, - sparse_thresh_, nthreads_); + this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, sparse_thresh_, + nthreads_); this->WriteCache(); } } diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index 30b53a294..db71c1c6d 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -22,13 +22,14 @@ class GradientIndexPageSource : public PageSourceIncMixIn { public: GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches, std::shared_ptr cache, BatchParam param, - common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat, + common::HistogramCuts cuts, bool is_dense, common::Span feature_types, std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, + std::isnan(param.sparse_thresh)), cuts_{std::move(cuts)}, is_dense_{is_dense}, - max_bin_per_feat_{max_bin_per_feat}, + max_bin_per_feat_{param.max_bin}, feature_types_{feature_types}, sparse_thresh_{param.sparse_thresh} { this->source_ = source; diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index a9fd9b7c1..a90150ce8 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -159,21 +159,6 @@ BatchSet SparsePageDMatrix::GetSortedColumnBatches() { BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam ¶m) { CHECK_GE(param.max_bin, 2); - if (param.hess.empty() && !param.regen) { - // hist method doesn't support full external memory implementation, so we concatenate - // all index here. 
- if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) { - this->InitializeSparsePage(); - ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.sparse_thresh, - param.regen, ctx_.Threads()}); - this->InitializeSparsePage(); - batch_param_ = param; - } - auto begin_iter = BatchIterator( - new SimpleBatchIteratorImpl(ghist_index_page_)); - return BatchSet(begin_iter); - } - auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_); this->InitializeSparsePage(); if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) { @@ -190,10 +175,9 @@ BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam ghist_index_source_.reset(); CHECK_NE(cuts.Values().size(), 0); auto ft = this->info_.feature_types.ConstHostSpan(); - ghist_index_source_.reset( - new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_, - this->n_batches_, cache_info_.at(id), param, std::move(cuts), - this->IsDense(), param.max_bin, ft, sparse_page_source_)); + ghist_index_source_.reset(new GradientIndexPageSource( + this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_, + cache_info_.at(id), param, std::move(cuts), this->IsDense(), ft, sparse_page_source_)); } else { CHECK(ghist_index_source_); ghist_index_source_->Reset(); diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index 82e1f3ce0..b36a0e2a3 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -11,6 +11,9 @@ namespace data { BatchSet SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) { CHECK_GE(param.gpu_id, 0); CHECK_GE(param.max_bin, 2); + if (!(batch_param_ != BatchParam{})) { + CHECK(param != BatchParam{}) << "Batch parameter is not initialized."; + } auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_); size_t row_stride = 0; this->InitializeSparsePage(); diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 4bada04c8..0a3e32e75 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -23,6 +23,7 @@ #include "proxy_dmatrix.h" #include "../common/common.h" +#include "../common/timer.h" namespace xgboost { namespace data { @@ -118,26 +119,30 @@ class SparsePageSourceImpl : public BatchIteratorImpl { size_t n_prefetch_batches = std::min(kPreFetch, n_batches_); CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; size_t fetch_it = count_; + for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring - if (ring_->at(fetch_it).valid()) { continue; } + if (ring_->at(fetch_it).valid()) { + continue; + } auto const *self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() { + common::Timer timer; + timer.Start(); std::unique_ptr> fmt{CreatePageFormat("raw")}; auto n = self->cache_info_->ShardName(); size_t offset = self->cache_info_->offset.at(fetch_it); - std::unique_ptr fi{ - dmlc::SeekStream::CreateForRead(n.c_str())}; + std::unique_ptr fi{dmlc::SeekStream::CreateForRead(n.c_str())}; fi->Seek(offset); CHECK_EQ(fi->Tell(), offset); auto page = std::make_shared(); CHECK(fmt->Read(page.get(), fi.get())); + LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds."; return page; }); } - CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), - [](auto const &f) { return f.valid(); }), + CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), 
[](auto const& f) { return f.valid(); }), n_prefetch_batches) << "Sparse DMatrix assumes forward iteration."; page_ = (*ring_)[count_].get(); @@ -146,12 +151,18 @@ class SparsePageSourceImpl : public BatchIteratorImpl { void WriteCache() { CHECK(!cache_info_->written); + common::Timer timer; + timer.Start(); std::unique_ptr> fmt{CreatePageFormat("raw")}; if (!fo_) { auto n = cache_info_->ShardName(); fo_.reset(dmlc::Stream::Create(n.c_str(), "w")); } auto bytes = fmt->Write(*page_, fo_.get()); + timer.Stop(); + + LOG(INFO) << static_cast(bytes) / 1024.0 / 1024.0 << " MB written in " + << timer.ElapsedSeconds() << " seconds."; cache_info_->offset.push_back(bytes); } @@ -280,15 +291,24 @@ template class PageSourceIncMixIn : public SparsePageSourceImpl { protected: std::shared_ptr source_; + using Super = SparsePageSourceImpl; + // synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page + // so we avoid fetching it. + bool sync_{true}; public: - using SparsePageSourceImpl::SparsePageSourceImpl; + PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, + std::shared_ptr cache, bool sync) + : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {} + PageSourceIncMixIn& operator++() final { TryLockGuard guard{this->single_threaded_}; - ++(*source_); + if (sync_) { + ++(*source_); + } ++this->count_; - this->at_end_ = source_->AtEnd(); + this->at_end_ = this->count_ == this->n_batches_; if (this->at_end_) { this->cache_info_->Commit(); @@ -299,7 +319,10 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { } else { this->Fetch(); } - CHECK_EQ(source_->Iter(), this->count_); + + if (sync_) { + CHECK_EQ(source_->Iter(), this->count_); + } return *this; } }; @@ -318,12 +341,9 @@ class CSCPageSource : public PageSourceIncMixIn { } public: - CSCPageSource( - float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, - std::shared_ptr cache, - std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, - n_batches, cache) { + CSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, + std::shared_ptr cache, std::shared_ptr source) + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) { this->source_ = source; this->Fetch(); } @@ -349,7 +369,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn { SortedCSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, std::shared_ptr cache, std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache) { + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) { this->source_ = source; this->Fetch(); } diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 242825b25..6020de28d 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -1,5 +1,5 @@ /*! 
- * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_ #define XGBOOST_TREE_HIST_HISTOGRAM_H_ @@ -8,10 +8,11 @@ #include #include -#include "rabit/rabit.h" -#include "xgboost/tree_model.h" #include "../../common/hist_util.h" #include "../../data/gradient_index.h" +#include "expand_entry.h" +#include "rabit/rabit.h" +#include "xgboost/tree_model.h" namespace xgboost { namespace tree { @@ -323,6 +324,25 @@ template class HistogramBuilder { (*sync_count) = std::max(1, n_left); } }; + +// Construct a work space for building histogram. Eventually we should move this +// function into histogram builder once hist tree method supports external memory. +template +common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners, + std::vector const &nodes_to_build) { + std::vector partition_size(nodes_to_build.size(), 0); + for (auto const &partition : partitioners) { + size_t k = 0; + for (auto node : nodes_to_build) { + auto n_rows_in_node = partition.Partitions()[node.nid].Size(); + partition_size[k] = std::max(partition_size[k], n_rows_in_node); + k++; + } + } + common::BlockedSpace2d space{ + nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256}; + return space; +} } // namespace tree } // namespace xgboost #endif // XGBOOST_TREE_HIST_HISTOGRAM_H_ diff --git a/src/tree/hist/param.cc b/src/tree/hist/param.cc new file mode 100644 index 000000000..05f1a24ad --- /dev/null +++ b/src/tree/hist/param.cc @@ -0,0 +1,10 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include "param.h" + +namespace xgboost { +namespace tree { +DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); +} // namespace tree +} // namespace xgboost diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index c91c1018c..3bad6f7da 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -94,7 +94,7 @@ class GloablApproxBuilder { rabit::Allreduce(reinterpret_cast(&root_sum), 2); std::vector nodes{best}; size_t i = 0; - auto space = this->ConstructHistSpace(nodes); + auto space = ConstructHistSpace(partitioner_, nodes); for (auto const &page : p_fmat->GetBatches(BatchSpec(param_, hess))) { histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes, {}, gpair); @@ -123,25 +123,6 @@ class GloablApproxBuilder { monitor_->Stop(__func__); } - // Construct a work space for building histogram. Eventually we should move this - // function into histogram builder once hist tree method supports external memory. 
- common::BlockedSpace2d ConstructHistSpace( - std::vector const &nodes_to_build) const { - std::vector partition_size(nodes_to_build.size(), 0); - for (auto const &partition : partitioner_) { - size_t k = 0; - for (auto node : nodes_to_build) { - auto n_rows_in_node = partition.Partitions()[node.nid].Size(); - partition_size[k] = std::max(partition_size[k], n_rows_in_node); - k++; - } - } - common::BlockedSpace2d space{nodes_to_build.size(), - [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, - 256}; - return space; - } - void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree, std::vector const &valid_candidates, std::vector const &gpair, common::Span hess) { @@ -164,7 +145,7 @@ class GloablApproxBuilder { } size_t i = 0; - auto space = this->ConstructHistSpace(nodes_to_build); + auto space = ConstructHistSpace(partitioner_, nodes_to_build); for (auto const &page : p_fmat->GetBatches(BatchSpec(param_, hess))) { histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes_to_build, nodes_to_sub, gpair); @@ -191,7 +172,7 @@ class GloablApproxBuilder { Driver driver(static_cast(param_.grow_policy)); auto &tree = *p_tree; driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)}); - bst_node_t num_leaves = 1; + bst_node_t num_leaves{1}; auto expand_set = driver.Pop(); /** @@ -223,10 +204,10 @@ class GloablApproxBuilder { } monitor_->Start("UpdatePosition"); - size_t i = 0; + size_t page_id = 0; for (auto const &page : p_fmat->GetBatches(BatchSpec(param_, hess))) { - partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree); - i++; + partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree); + page_id++; } monitor_->Stop("UpdatePosition"); @@ -288,9 +269,9 @@ class GlobalApproxUpdater : public TreeUpdater { out["hist_param"] = ToJson(hist_param_); } - void InitData(TrainParam const ¶m, HostDeviceVector *gpair, + void InitData(TrainParam const ¶m, HostDeviceVector const *gpair, std::vector *sampled) { - auto const &h_gpair = gpair->HostVector(); + auto const &h_gpair = gpair->ConstHostVector(); sampled->resize(h_gpair.size()); std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin()); auto &rnd = common::GlobalRandom(); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 8fda5e8dc..0e1b6db47 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -4,80 +4,39 @@ * \brief use quantized feature values to construct a tree * \author Philip Cho, Tianqi Checn, Egor Smirnov */ -#include +#include "./updater_quantile_hist.h" + #include #include -#include -#include #include #include -#include #include #include #include +#include "../common/column_matrix.h" +#include "../common/hist_util.h" +#include "../common/random.h" +#include "../common/threading_utils.h" +#include "constraints.h" +#include "hist/evaluate_splits.h" +#include "param.h" #include "xgboost/logging.h" #include "xgboost/tree_updater.h" -#include "constraints.h" -#include "param.h" -#include "./updater_quantile_hist.h" -#include "./split_evaluator.h" -#include "../common/random.h" -#include "../common/hist_util.h" -#include "../common/row_set.h" -#include "../common/column_matrix.h" -#include "../common/threading_utils.h" - namespace xgboost { namespace tree { DMLC_REGISTRY_FILE_TAG(updater_quantile_hist); -DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); - -void QuantileHistMaker::Configure(const Args& args) { - // initialize pruner - if (!pruner_) { - pruner_.reset(TreeUpdater::Create("prune", ctx_, 
task_)); - } - pruner_->Configure(args); +void QuantileHistMaker::Configure(const Args &args) { param_.UpdateAllowUnknown(args); hist_maker_param_.UpdateAllowUnknown(args); } -template -void QuantileHistMaker::SetBuilder(const size_t n_trees, - std::unique_ptr>* builder, DMatrix* dmat) { - builder->reset( - new Builder(n_trees, param_, std::move(pruner_), dmat, task_, ctx_)); -} - -template -void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr>& builder, - HostDeviceVector *gpair, - DMatrix *dmat, - GHistIndexMatrix const& gmat, - const std::vector &trees) { - for (auto tree : trees) { - builder->Update(gmat, column_matrix_, gpair, dmat, tree); - } -} - -void QuantileHistMaker::Update(HostDeviceVector *gpair, - DMatrix *dmat, +void QuantileHistMaker::Update(HostDeviceVector *gpair, DMatrix *dmat, const std::vector &trees) { - auto it = dmat->GetBatches(HistBatch(param_)).begin(); - auto p_gmat = it.Page(); - if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) { - updater_monitor_.Start("GmatInitialization"); - column_matrix_.Init(*p_gmat, param_.sparse_threshold, ctx_->Threads()); - updater_monitor_.Stop("GmatInitialization"); - // A proper solution is puting cut matrix in DMatrix, see: - // https://github.com/dmlc/xgboost/issues/5143 - is_gmat_initialized_ = true; - } // rescale learning rate according to size of trees float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); @@ -86,19 +45,23 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, const size_t n_trees = trees.size(); if (hist_maker_param_.single_precision_histogram) { if (!float_builder_) { - this->SetBuilder(n_trees, &float_builder_, dmat); + float_builder_.reset(new Builder(n_trees, param_, dmat, task_, ctx_)); } - CallBuilderUpdate(float_builder_, gpair, dmat, *p_gmat, trees); } else { if (!double_builder_) { - SetBuilder(n_trees, &double_builder_, dmat); + double_builder_.reset(new Builder(n_trees, param_, dmat, task_, ctx_)); + } + } + + for (auto p_tree : trees) { + if (hist_maker_param_.single_precision_histogram) { + this->float_builder_->UpdateTree(gpair, dmat, p_tree); + } else { + this->double_builder_->UpdateTree(gpair, dmat, p_tree); } - CallBuilderUpdate(double_builder_, gpair, dmat, *p_gmat, trees); } param_.learning_rate = lr; - - p_last_dmat_ = dmat; } bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data, @@ -113,23 +76,18 @@ bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data, } template -template -void QuantileHistMaker::Builder::InitRoot( - DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h, - int *num_leaves, std::vector *expand) { +CPUExpandEntry QuantileHistMaker::Builder::InitRoot( + DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h) { CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f); - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); - nodes_for_explicit_hist_build_.push_back(node); - - auto const& row_set_collection = partitioner_.front().Partitions(); size_t page_id = 0; - for (auto const& gidx : - p_fmat->GetBatches(HistBatch(param_))) { - this->histogram_builder_->BuildHist( - page_id, gidx, p_tree, row_set_collection, - nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h); + auto space = ConstructHistSpace(partitioner_, {node}); + for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { + std::vector nodes_to_build{node}; + std::vector nodes_to_sub; + this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree, + 
partitioner_.at(page_id).Partitions(), nodes_to_build, + nodes_to_sub, gpair_h); ++page_id; } @@ -165,168 +123,132 @@ void QuantileHistMaker::Builder::InitRoot( (*p_tree)[RegTree::kRoot].SetLeaf(param_.learning_rate * weight); std::vector entries{node}; - builder_monitor_->Start("EvaluateSplits"); + monitor_->Start("EvaluateSplits"); auto ft = p_fmat->Info().feature_types.ConstHostSpan(); - for (auto const& gmat : p_fmat->GetBatches(HistBatch(param_))) { - evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, - *p_tree, &entries); + for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { + evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree, &entries); break; } - builder_monitor_->Stop("EvaluateSplits"); + monitor_->Stop("EvaluateSplits"); node = entries.front(); } - expand->push_back(node); - ++(*num_leaves); + return node; } -template -void QuantileHistMaker::Builder::AddSplitsToTree( - const std::vector& expand, - RegTree *p_tree, - int *num_leaves, - std::vector* nodes_for_apply_split) { - for (auto const& entry : expand) { - if (entry.IsValid(param_, *num_leaves)) { - nodes_for_apply_split->push_back(entry); - evaluator_->ApplyTreeSplit(entry, p_tree); - (*num_leaves)++; - } - } -} - -// Split nodes to 2 sets depending on amount of rows in each node -// Histograms for small nodes will be built explicitly -// Histograms for big nodes will be built by 'Subtraction Trick' -// Exception: in distributed setting, we always build the histogram for the left child node -// and use 'Subtraction Trick' to built the histogram for the right child node. -// This ensures that the workers operate on the same set of tree nodes. template -void QuantileHistMaker::Builder::SplitSiblings( - const std::vector &nodes_for_apply_split, - std::vector *nodes_to_evaluate, RegTree *p_tree) { - builder_monitor_->Start("SplitSiblings"); - auto const& row_set_collection = this->partitioner_.front().Partitions(); - for (auto const& entry : nodes_for_apply_split) { - int nid = entry.nid; +void QuantileHistMaker::Builder::BuildHistogram( + DMatrix *p_fmat, RegTree *p_tree, std::vector const &valid_candidates, + std::vector const &gpair) { + std::vector nodes_to_build(valid_candidates.size()); + std::vector nodes_to_sub(valid_candidates.size()); - const int cleft = (*p_tree)[nid].LeftChild(); - const int cright = (*p_tree)[nid].RightChild(); - const CPUExpandEntry left_node = CPUExpandEntry(cleft, p_tree->GetDepth(cleft), 0.0); - const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0); - nodes_to_evaluate->push_back(left_node); - nodes_to_evaluate->push_back(right_node); - if (row_set_collection[cleft].Size() < row_set_collection[cright].Size()) { - nodes_for_explicit_hist_build_.push_back(left_node); - nodes_for_subtraction_trick_.push_back(right_node); - } else { - nodes_for_explicit_hist_build_.push_back(right_node); - nodes_for_subtraction_trick_.push_back(left_node); + size_t n_idx = 0; + for (auto const &c : valid_candidates) { + auto left_nidx = (*p_tree)[c.nid].LeftChild(); + auto right_nidx = (*p_tree)[c.nid].RightChild(); + auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess(); + + auto build_nidx = left_nidx; + auto subtract_nidx = right_nidx; + if (fewer_right) { + std::swap(build_nidx, subtract_nidx); } + nodes_to_build[n_idx] = CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}}; + nodes_to_sub[n_idx] = CPUExpandEntry{subtract_nidx, p_tree->GetDepth(subtract_nidx), {}}; + n_idx++; + } + 
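  // The loop above puts the child with the smaller Hessian sum into nodes_to_build,
  // so its histogram is built explicitly; its sibling in nodes_to_sub is later
  // obtained by the subtraction trick, i.e. the parent histogram minus the
  // explicitly built child's histogram.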
+ size_t page_id{0}; + auto space = ConstructHistSpace(partitioner_, nodes_to_build); + for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { + histogram_builder_->BuildHist(page_id, space, gidx, p_tree, + partitioner_.at(page_id).Partitions(), nodes_to_build, + nodes_to_sub, gpair); + ++page_id; } - CHECK_EQ(nodes_for_subtraction_trick_.size(), nodes_for_explicit_hist_build_.size()); - builder_monitor_->Stop("SplitSiblings"); } -template -template +template void QuantileHistMaker::Builder::ExpandTree( - const GHistIndexMatrix& gmat, - const common::ColumnMatrix& column_matrix, - DMatrix* p_fmat, - RegTree* p_tree, - const std::vector& gpair_h) { - builder_monitor_->Start("ExpandTree"); - int num_leaves = 0; + DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h) { + monitor_->Start(__func__); Driver driver(static_cast(param_.grow_policy)); - std::vector expand; - InitRoot(p_fmat, p_tree, gpair_h, &num_leaves, &expand); - driver.Push(expand[0]); + driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h)); + bst_node_t num_leaves{1}; + auto expand_set = driver.Pop(); - int32_t depth = 0; - while (!driver.IsEmpty()) { - expand = driver.Pop(); - depth = expand[0].depth + 1; - std::vector nodes_for_apply_split; - std::vector nodes_to_evaluate; - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); - - AddSplitsToTree(expand, p_tree, &num_leaves, &nodes_for_apply_split); - - if (nodes_for_apply_split.size() != 0) { - HistRowPartitioner &partitioner = this->partitioner_.front(); - if (gmat.cut.HasCategorical()) { - partitioner.UpdatePosition(this->ctx_, gmat, column_matrix, - nodes_for_apply_split, p_tree); - } else { - partitioner.UpdatePosition(this->ctx_, gmat, column_matrix, - nodes_for_apply_split, p_tree); + while (!expand_set.empty()) { + // candidates that can be further splited. + std::vector valid_candidates; + // candidaates that can be applied. 
+ std::vector applied; + int32_t depth = expand_set.front().depth + 1; + for (auto const& candidate : expand_set) { + if (!candidate.IsValid(param_, num_leaves)) { + continue; } - - SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree); - - if (param_.max_depth == 0 || depth < param_.max_depth) { - size_t i = 0; - for (auto const &gidx : p_fmat->GetBatches(HistBatch(param_))) { - this->histogram_builder_->BuildHist(i, gidx, p_tree, partitioner_.front().Partitions(), - nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, gpair_h); - ++i; - } - } else { - int starting_index = std::numeric_limits::max(); - int sync_count = 0; - this->histogram_builder_->AddHistRows( - &starting_index, &sync_count, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, p_tree); - } - - builder_monitor_->Start("EvaluateSplits"); - auto ft = p_fmat->Info().feature_types.ConstHostSpan(); - evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), - gmat.cut, ft, *p_tree, &nodes_to_evaluate); - builder_monitor_->Stop("EvaluateSplits"); - - for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) { - CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0); - CPUExpandEntry right_node = nodes_to_evaluate.at(i * 2 + 1); - driver.Push(left_node); - driver.Push(right_node); + evaluator_->ApplyTreeSplit(candidate, p_tree); + applied.push_back(candidate); + num_leaves++; + if (CPUExpandEntry::ChildIsValid(param_, depth, num_leaves)) { + valid_candidates.emplace_back(candidate); } } + + monitor_->Start("UpdatePosition"); + size_t page_id{0}; + for (auto const &page : p_fmat->GetBatches(HistBatch(param_))) { + partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree); + ++page_id; + } + monitor_->Stop("UpdatePosition"); + + std::vector best_splits; + if (!valid_candidates.empty()) { + this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair_h); + auto const &tree = *p_tree; + for (auto const &candidate : valid_candidates) { + int left_child_nidx = tree[candidate.nid].LeftChild(); + int right_child_nidx = tree[candidate.nid].RightChild(); + CPUExpandEntry l_best{left_child_nidx, depth, 0.0}; + CPUExpandEntry r_best{right_child_nidx, depth, 0.0}; + best_splits.push_back(l_best); + best_splits.push_back(r_best); + } + auto const &histograms = histogram_builder_->Histogram(); + auto ft = p_fmat->Info().feature_types.ConstHostSpan(); + for (auto const &gmat : p_fmat->GetBatches(HistBatch(param_))) { + evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, &best_splits); + break; + } + } + driver.Push(best_splits.begin(), best_splits.end()); + expand_set = driver.Pop(); } - builder_monitor_->Stop("ExpandTree"); + + monitor_->Stop(__func__); } template -void QuantileHistMaker::Builder::Update( - const GHistIndexMatrix &gmat, - const common::ColumnMatrix &column_matrix, - HostDeviceVector *gpair, - DMatrix *p_fmat, RegTree *p_tree) { - builder_monitor_->Start("Update"); +void QuantileHistMaker::Builder::UpdateTree(HostDeviceVector *gpair, + DMatrix *p_fmat, RegTree *p_tree) { + monitor_->Start(__func__); - std::vector* gpair_ptr = &(gpair->HostVector()); + std::vector *gpair_ptr = &(gpair->HostVector()); // in case 'num_parallel_trees != 1' no posibility to change initial gpair if (GetNumberOfTrees() != 1) { gpair_local_.resize(gpair_ptr->size()); gpair_local_ = *gpair_ptr; gpair_ptr = &gpair_local_; } - p_last_fmat_mutable_ = p_fmat; - this->InitData(gmat, p_fmat, *p_tree, gpair_ptr); + this->InitData(p_fmat, *p_tree, gpair_ptr); - if (column_matrix.AnyMissing()) { 
- ExpandTree(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr); - } else { - ExpandTree(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr); - } - pruner_->Update(gpair, p_fmat, std::vector{p_tree}); + ExpandTree(p_fmat, p_tree, *gpair_ptr); - builder_monitor_->Stop("Update"); + monitor_->Stop(__func__); } template @@ -334,21 +256,21 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( DMatrix const *data, linalg::VectorView out_preds) const { // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in // conjunction with Update(). - if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_ || - p_last_fmat_ != p_last_fmat_mutable_) { + if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { return false; } - builder_monitor_->Start(__func__); + monitor_->Start(__func__); CHECK_EQ(out_preds.Size(), data->Info().num_row_); UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, *evaluator_, param_, out_preds); - builder_monitor_->Stop(__func__); + monitor_->Stop(__func__); return true; } template void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat, std::vector *gpair) { - const auto& info = fmat.Info(); + monitor_->Start(__func__); + const auto &info = fmat.Info(); auto& rnd = common::GlobalRandom(); std::vector& gpair_ref = *gpair; @@ -380,6 +302,7 @@ void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat, } exc.Rethrow(); #endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG + monitor_->Stop(__func__); } template size_t QuantileHistMaker::Builder::GetNumberOfTrees() { @@ -387,10 +310,9 @@ size_t QuantileHistMaker::Builder::GetNumberOfTrees() { } template -void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix &gmat, DMatrix *fmat, - const RegTree &tree, +void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, std::vector *gpair) { - builder_monitor_->Start("InitData"); + monitor_->Start(__func__); const auto& info = fmat->Info(); { @@ -406,18 +328,14 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix & partitioner_.emplace_back(page.Size(), page.base_rowid, this->ctx_->Threads()); ++page_id; } - histogram_builder_->Reset(n_total_bins, BatchParam{param_.max_bin, param_.sparse_threshold}, - ctx_->Threads(), page_id, rabit::IsDistributed()); + histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id, + rabit::IsDistributed()); if (param_.subsample < 1.0f) { CHECK_EQ(param_.sampling_method, TrainParam::kUniform) - << "Only uniform sampling is supported, " - << "gradient-based sampling is only support by GPU Hist."; - builder_monitor_->Start("InitSampling"); + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; InitSampling(*fmat, gpair); - builder_monitor_->Stop("InitSampling"); - // We should check that the partitioning was done correctly - // and each row of the dataset fell into exactly one of the categories } } @@ -426,7 +344,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix & evaluator_.reset(new HistEvaluator{ param_, info, this->ctx_->Threads(), column_sampler_, task_}); - builder_monitor_->Stop("InitData"); + monitor_->Stop(__func__); } void HistRowPartitioner::FindSplitConditions(const std::vector &nodes, @@ -470,21 +388,8 @@ void HistRowPartitioner::AddSplitsToRowSet(const std::vector &no template struct QuantileHistMaker::Builder; template struct QuantileHistMaker::Builder; -XGBOOST_REGISTER_TREE_UPDATER(FastHistMaker, "grow_fast_histmaker") -.describe("(Deprecated, use 
diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index b7d028113..3c03a371e 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -7,7 +7,6 @@
 #ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
 #define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
 
-#include <dmlc/timer.h>
 #include <rabit/rabit.h>
 #include <xgboost/tree_updater.h>
@@ -29,7 +28,6 @@
 #include "constraints.h"
 #include "./param.h"
 #include "./driver.h"
-#include "./split_evaluator.h"
 #include "../common/random.h"
 #include "../common/timer.h"
 #include "../common/hist_util.h"
@@ -194,6 +192,24 @@ class HistRowPartitioner {
     AddSplitsToRowSet(nodes, p_tree);
   }
 
+  void UpdatePosition(GenericParameter const* ctx, GHistIndexMatrix const& page,
+                      std::vector<CPUExpandEntry> const& applied, RegTree const* p_tree) {
+    auto const& column_matrix = page.Transpose();
+    if (page.cut.HasCategorical()) {
+      if (column_matrix.AnyMissing()) {
+        this->template UpdatePosition<true, true>(ctx, page, column_matrix, applied, p_tree);
+      } else {
+        this->template UpdatePosition<false, true>(ctx, page, column_matrix, applied, p_tree);
+      }
+    } else {
+      if (column_matrix.AnyMissing()) {
+        this->template UpdatePosition<true, false>(ctx, page, column_matrix, applied, p_tree);
+      } else {
+        this->template UpdatePosition<false, false>(ctx, page, column_matrix, applied, p_tree);
+      }
+    }
+  }
+
   auto const& Partitions() const { return row_set_collection_; }
   size_t Size() const {
     return std::distance(row_set_collection_.begin(), row_set_collection_.end());
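
The non-template UpdatePosition overload above folds two runtime properties of a page, the presence of missing values and of categorical features, into compile-time flags so that the hot partitioning loop is compiled once per combination with dead branches removed. A self-contained sketch of the idiom, with illustrative names only:

    #include <iostream>

    // The hot loop is instantiated once per flag combination, so the compiler
    // can drop the untaken branches from the inner row-partitioning loop.
    template <bool any_missing, bool any_cat>
    void PartitionRowsSketch() {
      if (any_missing) { /* consult the missing-value mask */ }
      if (any_cat) { /* decode categorical split conditions */ }
      std::cout << "missing=" << any_missing << " cat=" << any_cat << "\n";
    }

    // Runtime booleans are folded into template arguments exactly once,
    // outside the hot loop, mirroring the nested if/else dispatch above.
    void DispatchSketch(bool has_missing, bool has_cat) {
      if (has_cat) {
        has_missing ? PartitionRowsSketch<true, true>() : PartitionRowsSketch<false, true>();
      } else {
        has_missing ? PartitionRowsSketch<true, false>() : PartitionRowsSketch<false, false>();
      }
    }
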
@@ -209,9 +225,7 @@ inline BatchParam HistBatch(TrainParam const& param) {
 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker: public TreeUpdater {
  public:
-  explicit QuantileHistMaker(ObjInfo task) : task_{task} {
-    updater_monitor_.Init("QuantileHistMaker");
-  }
+  explicit QuantileHistMaker(ObjInfo task) : task_{task} {}
   void Configure(const Args& args) override;
 
   void Update(HostDeviceVector<GradientPair>* gpair,
@@ -256,10 +270,6 @@ class QuantileHistMaker: public TreeUpdater {
   CPUHistMakerTrainParam hist_maker_param_;
   // training parameter
   TrainParam param_;
-  // column accessor
-  common::ColumnMatrix column_matrix_;
-  DMatrix const* p_last_dmat_ {nullptr};
-  bool is_gmat_initialized_ {false};
 
   // actual builder that runs the algorithm
   template <typename GradientSumT>
   struct Builder {
   public:
    using GradientPairT = xgboost::detail::GradientPairInternal<GradientSumT>;
    // constructor
-    explicit Builder(const size_t n_trees, const TrainParam& param,
-                     std::unique_ptr<TreeUpdater> pruner, DMatrix const* fmat, ObjInfo task,
-                     GenericParameter const* ctx)
+    explicit Builder(const size_t n_trees, const TrainParam& param, DMatrix const* fmat,
+                     ObjInfo task, GenericParameter const* ctx)
        : n_trees_(n_trees),
          param_(param),
-          pruner_(std::move(pruner)),
          p_last_fmat_(fmat),
          histogram_builder_{new HistogramBuilder<GradientSumT, CPUExpandEntry>},
          task_{task},
          ctx_{ctx},
-          builder_monitor_{std::make_unique<common::Monitor>()} {
-      builder_monitor_->Init("Quantile::Builder");
+          monitor_{std::make_unique<common::Monitor>()} {
+      monitor_->Init("Quantile::Builder");
    }
    // update one tree, growing
-    void Update(const GHistIndexMatrix& gmat, const common::ColumnMatrix& column_matrix,
-                HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
+    void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
 
    bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView<float> out_preds) const;
 
-   protected:
+   private:
    // initialize temp data structure
-    void InitData(const GHistIndexMatrix& gmat, DMatrix* fmat, const RegTree& tree,
-                  std::vector<GradientPair>* gpair);
+    void InitData(DMatrix* fmat, const RegTree& tree, std::vector<GradientPair>* gpair);
 
    size_t GetNumberOfTrees();
 
    void InitSampling(const DMatrix& fmat, std::vector<GradientPair>* gpair);
 
-    template <bool any_missing>
-    void InitRoot(DMatrix* p_fmat,
-                  RegTree *p_tree,
-                  const std::vector<GradientPair> &gpair_h,
-                  int *num_leaves, std::vector<CPUExpandEntry> *expand);
+    CPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree,
+                            const std::vector<GradientPair>& gpair_h);
 
-    // Split nodes to 2 sets depending on amount of rows in each node
-    // Histograms for small nodes will be built explicitly
-    // Histograms for big nodes will be built by 'Subtraction Trick'
-    void SplitSiblings(const std::vector<CPUExpandEntry>& nodes,
-                       std::vector<CPUExpandEntry>* nodes_to_evaluate,
-                       RegTree *p_tree);
+    void BuildHistogram(DMatrix* p_fmat, RegTree* p_tree,
+                        std::vector<CPUExpandEntry> const& valid_candidates,
+                        std::vector<GradientPair> const& gpair);
 
-    void AddSplitsToTree(const std::vector<CPUExpandEntry>& expand,
-                         RegTree *p_tree,
-                         int *num_leaves,
-                         std::vector<CPUExpandEntry>* nodes_for_apply_split);
+    void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector<GradientPair>& gpair_h);
 
-    template <bool any_missing>
-    void ExpandTree(const GHistIndexMatrix& gmat,
-                    const common::ColumnMatrix& column_matrix,
-                    DMatrix* p_fmat,
-                    RegTree* p_tree,
-                    const std::vector<GradientPair>& gpair_h);
-
-    // --data fields--
+   private:
    const size_t n_trees_;
    const TrainParam& param_;
    std::shared_ptr<common::ColumnSampler> column_sampler_{
@@ -328,48 +318,24 @@ class QuantileHistMaker: public TreeUpdater {
    std::vector<GradientPair> gpair_local_;
 
-    std::unique_ptr<TreeUpdater> pruner_;
    std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
-    // Right now there's only 1 partitioner in this vector, when external memory is fully
-    // supported we will have number of partitioners equal to number of pages.
    std::vector<HistRowPartitioner> partitioner_;
 
    // back pointers to tree and data matrix
    const RegTree* p_last_tree_{nullptr};
    DMatrix const* const p_last_fmat_;
-    DMatrix* p_last_fmat_mutable_;
-    // key is the node id which should be calculated by Subtraction Trick, value is the node which
-    // provides the evidence for subtraction
-    std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
-    // list of nodes whose histograms would be built explicitly.
-    std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
-
-    enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
 
    std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_;
    ObjInfo task_;
    // Context for number of threads
    GenericParameter const* ctx_;
 
-    std::unique_ptr<common::Monitor> builder_monitor_;
+    std::unique_ptr<common::Monitor> monitor_;
  };
 
-  common::Monitor updater_monitor_;
-
-  template <typename GradientSumT>
-  void SetBuilder(const size_t n_trees, std::unique_ptr<Builder<GradientSumT>>*, DMatrix *dmat);
-
-  template <typename GradientSumT>
-  void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
-                         HostDeviceVector<GradientPair> *gpair,
-                         DMatrix *dmat,
-                         GHistIndexMatrix const& gmat,
-                         const std::vector<RegTree *> &trees);
 
  protected:
  std::unique_ptr<Builder<float>> float_builder_;
  std::unique_ptr<Builder<double>> double_builder_;
-
-  std::unique_ptr<TreeUpdater> pruner_;
  ObjInfo task_;
 };
 }  // namespace tree
 }  // namespace xgboost
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 46d89fe97..2626b6fb3 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -21,7 +21,9 @@ TEST(DenseColumn, Test) {
    GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false,
                          common::OmpGetNumThreads(0)};
    ColumnMatrix column_matrix;
-    column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
+    for (auto const& page : dmat->GetBatches<SparsePage>()) {
+      column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
+    }
 
    for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
      for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
@@ -68,7 +70,9 @@ TEST(SparseColumn, Test) {
  auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
  GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)};
  ColumnMatrix column_matrix;
-  column_matrix.Init(gmat, 0.5, common::OmpGetNumThreads(0));
+  for (auto const& page : dmat->GetBatches<SparsePage>()) {
+    column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0));
+  }
  switch (column_matrix.GetTypeSize()) {
    case kUint8BinsTypeSize: {
      auto col = column_matrix.GetColumn(0);
@@ -106,9 +110,11 @@ TEST(DenseColumnWithMissing, Test) {
                                 static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
  for (int32_t max_num_bin : max_num_bins) {
    auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
-    GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)};
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0));
    ColumnMatrix column_matrix;
-    column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
+    for (auto const& page : dmat->GetBatches<SparsePage>()) {
+      column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
+    }
    switch (column_matrix.GetTypeSize()) {
      case kUint8BinsTypeSize: {
        auto col = column_matrix.GetColumn(0);
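
These tests exercise the sparse_threshold value that now flows from the batch parameter into ColumnMatrix::Init together with the SparsePage batch. The exact decision rule lives inside ColumnMatrix; as an assumption for illustration only, a column's storage can be pictured as chosen by comparing its density against the threshold, which is consistent with the tests passing 1.0 to force the sparse path:

    #include <cstddef>

    // Assumed rule, illustrative only (not the actual ColumnMatrix internals):
    // dense columns store one bin id per row plus a sentinel for missing
    // entries, sparse columns store (row index, bin id) pairs.
    enum class ColumnStorageSketch { kDense, kSparse };

    ColumnStorageSketch ChooseStorage(std::size_t n_non_missing, std::size_t n_rows,
                                      double sparse_threshold) {
      double density = static_cast<double>(n_non_missing) / static_cast<double>(n_rows);
      // With sparse_threshold == 1.0, any column containing a missing value
      // falls into the sparse representation, as the sparse-column test forces.
      return density < sparse_threshold ? ColumnStorageSketch::kSparse
                                        : ColumnStorageSketch::kDense;
    }
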
"../../../src/data/sparse_page_source.h" #include "../helpers.h" @@ -15,33 +16,31 @@ TEST(GHistIndexPageRawFormat, IO) { auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); dmlc::TemporaryDirectory tmpdir; std::string path = tmpdir.path + "/ghistindex.page"; + auto batch = BatchParam{256, 0.5}; { std::unique_ptr fo{dmlc::Stream::Create(path.c_str(), "w")}; - for (auto const &index : - m->GetBatches({GenericParameter::kCpuId, 256})) { + for (auto const &index : m->GetBatches(batch)) { format->Write(index, fo.get()); } } GHistIndexMatrix page; - std::unique_ptr fi{ - dmlc::SeekStream::CreateForRead(path.c_str())}; + std::unique_ptr fi{dmlc::SeekStream::CreateForRead(path.c_str())}; format->Read(&page, fi.get()); - for (auto const &gidx : - m->GetBatches({GenericParameter::kCpuId, 256})) { + for (auto const &gidx : m->GetBatches(batch)) { auto const &loaded = gidx; ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs()); ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues()); ASSERT_EQ(loaded.cut.Values(), page.cut.Values()); ASSERT_EQ(loaded.base_rowid, page.base_rowid); ASSERT_EQ(loaded.IsDense(), page.IsDense()); - ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), - page.index.begin())); - ASSERT_TRUE(std::equal(loaded.index.Offset(), - loaded.index.Offset() + loaded.index.OffsetSize(), + ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), page.index.begin())); + ASSERT_TRUE(std::equal(loaded.index.Offset(), loaded.index.Offset() + loaded.index.OffsetSize(), page.index.Offset())); + + ASSERT_EQ(loaded.Transpose().GetTypeSize(), loaded.Transpose().GetTypeSize()); } } } // namespace data diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index 1fa229999..06147afa3 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -446,6 +446,12 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) { TEST(CPUHistogram, ExternalMemory) { int32_t constexpr kBins = 256; TestHistogramExternalMemory(BatchParam{kBins, common::Span{}, false}, true); + + float sparse_thresh{0.5}; + TestHistogramExternalMemory({kBins, sparse_thresh}, false); + sparse_thresh = std::numeric_limits::quiet_NaN(); + TestHistogramExternalMemory({kBins, sparse_thresh}, false); + } } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 6286ec9ca..0c89cd5e8 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -18,138 +18,6 @@ namespace xgboost { namespace tree { - -class QuantileHistMock : public QuantileHistMaker { - static double constexpr kEps = 1e-6; - - template - struct BuilderMock : public QuantileHistMaker::Builder { - using RealImpl = QuantileHistMaker::Builder; - - BuilderMock(const TrainParam ¶m, std::unique_ptr pruner, - DMatrix const *fmat, GenericParameter const* ctx) - : RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}, ctx) {} - - public: - void TestInitData(const GHistIndexMatrix& gmat, - std::vector* gpair, - DMatrix* p_fmat, - const RegTree& tree) { - RealImpl::InitData(gmat, p_fmat, tree, gpair); - - /* The creation of HistCutMatrix and GHistIndexMatrix are not technically - * part of QuantileHist updater logic, but we include it here because - * QuantileHist updater object currently stores GHistIndexMatrix - * internally. 
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index 6286ec9ca..0c89cd5e8 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -18,138 +18,6 @@
 namespace xgboost {
 namespace tree {
-
-class QuantileHistMock : public QuantileHistMaker {
-  static double constexpr kEps = 1e-6;
-
-  template <typename GradientSumT>
-  struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
-    using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
-
-    BuilderMock(const TrainParam &param, std::unique_ptr<TreeUpdater> pruner,
-                DMatrix const *fmat, GenericParameter const* ctx)
-        : RealImpl(1, param, std::move(pruner), fmat, ObjInfo{ObjInfo::kRegression}, ctx) {}
-
-   public:
-    void TestInitData(const GHistIndexMatrix& gmat,
-                      std::vector<GradientPair>* gpair,
-                      DMatrix* p_fmat,
-                      const RegTree& tree) {
-      RealImpl::InitData(gmat, p_fmat, tree, gpair);
-
-      /* The creation of HistCutMatrix and GHistIndexMatrix are not technically
-       * part of QuantileHist updater logic, but we include it here because
-       * QuantileHist updater object currently stores GHistIndexMatrix
-       * internally. According to https://github.com/dmlc/xgboost/pull/3803,
-       * we should eventually move GHistIndexMatrix out of the QuantileHist
-       * updater. */
-
-      const size_t num_row = p_fmat->Info().num_row_;
-      const size_t num_col = p_fmat->Info().num_col_;
-      /* Validate HistCutMatrix */
-      ASSERT_EQ(gmat.cut.Ptrs().size(), num_col + 1);
-      for (size_t fid = 0; fid < num_col; ++fid) {
-        const size_t ibegin = gmat.cut.Ptrs()[fid];
-        const size_t iend = gmat.cut.Ptrs()[fid + 1];
-        // Ordered, but empty feature is allowed.
-        ASSERT_LE(ibegin, iend);
-        for (size_t i = ibegin; i < iend - 1; ++i) {
-          // Quantile points must be sorted in ascending order
-          // No duplicates allowed
-          ASSERT_LT(gmat.cut.Values()[i], gmat.cut.Values()[i + 1])
-              << "ibegin: " << ibegin << ", "
-              << "iend: " << iend;
-        }
-      }
-
-      /* Validate GHistIndexMatrix */
-      ASSERT_EQ(gmat.row_ptr.size(), num_row + 1);
-      ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()),
-                gmat.cut.Ptrs().back());
-      for (const auto& batch : p_fmat->GetBatches<SparsePage>()) {
-        auto page = batch.GetView();
-        for (size_t i = 0; i < batch.Size(); ++i) {
-          const size_t rid = batch.base_rowid + i;
-          ASSERT_LT(rid, num_row);
-          const size_t gmat_row_offset = gmat.row_ptr[rid];
-          ASSERT_LT(gmat_row_offset, gmat.index.Size());
-          SparsePage::Inst inst = page[i];
-          ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
-          for (size_t j = 0; j < inst.size(); ++j) {
-            // Each entry of GHistIndexMatrix represents a bin ID
-            const size_t bin_id = gmat.index[gmat_row_offset + j];
-            const size_t fid = inst[j].index;
-            // The bin ID must correspond to correct feature
-            ASSERT_GE(bin_id, gmat.cut.Ptrs()[fid]);
-            ASSERT_LT(bin_id, gmat.cut.Ptrs()[fid + 1]);
-            // The bin ID must correspond to a region between two
-            // suitable quantile points
-            ASSERT_LT(inst[j].fvalue, gmat.cut.Values()[bin_id]);
-            if (bin_id > gmat.cut.Ptrs()[fid]) {
-              ASSERT_GE(inst[j].fvalue, gmat.cut.Values()[bin_id - 1]);
-            } else {
-              ASSERT_GE(inst[j].fvalue, gmat.cut.MinValues()[fid]);
-            }
-          }
-        }
-      }
-    }
-  };
-
-  int static constexpr kNRows = 8, kNCols = 16;
-  std::shared_ptr<DMatrix> dmat_;
-  GenericParameter ctx_;
-  const std::vector<std::pair<std::string, std::string> > cfg_;
-  std::shared_ptr<BuilderMock<float> > float_builder_;
-  std::shared_ptr<BuilderMock<double> > double_builder_;
-
- public:
-  explicit QuantileHistMock(
-      const std::vector<std::pair<std::string, std::string> >& args,
-      const bool single_precision_histogram = false, bool batch = true) :
-      QuantileHistMaker{ObjInfo{ObjInfo::kRegression}}, cfg_{args} {
-    QuantileHistMaker::Configure(args);
-    dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
-    ctx_.UpdateAllowUnknown(Args{});
-    if (single_precision_histogram) {
-      float_builder_.reset(new BuilderMock<float>(param_, std::move(pruner_), dmat_.get(), &ctx_));
-    } else {
-      double_builder_.reset(
-          new BuilderMock<double>(param_, std::move(pruner_), dmat_.get(), &ctx_));
-    }
-  }
-  ~QuantileHistMock() override = default;
-
-  static size_t GetNumColumns() { return kNCols; }
-
-  void TestInitData() {
-    int32_t constexpr kMaxBins = 4;
-    GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
-
-    RegTree tree = RegTree();
-    tree.param.UpdateAllowUnknown(cfg_);
-
-    std::vector<GradientPair> gpair =
-        { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
-          {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
-    if (double_builder_) {
-      double_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
-    } else {
-      float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
-    }
-  }
-};
-
-TEST(QuantileHist, InitData) {
-  std::vector<std::pair<std::string, std::string>> cfg
-      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
-  QuantileHistMock maker(cfg);
-  maker.TestInitData();
-  const bool single_precision_histogram = true;
-  QuantileHistMock maker_float(cfg, single_precision_histogram);
-  maker_float.TestInitData();
-}
-
 TEST(QuantileHist, Partitioner) {
   size_t n_samples = 1024, n_features = 1, base_rowid = 0;
   GenericParameter ctx;
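
The surviving Partitioner test below pins down the observable contract of UpdatePosition: once a split has been applied, every row of the node lands in exactly one child, decided by its quantized bin against the split condition. A compact, sequential sketch of that contract (the real partitioner works block-wise and in parallel):

    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Rows whose quantized bin id is <= the split condition go left, the rest
    // go right; every row lands in exactly one child. Sequential for clarity.
    std::pair<std::vector<std::size_t>, std::vector<std::size_t>> PartitionNodeSketch(
        std::vector<std::size_t> const& rows,
        std::vector<std::int32_t> const& row_bin_idx, std::int32_t split_cond) {
      std::pair<std::vector<std::size_t>, std::vector<std::size_t>> children;
      for (std::size_t ridx : rows) {
        (row_bin_idx[ridx] <= split_cond ? children.first : children.second).push_back(ridx);
      }
      return children;
    }
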
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}}; - QuantileHistMock maker(cfg); - maker.TestInitData(); - const bool single_precision_histogram = true; - QuantileHistMock maker_float(cfg, single_precision_histogram); - maker_float.TestInitData(); -} - TEST(QuantileHist, Partitioner) { size_t n_samples = 1024, n_features = 1, base_rowid = 0; GenericParameter ctx; @@ -163,45 +31,44 @@ TEST(QuantileHist, Partitioner) { auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); std::vector candidates{{0, 0, 0.4}}; - auto grad = GenerateRandomGradients(n_samples); - std::vector hess(grad.Size()); - std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(), - [](auto gpair) { return gpair.GetHess(); }); + auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads()); - for (auto const& page : Xy->GetBatches({64, 0.5})) { + for (auto const& page : Xy->GetBatches()) { + GHistIndexMatrix gmat; + gmat.Init(page, {}, cuts, 64, false, 0.5, ctx.Threads()); bst_feature_t const split_ind = 0; common::ColumnMatrix column_indices; - column_indices.Init(page, 0.5, ctx.Threads()); + column_indices.Init(page, gmat, 0.5, ctx.Threads()); { - auto min_value = page.cut.MinValues()[split_ind]; + auto min_value = gmat.cut.MinValues()[split_ind]; RegTree tree; HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()}; GetSplit(&tree, min_value, &candidates); - partitioner.UpdatePosition(&ctx, page, column_indices, candidates, &tree); + partitioner.UpdatePosition(&ctx, gmat, column_indices, candidates, &tree); ASSERT_EQ(partitioner.Size(), 3); ASSERT_EQ(partitioner[1].Size(), 0); ASSERT_EQ(partitioner[2].Size(), n_samples); } { HistRowPartitioner partitioner{n_samples, base_rowid, ctx.Threads()}; - auto ptr = page.cut.Ptrs()[split_ind + 1]; - float split_value = page.cut.Values().at(ptr / 2); + auto ptr = gmat.cut.Ptrs()[split_ind + 1]; + float split_value = gmat.cut.Values().at(ptr / 2); RegTree tree; GetSplit(&tree, split_value, &candidates); auto left_nidx = tree[RegTree::kRoot].LeftChild(); - partitioner.UpdatePosition(&ctx, page, column_indices, candidates, &tree); + partitioner.UpdatePosition(&ctx, gmat, column_indices, candidates, &tree); auto elem = partitioner[left_nidx]; ASSERT_LT(elem.Size(), n_samples); ASSERT_GT(elem.Size(), 1); for (auto it = elem.begin; it != elem.end; ++it) { - auto value = page.cut.Values().at(page.index[*it]); + auto value = gmat.cut.Values().at(gmat.index[*it]); ASSERT_LE(value, split_value); } auto right_nidx = tree[RegTree::kRoot].RightChild(); elem = partitioner[right_nidx]; for (auto it = elem.begin; it != elem.end; ++it) { - auto value = page.cut.Values().at(page.index[*it]); + auto value = gmat.cut.Values().at(gmat.index[*it]); ASSERT_GT(value, split_value) << *it; } } diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index 946127d13..e4254bb9e 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -1,7 +1,7 @@ import xgboost as xgb from xgboost.data import SingleBatchInternalIter as SingleBatch import numpy as np -from testing import IteratorForTest +from testing import IteratorForTest, non_increasing from typing import Tuple, List import pytest from hypothesis import given, strategies, settings @@ -108,7 +108,7 @@ def run_data_iterator( evals_result=results_from_it, verbose_eval=False, ) - it_predt = from_it.predict(Xy) + assert non_increasing(results_from_it["Train"]["rmse"]) X, y = it.as_arrays() Xy = xgb.DMatrix(X, y) @@ 
@@ -125,13 +125,13 @@ def run_data_iterator(
         verbose_eval=False,
     )
     arr_predt = from_arrays.predict(Xy)
+    assert non_increasing(results_from_arrays["Train"]["rmse"])
 
-    if tree_method != "gpu_hist":
-        rtol = 1e-1  # flaky
-    else:
-        # Model can be sensitive to quantiles, use 1e-2 to relax the test.
-        np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-2)
-        rtol = 1e-6
+    rtol = 1e-2
+    # CPU sketching is more memory efficient but less consistent due to small chunks
+    it_predt = from_it.predict(Xy)
+    arr_predt = from_arrays.predict(Xy)
+    np.testing.assert_allclose(it_predt, arr_predt, rtol=rtol)
 
     np.testing.assert_allclose(
         results_from_it["Train"]["rmse"],