diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 77c67620b..45614e6e2 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -16,6 +16,7 @@ #include // std::move #include +#include "../data/adapter.h" #include "../data/gradient_index.h" #include "hist_util.h" @@ -128,7 +129,7 @@ class DenseColumnIter : public Column { /** * \brief Column major matrix for gradient index. This matrix contains both dense column * and sparse column, the type of the column is controlled by sparse threshold. When the - * number of missing values in a column is below the threshold it classified as dense + * number of missing values in a column is below the threshold it's classified as dense * column. */ class ColumnMatrix { @@ -136,9 +137,9 @@ class ColumnMatrix { // get number of features bst_feature_t GetNumFeature() const { return static_cast(type_.size()); } - // construct column matrix from GHistIndexMatrix - void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold, - int32_t n_threads) { + template + void Init(Batch const& batch, float missing, GHistIndexMatrix const& gmat, + double sparse_threshold, int32_t n_threads) { auto const nfeature = static_cast(gmat.cut.Ptrs().size() - 1); const size_t nrow = gmat.row_ptr.size() - 1; // identify type of each column @@ -190,6 +191,7 @@ class ColumnMatrix { any_missing_ = !gmat.IsDense(); missing_flags_.clear(); + // pre-fill index_ for dense columns BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize(); if (!any_missing_) { @@ -197,14 +199,21 @@ class ColumnMatrix { // row index is compressed, we need to dispatch it. DispatchBinType(gmat_bin_size, [&, nrow, nfeature, n_threads](auto t) { using RowBinIdxT = decltype(t); - SetIndexNoMissing(page, gmat.index.data(), nrow, nfeature, n_threads); + SetIndexNoMissing(gmat.index.data(), nrow, nfeature, n_threads); }); } else { missing_flags_.resize(feature_offsets_[nfeature], true); - SetIndexMixedColumns(page, gmat.index.data(), gmat, nfeature); + SetIndexMixedColumns(batch, gmat.index.data(), gmat, nfeature, missing); } } + // construct column matrix from GHistIndexMatrix + void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold, + int32_t n_threads) { + auto batch = data::SparsePageAdapterBatch{page.GetView()}; + this->Init(batch, std::numeric_limits::quiet_NaN(), gmat, sparse_threshold, n_threads); + } + /* Set the number of bytes based on numeric limit of maximum number of bins provided by user */ void SetTypeSize(size_t max_bin_per_feat) { if ((max_bin_per_feat - 1) <= static_cast(std::numeric_limits::max())) { @@ -241,8 +250,8 @@ class ColumnMatrix { // all columns are dense column and has no missing value // FIXME(jiamingy): We don't need a column matrix if there's no missing value. template - void SetIndexNoMissing(SparsePage const& page, RowBinIdxT const* row_index, - const size_t n_samples, const size_t n_features, int32_t n_threads) { + void SetIndexNoMissing(RowBinIdxT const* row_index, const size_t n_samples, + const size_t n_features, int32_t n_threads) { DispatchBinType(bins_type_size_, [&](auto t) { using ColumnBinT = decltype(t); auto column_index = Span{reinterpret_cast(index_.data()), @@ -263,10 +272,12 @@ class ColumnMatrix { /** * \brief Set column index for both dense and sparse columns */ - void SetIndexMixedColumns(SparsePage const& page, uint32_t const* row_index, - const GHistIndexMatrix& gmat, size_t n_features) { + template + void SetIndexMixedColumns(Batch const& batch, uint32_t const* row_index, + const GHistIndexMatrix& gmat, size_t n_features, float missing) { std::vector num_nonzeros; num_nonzeros.resize(n_features, 0); + auto is_valid = data::IsValidFunctor {missing}; DispatchBinType(bins_type_size_, [&](auto t) { using ColumnBinT = decltype(t); @@ -276,7 +287,8 @@ class ColumnMatrix { if (type_[fid] == kDenseColumn) { ColumnBinT* begin = &local_index[feature_offsets_[fid]]; begin[rid] = bin_id - index_base_[fid]; - // not thread-safe with bool vector. + // not thread-safe with bool vector. FIXME(jiamingy): We can directly assign + // kMissingId to the index to avoid missing flags. missing_flags_[feature_offsets_[fid] + rid] = false; } else { ColumnBinT* begin = &local_index[feature_offsets_[fid]]; @@ -286,22 +298,18 @@ class ColumnMatrix { } }; - const xgboost::Entry* data_ptr = page.data.HostVector().data(); - const std::vector& offset_vec = page.offset.HostVector(); const size_t batch_size = gmat.Size(); - CHECK_LT(batch_size, offset_vec.size()); + size_t k{0}; for (size_t rid = 0; rid < batch_size; ++rid) { - const size_t ibegin = gmat.row_ptr[rid]; - const size_t iend = gmat.row_ptr[rid + 1]; - const size_t size = offset_vec[rid + 1] - offset_vec[rid]; - SparsePage::Inst inst = {data_ptr + offset_vec[rid], size}; - - CHECK_EQ(ibegin + inst.size(), iend); - size_t j = 0; - for (size_t i = ibegin; i < iend; ++i, ++j) { - const uint32_t bin_id = row_index[i]; - auto fid = inst[j].index; - get_bin_idx(bin_id, rid, fid); + auto line = batch.GetLine(rid); + for (size_t i = 0; i < line.Size(); ++i) { + auto coo = line.GetElement(i); + if (is_valid(coo)) { + auto fid = coo.column_idx; + const uint32_t bin_id = row_index[k]; + get_bin_idx(bin_id, rid, fid); + ++k; + } } } }); diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 66188bec6..6671f05e3 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -211,6 +211,8 @@ auto DispatchBinType(BinTypeSize type, Fn&& fn) { return fn(uint32_t{}); } } + LOG(FATAL) << "Unreachable"; + return fn(uint32_t{}); } /** diff --git a/src/data/adapter.h b/src/data/adapter.h index 4025ccd8e..e6cb6d8b9 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -1131,6 +1131,24 @@ class RecordBatchesIterAdapter: public dmlc::DataIter { struct ArrowSchemaImporter schema_; ArrowColumnarBatchVec batches_; }; + +class SparsePageAdapterBatch { + HostSparsePageView page_; + + public: + struct Line { + SparsePage::Inst inst; + bst_row_t ridx; + COOTuple GetElement(size_t idx) const { + return COOTuple{ridx, inst.data()[idx].index, inst.data()[idx].fvalue}; + } + size_t Size() const { return inst.size(); } + }; + + explicit SparsePageAdapterBatch(HostSparsePageView page) : page_{std::move(page)} {} + Line GetLine(size_t ridx) const { return Line{page_[ridx], ridx}; } + size_t Size() const { return page_.Size(); } +}; }; // namespace data } // namespace xgboost #endif // XGBOOST_DATA_ADAPTER_H_ diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc index 1122c04d5..4b6b0e91d 100644 --- a/tests/cpp/common/test_column_matrix.cc +++ b/tests/cpp/common/test_column_matrix.cc @@ -31,34 +31,33 @@ TEST(DenseColumn, Test) { ASSERT_FALSE(column_matrix.AnyMissing()); for (auto i = 0ull; i < dmat->Info().num_row_; i++) { for (auto j = 0ull; j < dmat->Info().num_col_; j++) { - switch (column_matrix.GetTypeSize()) { - case kUint8BinsTypeSize: { - auto col = column_matrix.DenseColumn(j); - ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i)); - } break; - case kUint16BinsTypeSize: { - auto col = column_matrix.DenseColumn(j); - ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i)); - } break; - case kUint32BinsTypeSize: { - auto col = column_matrix.DenseColumn(j); - ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i)); - } break; - } + DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { + using T = decltype(dtype); + auto col = column_matrix.DenseColumn(j); + ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i)); + }); } } } } template -inline void CheckSparseColumn(const SparseColumnIter& col_input, - const GHistIndexMatrix& gmat) { - const SparseColumnIter& col = - static_cast&>(col_input); +void CheckSparseColumn(SparseColumnIter* p_col, const GHistIndexMatrix& gmat) { + auto& col = *p_col; + + size_t n_samples = gmat.row_ptr.size() - 1; ASSERT_EQ(col.Size(), gmat.index.Size()); for (auto i = 0ull; i < col.Size(); i++) { ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]], col.GetGlobalBinIdx(i)); } + + for (auto i = 0ull; i < n_samples; i++) { + if (col[i] == Column::kMissingId) { + auto beg = gmat.row_ptr[i]; + auto end = gmat.row_ptr[i + 1]; + ASSERT_EQ(end - beg, 0); + } + } } TEST(SparseColumn, Test) { @@ -72,26 +71,17 @@ TEST(SparseColumn, Test) { for (auto const& page : dmat->GetBatches()) { column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0)); } - switch (column_matrix.GetTypeSize()) { - case kUint8BinsTypeSize: { - auto col = column_matrix.SparseColumn(0, 0); - CheckSparseColumn(col, gmat); - } break; - case kUint16BinsTypeSize: { - auto col = column_matrix.SparseColumn(0, 0); - CheckSparseColumn(col, gmat); - } break; - case kUint32BinsTypeSize: { - auto col = column_matrix.SparseColumn(0, 0); - CheckSparseColumn(col, gmat); - } break; - } + common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { + using T = decltype(dtype); + auto col = column_matrix.SparseColumn(0, 0); + CheckSparseColumn(&col, gmat); + }); } } template -inline void CheckColumWithMissingValue(const DenseColumnIter& col, - const GHistIndexMatrix& gmat) { +void CheckColumWithMissingValue(const DenseColumnIter& col, + const GHistIndexMatrix& gmat) { for (auto i = 0ull; i < col.Size(); i++) { if (col.IsMissing(i)) continue; EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i)); @@ -110,20 +100,11 @@ TEST(DenseColumnWithMissing, Test) { column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0)); } ASSERT_TRUE(column_matrix.AnyMissing()); - switch (column_matrix.GetTypeSize()) { - case kUint8BinsTypeSize: { - auto col = column_matrix.DenseColumn(0); - CheckColumWithMissingValue(col, gmat); - } break; - case kUint16BinsTypeSize: { - auto col = column_matrix.DenseColumn(0); - CheckColumWithMissingValue(col, gmat); - } break; - case kUint32BinsTypeSize: { - auto col = column_matrix.DenseColumn(0); - CheckColumWithMissingValue(col, gmat); - } break; - } + DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { + using T = decltype(dtype); + auto col = column_matrix.DenseColumn(0); + CheckColumWithMissingValue(col, gmat); + }); } }