From 441ffc017aee142cbcbae9b7f74754fd5d9d1ebd Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 6 Sep 2022 23:05:49 +0800 Subject: [PATCH] Copy data from Ellpack to GHist. (#8215) --- src/common/algorithm.cuh | 27 ++++ src/common/algorithm.h | 16 ++ src/common/column_matrix.h | 153 +++++++++++++----- src/common/device_helpers.cuh | 13 +- src/common/hist_util.h | 35 +++- src/data/ellpack_page.cu | 11 ++ src/data/ellpack_page.cuh | 19 ++- src/data/gradient_index.cc | 10 +- src/data/gradient_index.cu | 111 +++++++++++++ src/data/gradient_index.h | 45 +++--- src/data/iterative_dmatrix.cc | 11 +- src/data/iterative_dmatrix.h | 13 +- tests/cpp/common/test_column_matrix.cc | 6 +- tests/cpp/data/test_gradient_index.cc | 77 +++++++++ tests/cpp/tree/test_quantile_hist.cc | 2 +- .../test_device_quantile_dmatrix.py | 29 ++-- 16 files changed, 466 insertions(+), 112 deletions(-) create mode 100644 src/common/algorithm.cuh create mode 100644 src/common/algorithm.h create mode 100644 src/data/gradient_index.cu diff --git a/src/common/algorithm.cuh b/src/common/algorithm.cuh new file mode 100644 index 000000000..dfce723da --- /dev/null +++ b/src/common/algorithm.cuh @@ -0,0 +1,27 @@ +/*! + * Copyright 2022 by XGBoost Contributors + */ +#pragma once + +#include // thrust::upper_bound +#include // thrust::seq + +#include "xgboost/base.h" +#include "xgboost/span.h" + +namespace xgboost { +namespace common { +namespace cuda { +template +size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) { + size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) - 1 - first; + return segment_id; +} + +template +size_t XGBOOST_DEVICE SegmentId(Span segments_ptr, size_t idx) { + return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx); +} +} // namespace cuda +} // namespace common +} // namespace xgboost diff --git a/src/common/algorithm.h b/src/common/algorithm.h new file mode 100644 index 000000000..addcd95cf --- /dev/null +++ b/src/common/algorithm.h @@ -0,0 +1,16 @@ +/*! + * Copyright 2022 by XGBoost Contributors + */ +#pragma once +#include // std::upper_bound +#include // std::size_t + +namespace xgboost { +namespace common { +template +auto SegmentId(It first, It last, Idx idx) { + std::size_t segment_id = std::upper_bound(first, last, idx) - 1 - first; + return segment_id; +} +} // namespace common +} // namespace xgboost diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index b7f0bc2c2..f663987f0 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -18,6 +18,7 @@ #include "../data/adapter.h" #include "../data/gradient_index.h" +#include "algorithm.h" #include "hist_util.h" namespace xgboost { @@ -135,6 +136,22 @@ class DenseColumnIter : public Column { class ColumnMatrix { void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold); + template + void SetBinSparse(BinT bin_id, RIdx rid, bst_feature_t fid, ColumnBinT* local_index) { + if (type_[fid] == kDenseColumn) { + ColumnBinT* begin = &local_index[feature_offsets_[fid]]; + begin[rid] = bin_id - index_base_[fid]; + // not thread-safe with bool vector. FIXME(jiamingy): We can directly assign + // kMissingId to the index to avoid missing flags. 
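For orientation, here is a tiny self-contained illustration of the SegmentId helpers added in src/common/algorithm.h and algorithm.cuh above; SegmentIdSketch, the example row pointer, and the main driver are made up for this sketch and are not part of the patch.

#include <algorithm>  // std::upper_bound
#include <cassert>
#include <cstddef>    // std::size_t
#include <vector>

// Same formula as the SegmentId helpers above: find the segment (for example a
// CSR row) that owns a flat element index.
template <typename It>
std::size_t SegmentIdSketch(It first, It last, std::size_t idx) {
  return std::upper_bound(first, last, idx) - 1 - first;
}

int main() {
  // Three segments holding 3, 2 and 4 elements respectively.
  std::vector<std::size_t> ptr{0, 3, 5, 9};
  // Flat index 4 falls in segment 1: upper_bound lands on the pointer 5, and
  // stepping back one pointer gives the owning segment.
  std::size_t segment = SegmentIdSketch(ptr.cbegin(), ptr.cend(), 4);
  assert(segment == 1);
  return 0;
}
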
+ missing_flags_[feature_offsets_[fid] + rid] = false; + } else { + ColumnBinT* begin = &local_index[feature_offsets_[fid]]; + begin[num_nonzeros_[fid]] = bin_id - index_base_[fid]; + row_ind_[feature_offsets_[fid] + num_nonzeros_[fid]] = rid; + ++num_nonzeros_[fid]; + } + } + public: // get number of features bst_feature_t GetNumFeature() const { return static_cast(type_.size()); } @@ -144,34 +161,66 @@ class ColumnMatrix { this->InitStorage(gmat, sparse_threshold); } - template - void PushBatch(int32_t n_threads, Batch const& batch, float missing, GHistIndexMatrix const& gmat, - size_t base_rowid) { - // pre-fill index_ for dense columns - auto n_features = gmat.Features(); - if (!any_missing_) { - missing_flags_.resize(feature_offsets_[n_features], false); - // row index is compressed, we need to dispatch it. - DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = batch.Size(), n_features = n_features, - n_threads = n_threads](auto t) { - using RowBinIdxT = decltype(t); - SetIndexNoMissing(base_rowid, gmat.index.data(), size, n_features, n_threads); - }); - } else { - missing_flags_.resize(feature_offsets_[n_features], true); - SetIndexMixedColumns(base_rowid, batch, gmat, n_features, missing); - } - } - - // construct column matrix from GHistIndexMatrix - void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold, - int32_t n_threads) { + /** + * \brief Initialize ColumnMatrix from GHistIndexMatrix with reference to the original + * SparsePage. + */ + void InitFromSparse(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold, + int32_t n_threads) { auto batch = data::SparsePageAdapterBatch{page.GetView()}; this->InitStorage(gmat, sparse_threshold); // ignore base row id here as we always has one column matrix for each sparse page. this->PushBatch(n_threads, batch, std::numeric_limits::quiet_NaN(), gmat, 0); } + /** + * \brief Initialize ColumnMatrix from GHistIndexMatrix without reference to actual + * data. + * + * This function requires a binary search for each bin to get back the feature index + * for those bins. + */ + void InitFromGHist(Context const* ctx, GHistIndexMatrix const& gmat) { + auto n_threads = ctx->Threads(); + if (!any_missing_) { + // row index is compressed, we need to dispatch it. + DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = gmat.Size(), n_threads = n_threads, + n_features = gmat.Features()](auto t) { + using RowBinIdxT = decltype(t); + SetIndexNoMissing(gmat.base_rowid, gmat.index.data(), size, n_features, + n_threads); + }); + } else { + SetIndexMixedColumns(gmat); + } + } + + /** + * \brief Push batch of data for Quantile DMatrix support. + * + * \param batch Input data wrapped inside a adapter batch. + * \param gmat The row-major histogram index that contains index for ALL data. + * \param base_rowid The beginning row index for current batch. + */ + template + void PushBatch(int32_t n_threads, Batch const& batch, float missing, GHistIndexMatrix const& gmat, + size_t base_rowid) { + // pre-fill index_ for dense columns + if (!any_missing_) { + // row index is compressed, we need to dispatch it. + + // use base_rowid from input parameter as gmat is a single matrix that contains all + // the histogram index instead of being only a batch. 
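As a quick orientation for the two ColumnMatrix entry points introduced above, the following sketch contrasts initialization from a SparsePage with initialization from the histogram index alone. BuildColumnsBothWays is a hypothetical helper, the include paths and the 0.2 threshold are illustrative, and ctx, page and gmat are assumed to be constructed elsewhere.

#include "../common/column_matrix.h"  // common::ColumnMatrix
#include "../data/gradient_index.h"   // GHistIndexMatrix

namespace xgboost {
// Sketch only: build the CSC-style column matrix through both paths.
void BuildColumnsBothWays(Context const* ctx, SparsePage const& page,
                          GHistIndexMatrix const& gmat) {
  // CPU input: the original SparsePage is available, so values can be pushed
  // into columns directly.
  common::ColumnMatrix from_sparse;
  from_sparse.InitFromSparse(page, gmat, /*sparse_threshold=*/0.2, ctx->Threads());

  // Data copied back from an ELLPACK page: only the histogram index is
  // available, so InitFromGHist recovers the feature owning each bin from the
  // cut pointers.
  common::ColumnMatrix from_ghist(gmat, /*sparse_threshold=*/0.2);
  from_ghist.InitFromGHist(ctx, gmat);
}
}  // namespace xgboost
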
+ DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = batch.Size(), n_threads = n_threads, + n_features = gmat.Features()](auto t) { + using RowBinIdxT = decltype(t); + SetIndexNoMissing(base_rowid, gmat.index.data(), size, n_features, n_threads); + }); + } else { + SetIndexMixedColumns(base_rowid, batch, gmat, missing); + } + } + /* Set the number of bytes based on numeric limit of maximum number of bins provided by user */ void SetTypeSize(size_t max_bin_per_feat) { if ((max_bin_per_feat - 1) <= static_cast(std::numeric_limits::max())) { @@ -210,6 +259,7 @@ class ColumnMatrix { template void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples, const size_t n_features, int32_t n_threads) { + missing_flags_.resize(feature_offsets_[n_features], false); DispatchBinType(bins_type_size_, [&](auto t) { using ColumnBinT = decltype(t); auto column_index = Span{reinterpret_cast(index_.data()), @@ -232,29 +282,16 @@ class ColumnMatrix { */ template void SetIndexMixedColumns(size_t base_rowid, Batch const& batch, const GHistIndexMatrix& gmat, - size_t n_features, float missing) { + float missing) { + auto n_features = gmat.Features(); + missing_flags_.resize(feature_offsets_[n_features], true); auto const* row_index = gmat.index.data() + gmat.row_ptr[base_rowid]; - auto is_valid = data::IsValidFunctor {missing}; + num_nonzeros_.resize(n_features, 0); + auto is_valid = data::IsValidFunctor{missing}; DispatchBinType(bins_type_size_, [&](auto t) { using ColumnBinT = decltype(t); ColumnBinT* local_index = reinterpret_cast(index_.data()); - num_nonzeros_.resize(n_features, 0); - auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) { - if (type_[fid] == kDenseColumn) { - ColumnBinT* begin = reinterpret_cast(&local_index[feature_offsets_[fid]]); - begin[rid] = bin_id - index_base_[fid]; - // not thread-safe with bool vector. FIXME(jiamingy): We can directly assign - // kMissingId to the index to avoid missing flags. - missing_flags_[feature_offsets_[fid] + rid] = false; - } else { - ColumnBinT* begin = reinterpret_cast(&local_index[feature_offsets_[fid]]); - begin[num_nonzeros_[fid]] = bin_id - index_base_[fid]; - row_ind_[feature_offsets_[fid] + num_nonzeros_[fid]] = rid; - ++num_nonzeros_[fid]; - } - }; - size_t const batch_size = batch.Size(); size_t k{0}; for (size_t rid = 0; rid < batch_size; ++rid) { @@ -264,7 +301,7 @@ class ColumnMatrix { if (is_valid(coo)) { auto fid = coo.column_idx; const uint32_t bin_id = row_index[k]; - get_bin_idx(bin_id, rid + base_rowid, fid); + SetBinSparse(bin_id, rid + base_rowid, fid, local_index); ++k; } } @@ -272,6 +309,40 @@ class ColumnMatrix { }); } + /** + * \brief Set column index for both dense and sparse columns, but with only GHistMatrix + * available and requires a search for each bin. 
+ */ + void SetIndexMixedColumns(const GHistIndexMatrix& gmat) { + auto n_features = gmat.Features(); + missing_flags_.resize(feature_offsets_[n_features], true); + auto const* row_index = gmat.index.data() + gmat.row_ptr[gmat.base_rowid]; + num_nonzeros_.resize(n_features, 0); + auto const& ptrs = gmat.cut.Ptrs(); + + DispatchBinType(bins_type_size_, [&](auto t) { + using ColumnBinT = decltype(t); + ColumnBinT* local_index = reinterpret_cast(index_.data()); + auto const batch_size = gmat.Size(); + size_t k{0}; + + for (size_t ridx = 0; ridx < batch_size; ++ridx) { + auto r_beg = gmat.row_ptr[ridx]; + auto r_end = gmat.row_ptr[ridx + 1]; + bst_feature_t fidx{0}; + for (size_t j = r_beg; j < r_end; ++j) { + const uint32_t bin_idx = row_index[k]; + // find the feature index for current bin. + while (bin_idx >= ptrs[fidx + 1]) { + fidx++; + } + SetBinSparse(bin_idx, ridx, fidx, local_index); + ++k; + } + } + }); + } + BinTypeSize GetTypeSize() const { return bins_type_size_; } auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; } diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 5c922bbf3..754e47ff4 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -35,6 +35,7 @@ #include "xgboost/global_config.h" #include "common.h" +#include "algorithm.cuh" #ifdef XGBOOST_USE_NCCL #include "nccl.h" @@ -1556,17 +1557,7 @@ XGBOOST_DEVICE thrust::transform_iterator MakeTransformIt return thrust::transform_iterator(iter, func); } -template -size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) { - size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) - - 1 - first; - return segment_id; -} - -template -size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span segments_ptr, size_t idx) { - return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx); -} +using xgboost::common::cuda::SegmentId; // import it for compatibility namespace detail { template diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 3b9f01c06..9bcc78ba4 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -22,6 +22,7 @@ #include "row_set.h" #include "threading_utils.h" #include "timer.h" +#include "algorithm.h" // SegmentId namespace xgboost { class GHistIndexMatrix; @@ -130,19 +131,23 @@ class HistogramCuts { /** * \brief Search the bin index for categorical feature. */ - bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const { - auto const &ptrs = this->Ptrs(); - auto const &vals = this->Values(); + bst_bin_t SearchCatBin(float value, bst_feature_t fidx, std::vector const& ptrs, + std::vector const& vals) const { auto end = ptrs.at(fidx + 1) + vals.cbegin(); auto beg = ptrs[fidx] + vals.cbegin(); // Truncates the value in case it's not perfectly rounded. - auto v = static_cast(common::AsCat(value)); + auto v = static_cast(common::AsCat(value)); auto bin_idx = std::lower_bound(beg, end, v) - vals.cbegin(); if (bin_idx == ptrs.at(fidx + 1)) { bin_idx -= 1; } return bin_idx; } + bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const { + auto const& ptrs = this->Ptrs(); + auto const& vals = this->Values(); + return this->SearchCatBin(value, fidx, ptrs, vals); + } bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); } }; @@ -189,6 +194,28 @@ auto DispatchBinType(BinTypeSize type, Fn&& fn) { * storage class. */ struct Index { + // Inside the compressor, bin_idx is the index for cut value across all features. 
By + // subtracting it with starting pointer of each feature, we can reduce it to smaller + // value and store it with smaller types. Usable only with dense data. + // + // For sparse input we have to store an addition feature index (similar to sparse matrix + // formats like CSR) for each bin in index field to choose the right offset. + template + struct CompressBin { + uint32_t const* offsets; + + template + auto operator()(Bin bin_idx, Feat fidx) const { + return static_cast(bin_idx - offsets[fidx]); + } + }; + + template + CompressBin MakeCompressor() const { + uint32_t const* offsets = this->Offset(); + return CompressBin{offsets}; + } + Index() { SetBinTypeSize(binTypeSize_); } Index(const Index& i) = delete; Index& operator=(Index i) = delete; diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index cf04ab16e..11de33d8f 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -547,4 +547,15 @@ EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor( NumSymbols()), feature_types}; } +EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor( + common::Span feature_types) const { + return {Context::kCpuId, + cuts_, + is_dense, + row_stride, + base_rowid, + n_rows, + common::CompressedIterator(gidx_buffer.ConstHostPointer(), NumSymbols()), + feature_types}; +} } // namespace xgboost diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 75d394e30..16e2f13b3 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -43,12 +43,18 @@ struct EllpackDeviceAccessor { base_rowid(base_rowid), n_rows(n_rows) ,gidx_iter(gidx_iter), feature_types{feature_types} { - cuts.cut_values_.SetDevice(device); - cuts.cut_ptrs_.SetDevice(device); - cuts.min_vals_.SetDevice(device); - gidx_fvalue_map = cuts.cut_values_.ConstDeviceSpan(); - feature_segments = cuts.cut_ptrs_.ConstDeviceSpan(); - min_fvalue = cuts.min_vals_.ConstDeviceSpan(); + if (device == Context::kCpuId) { + gidx_fvalue_map = cuts.cut_values_.ConstHostSpan(); + feature_segments = cuts.cut_ptrs_.ConstHostSpan(); + min_fvalue = cuts.min_vals_.ConstHostSpan(); + } else { + cuts.cut_values_.SetDevice(device); + cuts.cut_ptrs_.SetDevice(device); + cuts.min_vals_.SetDevice(device); + gidx_fvalue_map = cuts.cut_values_.ConstDeviceSpan(); + feature_segments = cuts.cut_ptrs_.ConstDeviceSpan(); + min_fvalue = cuts.min_vals_.ConstDeviceSpan(); + } } // Get a matrix element, uses binary search for look up Return NaN if missing // Given a row index and a feature index, returns the corresponding cut value @@ -202,6 +208,7 @@ class EllpackPageImpl { EllpackDeviceAccessor GetDeviceAccessor(int device, common::Span feature_types = {}) const; + EllpackDeviceAccessor GetHostAccessor(common::Span feature_types = {}) const; private: /*! 
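Before moving on to the gradient_index changes, a sketch of how the pieces above fit together. ReadBinsOnHost is a hypothetical helper, not part of the patch; it assumes a dense EllpackPageImpl whose per-feature bin counts fit in uint8_t and a GHistIndexMatrix whose bin offsets were already set via SetBinOffset. It walks the compressed bins on the CPU through the new GetHostAccessor and narrows each global bin with the compressor returned by Index::MakeCompressor.

#include <cstddef>  // std::size_t
#include <cstdint>  // std::uint8_t

#include "../common/hist_util.h"  // common::Index
#include "ellpack_page.cuh"       // EllpackPageImpl
#include "gradient_index.h"       // GHistIndexMatrix

namespace xgboost {
// Hypothetical sketch: read ELLPACK bins on the host and compress them per feature.
void ReadBinsOnHost(EllpackPageImpl const* page, GHistIndexMatrix* out) {
  auto accessor = page->GetHostAccessor();
  auto const kNull = accessor.NullValue();
  // For a dense page, position j inside a row is also the feature index, so the
  // global bin index can be shifted by that feature's first cut and stored in a
  // narrower type, which is exactly what CompressBin does.
  auto compress = out->index.MakeCompressor<std::uint8_t>();
  for (std::size_t i = 0; i < page->Size(); ++i) {
    for (std::size_t j = 0; j < page->row_stride; ++j) {
      auto bin_idx = accessor.gidx_iter[i * page->row_stride + j];
      if (bin_idx == kNull) {
        continue;  // missing entry; only possible for sparse pages
      }
      std::uint8_t local_bin = compress(bin_idx, j);
      (void)local_bin;  // a real implementation would store this into out->index
    }
  }
}
}  // namespace xgboost
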
diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index e34db5495..372dd6f54 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -53,7 +53,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat, // hist CHECK(!sorted_sketch); for (auto const &page : p_fmat->GetBatches()) { - this->columns_->Init(page, *this, sparse_thresh, n_threads); + this->columns_->InitFromSparse(page, *this, sparse_thresh, n_threads); } } } @@ -66,6 +66,12 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts & max_num_bins(max_bin_per_feat), isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {} +#if !defined(XGBOOST_USE_CUDA) +GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &, + BatchParam const &) { + common::AssertGPUSupport(); +} +#endif // defined(XGBOOST_USE_CUDA) GHistIndexMatrix::~GHistIndexMatrix() = default; @@ -99,7 +105,7 @@ GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::SpanPushBatch(batch, ft, n_threads); this->columns_ = std::make_unique(); if (!std::isnan(sparse_thresh)) { - this->columns_->Init(batch, *this, sparse_thresh, n_threads); + this->columns_->InitFromSparse(batch, *this, sparse_thresh, n_threads); } } diff --git a/src/data/gradient_index.cu b/src/data/gradient_index.cu new file mode 100644 index 000000000..42d935b3c --- /dev/null +++ b/src/data/gradient_index.cu @@ -0,0 +1,111 @@ +/*! + * Copyright 2022 by XGBoost Contributors + */ +#include // std::unique_ptr + +#include "../common/column_matrix.h" +#include "../common/hist_util.h" // Index +#include "ellpack_page.cuh" +#include "gradient_index.h" +#include "xgboost/data.h" + +namespace xgboost { +// Similar to GHistIndexMatrix::SetIndexData, but without the need for adaptor or bin +// searching. Is there a way to unify the code? 
+template +void SetIndexData(Context const* ctx, EllpackPageImpl const* page, + std::vector* p_hit_count_tloc, CompressOffset&& get_offset, + GHistIndexMatrix* out) { + auto accessor = page->GetHostAccessor(); + auto const kNull = static_cast(accessor.NullValue()); + + common::Span index_data_span = {out->index.data(), out->index.Size()}; + auto n_bins_total = page->Cuts().TotalBins(); + + auto& hit_count_tloc = *p_hit_count_tloc; + hit_count_tloc.clear(); + hit_count_tloc.resize(ctx->Threads() * n_bins_total, 0); + + common::ParallelFor(page->Size(), ctx->Threads(), [&](auto i) { + auto tid = omp_get_thread_num(); + size_t in_rbegin = page->row_stride * i; + size_t out_rbegin = out->row_ptr[i]; + auto r_size = out->row_ptr[i + 1] - out->row_ptr[i]; + for (size_t j = 0; j < r_size; ++j) { + auto bin_idx = accessor.gidx_iter[in_rbegin + j]; + assert(bin_idx != kNull); + index_data_span[out_rbegin + j] = get_offset(bin_idx, j); + ++hit_count_tloc[tid * n_bins_total + bin_idx]; + } + }); +} + +void GetRowPtrFromEllpack(Context const* ctx, EllpackPageImpl const* page, + std::vector* p_out) { + auto& row_ptr = *p_out; + row_ptr.resize(page->Size() + 1, 0); + if (page->is_dense) { + std::fill(row_ptr.begin() + 1, row_ptr.end(), page->row_stride); + } else { + auto accessor = page->GetHostAccessor(); + auto const kNull = static_cast(accessor.NullValue()); + + common::ParallelFor(page->Size(), ctx->Threads(), [&](auto i) { + size_t ibegin = page->row_stride * i; + for (size_t j = 0; j < page->row_stride; ++j) { + bst_bin_t bin_idx = accessor.gidx_iter[ibegin + j]; + if (bin_idx != kNull) { + row_ptr[i + 1]++; + } + } + }); + } + std::partial_sum(row_ptr.begin(), row_ptr.end(), row_ptr.begin()); +} + +GHistIndexMatrix::GHistIndexMatrix(Context const* ctx, MetaInfo const& info, + EllpackPage const& in_page, BatchParam const& p) + : max_num_bins{p.max_bin} { + auto page = in_page.Impl(); + isDense_ = page->is_dense; + + CHECK_EQ(info.num_row_, in_page.Size()); + + this->cut = page->Cuts(); + // pull to host early, prevent race condition + this->cut.Ptrs(); + this->cut.Values(); + this->cut.MinValues(); + + this->ResizeIndex(info.num_nonzero_, page->is_dense); + if (page->is_dense) { + this->index.SetBinOffset(page->Cuts().Ptrs()); + } + + auto n_bins_total = page->Cuts().TotalBins(); + GetRowPtrFromEllpack(ctx, page, &this->row_ptr); + if (page->is_dense) { + common::DispatchBinType(this->index.GetBinTypeSize(), [&](auto dtype) { + using T = decltype(dtype); + ::xgboost::SetIndexData(ctx, page, &hit_count_tloc_, index.MakeCompressor(), this); + }); + } else { + // no compression + ::xgboost::SetIndexData( + ctx, page, &hit_count_tloc_, [&](auto bin_idx, auto) { return bin_idx; }, this); + } + + this->hit_count.resize(n_bins_total, 0); + this->GatherHitCount(ctx->Threads(), n_bins_total); + + // sanity checks + CHECK_EQ(this->Features(), info.num_col_); + CHECK_EQ(this->Size(), info.num_row_); + CHECK(this->cut.cut_ptrs_.HostCanRead()); + CHECK(this->cut.cut_values_.HostCanRead()); + CHECK(this->cut.min_vals_.HostCanRead()); + + this->columns_ = std::make_unique(*this, p.sparse_thresh); + this->columns_->InitFromGHist(ctx, *this); +} +} // namespace xgboost diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 71c199f81..1e58fcb42 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -69,7 +69,7 @@ class GHistIndexMatrix { if (is_valid(elem)) { bst_bin_t bin_idx{-1}; if (common::IsCat(ft, elem.column_idx)) { - bin_idx = cut.SearchCatBin(elem.value, 
elem.column_idx); + bin_idx = cut.SearchCatBin(elem.value, elem.column_idx, ptrs, values); } else { bin_idx = cut.SearchBin(elem.value, elem.column_idx, ptrs, values); } @@ -81,6 +81,17 @@ class GHistIndexMatrix { }); } + // Gather hit_count from all threads + void GatherHitCount(int32_t n_threads, bst_bin_t n_bins_total) { + CHECK_EQ(hit_count.size(), n_bins_total); + common::ParallelFor(n_bins_total, n_threads, [&](bst_omp_uint idx) { + for (int32_t tid = 0; tid < n_threads; ++tid) { + hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx]; + hit_count_tloc_[tid * n_bins_total + idx] = 0; // reset for next batch + } + }); + } + template void PushBatchImpl(int32_t n_threads, Batch const& batch, size_t rbegin, IsValid&& is_valid, common::Span ft) { @@ -95,33 +106,20 @@ class GHistIndexMatrix { if (isDense_) { index.SetBinOffset(cut.Ptrs()); } - uint32_t const* offsets = index.Offset(); if (isDense_) { - // Inside the lambda functions, bin_idx is the index for cut value across all - // features. By subtracting it with starting pointer of each feature, we can reduce - // it to smaller value and compress it to smaller types. common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) { using T = decltype(dtype); common::Span index_data_span = {index.data(), index.Size()}; - SetIndexData( - index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total, - [offsets](auto bin_idx, auto fidx) { return static_cast(bin_idx - offsets[fidx]); }); + SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total, + index.MakeCompressor()); }); } else { - /* For sparse DMatrix we have to store index of feature for each bin - in index field to chose right offset. So offset is nullptr and index is - not reduced */ common::Span index_data_span = {index.data(), n_index}; + // no compression SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total, [](auto idx, auto) { return idx; }); } - - common::ParallelFor(n_bins_total, n_threads, [&](bst_omp_uint idx) { - for (int32_t tid = 0; tid < n_threads; ++tid) { - hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx]; - hit_count_tloc_[tid * n_bins_total + idx] = 0; // reset for next batch - } - }); + this->GatherHitCount(n_threads, n_bins_total); } public: @@ -129,12 +127,12 @@ class GHistIndexMatrix { std::vector row_ptr; /*! \brief The index data */ common::Index index; - /*! \brief hit count of each index */ + /*! \brief hit count of each index, used for constructing the ColumnMatrix */ std::vector hit_count; /*! \brief The corresponding cuts */ common::HistogramCuts cut; /*! \brief max_bin for each feature. */ - size_t max_num_bins; + bst_bin_t max_num_bins; /*! \brief base row index for current page (used by external memory) */ size_t base_rowid{0}; @@ -149,6 +147,13 @@ class GHistIndexMatrix { * for push batch. */ GHistIndexMatrix(MetaInfo const& info, common::HistogramCuts&& cuts, bst_bin_t max_bin_per_feat); + /** + * \brief Constructor fro Iterative DMatrix where we might copy an existing ellpack page + * to host gradient index. + */ + GHistIndexMatrix(Context const* ctx, MetaInfo const& info, EllpackPage const& page, + BatchParam const& p); + /** * \brief Constructor for external memory. 
*/ diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index e43fcccbc..f108c746b 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -205,12 +205,11 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, BatchSet IterativeDMatrix::GetGradientIndex(BatchParam const& param) { CheckParam(param); - CHECK(ghist_) << R"(`QuantileDMatrix` is not initialized with CPU data but used for CPU training. -Possible solutions: -- Use `DMatrix` instead. -- Use CPU input for `QuantileDMatrix`. -- Run training on GPU. -)"; + if (!ghist_) { + CHECK(ellpack_); + ghist_ = std::make_shared(&ctx_, Info(), *ellpack_, param); + } + auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(ghist_)); return BatchSet(begin_iter); diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index 06d061382..7a8e5188c 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -29,20 +29,17 @@ namespace data { * `QuantileDMatrix` is an intermediate storage for quantilization results including * quantile cuts and histogram index. Quantilization is designed to be performed on stream * of data (or batches of it). As a result, the `QuantileDMatrix` is also designed to work - * with batches of data. During initializaion, it will walk through the data multiple - * times iteratively in order to perform quantilization. This design can help us reduce - * memory usage significantly by avoiding data concatenation along with removing the CSR - * matrix `SparsePage`. However, it has its limitation (can be fixed if needed): + * with batches of data. During initializaion, it walks through the data multiple times + * iteratively in order to perform quantilization. This design helps us reduce memory + * usage significantly by avoiding data concatenation along with removing the CSR matrix + * `SparsePage`. However, it has its limitation (can be fixed if needed): * * - It's only supported by hist tree method (both CPU and GPU) since approx requires a * re-calculation of quantiles for each iteration. We can fix this by retaining a * reference to the callback if there are feature requests. * * - The CPU format and the GPU format are different, the former uses a CSR + CSC for - * histogram index while the latter uses only Ellpack. This results into a design that - * we can obtain the GPU format from CPU but the other way around is not yet - * supported. We can search the bin value from ellpack to recover the feature index when - * we support copying data from GPU to CPU. + * histogram index while the latter uses only Ellpack. 
*/ class IterativeDMatrix : public DMatrix { MetaInfo info_; diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc index cdd38468a..e2f59c58d 100644 --- a/tests/cpp/common/test_column_matrix.cc +++ b/tests/cpp/common/test_column_matrix.cc @@ -23,7 +23,7 @@ TEST(DenseColumn, Test) { common::OmpGetNumThreads(0)}; ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { - column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0)); + column_matrix.InitFromSparse(page, gmat, sparse_thresh, common::OmpGetNumThreads(0)); } ASSERT_GE(column_matrix.GetTypeSize(), last); ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize); @@ -69,7 +69,7 @@ TEST(SparseColumn, Test) { GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)}; ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { - column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0)); + column_matrix.InitFromSparse(page, gmat, 1.0, common::OmpGetNumThreads(0)); } common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { using T = decltype(dtype); @@ -97,7 +97,7 @@ TEST(DenseColumnWithMissing, Test) { GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)); ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { - column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0)); + column_matrix.InitFromSparse(page, gmat, 0.2, common::OmpGetNumThreads(0)); } ASSERT_TRUE(column_matrix.AnyMissing()); DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc index 6e5d1312d..6233f1b25 100644 --- a/tests/cpp/data/test_gradient_index.cc +++ b/tests/cpp/data/test_gradient_index.cc @@ -5,6 +5,7 @@ #include #include "../../../src/common/column_matrix.h" +#include "../../../src/common/io.h" // MemoryBufferStream #include "../../../src/data/gradient_index.h" #include "../helpers.h" @@ -107,5 +108,81 @@ TEST(GradientIndex, PushBatch) { test(0.5f); test(0.9f); } + +#if defined(XGBOOST_USE_CUDA) + +namespace { +class GHistIndexMatrixTest : public testing::TestWithParam> { + protected: + void Run(float density, double threshold) { + // Only testing with small sample size as the cuts might be different between host and + // device. 
+ size_t n_samples{128}, n_features{13}; + Context ctx; + ctx.gpu_id = 0; + auto Xy = RandomDataGenerator{n_samples, n_features, 1 - density}.GenerateDMatrix(true); + std::unique_ptr from_ellpack; + ASSERT_TRUE(Xy->SingleColBlock()); + bst_bin_t constexpr kBins{17}; + auto p = BatchParam{kBins, threshold}; + for (auto const &page : Xy->GetBatches(BatchParam{0, kBins})) { + from_ellpack.reset(new GHistIndexMatrix{&ctx, Xy->Info(), page, p}); + } + + for (auto const &from_sparse_page : Xy->GetBatches(p)) { + ASSERT_EQ(from_sparse_page.IsDense(), from_ellpack->IsDense()); + ASSERT_EQ(from_sparse_page.base_rowid, 0); + ASSERT_EQ(from_sparse_page.base_rowid, from_ellpack->base_rowid); + ASSERT_EQ(from_sparse_page.Size(), from_ellpack->Size()); + ASSERT_EQ(from_sparse_page.index.Size(), from_ellpack->index.Size()); + + auto const &gidx_from_sparse = from_sparse_page.index; + auto const &gidx_from_ellpack = from_ellpack->index; + + for (size_t i = 0; i < gidx_from_sparse.Size(); ++i) { + ASSERT_EQ(gidx_from_sparse[i], gidx_from_ellpack[i]); + } + + auto const &columns_from_sparse = from_sparse_page.Transpose(); + auto const &columns_from_ellpack = from_ellpack->Transpose(); + ASSERT_EQ(columns_from_sparse.AnyMissing(), columns_from_ellpack.AnyMissing()); + ASSERT_EQ(columns_from_sparse.GetTypeSize(), columns_from_ellpack.GetTypeSize()); + ASSERT_EQ(columns_from_sparse.GetNumFeature(), columns_from_ellpack.GetNumFeature()); + for (size_t i = 0; i < n_features; ++i) { + ASSERT_EQ(columns_from_sparse.GetColumnType(i), columns_from_ellpack.GetColumnType(i)); + } + + std::string from_sparse_buf; + { + common::MemoryBufferStream fo{&from_sparse_buf}; + columns_from_sparse.Write(&fo); + } + std::string from_ellpack_buf; + { + common::MemoryBufferStream fo{&from_ellpack_buf}; + columns_from_sparse.Write(&fo); + } + ASSERT_EQ(from_sparse_buf, from_ellpack_buf); + } + } +}; +} // anonymous namespace + +TEST_P(GHistIndexMatrixTest, FromEllpack) { + float sparsity; + double thresh; + std::tie(sparsity, thresh) = GetParam(); + this->Run(sparsity, thresh); +} + +INSTANTIATE_TEST_SUITE_P(GHistIndexMatrix, GHistIndexMatrixTest, + testing::Values(std::make_tuple(1.f, .0), // no missing + std::make_tuple(.2f, .8), // sparse columns + std::make_tuple(.8f, .2), // dense columns + std::make_tuple(1.f, .2), // no missing + std::make_tuple(.5f, .6), // sparse columns + std::make_tuple(.6f, .4))); // dense columns + +#endif // defined(XGBOOST_USE_CUDA) } // namespace data } // namespace xgboost diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index c34f63b46..f1491b829 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -37,7 +37,7 @@ TEST(QuantileHist, Partitioner) { GHistIndexMatrix gmat(page, {}, cuts, 64, true, 0.5, ctx.Threads()); bst_feature_t const split_ind = 0; common::ColumnMatrix column_indices; - column_indices.Init(page, gmat, 0.5, ctx.Threads()); + column_indices.InitFromSparse(page, gmat, 0.5, ctx.Threads()); { auto min_value = gmat.cut.MinValues()[split_ind]; RegTree tree; diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py index dee603920..88eae3890 100644 --- a/tests/python-gpu/test_device_quantile_dmatrix.py +++ b/tests/python-gpu/test_device_quantile_dmatrix.py @@ -32,32 +32,41 @@ class TestDeviceQuantileDMatrix: xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64)) @pytest.mark.skipif(**tm.no_cupy()) - def test_from_host(self) -> None: + 
@pytest.mark.parametrize( + "tree_method,max_bin", [ + ("hist", 16), ("gpu_hist", 16), ("hist", 64), ("gpu_hist", 64) + ] + ) + def test_interoperability(self, tree_method: str, max_bin: int) -> None: import cupy as cp n_samples = 64 n_features = 3 X, y, w = tm.make_batches( n_samples, n_features=n_features, n_batches=1, use_cupy=False ) - Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0]) - booster_0 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4) + # from CPU + Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin) + booster_0 = xgb.train( + {"tree_method": tree_method, "max_bin": max_bin}, Xy, num_boost_round=4 + ) X[0] = cp.array(X[0]) y[0] = cp.array(y[0]) w[0] = cp.array(w[0]) - Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0]) - booster_1 = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4) + # from GPU + Xy = xgb.QuantileDMatrix(X[0], y[0], weight=w[0], max_bin=max_bin) + booster_1 = xgb.train( + {"tree_method": tree_method, "max_bin": max_bin}, Xy, num_boost_round=4 + ) cp.testing.assert_allclose( booster_0.inplace_predict(X[0]), booster_1.inplace_predict(X[0]) ) - with pytest.raises(ValueError, match="not initialized with CPU"): - # Training on CPU with GPU data is not supported. - xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4) - with pytest.raises(ValueError, match=r"Only.*hist.*"): - xgb.train({"tree_method": "approx"}, Xy, num_boost_round=4) + xgb.train( + {"tree_method": "approx", "max_bin": max_bin}, Xy, num_boost_round=4 + ) @pytest.mark.skipif(**tm.no_cupy()) def test_metainfo(self) -> None: