Small cleanups to various data types. (#8086)

- Use `bst_bin_t` in batch param constructor.
- Use `StringView` to avoid `std::string` when appropriate.
- Avoid using `MetaInfo` in quantile constructor to limit the scope of parameter.
This commit is contained in:
Jiaming Yuan
2022-07-18 22:39:36 +08:00
committed by GitHub
parent e28f6f6657
commit 4083440690
9 changed files with 52 additions and 53 deletions

View File

@@ -49,14 +49,14 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b
}
if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
n_threads);
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(&out);
} else {
SortedSketchContainer container{max_bins, m->Info(), reduced,
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);

View File

@@ -86,7 +86,7 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
template <typename Batch>
void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid,
MetaInfo const &info, size_t nnz, float missing) {
MetaInfo const &info, float missing) {
auto const &h_weights =
(use_group_ind_ ? detail::UnrollGroupWeights(info) : info.weights_.HostVector());
@@ -94,14 +94,14 @@ void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid
auto weights = OptionalWeights{Span<float const>{h_weights}};
// the nnz from info is not reliable as sketching might be the first place to go through
// the data.
auto is_dense = nnz == info.num_col_ * info.num_row_;
this->PushRowPageImpl(batch, base_rowid, weights, nnz, info.num_col_, is_dense, is_valid);
auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
this->PushRowPageImpl(batch, base_rowid, weights, info.num_nonzero_, info.num_col_, is_dense,
is_valid);
}
#define INSTANTIATE(_type) \
template void HostSketchContainer::PushAdapterBatch<data::_type>( \
data::_type const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz, \
float missing);
#define INSTANTIATE(_type) \
template void HostSketchContainer::PushAdapterBatch<data::_type>( \
data::_type const &batch, size_t base_rowid, MetaInfo const &info, float missing);
INSTANTIATE(ArrayAdapterBatch)
INSTANTIATE(CSRArrayAdapterBatch)
@@ -436,11 +436,10 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
template class SketchContainerImpl<WQuantileSketch<float, float>>;
template class SketchContainerImpl<WXQuantileSketch<float, float>>;
HostSketchContainer::HostSketchContainer(int32_t max_bins, MetaInfo const &info,
HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group,
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
n_threads} {
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
monitor_.Init(__func__);
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);

View File

@@ -903,12 +903,11 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl
using WQSketch = WQuantileSketch<float, float>;
public:
HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
bool use_group, int32_t n_threads);
HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group, int32_t n_threads);
template <typename Batch>
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz,
float missing);
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
};
/**
@@ -1000,13 +999,12 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,
using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
public:
explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group,
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
n_threads} {
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
monitor_.Init(__func__);
sketches_.resize(info.num_col_);
sketches_.resize(columns_size.size());
size_t i = 0;
for (auto &sketch : sketches_) {
sketch.sketch = &Super::sketches_[i];