Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -6,6 +6,8 @@
#define XGBOOST_DATA_GRADIENT_INDEX_H_
#include <algorithm> // std::min
#include <cinttypes> // std::uint32_t
#include <cstddef> // std::size_t
#include <memory>
#include <vector>
@@ -229,6 +231,53 @@ class GHistIndexMatrix {
bool isDense_;
};
/**
 * \brief Walk every stored histogram bin of a row-major `GHistIndexMatrix` page and
 *        report it, together with the feature it belongs to, through a callback.
 *
 * Bins are visited row by row, in storage order.  For each bin the callback is invoked
 * as `assign(bin_idx, k, ridx, fidx)`.
 *
 * \param page   The gradient index page to traverse.
 * \param assign A callback function that takes bin index, index into the whole batch,
 *               row index and feature index (in that order).
 */
template <typename Fn>
void AssignColumnBinIndex(GHistIndexMatrix const& page, Fn&& assign) {
  auto const n_rows = page.Size();
  auto const& feature_ptrs = page.cut.Ptrs();
  auto const is_dense = page.IsDense();
  // Position of the current entry counted from the start of the page; it keeps
  // increasing across rows and is never reset.
  std::size_t entry{0};

  common::DispatchBinType(page.index.GetBinTypeSize(), [&](auto bin_tag) {
    using BinT = decltype(bin_tag);
    auto const& bins = page.index;

    for (std::size_t row = 0; row < n_rows; ++row) {
      auto row_begin = page.row_ptr[row];
      auto row_end = page.row_ptr[row + 1];

      if (is_dense) {
        // Compressed storage: go through the index's operator[] to obtain the true
        // bin value.  The feature index is the entry's offset within its row.
        for (std::size_t pos = row_begin; pos < row_end; ++pos) {
          bst_feature_t feature = pos - row_begin;
          std::uint32_t bin = bins[entry];
          assign(bin, entry, row, feature);
          ++entry;
        }
      } else {
        // Not compressed: read the raw buffer directly.
        // NOTE(review): the offset is taken from `row_ptr[base_rowid]`; this is zero
        // when `base_rowid` is 0 -- confirm it is intended for pages with a non-zero
        // base row id.
        auto const* raw_bins = bins.data<BinT>() + page.row_ptr[page.base_rowid];
        // The feature cursor restarts at the first feature for every row and only
        // moves forward within the row (presumably bins in a row are ordered by
        // feature -- verify against the page builder).
        bst_feature_t feature{0};
        for (std::size_t pos = row_begin; pos < row_end; ++pos) {
          std::uint32_t bin = raw_bins[entry];
          // Advance to the feature whose cut range contains the current bin.
          for (; feature_ptrs[feature + 1] <= bin; feature++) {
          }
          assign(bin, entry, row, feature);
          ++entry;
        }
      }
    }
  });
}
/**
* \brief Should we regenerate the gradient index?
*