Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -6,6 +6,8 @@
|
||||
#define XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <cinttypes> // std::uint32_t
|
||||
#include <cstddef> // std::size_t
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
@@ -229,6 +231,53 @@ class GHistIndexMatrix {
|
||||
bool isDense_;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Helper for recovering feature index from row-based storage of histogram
|
||||
* bin. (`GHistIndexMatrix`).
|
||||
*
|
||||
* \param assign A callback function that takes bin index, index into the whole batch, row
|
||||
* index and feature index
|
||||
*/
|
||||
template <typename Fn>
|
||||
void AssignColumnBinIndex(GHistIndexMatrix const& page, Fn&& assign) {
|
||||
auto const batch_size = page.Size();
|
||||
auto const& ptrs = page.cut.Ptrs();
|
||||
std::size_t k{0};
|
||||
|
||||
auto dense = page.IsDense();
|
||||
|
||||
common::DispatchBinType(page.index.GetBinTypeSize(), [&](auto t) {
|
||||
using BinT = decltype(t);
|
||||
auto const& index = page.index;
|
||||
for (std::size_t ridx = 0; ridx < batch_size; ++ridx) {
|
||||
auto r_beg = page.row_ptr[ridx];
|
||||
auto r_end = page.row_ptr[ridx + 1];
|
||||
bst_feature_t fidx{0};
|
||||
if (dense) {
|
||||
// compressed, use the operator to obtain the true value.
|
||||
for (std::size_t j = r_beg; j < r_end; ++j) {
|
||||
bst_feature_t fidx = j - r_beg;
|
||||
std::uint32_t bin_idx = index[k];
|
||||
assign(bin_idx, k, ridx, fidx);
|
||||
++k;
|
||||
}
|
||||
} else {
|
||||
// not compressed
|
||||
auto const* row_index = index.data<BinT>() + page.row_ptr[page.base_rowid];
|
||||
for (std::size_t j = r_beg; j < r_end; ++j) {
|
||||
std::uint32_t bin_idx = row_index[k];
|
||||
// find the feature index for current bin.
|
||||
while (bin_idx >= ptrs[fidx + 1]) {
|
||||
fidx++;
|
||||
}
|
||||
assign(bin_idx, k, ridx, fidx);
|
||||
++k;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Should we regenerate the gradient index?
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user