From 2775c2a1abd4b5b759ff517617434c8b9aeb4cc0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 10 Feb 2022 16:58:02 +0800 Subject: [PATCH] Prepare external memory support for hist. (#7638) This PR prepares the GHistIndexMatrix to host the column matrix which is used by the hist tree method by accepting sparse_threshold parameter. Some cleanups are made to ensure the correct batch param is being passed into DMatrix along with some additional tests for correctness of SimpleDMatrix. --- include/xgboost/data.h | 30 +++-- src/common/column_matrix.h | 3 +- src/data/gradient_index.cc | 43 +++--- src/data/gradient_index.h | 50 ++++--- src/data/gradient_index_page_source.cc | 4 +- src/data/gradient_index_page_source.h | 23 ++-- src/data/simple_dmatrix.cc | 22 ++-- src/data/simple_dmatrix.h | 2 +- src/data/sparse_page_dmatrix.cc | 4 +- src/predictor/gpu_predictor.cu | 4 +- src/tree/updater_approx.cc | 4 +- src/tree/updater_quantile_hist.cc | 19 +-- src/tree/updater_quantile_hist.h | 4 + tests/cpp/common/test_column_matrix.cc | 37 +++--- tests/cpp/common/test_hist_util.cc | 4 +- tests/cpp/data/test_ellpack_page.cu | 4 +- tests/cpp/data/test_gradient_index.cc | 22 +++- .../cpp/data/test_iterative_device_dmatrix.cu | 4 +- tests/cpp/predictor/test_gpu_predictor.cu | 1 + tests/cpp/tree/hist/test_evaluate_splits.cc | 7 +- tests/cpp/tree/hist/test_histogram.cc | 98 ++++++-------- tests/cpp/tree/test_approx.cc | 41 +++--- tests/cpp/tree/test_quantile_hist.cc | 15 ++- tests/cpp/tree/test_regen.cc | 124 ++++++++++++++++++ 24 files changed, 368 insertions(+), 201 deletions(-) create mode 100644 tests/cpp/tree/test_regen.cc diff --git a/include/xgboost/data.h b/include/xgboost/data.h index a728cdd90..7399b8265 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -217,24 +218,33 @@ struct BatchParam { common::Span hess; /*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */ bool regen {false}; + /*! \brief Parameter used to generate column matrix for hist. */ + double sparse_thresh{std::numeric_limits::quiet_NaN()}; BatchParam() = default; + // GPU Hist BatchParam(int32_t device, int32_t max_bin) : gpu_id{device}, max_bin{max_bin} {} + // Hist + BatchParam(int32_t max_bin, double sparse_thresh) + : max_bin{max_bin}, sparse_thresh{sparse_thresh} {} + // Approx /** * \brief Get batch with sketch weighted by hessian. The batch will be regenerated if * the span is changed, so caller should keep the span for each iteration. */ - BatchParam(int32_t device, int32_t max_bin, common::Span hessian, - bool regenerate = false) - : gpu_id{device}, max_bin{max_bin}, hess{hessian}, regen{regenerate} {} + BatchParam(int32_t max_bin, common::Span hessian, bool regenerate) + : max_bin{max_bin}, hess{hessian}, regen{regenerate} {} - bool operator!=(const BatchParam& other) const { + bool operator!=(BatchParam const& other) const { if (hess.empty() && other.hess.empty()) { return gpu_id != other.gpu_id || max_bin != other.max_bin; } return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data(); } + bool operator==(BatchParam const& other) const { + return !(*this != other); + } }; struct HostSparsePageView { @@ -477,8 +487,10 @@ class DMatrix { /** * \brief Gets batches. Use range based for loop over BatchSet to access individual batches. 
*/ - template - BatchSet GetBatches(const BatchParam& param = {}); + template + BatchSet GetBatches(); + template + BatchSet GetBatches(const BatchParam& param); template bool PageExists() const; @@ -592,7 +604,7 @@ class DMatrix { }; template<> -inline BatchSet DMatrix::GetBatches(const BatchParam&) { +inline BatchSet DMatrix::GetBatches() { return GetRowBatches(); } @@ -607,12 +619,12 @@ inline bool DMatrix::PageExists() const { } template<> -inline BatchSet DMatrix::GetBatches(const BatchParam&) { +inline BatchSet DMatrix::GetBatches() { return GetColumnBatches(); } template<> -inline BatchSet DMatrix::GetBatches(const BatchParam&) { +inline BatchSet DMatrix::GetBatches() { return GetSortedColumnBatches(); } diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index b652bcc4a..747004cc0 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -379,12 +379,11 @@ class ColumnMatrix { std::vector feature_offsets_; // index_base_[fid]: least bin id for feature fid - uint32_t* index_base_; + uint32_t const* index_base_; std::vector missing_flags_; BinTypeSize bins_type_size_; bool any_missing_; }; - } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_COLUMN_MATRIX_H_ diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 4f815fd04..abd80264d 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -2,13 +2,27 @@ * Copyright 2017-2022 by XGBoost Contributors * \brief Data type for fast histogram aggregation. */ +#include "gradient_index.h" + #include #include -#include "gradient_index.h" +#include + +#include "../common/column_matrix.h" #include "../common/hist_util.h" namespace xgboost { +GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique()} {} + +GHistIndexMatrix::GHistIndexMatrix(DMatrix *x, int32_t max_bin, double sparse_thresh, + bool sorted_sketch, int32_t n_threads, + common::Span hess) { + this->Init(x, max_bin, sparse_thresh, sorted_sketch, n_threads, hess); +} + +GHistIndexMatrix::~GHistIndexMatrix() = default; + void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span ft, size_t rbegin, size_t prev_sum, uint32_t nbins, @@ -90,18 +104,15 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, [offsets](auto idx, auto j) { return static_cast(idx - offsets[j]); }); - } else if (curent_bin_size == common::kUint16BinsTypeSize) { - common::Span index_data_span = {index.data(), - n_index}; + common::Span index_data_span = {index.data(), n_index}; SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins, [offsets](auto idx, auto j) { return static_cast(idx - offsets[j]); }); } else { CHECK_EQ(curent_bin_size, common::kUint32BinsTypeSize); - common::Span index_data_span = {index.data(), - n_index}; + common::Span index_data_span = {index.data(), n_index}; SetIndexData(index_data_span, ft, batch_threads, batch, rbegin, nbins, [offsets](auto idx, auto j) { return static_cast(idx - offsets[j]); @@ -125,8 +136,8 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, }); } -void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, int32_t n_threads, - common::Span hess) { +void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch, + int32_t n_threads, common::Span hess) { // We use sorted sketching for approx tree method since it's more efficient in // computation time (but higher memory usage). 
cut = common::SketchOnDMatrix(p_fmat, max_bins, n_threads, sorted_sketch, hess); @@ -158,11 +169,9 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch, i } } -void GHistIndexMatrix::Init(SparsePage const &batch, - common::Span ft, - common::HistogramCuts const &cuts, - int32_t max_bins_per_feat, bool isDense, - int32_t n_threads) { +void GHistIndexMatrix::Init(SparsePage const &batch, common::Span ft, + common::HistogramCuts const &cuts, int32_t max_bins_per_feat, + bool isDense, double sparse_thresh, int32_t n_threads) { CHECK_GE(n_threads, 1); base_rowid = batch.base_rowid; isDense_ = isDense; @@ -183,13 +192,13 @@ void GHistIndexMatrix::Init(SparsePage const &batch, this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads); } -void GHistIndexMatrix::ResizeIndex(const size_t n_index, - const bool isDense) { +void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) { if ((max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && isDense) { index.SetBinTypeSize(common::kUint8BinsTypeSize); index.Resize((sizeof(uint8_t)) * n_index); - } else if ((max_num_bins - 1 > static_cast(std::numeric_limits::max()) && - max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && isDense) { + } else if ((max_num_bins - 1 > static_cast(std::numeric_limits::max()) && + max_num_bins - 1 <= static_cast(std::numeric_limits::max())) && + isDense) { index.SetBinTypeSize(common::kUint16BinsTypeSize); index.Resize((sizeof(uint16_t)) * n_index); } else { diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 76062b57c..83da8c784 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -4,12 +4,14 @@ */ #ifndef XGBOOST_DATA_GRADIENT_INDEX_H_ #define XGBOOST_DATA_GRADIENT_INDEX_H_ +#include #include -#include "xgboost/base.h" -#include "xgboost/data.h" + #include "../common/categorical.h" #include "../common/hist_util.h" #include "../common/threading_utils.h" +#include "xgboost/base.h" +#include "xgboost/data.h" namespace xgboost { /*! @@ -32,20 +34,22 @@ class GHistIndexMatrix { /*! \brief The corresponding cuts */ common::HistogramCuts cut; DMatrix* p_fmat; + /*! \brief max_bin for each feature. */ size_t max_num_bins; + /*! 
\brief base row index for current page (used by external memory) */ size_t base_rowid{0}; - GHistIndexMatrix() = default; - GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, int32_t n_threads, - common::Span hess = {}) { - this->Init(x, max_bin, sorted_sketch, n_threads, hess); - } + GHistIndexMatrix(); + GHistIndexMatrix(DMatrix* x, int32_t max_bin, double sparse_thresh, bool sorted_sketch, + int32_t n_threads, common::Span hess = {}); + ~GHistIndexMatrix(); + // Create a global histogram matrix, given cut - void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, int32_t n_threads, - common::Span hess); + void Init(DMatrix* p_fmat, int max_bins, double sparse_thresh, bool sorted_sketch, + int32_t n_threads, common::Span hess); void Init(SparsePage const& page, common::Span ft, common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense, - int32_t n_threads); + double sparse_thresh, int32_t n_threads); // specific method for sparse data as no possibility to reduce allocated memory template @@ -74,7 +78,7 @@ class GHistIndexMatrix { index_data[ibegin + j] = get_offset(bin_idx, j); ++hit_count_tloc_[tid * nbins + bin_idx]; } else { - uint32_t idx = cut.SearchBin(inst[j].fvalue, inst[j].index, ptrs, values); + uint32_t idx = cut.SearchBin(e.fvalue, e.index, ptrs, values); index_data[ibegin + j] = get_offset(idx, j); ++hit_count_tloc_[tid * nbins + idx]; } @@ -82,10 +86,9 @@ class GHistIndexMatrix { }); } - void ResizeIndex(const size_t n_index, - const bool isDense); + void ResizeIndex(const size_t n_index, const bool isDense); - inline void GetFeatureCounts(size_t* counts) const { + void GetFeatureCounts(size_t* counts) const { auto nfeature = cut.Ptrs().size() - 1; for (unsigned fid = 0; fid < nfeature; ++fid) { auto ibegin = cut.Ptrs()[fid]; @@ -95,7 +98,8 @@ class GHistIndexMatrix { } } } - inline bool IsDense() const { + + bool IsDense() const { return isDense_; } void SetDense(bool is_dense) { isDense_ = is_dense; } @@ -105,6 +109,8 @@ class GHistIndexMatrix { } private: + // unused at the moment: https://github.com/dmlc/xgboost/pull/7531 + std::unique_ptr columns_; std::vector hit_count_tloc_; bool isDense_; }; @@ -117,7 +123,19 @@ class GHistIndexMatrix { */ inline bool RegenGHist(BatchParam old, BatchParam p) { // parameter is renewed or caller requests a regen - return p.regen || (old.gpu_id != p.gpu_id || old.max_bin != p.max_bin); + if (p == BatchParam{}) { + // empty parameter is passed in, don't regenerate so that we can use gindex in + // predictor, which doesn't have any training parameter. + return false; + } + + // Avoid comparing nan values. + bool l_nan = std::isnan(old.sparse_thresh); + bool r_nan = std::isnan(p.sparse_thresh); + // regenerate if parameter is changed. + bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (old.sparse_thresh != p.sparse_thresh)); + bool param_chg = old.gpu_id != p.gpu_id || old.max_bin != p.max_bin; + return p.regen || param_chg || st_chg; } } // namespace xgboost #endif // XGBOOST_DATA_GRADIENT_INDEX_H_ diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index 8f592213f..9ec69d904 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -1,5 +1,5 @@ /*! 
- * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #include "gradient_index_page_source.h" @@ -11,7 +11,7 @@ void GradientIndexPageSource::Fetch() { this->page_.reset(new GHistIndexMatrix()); CHECK_NE(cuts_.Values().size(), 0); this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, - nthreads_); + sparse_thresh_, nthreads_); this->WriteCache(); } } diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index a11057d54..30b53a294 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -1,5 +1,5 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ #define XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ @@ -7,8 +7,8 @@ #include #include -#include "sparse_page_source.h" #include "gradient_index.h" +#include "sparse_page_source.h" namespace xgboost { namespace data { @@ -17,23 +17,26 @@ class GradientIndexPageSource : public PageSourceIncMixIn { bool is_dense_; int32_t max_bin_per_feat_; common::Span feature_types_; + double sparse_thresh_; public: - GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, - size_t n_batches, std::shared_ptr cache, - BatchParam param, common::HistogramCuts cuts, - bool is_dense, int32_t max_bin_per_feat, + GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches, + std::shared_ptr cache, BatchParam param, + common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat, common::Span feature_types, std::shared_ptr source) : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), - cuts_{std::move(cuts)}, is_dense_{is_dense}, - max_bin_per_feat_{max_bin_per_feat}, feature_types_{feature_types} { + cuts_{std::move(cuts)}, + is_dense_{is_dense}, + max_bin_per_feat_{max_bin_per_feat}, + feature_types_{feature_types}, + sparse_thresh_{param.sparse_thresh} { this->source_ = source; this->Fetch(); } void Fetch() final; }; -} // namespace data -} // namespace xgboost +} // namespace data +} // namespace xgboost #endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 3e1e1de79..754304fb2 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -74,12 +74,18 @@ BatchSet SimpleDMatrix::GetSortedColumnBatches() { return BatchSet(begin_iter); } +namespace { +void CheckEmpty(BatchParam const& l, BatchParam const& r) { + if (l == BatchParam{}) { + CHECK(r != BatchParam{}) << "Batch parameter is not initialized."; + } +} +} // anonymous namespace + BatchSet SimpleDMatrix::GetEllpackBatches(const BatchParam& param) { // ELLPACK page doesn't exist, generate it - if (!(batch_param_ != BatchParam{})) { - CHECK(param != BatchParam{}) << "Batch parameter is not initialized."; - } - if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) { + CheckEmpty(batch_param_, param); + if (!ellpack_page_ || RegenGHist(batch_param_, param)) { CHECK_GE(param.gpu_id, 0); CHECK_GE(param.max_bin, 2); ellpack_page_.reset(new EllpackPage(this, param)); @@ -91,17 +97,15 @@ BatchSet SimpleDMatrix::GetEllpackBatches(const BatchParam& param) } BatchSet SimpleDMatrix::GetGradientIndex(const BatchParam& param) { - if (!(batch_param_ != BatchParam{})) { - CHECK(param != BatchParam{}) << "Batch parameter is not initialized."; - } + CheckEmpty(batch_param_, param); if (!gradient_index_ || 
RegenGHist(batch_param_, param)) { LOG(INFO) << "Generating new Gradient Index."; CHECK_GE(param.max_bin, 2); CHECK_EQ(param.gpu_id, -1); // Used only by approx. auto sorted_sketch = param.regen; - gradient_index_.reset( - new GHistIndexMatrix(this, param.max_bin, sorted_sketch, this->ctx_.Threads(), param.hess)); + gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.sparse_thresh, + sorted_sketch, this->ctx_.Threads(), param.hess)); batch_param_ = param; CHECK_EQ(batch_param_.hess.data(), param.hess.data()); } diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index 4c6a3e28c..8bb438481 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -39,7 +39,7 @@ class SimpleDMatrix : public DMatrix { /*! \brief magic number used to identify SimpleDMatrix binary files */ static const int kMagic = 0xffffab01; - private: + protected: BatchSet GetRowBatches() override; BatchSet GetColumnBatches() override; BatchSet GetSortedColumnBatches() override; diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 22ad0f85d..a9fd9b7c1 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -164,8 +164,8 @@ BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam // all index here. if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) { this->InitializeSparsePage(); - ghist_index_page_.reset( - new GHistIndexMatrix{this, param.max_bin, param.regen, ctx_.Threads()}); + ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.sparse_thresh, + param.regen, ctx_.Threads()}); this->InitializeSparsePage(); batch_param_ = param; } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 5c61fafa0..0a09dc255 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -708,7 +708,7 @@ class GPUPredictor : public xgboost::Predictor { } } else { size_t batch_offset = 0; - for (auto const& page : dmat->GetBatches()) { + for (auto const& page : dmat->GetBatches(BatchParam{})) { dmat->Info().feature_types.SetDevice(ctx_->gpu_id); auto feature_types = dmat->Info().feature_types.ConstDeviceSpan(); this->PredictInternal( @@ -984,7 +984,7 @@ class GPUPredictor : public xgboost::Predictor { batch_offset += batch.Size(); } } else { - for (auto const& batch : p_fmat->GetBatches()) { + for (auto const& batch : p_fmat->GetBatches(BatchParam{})) { bst_row_t batch_offset = 0; EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->gpu_id)}; size_t num_rows = batch.Size(); diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 6acc096e0..843afeec1 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -31,11 +31,11 @@ namespace { template auto BatchSpec(TrainParam const &p, common::Span hess, HistEvaluator const &evaluator) { - return BatchParam{GenericParameter::kCpuId, p.max_bin, hess, !evaluator.Task().const_hess}; + return BatchParam{p.max_bin, hess, !evaluator.Task().const_hess}; } auto BatchSpec(TrainParam const &p, common::Span hess) { - return BatchParam{GenericParameter::kCpuId, p.max_bin, hess, false}; + return BatchParam{p.max_bin, hess, false}; } } // anonymous namespace diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 540b157a2..8c52ff382 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -68,9 +68,7 @@ void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr *gpair, DMatrix *dmat, const 
std::vector &trees) { - auto it = dmat->GetBatches( - BatchParam{GenericParameter::kCpuId, param_.max_bin}) - .begin(); + auto it = dmat->GetBatches(HistBatch(param_)).begin(); auto p_gmat = it.Page(); if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) { updater_monitor_.Start("GmatInitialization"); @@ -127,8 +125,8 @@ void QuantileHistMaker::Builder::InitRoot( nodes_for_explicit_hist_build_.push_back(node); size_t page_id = 0; - for (auto const &gidx : p_fmat->GetBatches( - {GenericParameter::kCpuId, param_.max_bin})) { + for (auto const& gidx : + p_fmat->GetBatches(HistBatch(param_))) { this->histogram_builder_->BuildHist( page_id, gidx, p_tree, row_set_collection_, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h); @@ -141,10 +139,7 @@ void QuantileHistMaker::Builder::InitRoot( GradientPairT grad_stat; if (data_layout_ == DataLayout::kDenseDataZeroBased || data_layout_ == DataLayout::kDenseDataOneBased) { - auto const &gmat = *(p_fmat - ->GetBatches(BatchParam{ - GenericParameter::kCpuId, param_.max_bin}) - .begin()); + auto const& gmat = *(p_fmat->GetBatches(HistBatch(param_)).begin()); const std::vector &row_ptr = gmat.cut.Ptrs(); const uint32_t ibegin = row_ptr[fid_least_bins_]; const uint32_t iend = row_ptr[fid_least_bins_ + 1]; @@ -170,8 +165,7 @@ void QuantileHistMaker::Builder::InitRoot( std::vector entries{node}; builder_monitor_.Start("EvaluateSplits"); auto ft = p_fmat->Info().feature_types.ConstHostSpan(); - for (auto const &gmat : p_fmat->GetBatches( - BatchParam{GenericParameter::kCpuId, param_.max_bin})) { + for (auto const& gmat : p_fmat->GetBatches(HistBatch(param_))) { evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree, &entries); break; @@ -264,8 +258,7 @@ void QuantileHistMaker::Builder::ExpandTree( if (param_.max_depth == 0 || depth < param_.max_depth) { size_t i = 0; - for (auto const &gidx : p_fmat->GetBatches( - {GenericParameter::kCpuId, param_.max_bin})) { + for (auto const& gidx : p_fmat->GetBatches(HistBatch(param_))) { this->histogram_builder_->BuildHist( i, gidx, p_tree, row_set_collection_, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index f2103270c..3f2b07ff9 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -92,6 +92,10 @@ using xgboost::common::GHistBuilder; using xgboost::common::ColumnMatrix; using xgboost::common::Column; +inline BatchParam HistBatch(TrainParam const& param) { + return {param.max_bin, param.sparse_threshold}; +} + /*! 
\brief construct a tree using quantized feature values */ class QuantileHistMaker: public TreeUpdater { public: diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc index 46ca6e6bb..46d89fe97 100644 --- a/tests/cpp/common/test_column_matrix.cc +++ b/tests/cpp/common/test_column_matrix.cc @@ -12,12 +12,14 @@ namespace xgboost { namespace common { TEST(DenseColumn, Test) { - uint64_t max_num_bins[] = {static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 2}; - for (size_t max_num_bin : max_num_bins) { + int32_t max_num_bins[] = {static_cast(std::numeric_limits::max()) + 1, + static_cast(std::numeric_limits::max()) + 1, + static_cast(std::numeric_limits::max()) + 2}; + for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); + auto sparse_thresh = 0.2; + GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false, + common::OmpGetNumThreads(0)}; ColumnMatrix column_matrix; column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0)); @@ -59,12 +61,12 @@ inline void CheckSparseColumn(const Column& col_input, const GHistIn } TEST(SparseColumn, Test) { - uint64_t max_num_bins[] = {static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 2}; - for (size_t max_num_bin : max_num_bins) { + int32_t max_num_bins[] = {static_cast(std::numeric_limits::max()) + 1, + static_cast(std::numeric_limits::max()) + 1, + static_cast(std::numeric_limits::max()) + 2}; + for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)}; ColumnMatrix column_matrix; column_matrix.Init(gmat, 0.5, common::OmpGetNumThreads(0)); switch (column_matrix.GetTypeSize()) { @@ -99,12 +101,12 @@ inline void CheckColumWithMissingValue(const Column& col_input, } TEST(DenseColumnWithMissing, Test) { - uint64_t max_num_bins[] = { static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 2 }; - for (size_t max_num_bin : max_num_bins) { + int32_t max_num_bins[] = {static_cast(std::numeric_limits::max()) + 1, + static_cast(std::numeric_limits::max()) + 1, + static_cast(std::numeric_limits::max()) + 2}; + for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), max_num_bin, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)}; ColumnMatrix column_matrix; column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0)); switch (column_matrix.GetTypeSize()) { @@ -131,9 +133,8 @@ void TestGHistIndexMatrixCreation(size_t nthreads) { size_t constexpr kPageSize = 1024, kEntriesPerCol = 3; size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2; /* This should create multiple sparse pages */ - std::unique_ptr dmat{ CreateSparsePageDMatrix(kEntries) }; - omp_set_num_threads(nthreads); - GHistIndexMatrix gmat(dmat.get(), 256, false, common::OmpGetNumThreads(0)); + std::unique_ptr dmat{CreateSparsePageDMatrix(kEntries)}; + GHistIndexMatrix gmat(dmat.get(), 256, 
0.5f, false, common::OmpGetNumThreads(nthreads)); } TEST(HistIndexCreationWithExternalMemory, Test) { diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index b820eeefe..13fd84691 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -299,7 +299,7 @@ TEST(HistUtil, IndexBinBound) { for (auto max_bin : bin_sizes) { auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0)); EXPECT_EQ(hmat.index.Size(), kRows*kCols); EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize()); } @@ -322,7 +322,7 @@ TEST(HistUtil, IndexBinData) { for (auto max_bin : kBinSizes) { auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - GHistIndexMatrix hmat(p_fmat.get(), max_bin, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0)); uint32_t* offsets = hmat.index.Offset(); EXPECT_EQ(hmat.index.Size(), kRows*kCols); switch (max_bin) { diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index c73dd9910..a67ab1d59 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -81,13 +81,13 @@ TEST(EllpackPage, BuildGidxSparse) { TEST(EllpackPage, FromCategoricalBasic) { using common::AsCat; size_t constexpr kRows = 1000, kCats = 13, kCols = 1; - size_t max_bins = 8; + int32_t max_bins = 8; auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats); auto m = GetDMatrixFromData(x, kRows, 1); auto& h_ft = m->Info().feature_types.HostVector(); h_ft.resize(kCols, FeatureType::kCategorical); - BatchParam p(0, max_bins); + BatchParam p{0, max_bins}; auto ellpack = EllpackPage(m.get(), p); auto accessor = ellpack.Impl()->GetDeviceAccessor(0); ASSERT_EQ(kCats, accessor.NumBins()); diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc index dacf5d901..6bf12a060 100644 --- a/tests/cpp/data/test_gradient_index.cc +++ b/tests/cpp/data/test_gradient_index.cc @@ -4,8 +4,8 @@ #include #include -#include "../helpers.h" #include "../../../src/data/gradient_index.h" +#include "../helpers.h" namespace xgboost { namespace data { @@ -13,12 +13,22 @@ TEST(GradientIndex, ExternalMemory) { std::unique_ptr dmat = CreateSparsePageDMatrix(10000); std::vector base_rowids; std::vector hessian(dmat->Info().num_row_, 1); - for (auto const &page : dmat->GetBatches( - {GenericParameter::kCpuId, 64, hessian})) { + for (auto const &page : dmat->GetBatches({64, hessian, true})) { base_rowids.push_back(page.base_rowid); } size_t i = 0; - for (auto const& page : dmat->GetBatches()) { + for (auto const &page : dmat->GetBatches()) { + ASSERT_EQ(base_rowids[i], page.base_rowid); + ++i; + } + + + base_rowids.clear(); + for (auto const &page : dmat->GetBatches({64, hessian, false})) { + base_rowids.push_back(page.base_rowid); + } + i = 0; + for (auto const &page : dmat->GetBatches()) { ASSERT_EQ(base_rowids[i], page.base_rowid); ++i; } @@ -33,10 +43,10 @@ TEST(GradientIndex, FromCategoricalBasic) { auto &h_ft = m->Info().feature_types.HostVector(); h_ft.resize(kCols, FeatureType::kCategorical); - BatchParam p(0, max_bins); + BatchParam p(max_bins, 0.8); GHistIndexMatrix gidx; - gidx.Init(m.get(), max_bins, false, common::OmpGetNumThreads(0), {}); + gidx.Init(m.get(), max_bins, p.sparse_thresh, false, 
common::OmpGetNumThreads(0), {}); auto x_copy = x; std::sort(x_copy.begin(), x_copy.end()); diff --git a/tests/cpp/data/test_iterative_device_dmatrix.cu b/tests/cpp/data/test_iterative_device_dmatrix.cu index 0fc992f24..629c67bf9 100644 --- a/tests/cpp/data/test_iterative_device_dmatrix.cu +++ b/tests/cpp/data/test_iterative_device_dmatrix.cu @@ -21,7 +21,7 @@ void TestEquivalent(float sparsity) { std::unique_ptr page_concatenated { new EllpackPageImpl(0, first->Cuts(), first->is_dense, first->row_stride, 1000 * 100)}; - for (auto& batch : m.GetBatches()) { + for (auto& batch : m.GetBatches({})) { auto page = batch.Impl(); size_t num_elements = page_concatenated->Copy(0, page, offset); offset += num_elements; @@ -93,7 +93,7 @@ TEST(IterativeDeviceDMatrix, RowMajor) { 0, 256); size_t n_batches = 0; std::string interface_str = iter.AsArray(); - for (auto& ellpack : m.GetBatches()) { + for (auto& ellpack : m.GetBatches({})) { n_batches ++; auto impl = ellpack.Impl(); common::CompressedIterator iterator( diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index b494a1410..3113bc62b 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -68,6 +68,7 @@ TEST(GPUPredictor, EllpackBasic) { .Bins(bins) .Device(0) .GenerateDeviceDMatrix(true); + ASSERT_FALSE(p_m->PageExists()); TestPredictionFromGradientIndex("gpu_predictor", rows, kCols, p_m); TestPredictionFromGradientIndex("gpu_predictor", bins, kCols, p_m); } diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 5df2ff4e9..7819ec307 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -31,8 +31,7 @@ template void TestEvaluateSplits() { size_t constexpr kMaxBins = 4; // dense, no missing values - - GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix gmat(dmat.get(), kMaxBins, 0.5, false, common::OmpGetNumThreads(0)); common::RowSetCollection row_set_collection; std::vector &row_indices = *row_set_collection.Data(); row_indices.resize(kRows); @@ -127,7 +126,7 @@ TEST(HistEvaluator, CategoricalPartition) { auto evaluator = HistEvaluator{ param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}}; - for (auto const &gmat : dmat->GetBatches({GenericParameter::kCpuId, 32})) { + for (auto const &gmat : dmat->GetBatches({32, param.sparse_threshold})) { common::HistCollection hist; std::vector entries(1); @@ -212,7 +211,7 @@ auto CompareOneHotAndPartition(bool onehot) { param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}}; std::vector entries(1); - for (auto const &gmat : dmat->GetBatches({GenericParameter::kCpuId, 32})) { + for (auto const &gmat : dmat->GetBatches({32, param.sparse_threshold})) { common::HistCollection hist; entries.front().nid = 0; diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index 8acf0959f..553550e33 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -1,20 +1,20 @@ /*! 
- * Copyright 2018-2021 by Contributors + * Copyright 2018-2022 by Contributors */ #include -#include "../../helpers.h" -#include "../../categorical_helpers.h" +#include #include "../../../../src/common/categorical.h" #include "../../../../src/tree/hist/histogram.h" #include "../../../../src/tree/updater_quantile_hist.h" +#include "../../categorical_helpers.h" +#include "../../helpers.h" namespace xgboost { namespace tree { namespace { -void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples, - size_t base_rowid = 0) { +void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples, size_t base_rowid = 0) { auto &row_indices = *row_set->Data(); row_indices.resize(n_samples); std::iota(row_indices.begin(), row_indices.end(), base_rowid); @@ -33,10 +33,7 @@ void TestAddHistRows(bool is_distributed) { int32_t constexpr kMaxBins = 4; auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); - auto const &gmat = *(p_fmat - ->GetBatches( - BatchParam{GenericParameter::kCpuId, kMaxBins}) - .begin()); + auto const &gmat = *(p_fmat->GetBatches(BatchParam{kMaxBins, 0.5}).begin()); RegTree tree; @@ -49,9 +46,8 @@ void TestAddHistRows(bool is_distributed) { nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); HistogramBuilder histogram_builder; - histogram_builder.Reset(gmat.cut.TotalBins(), - {GenericParameter::kCpuId, kMaxBins}, - omp_get_max_threads(), 1, is_distributed); + histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1, + is_distributed); histogram_builder.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, &tree); @@ -89,15 +85,11 @@ void TestSyncHist(bool is_distributed) { auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); - auto const &gmat = *(p_fmat - ->GetBatches( - BatchParam{GenericParameter::kCpuId, kMaxBins}) - .begin()); + auto const &gmat = *(p_fmat->GetBatches(BatchParam{kMaxBins, 0.5}).begin()); HistogramBuilder histogram; uint32_t total_bins = gmat.cut.Ptrs().back(); - histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins}, - omp_get_max_threads(), 1, is_distributed); + histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); RowSetCollection row_set_collection_; { @@ -250,10 +242,7 @@ void TestBuildHistogram(bool is_distributed) { int32_t constexpr kMaxBins = 4; auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); - auto const &gmat = *(p_fmat - ->GetBatches( - BatchParam{GenericParameter::kCpuId, kMaxBins}) - .begin()); + auto const &gmat = *(p_fmat->GetBatches(BatchParam{kMaxBins, 0.5}).begin()); uint32_t total_bins = gmat.cut.Ptrs().back(); static double constexpr kEps = 1e-6; @@ -263,8 +252,7 @@ void TestBuildHistogram(bool is_distributed) { bst_node_t nid = 0; HistogramBuilder histogram; - histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins}, - omp_get_max_threads(), 1, is_distributed); + histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); RegTree tree; @@ -278,8 +266,7 @@ void TestBuildHistogram(bool is_distributed) { CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); std::vector nodes_for_explicit_hist_build; nodes_for_explicit_hist_build.push_back(node); - for (auto const &gidx : p_fmat->GetBatches( - {GenericParameter::kCpuId, kMaxBins})) { + for (auto const &gidx : p_fmat->GetBatches({kMaxBins, 0.5})) { histogram.BuildHist(0, gidx, &tree, row_set_collection, 
nodes_for_explicit_hist_build, {}, gpair); } @@ -342,11 +329,9 @@ void TestHistogramCategorical(size_t n_categories) { * Generate hist with cat data. */ HistogramBuilder cat_hist; - for (auto const &gidx : cat_m->GetBatches( - {GenericParameter::kCpuId, kBins})) { + for (auto const &gidx : cat_m->GetBatches({kBins, 0.5})) { auto total_bins = gidx.cut.TotalBins(); - cat_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins}, - omp_get_max_threads(), 1, false); + cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false); cat_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {}, gpair.HostVector()); } @@ -357,13 +342,10 @@ void TestHistogramCategorical(size_t n_categories) { auto x_encoded = OneHotEncodeFeature(x, n_categories); auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories); HistogramBuilder onehot_hist; - for (auto const &gidx : encode_m->GetBatches( - {GenericParameter::kCpuId, kBins})) { + for (auto const &gidx : encode_m->GetBatches({kBins, 0.5})) { auto total_bins = gidx.cut.TotalBins(); - onehot_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins}, - omp_get_max_threads(), 1, false); - onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, - nodes_for_explicit_hist_build, {}, + onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false); + onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {}, gpair.HostVector()); } @@ -378,11 +360,16 @@ TEST(CPUHistogram, Categorical) { TestHistogramCategorical(n_categories); } } - -TEST(CPUHistogram, ExternalMemory) { +namespace { +void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) { size_t constexpr kEntries = 1 << 16; - int32_t constexpr kBins = 32; auto m = CreateSparsePageDMatrix(kEntries, "cache"); + + std::vector hess(m->Info().num_row_, 1.0); + if (is_approx) { + batch_param.hess = hess; + } + std::vector partition_size(1, 0); size_t total_bins{0}; size_t n_samples{0}; @@ -401,9 +388,7 @@ TEST(CPUHistogram, ExternalMemory) { * Multi page */ std::vector rows_set; - std::vector hess(m->Info().num_row_, 1.0); - for (auto const &page : m->GetBatches( - {GenericParameter::kCpuId, kBins, hess})) { + for (auto const &page : m->GetBatches(batch_param)) { CHECK_LT(page.base_rowid, m->Info().num_row_); auto n_rows_in_node = page.Size(); partition_size[0] = std::max(partition_size[0], n_rows_in_node); @@ -419,12 +404,10 @@ TEST(CPUHistogram, ExternalMemory) { 1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); }, 256}; - multi_build.Reset(total_bins, {GenericParameter::kCpuId, kBins}, - omp_get_max_threads(), rows_set.size(), false); + multi_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), rows_set.size(), false); size_t page_idx{0}; - for (auto const &page : m->GetBatches( - {GenericParameter::kCpuId, kBins, hess})) { + for (auto const &page : m->GetBatches(batch_param)) { multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {}, h_gpair); ++page_idx; @@ -442,16 +425,13 @@ TEST(CPUHistogram, ExternalMemory) { RowSetCollection row_set_collection; InitRowPartitionForTest(&row_set_collection, n_samples); - single_build.Reset(total_bins, {GenericParameter::kCpuId, kBins}, - omp_get_max_threads(), 1, false); - size_t n_batches{0}; - for (auto const &page : - m->GetBatches({GenericParameter::kCpuId, kBins})) { - single_build.BuildHist(0, page, &tree, row_set_collection, nodes, {}, - h_gpair); - n_batches ++; - } - ASSERT_EQ(n_batches, 1); + 
single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false); + SparsePage concat; + GHistIndexMatrix gmat; + std::vector hess(m->Info().num_row_, 1.0f); + gmat.Init(m.get(), batch_param.max_bin, std::numeric_limits::quiet_NaN(), false, + common::OmpGetNumThreads(0), hess); + single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair); single_page = single_build.Histogram()[0]; } @@ -460,5 +440,11 @@ TEST(CPUHistogram, ExternalMemory) { ASSERT_NEAR(single_page[i].GetHess(), multi_page[i].GetHess(), kRtEps); } } +} // anonymous namespace + +TEST(CPUHistogram, ExternalMemory) { + int32_t constexpr kBins = 256; + TestHistogramExternalMemory(BatchParam{kBins, common::Span{}, false}, true); +} } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 57a0cd354..738d30d29 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -8,6 +8,19 @@ namespace xgboost { namespace tree { +namespace { +void GetSplit(RegTree *tree, float split_value, std::vector *candidates) { + tree->ExpandNode( + /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value, + /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, + /*right_sum=*/0.0f); + candidates->front().split.split_value = split_value; + candidates->front().split.sindex = 0; + candidates->front().split.sindex |= (1U << 31); +} +} // anonymous namespace + TEST(Approx, Partitioner) { size_t n_samples = 1024, n_features = 1, base_rowid = 0; ApproxRowPartitioner partitioner{n_samples, base_rowid}; @@ -20,20 +33,18 @@ TEST(Approx, Partitioner) { ctx.InitAllowUnknown(Args{}); std::vector candidates{{0, 0, 0.4}}; - for (auto const &page : Xy->GetBatches({GenericParameter::kCpuId, 64})) { - bst_feature_t split_ind = 0; + auto grad = GenerateRandomGradients(n_samples); + std::vector hess(grad.Size()); + std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(), + [](auto gpair) { return gpair.GetHess(); }); + + for (auto const &page : Xy->GetBatches({64, hess, true})) { + bst_feature_t const split_ind = 0; { auto min_value = page.cut.MinValues()[split_ind]; RegTree tree; - tree.ExpandNode( - /*nid=*/0, /*split_index=*/0, /*split_value=*/min_value, - /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - /*left_sum=*/0.0f, - /*right_sum=*/0.0f); ApproxRowPartitioner partitioner{n_samples, base_rowid}; - candidates.front().split.split_value = min_value; - candidates.front().split.sindex = 0; - candidates.front().split.sindex |= (1U << 31); + GetSplit(&tree, min_value, &candidates); partitioner.UpdatePosition(&ctx, page, candidates, &tree); ASSERT_EQ(partitioner.Size(), 3); ASSERT_EQ(partitioner[1].Size(), 0); @@ -44,16 +55,8 @@ TEST(Approx, Partitioner) { auto ptr = page.cut.Ptrs()[split_ind + 1]; float split_value = page.cut.Values().at(ptr / 2); RegTree tree; - tree.ExpandNode( - /*nid=*/RegTree::kRoot, /*split_index=*/split_ind, - /*split_value=*/split_value, - /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - /*left_sum=*/0.0f, - /*right_sum=*/0.0f); + GetSplit(&tree, split_value, &candidates); auto left_nidx = tree[RegTree::kRoot].LeftChild(); - candidates.front().split.split_value = split_value; - candidates.front().split.sindex = 0; - candidates.front().split.sindex |= (1U << 31); partitioner.UpdatePosition(&ctx, page, candidates, &tree); auto elem = partitioner[left_nidx]; diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 
931ac79cb..fc7c43ad7 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -155,18 +155,19 @@ class QuantileHistMock : public QuantileHistMaker { std::vector row_gpairs = { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f}, {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} }; - size_t constexpr kMaxBins = 4; + int32_t constexpr kMaxBins = 4; // try out different sparsity to get different number of missing values for (double sparsity : {0.0, 0.1, 0.2}) { // kNRows samples with kNCols features auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0)); + float sparse_th = 0.0; + GHistIndexMatrix gmat{dmat.get(), kMaxBins, sparse_th, false, common::OmpGetNumThreads(0)}; ColumnMatrix cm; // treat everything as dense, as this is what we intend to test here - cm.Init(gmat, 0.0, common::OmpGetNumThreads(0)); + cm.Init(gmat, sparse_th, common::OmpGetNumThreads(0)); RealImpl::InitData(gmat, *dmat, tree, &row_gpairs); const size_t num_row = dmat->Info().num_row_; // split by feature 0 @@ -247,8 +248,8 @@ class QuantileHistMock : public QuantileHistMaker { static size_t GetNumColumns() { return kNCols; } void TestInitData() { - size_t constexpr kMaxBins = 4; - GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0)); + int32_t constexpr kMaxBins = 4; + GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)}; RegTree tree = RegTree(); tree.param.UpdateAllowUnknown(cfg_); @@ -264,8 +265,8 @@ class QuantileHistMock : public QuantileHistMaker { } void TestInitDataSampling() { - size_t constexpr kMaxBins = 4; - GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0)); + int32_t constexpr kMaxBins = 4; + GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)}; RegTree tree = RegTree(); tree.param.UpdateAllowUnknown(cfg_); diff --git a/tests/cpp/tree/test_regen.cc b/tests/cpp/tree/test_regen.cc new file mode 100644 index 000000000..47a576f45 --- /dev/null +++ b/tests/cpp/tree/test_regen.cc @@ -0,0 +1,124 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include + +#include "../../../src/data/adapter.h" +#include "../../../src/data/simple_dmatrix.h" +#include "../helpers.h" + +namespace xgboost { +namespace { +class DMatrixForTest : public data::SimpleDMatrix { + size_t n_regen_{0}; + + public: + using SimpleDMatrix::SimpleDMatrix; + BatchSet GetGradientIndex(const BatchParam& param) override { + auto backup = this->gradient_index_; + auto iter = SimpleDMatrix::GetGradientIndex(param); + n_regen_ += (backup != this->gradient_index_); + return iter; + } + + BatchSet GetEllpackBatches(const BatchParam& param) override { + auto backup = this->ellpack_page_; + auto iter = SimpleDMatrix::GetEllpackBatches(param); + n_regen_ += (backup != this->ellpack_page_); + return iter; + } + + auto NumRegen() const { return n_regen_; } + + void Reset() { + this->gradient_index_.reset(); + this->ellpack_page_.reset(); + n_regen_ = 0; + } +}; + +/** + * \brief Test for whether the gradient index is correctly regenerated. 
+ */ +class RegenTest : public ::testing::Test { + protected: + std::shared_ptr p_fmat_; + + void SetUp() override { + size_t constexpr kRows = 256, kCols = 10; + HostDeviceVector storage; + auto dense = RandomDataGenerator{kRows, kCols, 0.5}.GenerateArrayInterface(&storage); + auto adapter = data::ArrayAdapter(StringView{dense}); + p_fmat_ = std::shared_ptr(new DMatrixForTest{ + &adapter, std::numeric_limits::quiet_NaN(), common::OmpGetNumThreads(0)}); + + p_fmat_->Info().labels.Reshape(256, 1); + auto labels = p_fmat_->Info().labels.Data(); + RandomDataGenerator{kRows, 1, 0}.GenerateDense(labels); + } + + auto constexpr Iter() const { return 4; } + + template + size_t TestTreeMethod(std::string tree_method, std::string obj, bool reset = true) const { + auto learner = std::unique_ptr{Learner::Create({p_fmat_})}; + learner->SetParam("tree_method", tree_method); + learner->SetParam("objective", obj); + learner->Configure(); + + for (auto i = 0; i < Iter(); ++i) { + learner->UpdateOneIter(i, p_fmat_); + } + + auto for_test = dynamic_cast(p_fmat_.get()); + CHECK(for_test); + auto backup = for_test->NumRegen(); + for_test->GetBatches(BatchParam{}); + CHECK_EQ(for_test->NumRegen(), backup); + + if (reset) { + for_test->Reset(); + } + return backup; + } +}; +} // anonymous namespace + +TEST_F(RegenTest, Approx) { + auto n = this->TestTreeMethod("approx", "reg:squarederror"); + ASSERT_EQ(n, 1); + n = this->TestTreeMethod("approx", "reg:logistic"); + ASSERT_EQ(n, this->Iter()); +} + +TEST_F(RegenTest, Hist) { + auto n = this->TestTreeMethod("hist", "reg:squarederror"); + ASSERT_EQ(n, 1); + n = this->TestTreeMethod("hist", "reg:logistic"); + ASSERT_EQ(n, 1); +} + +TEST_F(RegenTest, Mixed) { + auto n = this->TestTreeMethod("hist", "reg:squarederror", false); + ASSERT_EQ(n, 1); + n = this->TestTreeMethod("approx", "reg:logistic", true); + ASSERT_EQ(n, this->Iter() + 1); + + n = this->TestTreeMethod("approx", "reg:logistic", false); + ASSERT_EQ(n, this->Iter()); + n = this->TestTreeMethod("hist", "reg:squarederror", true); + ASSERT_EQ(n, this->Iter() + 1); +} + +#if defined(XGBOOST_USE_CUDA) +TEST_F(RegenTest, GpuHist) { + auto n = this->TestTreeMethod("gpu_hist", "reg:squarederror"); + ASSERT_EQ(n, 1); + n = this->TestTreeMethod("gpu_hist", "reg:logistic", false); + ASSERT_EQ(n, 1); + + n = this->TestTreeMethod("hist", "reg:logistic"); + ASSERT_EQ(n, 2); +} +#endif // defined(XGBOOST_USE_CUDA) +} // namespace xgboost
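
Note (editor, not part of the patch): the subtle point in this change is that the new BatchParam::sparse_thresh defaults to quiet_NaN, so the regeneration check added in src/data/gradient_index.h (RegenGHist) must avoid comparing NaN values directly. Below is a minimal, self-contained sketch of that NaN-aware check using hypothetical names (MiniBatchParam, NeedRegen) rather than the real xgboost types; it omits the early-out in the real function that skips regeneration when a default-constructed parameter is passed in (the predictor path), and is only an illustration of the logic under those assumptions.

    // Simplified sketch of the NaN-aware regeneration decision (hypothetical types,
    // not the actual xgboost::BatchParam / RegenGHist declarations).
    #include <cassert>
    #include <cmath>
    #include <limits>

    struct MiniBatchParam {
      int gpu_id{-1};
      int max_bin{0};
      bool regen{false};
      // NaN means "not set", mirroring the new BatchParam::sparse_thresh default.
      double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
    };

    // Decide whether the cached gradient index must be regenerated.
    bool NeedRegen(MiniBatchParam const& old_p, MiniBatchParam const& p) {
      bool l_nan = std::isnan(old_p.sparse_thresh);
      bool r_nan = std::isnan(p.sparse_thresh);
      // sparse_thresh changed iff exactly one side is unset, or both are set and differ.
      bool st_chg = (l_nan != r_nan) ||
                    (!l_nan && !r_nan && old_p.sparse_thresh != p.sparse_thresh);
      bool param_chg = old_p.gpu_id != p.gpu_id || old_p.max_bin != p.max_bin;
      return p.regen || param_chg || st_chg;
    }

    int main() {
      MiniBatchParam cached{-1, 256, false, 0.2};      // hist: max_bin + sparse_thresh
      MiniBatchParam same{-1, 256, false, 0.2};
      MiniBatchParam new_thresh{-1, 256, false, 0.5};  // sparse_thresh changed
      assert(!NeedRegen(cached, same));                // identical parameters: reuse cache
      assert(NeedRegen(cached, new_thresh));           // changed threshold: rebuild index
      return 0;
    }

In the patch itself, hist callers build this parameter through the new HistBatch(TrainParam) helper, which packs max_bin and sparse_threshold, while approx continues to pass the hessian span and a regenerate flag.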