Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -164,34 +164,30 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
|
||||
return values[gidx];
|
||||
}
|
||||
|
||||
auto lower = static_cast<bst_bin_t>(cut.Ptrs()[fidx]);
|
||||
auto get_bin_idx = [&](auto &column) {
|
||||
auto get_bin_val = [&](auto &column) {
|
||||
auto bin_idx = column[ridx];
|
||||
if (bin_idx == common::DenseColumnIter<uint8_t, true>::kMissingId) {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
if (bin_idx == lower) {
|
||||
return mins[fidx];
|
||||
}
|
||||
return values[bin_idx - 1];
|
||||
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
|
||||
};
|
||||
|
||||
if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
|
||||
if (columns_->AnyMissing()) {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
|
||||
return get_bin_idx(column);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
|
||||
return get_bin_idx(column);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
|
||||
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
|
||||
return get_bin_idx(column);
|
||||
return get_bin_val(column);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user