Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -316,30 +316,16 @@ class ColumnMatrix {
|
||||
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
|
||||
auto n_features = gmat.Features();
|
||||
missing_flags_.resize(feature_offsets_[n_features], true);
|
||||
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[gmat.base_rowid];
|
||||
num_nonzeros_.resize(n_features, 0);
|
||||
auto const& ptrs = gmat.cut.Ptrs();
|
||||
|
||||
DispatchBinType(bins_type_size_, [&](auto t) {
|
||||
using ColumnBinT = decltype(t);
|
||||
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
|
||||
auto const batch_size = gmat.Size();
|
||||
size_t k{0};
|
||||
|
||||
for (size_t ridx = 0; ridx < batch_size; ++ridx) {
|
||||
auto r_beg = gmat.row_ptr[ridx];
|
||||
auto r_end = gmat.row_ptr[ridx + 1];
|
||||
bst_feature_t fidx{0};
|
||||
for (size_t j = r_beg; j < r_end; ++j) {
|
||||
const uint32_t bin_idx = row_index[k];
|
||||
// find the feature index for current bin.
|
||||
while (bin_idx >= ptrs[fidx + 1]) {
|
||||
fidx++;
|
||||
}
|
||||
SetBinSparse(bin_idx, ridx, fidx, local_index);
|
||||
++k;
|
||||
}
|
||||
}
|
||||
CHECK(this->any_missing_);
|
||||
AssignColumnBinIndex(gmat,
|
||||
[&](auto bin_idx, std::size_t, std::size_t ridx, bst_feature_t fidx) {
|
||||
SetBinSparse(bin_idx, ridx, fidx, local_index);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -149,6 +149,19 @@ class HistogramCuts {
|
||||
return this->SearchCatBin(value, fidx, ptrs, vals);
|
||||
}
|
||||
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
|
||||
|
||||
/**
|
||||
* \brief Return numerical bin value given bin index.
|
||||
*/
|
||||
static float NumericBinValue(std::vector<std::uint32_t> const& ptrs,
|
||||
std::vector<float> const& vals, std::vector<float> const& mins,
|
||||
bst_feature_t fidx, bst_bin_t bin_idx) {
|
||||
auto lower = static_cast<bst_bin_t>(ptrs[fidx]);
|
||||
if (bin_idx == lower) {
|
||||
return mins[fidx];
|
||||
}
|
||||
return vals[bin_idx - 1];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user