[backport] Optimize prediction with QuantileDMatrix. (#9096) (#9303)

This commit is contained in:
Jiaming Yuan 2023-06-15 23:32:03 +08:00 committed by GitHub
parent 573f1c7db4
commit 39ddf40a8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 64 additions and 37 deletions

View File

@@ -508,7 +508,7 @@ class RegTree : public Model {
* \brief drop the trace after fill, must be called after fill. * \brief drop the trace after fill, must be called after fill.
* \param inst The sparse instance to drop. * \param inst The sparse instance to drop.
*/ */
void Drop(const SparsePage::Inst& inst); void Drop();
/*! /*!
* \brief returns the size of the feature vector * \brief returns the size of the feature vector
* \return the size of the feature vector * \return the size of the feature vector
@@ -709,13 +709,10 @@ inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
has_missing_ = data_.size() != feature_count; has_missing_ = data_.size() != feature_count;
} }
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) { inline void RegTree::FVec::Drop() {
for (auto const& entry : inst) { Entry e{};
if (entry.index >= data_.size()) { e.flag = -1;
continue; std::fill_n(data_.data(), data_.size(), e);
}
data_[entry.index].flag = -1;
}
has_missing_ = true; has_missing_ = true;
} }

View File

@@ -149,10 +149,28 @@ common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
return *columns_; return *columns_;
} }
bst_bin_t GHistIndexMatrix::GetGindex(size_t ridx, size_t fidx) const {
auto begin = RowIdx(ridx);
if (IsDense()) {
return static_cast<bst_bin_t>(index[begin + fidx]);
}
auto end = RowIdx(ridx + 1);
auto const& cut_ptrs = cut.Ptrs();
auto f_begin = cut_ptrs[fidx];
auto f_end = cut_ptrs[fidx + 1];
return BinarySearchBin(begin, end, index, f_begin, f_end);
}
float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const { float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
auto const &values = cut.Values(); auto const &values = cut.Values();
auto const &mins = cut.MinValues(); auto const &mins = cut.MinValues();
auto const &ptrs = cut.Ptrs(); auto const &ptrs = cut.Ptrs();
return this->GetFvalue(ptrs, values, mins, ridx, fidx, is_cat);
}
float GHistIndexMatrix::GetFvalue(std::vector<std::uint32_t> const &ptrs,
std::vector<float> const &values, std::vector<float> const &mins,
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const {
if (is_cat) { if (is_cat) {
auto f_begin = ptrs[fidx]; auto f_begin = ptrs[fidx];
auto f_end = ptrs[fidx + 1]; auto f_end = ptrs[fidx + 1];
@@ -172,24 +190,27 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
} }
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx); return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
}; };
switch (columns_->GetColumnType(fidx)) {
if (columns_->GetColumnType(fidx) == common::kDenseColumn) { case common::kDenseColumn: {
if (columns_->AnyMissing()) { if (columns_->AnyMissing()) {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
return get_bin_val(column);
});
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
auto bin_idx = column[ridx];
return common::HistogramCuts::NumericBinValue(ptrs, values, mins, fidx, bin_idx);
});
}
}
case common::kSparseColumn: {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) { return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), true>(fidx); auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
return get_bin_val(column);
});
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
return get_bin_val(column); return get_bin_val(column);
}); });
} }
} else {
return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
return get_bin_val(column);
});
} }
SPAN_CHECK(false); SPAN_CHECK(false);

View File

@@ -227,7 +227,12 @@ class GHistIndexMatrix {
common::ColumnMatrix const& Transpose() const; common::ColumnMatrix const& Transpose() const;
bst_bin_t GetGindex(size_t ridx, size_t fidx) const;
float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const; float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
float GetFvalue(std::vector<std::uint32_t> const& ptrs, std::vector<float> const& values,
std::vector<float> const& mins, bst_row_t ridx, bst_feature_t fidx,
bool is_cat) const;
private: private:
std::unique_ptr<common::ColumnMatrix> columns_; std::unique_ptr<common::ColumnMatrix> columns_;

View File

@@ -63,7 +63,7 @@ bst_float PredValue(const SparsePage::Inst &inst,
psum += (*trees[i])[nidx].LeafValue(); psum += (*trees[i])[nidx].LeafValue();
} }
} }
p_feats->Drop(inst); p_feats->Drop();
return psum; return psum;
} }
@@ -116,13 +116,11 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
} }
} }
template <typename DataView> void FVecDrop(std::size_t const block_size, std::size_t const fvec_offset,
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch, std::vector<RegTree::FVec> *p_feats) {
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
for (size_t i = 0; i < block_size; ++i) { for (size_t i = 0; i < block_size; ++i) {
RegTree::FVec &feats = (*p_feats)[fvec_offset + i]; RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
const SparsePage::Inst inst = (*batch)[batch_offset + i]; feats.Drop();
feats.Drop(inst);
} }
} }
@@ -142,11 +140,15 @@ struct SparsePageView {
struct GHistIndexMatrixView { struct GHistIndexMatrixView {
private: private:
GHistIndexMatrix const &page_; GHistIndexMatrix const &page_;
uint64_t n_features_; std::uint64_t const n_features_;
common::Span<FeatureType const> ft_; common::Span<FeatureType const> ft_;
common::Span<Entry> workspace_; common::Span<Entry> workspace_;
std::vector<size_t> current_unroll_; std::vector<size_t> current_unroll_;
std::vector<std::uint32_t> const& ptrs_;
std::vector<float> const& mins_;
std::vector<float> const& values_;
public: public:
size_t base_rowid; size_t base_rowid;
@@ -159,6 +161,9 @@ struct GHistIndexMatrixView {
ft_{ft}, ft_{ft},
workspace_{workplace}, workspace_{workplace},
current_unroll_(n_threads > 0 ? n_threads : 1, 0), current_unroll_(n_threads > 0 ? n_threads : 1, 0),
ptrs_{_page.cut.Ptrs()},
mins_{_page.cut.MinValues()},
values_{_page.cut.Values()},
base_rowid{_page.base_rowid} {} base_rowid{_page.base_rowid} {}
SparsePage::Inst operator[](size_t r) { SparsePage::Inst operator[](size_t r) {
@@ -167,7 +172,7 @@ struct GHistIndexMatrixView {
size_t non_missing{beg}; size_t non_missing{beg};
for (bst_feature_t c = 0; c < n_features_; ++c) { for (bst_feature_t c = 0; c < n_features_; ++c) {
float f = page_.GetFvalue(r, c, common::IsCat(ft_, c)); float f = page_.GetFvalue(ptrs_, values_, mins_, r, c, common::IsCat(ft_, c));
if (!common::CheckNAN(f)) { if (!common::CheckNAN(f)) {
workspace_[non_missing] = Entry{c, f}; workspace_[non_missing] = Entry{c, f};
++non_missing; ++non_missing;
@@ -250,10 +255,9 @@ void PredictBatchByBlockOfRowsKernel(
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
p_thread_temp); p_thread_temp);
// process block of rows through all trees to keep cache locality // process block of rows through all trees to keep cache locality
PredictByAllTrees(model, tree_begin, tree_end, out_preds, PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
batch_offset + batch.base_rowid, num_group, thread_temp, num_group, thread_temp, fvec_offset, block_size);
fvec_offset, block_size); FVecDrop(block_size, fvec_offset, p_thread_temp);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
}); });
} }
@@ -470,7 +474,7 @@ class CPUPredictor : public Predictor {
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats); bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid); preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
} }
feats.Drop(page[i]); feats.Drop();
}); });
} }
} }
@@ -544,7 +548,7 @@ class CPUPredictor : public Predictor {
(tree_weights == nullptr ? 1 : (*tree_weights)[j]); (tree_weights == nullptr ? 1 : (*tree_weights)[j]);
} }
} }
feats.Drop(page[i]); feats.Drop();
// add base margin to BIAS // add base margin to BIAS
if (base_margin.Size() != 0) { if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Shape(1), ngroup); CHECK_EQ(base_margin.Shape(1), ngroup);

View File

@@ -89,7 +89,7 @@ class TreeRefresher : public TreeUpdater {
dmlc::BeginPtr(stemp[tid]) + offset); dmlc::BeginPtr(stemp[tid]) + offset);
offset += tree->param.num_nodes; offset += tree->param.num_nodes;
} }
feats.Drop(inst); feats.Drop();
}); });
} }
// aggregate the statistics // aggregate the statistics