Sort sparse page index when constructing DMatrix. (#7731)

This commit is contained in:
Jiaming Yuan 2022-03-16 18:01:05 +08:00 committed by GitHub
parent 613ec36c5a
commit e78a38b837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 1 deletions

View File

@ -201,6 +201,9 @@ struct Entry {
inline static bool CmpValue(const Entry& a, const Entry& b) {
return a.fvalue < b.fvalue;
}
static bool CmpIndex(Entry const& a, Entry const& b) {
return a.index < b.index;
}
inline bool operator==(const Entry& other) const {
return (this->index == other.index && this->fvalue == other.fvalue);
}
@ -313,6 +316,15 @@ class SparsePage {
SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
/**
* \brief Sort the column index.
*/
void SortIndices(int32_t n_threads);
/**
* \brief Check wether the column index is sorted.
*/
bool IsIndicesSorted(int32_t n_threads) const;
void SortRows(int32_t n_threads);
/**

View File

@ -1037,6 +1037,32 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
return transpose;
}
bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
auto& h_offset = this->offset.HostVector();
auto& h_data = this->data.HostVector();
std::vector<int32_t> is_sorted_tloc(n_threads, 0);
common::ParallelFor(this->Size(), n_threads, [&](auto i) {
auto beg = h_offset[i];
auto end = h_offset[i + 1];
is_sorted_tloc[omp_get_thread_num()] +=
!!std::is_sorted(h_data.begin() + beg, h_data.begin() + end, Entry::CmpIndex);
});
auto is_sorted = std::accumulate(is_sorted_tloc.cbegin(), is_sorted_tloc.cend(),
static_cast<size_t>(0)) == this->Size();
return is_sorted;
}
void SparsePage::SortIndices(int32_t n_threads) {
auto& h_offset = this->offset.HostVector();
auto& h_data = this->data.HostVector();
common::ParallelFor(this->Size(), n_threads, [&](auto i) {
auto beg = h_offset[i];
auto end = h_offset[i + 1];
std::sort(h_data.begin() + beg, h_data.begin() + end, Entry::CmpIndex);
});
}
void SparsePage::SortRows(int32_t n_threads) {
auto& h_offset = this->offset.HostVector();
auto& h_data = this->data.HostVector();
@ -1162,7 +1188,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
});
}
exec.Rethrow();
return max_columns;
}

View File

@ -211,6 +211,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
info_.num_row_ = adapter->NumRows();
}
info_.num_nonzero_ = data_vec.size();
// Sort the index for row partitioners used by variuos tree methods.
if (!sparse_page_->IsIndicesSorted(this->ctx_.Threads())) {
sparse_page_->SortIndices(this->ctx_.Threads());
}
}
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {

View File

@ -86,6 +86,30 @@ TEST(SparsePage, PushCSCAfterTranspose) {
}
}
TEST(SparsePage, SortIndices) {
auto p_fmat = RandomDataGenerator{100, 10, 0.6}.GenerateDMatrix();
auto n_threads = common::OmpGetNumThreads(0);
SparsePage copy;
for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
ASSERT_TRUE(page.IsIndicesSorted(n_threads));
copy.Push(page);
}
ASSERT_TRUE(copy.IsIndicesSorted(n_threads));
for (size_t ridx = 0; ridx < copy.Size(); ++ridx) {
auto beg = copy.offset.HostVector()[ridx];
auto end = copy.offset.HostVector()[ridx + 1];
auto& h_data = copy.data.HostVector();
if (end - beg >= 2) {
std::swap(h_data[beg], h_data[end - 1]);
}
}
ASSERT_FALSE(copy.IsIndicesSorted(n_threads));
copy.SortIndices(n_threads);
ASSERT_TRUE(copy.IsIndicesSorted(n_threads));
}
TEST(DMatrix, Uri) {
size_t constexpr kRows {16};
size_t constexpr kCols {8};