Small cleanups to various data types. (#8086)
- Use `bst_bin_t` in batch param constructor. - Use `StringView` to avoid `std::string` when appropriate. - Avoid using `MetaInfo` in quantile constructor to limit the scope of parameter.
This commit is contained in:
@@ -49,14 +49,14 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
|
||||
}
|
||||
|
||||
if (!use_sorted) {
|
||||
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
|
||||
n_threads);
|
||||
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||
HostSketchContainer::UseGroup(info), n_threads);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(&out);
|
||||
} else {
|
||||
SortedSketchContainer container{max_bins, m->Info(), reduced,
|
||||
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
|
||||
HostSketchContainer::UseGroup(info), n_threads};
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
container.PushColPage(page, info, hessian);
|
||||
|
||||
@@ -86,7 +86,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
|
||||
|
||||
template <typename Batch>
|
||||
void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid,
|
||||
MetaInfo const &info, size_t nnz, float missing) {
|
||||
MetaInfo const &info, float missing) {
|
||||
auto const &h_weights =
|
||||
(use_group_ind_ ? detail::UnrollGroupWeights(info) : info.weights_.HostVector());
|
||||
|
||||
@@ -94,14 +94,14 @@ void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid
|
||||
auto weights = OptionalWeights{Span<float const>{h_weights}};
|
||||
// the nnz from info is not reliable as sketching might be the first place to go through
|
||||
// the data.
|
||||
auto is_dense = nnz == info.num_col_ * info.num_row_;
|
||||
this->PushRowPageImpl(batch, base_rowid, weights, nnz, info.num_col_, is_dense, is_valid);
|
||||
auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
|
||||
this->PushRowPageImpl(batch, base_rowid, weights, info.num_nonzero_, info.num_col_, is_dense,
|
||||
is_valid);
|
||||
}
|
||||
|
||||
#define INSTANTIATE(_type) \
|
||||
template void HostSketchContainer::PushAdapterBatch<data::_type>( \
|
||||
data::_type const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz, \
|
||||
float missing);
|
||||
#define INSTANTIATE(_type) \
|
||||
template void HostSketchContainer::PushAdapterBatch<data::_type>( \
|
||||
data::_type const &batch, size_t base_rowid, MetaInfo const &info, float missing);
|
||||
|
||||
INSTANTIATE(ArrayAdapterBatch)
|
||||
INSTANTIATE(CSRArrayAdapterBatch)
|
||||
@@ -436,11 +436,10 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
|
||||
template class SketchContainerImpl<WQuantileSketch<float, float>>;
|
||||
template class SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||
|
||||
HostSketchContainer::HostSketchContainer(int32_t max_bins, MetaInfo const &info,
|
||||
HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group,
|
||||
int32_t n_threads)
|
||||
: SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
|
||||
n_threads} {
|
||||
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
|
||||
monitor_.Init(__func__);
|
||||
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
|
||||
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
|
||||
|
||||
@@ -903,12 +903,11 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
|
||||
using WQSketch = WQuantileSketch<float, float>;
|
||||
|
||||
public:
|
||||
HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
|
||||
bool use_group, int32_t n_threads);
|
||||
HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group, int32_t n_threads);
|
||||
|
||||
template <typename Batch>
|
||||
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz,
|
||||
float missing);
|
||||
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -1000,13 +999,12 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
|
||||
using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
|
||||
|
||||
public:
|
||||
explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
|
||||
explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
|
||||
std::vector<size_t> columns_size, bool use_group,
|
||||
int32_t n_threads)
|
||||
: SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
|
||||
n_threads} {
|
||||
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
|
||||
monitor_.Init(__func__);
|
||||
sketches_.resize(info.num_col_);
|
||||
sketches_.resize(columns_size.size());
|
||||
size_t i = 0;
|
||||
for (auto &sketch : sketches_) {
|
||||
sketch.sketch = &Super::sketches_[i];
|
||||
|
||||
Reference in New Issue
Block a user