Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -164,34 +164,30 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
return values[gidx];
}
auto lower = static_cast<bst_bin_t>(cut.Ptrs()[fidx]);
auto get_bin_idx = [&](auto &column) {
auto get_bin_val = [&](auto &column) {
auto bin_idx = column[ridx];
if (bin_idx == common::DenseColumnIter<uint8_t, true>::kMissingId) {
return std::numeric_limits<float>::quiet_NaN();
}
if (bin_idx == lower) {
return mins[fidx];
}
return values[bin_idx - 1];
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
};
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
if (columns_->AnyMissing()) {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
return get_bin_idx(column);
return get_bin_val(column);
});
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
return get_bin_idx(column);
return get_bin_val(column);
});
}
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
return get_bin_idx(column);
return get_bin_val(column);
});
}