Sort sparse page index when constructing DMatrix. (#7731)
This commit is contained in:
parent
613ec36c5a
commit
e78a38b837
@ -201,6 +201,9 @@ struct Entry {
|
|||||||
inline static bool CmpValue(const Entry& a, const Entry& b) {
|
inline static bool CmpValue(const Entry& a, const Entry& b) {
|
||||||
return a.fvalue < b.fvalue;
|
return a.fvalue < b.fvalue;
|
||||||
}
|
}
|
||||||
|
static bool CmpIndex(Entry const& a, Entry const& b) {
|
||||||
|
return a.index < b.index;
|
||||||
|
}
|
||||||
inline bool operator==(const Entry& other) const {
|
inline bool operator==(const Entry& other) const {
|
||||||
return (this->index == other.index && this->fvalue == other.fvalue);
|
return (this->index == other.index && this->fvalue == other.fvalue);
|
||||||
}
|
}
|
||||||
@ -313,6 +316,15 @@ class SparsePage {
|
|||||||
|
|
||||||
SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
|
SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Sort the column index.
|
||||||
|
*/
|
||||||
|
void SortIndices(int32_t n_threads);
|
||||||
|
/**
|
||||||
|
* \brief Check wether the column index is sorted.
|
||||||
|
*/
|
||||||
|
bool IsIndicesSorted(int32_t n_threads) const;
|
||||||
|
|
||||||
void SortRows(int32_t n_threads);
|
void SortRows(int32_t n_threads);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1037,6 +1037,32 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
|||||||
return transpose;
|
return transpose;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
|
||||||
|
auto& h_offset = this->offset.HostVector();
|
||||||
|
auto& h_data = this->data.HostVector();
|
||||||
|
std::vector<int32_t> is_sorted_tloc(n_threads, 0);
|
||||||
|
common::ParallelFor(this->Size(), n_threads, [&](auto i) {
|
||||||
|
auto beg = h_offset[i];
|
||||||
|
auto end = h_offset[i + 1];
|
||||||
|
is_sorted_tloc[omp_get_thread_num()] +=
|
||||||
|
!!std::is_sorted(h_data.begin() + beg, h_data.begin() + end, Entry::CmpIndex);
|
||||||
|
});
|
||||||
|
auto is_sorted = std::accumulate(is_sorted_tloc.cbegin(), is_sorted_tloc.cend(),
|
||||||
|
static_cast<size_t>(0)) == this->Size();
|
||||||
|
return is_sorted;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SparsePage::SortIndices(int32_t n_threads) {
|
||||||
|
auto& h_offset = this->offset.HostVector();
|
||||||
|
auto& h_data = this->data.HostVector();
|
||||||
|
|
||||||
|
common::ParallelFor(this->Size(), n_threads, [&](auto i) {
|
||||||
|
auto beg = h_offset[i];
|
||||||
|
auto end = h_offset[i + 1];
|
||||||
|
std::sort(h_data.begin() + beg, h_data.begin() + end, Entry::CmpIndex);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void SparsePage::SortRows(int32_t n_threads) {
|
void SparsePage::SortRows(int32_t n_threads) {
|
||||||
auto& h_offset = this->offset.HostVector();
|
auto& h_offset = this->offset.HostVector();
|
||||||
auto& h_data = this->data.HostVector();
|
auto& h_data = this->data.HostVector();
|
||||||
@ -1162,7 +1188,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
exec.Rethrow();
|
exec.Rethrow();
|
||||||
|
|
||||||
return max_columns;
|
return max_columns;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -211,6 +211,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
|||||||
info_.num_row_ = adapter->NumRows();
|
info_.num_row_ = adapter->NumRows();
|
||||||
}
|
}
|
||||||
info_.num_nonzero_ = data_vec.size();
|
info_.num_nonzero_ = data_vec.size();
|
||||||
|
|
||||||
|
// Sort the index for row partitioners used by variuos tree methods.
|
||||||
|
if (!sparse_page_->IsIndicesSorted(this->ctx_.Threads())) {
|
||||||
|
sparse_page_->SortIndices(this->ctx_.Threads());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
|
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
|
||||||
|
|||||||
@ -86,6 +86,30 @@ TEST(SparsePage, PushCSCAfterTranspose) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SparsePage, SortIndices) {
|
||||||
|
auto p_fmat = RandomDataGenerator{100, 10, 0.6}.GenerateDMatrix();
|
||||||
|
auto n_threads = common::OmpGetNumThreads(0);
|
||||||
|
SparsePage copy;
|
||||||
|
for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
|
||||||
|
ASSERT_TRUE(page.IsIndicesSorted(n_threads));
|
||||||
|
copy.Push(page);
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(copy.IsIndicesSorted(n_threads));
|
||||||
|
|
||||||
|
for (size_t ridx = 0; ridx < copy.Size(); ++ridx) {
|
||||||
|
auto beg = copy.offset.HostVector()[ridx];
|
||||||
|
auto end = copy.offset.HostVector()[ridx + 1];
|
||||||
|
auto& h_data = copy.data.HostVector();
|
||||||
|
if (end - beg >= 2) {
|
||||||
|
std::swap(h_data[beg], h_data[end - 1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_FALSE(copy.IsIndicesSorted(n_threads));
|
||||||
|
|
||||||
|
copy.SortIndices(n_threads);
|
||||||
|
ASSERT_TRUE(copy.IsIndicesSorted(n_threads));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(DMatrix, Uri) {
|
TEST(DMatrix, Uri) {
|
||||||
size_t constexpr kRows {16};
|
size_t constexpr kRows {16};
|
||||||
size_t constexpr kCols {8};
|
size_t constexpr kCols {8};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user