Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -316,30 +316,16 @@ class ColumnMatrix {
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
auto n_features = gmat.Features();
missing_flags_.resize(feature_offsets_[n_features], true);
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[gmat.base_rowid];
num_nonzeros_.resize(n_features, 0);
auto const& ptrs = gmat.cut.Ptrs();
DispatchBinType(bins_type_size_, [&](auto t) {
using ColumnBinT = decltype(t);
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
auto const batch_size = gmat.Size();
size_t k{0};
for (size_t ridx = 0; ridx < batch_size; ++ridx) {
auto r_beg = gmat.row_ptr[ridx];
auto r_end = gmat.row_ptr[ridx + 1];
bst_feature_t fidx{0};
for (size_t j = r_beg; j < r_end; ++j) {
const uint32_t bin_idx = row_index[k];
// find the feature index for current bin.
while (bin_idx >= ptrs[fidx + 1]) {
fidx++;
}
SetBinSparse(bin_idx, ridx, fidx, local_index);
++k;
}
}
CHECK(this->any_missing_);
AssignColumnBinIndex(gmat,
[&](auto bin_idx, std::size_t, std::size_t ridx, bst_feature_t fidx) {
SetBinSparse(bin_idx, ridx, fidx, local_index);
});
});
}

View File

@@ -149,6 +149,19 @@ class HistogramCuts {
return this->SearchCatBin(value, fidx, ptrs, vals);
}
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
/**
* \brief Return numerical bin value given bin index.
*/
static float NumericBinValue(std::vector<std::uint32_t> const& ptrs,
std::vector<float> const& vals, std::vector<float> const& mins,
bst_feature_t fidx, bst_bin_t bin_idx) {
auto lower = static_cast<bst_bin_t>(ptrs[fidx]);
if (bin_idx == lower) {
return mins[fidx];
}
return vals[bin_idx - 1];
}
};
/**