From e78a38b8375e0bc8e9c0a6f9232e80faadc17c21 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 16 Mar 2022 18:01:05 +0800 Subject: [PATCH] Sort sparse page index when constructing DMatrix. (#7731) --- include/xgboost/data.h | 12 ++++++++++++ src/data/data.cc | 27 ++++++++++++++++++++++++++- src/data/simple_dmatrix.cc | 5 +++++ tests/cpp/data/test_data.cc | 24 ++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 1 deletion(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 7399b8265..1655f9d0e 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -201,6 +201,9 @@ struct Entry { inline static bool CmpValue(const Entry& a, const Entry& b) { return a.fvalue < b.fvalue; } + static bool CmpIndex(Entry const& a, Entry const& b) { + return a.index < b.index; + } inline bool operator==(const Entry& other) const { return (this->index == other.index && this->fvalue == other.fvalue); } @@ -313,6 +316,15 @@ class SparsePage { SparsePage GetTranspose(int num_columns, int32_t n_threads) const; + /** + * \brief Sort the column index. + */ + void SortIndices(int32_t n_threads); + /** + * \brief Check wether the column index is sorted. + */ + bool IsIndicesSorted(int32_t n_threads) const; + void SortRows(int32_t n_threads); /** diff --git a/src/data/data.cc b/src/data/data.cc index 3e3a1141f..57940199f 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1037,6 +1037,32 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const { return transpose; } +bool SparsePage::IsIndicesSorted(int32_t n_threads) const { + auto& h_offset = this->offset.HostVector(); + auto& h_data = this->data.HostVector(); + std::vector is_sorted_tloc(n_threads, 0); + common::ParallelFor(this->Size(), n_threads, [&](auto i) { + auto beg = h_offset[i]; + auto end = h_offset[i + 1]; + is_sorted_tloc[omp_get_thread_num()] += + !!std::is_sorted(h_data.begin() + beg, h_data.begin() + end, Entry::CmpIndex); + }); + auto is_sorted = std::accumulate(is_sorted_tloc.cbegin(), is_sorted_tloc.cend(), + static_cast(0)) == this->Size(); + return is_sorted; +} + +void SparsePage::SortIndices(int32_t n_threads) { + auto& h_offset = this->offset.HostVector(); + auto& h_data = this->data.HostVector(); + + common::ParallelFor(this->Size(), n_threads, [&](auto i) { + auto beg = h_offset[i]; + auto end = h_offset[i + 1]; + std::sort(h_data.begin() + beg, h_data.begin() + end, Entry::CmpIndex); + }); +} + void SparsePage::SortRows(int32_t n_threads) { auto& h_offset = this->offset.HostVector(); auto& h_data = this->data.HostVector(); @@ -1162,7 +1188,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread }); } exec.Rethrow(); - return max_columns; } diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 7d2ab32c2..8d42a4220 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -211,6 +211,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { info_.num_row_ = adapter->NumRows(); } info_.num_nonzero_ = data_vec.size(); + + // Sort the index for row partitioners used by variuos tree methods. + if (!sparse_page_->IsIndicesSorted(this->ctx_.Threads())) { + sparse_page_->SortIndices(this->ctx_.Threads()); + } } SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) { diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index 5dc3d0646..92e94fee8 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -86,6 +86,30 @@ TEST(SparsePage, PushCSCAfterTranspose) { } } +TEST(SparsePage, SortIndices) { + auto p_fmat = RandomDataGenerator{100, 10, 0.6}.GenerateDMatrix(); + auto n_threads = common::OmpGetNumThreads(0); + SparsePage copy; + for (auto const& page : p_fmat->GetBatches()) { + ASSERT_TRUE(page.IsIndicesSorted(n_threads)); + copy.Push(page); + } + ASSERT_TRUE(copy.IsIndicesSorted(n_threads)); + + for (size_t ridx = 0; ridx < copy.Size(); ++ridx) { + auto beg = copy.offset.HostVector()[ridx]; + auto end = copy.offset.HostVector()[ridx + 1]; + auto& h_data = copy.data.HostVector(); + if (end - beg >= 2) { + std::swap(h_data[beg], h_data[end - 1]); + } + } + ASSERT_FALSE(copy.IsIndicesSorted(n_threads)); + + copy.SortIndices(n_threads); + ASSERT_TRUE(copy.IsIndicesSorted(n_threads)); +} + TEST(DMatrix, Uri) { size_t constexpr kRows {16}; size_t constexpr kCols {8};