CPU predict performance improvement (#6127)

Co-authored-by: ShvetsKS <kirill.shvets@intel.com>
Author: ShvetsKS, 2020-10-08 15:50:21 +03:00 (committed by GitHub)
Parent: 4cfdcaaf7b
Commit: a4ce0eae43
3 changed files with 108 additions and 56 deletions

include/xgboost/tree_model.h

@@ -457,6 +457,7 @@ class RegTree : public Model {
}
return depth;
}
/*!
* \brief get maximum depth
* \param nid node id
@@ -498,6 +499,7 @@ class RegTree : public Model {
* \param inst The sparse instance to fill.
*/
void Fill(const SparsePage::Inst& inst);
/*!
* \brief drop the trace after fill, must be called after fill.
* \param inst The sparse instance to drop.
@@ -520,6 +522,8 @@ class RegTree : public Model {
* \return whether i-th value is missing.
*/
bool IsMissing(size_t i) const;
bool HasMissing() const;
private:
/*!
@@ -531,13 +535,16 @@ class RegTree : public Model {
int flag;
};
std::vector<Entry> data_;
bool has_missing_;
};
/*!
* \brief get the leaf index
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \return the leaf index of the given feature
*/
template <bool has_missing = true>
int GetLeafIndex(const FVec& feat) const;
/*!
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
* \param feat dense feature vector, if the feature is missing the field is set to NaN
@@ -581,6 +588,7 @@ class RegTree : public Model {
* \param fvalue feature value if not missing.
* \param is_unknown Whether current required feature is missing.
*/
template <bool has_missing = true>
inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
/*!
* \brief dump the model in the requested format as a text string
@@ -676,15 +684,19 @@ inline void RegTree::FVec::Init(size_t size) {
Entry e; e.flag = -1;
data_.resize(size);
std::fill(data_.begin(), data_.end(), e);
has_missing_ = true;
}
inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
size_t feature_count = 0;
for (auto const& entry : inst) {
if (entry.index >= data_.size()) {
continue;
}
data_[entry.index].fvalue = entry.fvalue;
++feature_count;
}
has_missing_ = data_.size() != feature_count;
}
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
@@ -694,6 +706,7 @@ inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
}
data_[entry.index].flag = -1;
}
has_missing_ = true;
}
inline size_t RegTree::FVec::Size() const {
@@ -708,27 +721,41 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
return data_[i].flag == -1;
}
inline bool RegTree::FVec::HasMissing() const {
return has_missing_;
}
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
  bst_node_t nid = 0;
  while (!(*this)[nid].IsLeaf()) {
    unsigned split_index = (*this)[nid].SplitIndex();
    nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
  }
  return nid;
}
template <bool has_missing>
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
  bst_node_t nid = 0;
  while (!(*this)[nid].IsLeaf()) {
    unsigned split_index = (*this)[nid].SplitIndex();
    nid = this->GetNext<has_missing>(nid, feat.GetFvalue(split_index),
                                     has_missing && feat.IsMissing(split_index));
  }
  return nid;
}
/*! \brief get next position of the tree given current pid */
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
  bst_float split_value = (*this)[pid].SplitCond();
  if (is_unknown) {
    return (*this)[pid].DefaultChild();
  } else {
    if (fvalue < split_value) {
      return (*this)[pid].LeftChild();
    } else {
      return (*this)[pid].RightChild();
    }
  }
}
template <bool has_missing>
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
  if (has_missing) {
    if (is_unknown) {
      return (*this)[pid].DefaultChild();
    } else {
      if (fvalue < (*this)[pid].SplitCond()) {
        return (*this)[pid].LeftChild();
      } else {
        return (*this)[pid].RightChild();
      }
    }
  } else {
    // 35% speed up due to reduced branch mispredictions:
    // returns the left child if (fvalue < split_cond), the right child otherwise.
    return (*this)[pid].LeftChild() + !(fvalue < (*this)[pid].SplitCond());
  }
}
} // namespace xgboost
#endif // XGBOOST_TREE_MODEL_H_
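
The branchless fast path in GetNext<false> above leans on a layout invariant of RegTree: a node's right child is allocated immediately after its left child, so RightChild() == LeftChild() + 1 (the CHECK_EQ added to AddSplitsToRowSet further down makes this explicit). A minimal standalone sketch of the trick, with hypothetical names and node indices:

#include <cassert>

// Branchless child selection, assuming right_child == left_child + 1.
int NextNode(int left_child, float fvalue, float split_cond) {
  // (fvalue < split_cond) is 1 when the row goes left, so !(...) adds 0
  // for the left child and 1 for the right child. The compiler can emit
  // a compare-and-add instead of a conditional jump, avoiding the branch
  // mispredictions that unpredictable feature values cause.
  return left_child + !(fvalue < split_cond);
}

int main() {
  assert(NextNode(1, 0.3f, 0.5f) == 1);  // 0.3 < 0.5: go left
  assert(NextNode(1, 0.7f, 0.5f) == 2);  // otherwise: go right
  return 0;
}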

src/predictor/cpu_predictor.cc

@@ -42,6 +42,47 @@ bst_float PredValue(const SparsePage::Inst &inst,
return psum;
}
inline bst_float PredValueByOneTree(const RegTree::FVec& p_feats,
const std::unique_ptr<RegTree>& tree) {
const int lid = p_feats.HasMissing() ? tree->GetLeafIndex<true>(p_feats) :
tree->GetLeafIndex<false>(p_feats); // 35% speed up
return (*tree)[lid].LeafValue();
}
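
PredValueByOneTree hoists the missing-value test out of the per-node loop: the HasMissing() flag computed once in FVec::Fill picks GetLeafIndex<true> or GetLeafIndex<false>, and the dense specialization walks the tree with the branchless GetNext<false>. A hedged sketch of this dispatch-once pattern (Node, WalkTree, and Predict are illustrative stand-ins, not XGBoost types):

#include <cmath>
#include <vector>

// Illustrative node layout; leaves carry a value, internal nodes a split.
struct Node { int left; unsigned split_index; float split_cond; bool is_leaf; float leaf_value; };

template <bool has_missing>
float WalkTree(const std::vector<Node>& tree, const std::vector<float>& row) {
  int nid = 0;
  while (!tree[nid].is_leaf) {
    const float fvalue = row[tree[nid].split_index];
    if (has_missing && std::isnan(fvalue)) {
      nid = tree[nid].left;  // simplified: default direction is always left here
    } else {
      nid = tree[nid].left + !(fvalue < tree[nid].split_cond);  // branchless step
    }
  }
  return tree[nid].leaf_value;
}

// One runtime test per row; each specialization compiles without the dead branch.
float Predict(bool any_missing, const std::vector<Node>& tree, const std::vector<float>& row) {
  return any_missing ? WalkTree<true>(tree, row) : WalkTree<false>(tree, row);
}

int main() {
  // Root splits feature 0 at 0.5; nodes 1 and 2 are leaves.
  std::vector<Node> tree = {{1, 0, 0.5f, false, 0.f},
                            {0, 0, 0.f, true, -1.f},
                            {0, 0, 0.f, true, +1.f}};
  std::vector<float> row = {0.7f};
  return Predict(false, tree, row) > 0.f ? 0 : 1;  // dense row lands on +1
}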
inline void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, std::vector<bst_float>* out_preds,
const size_t predict_offset, const size_t num_group,
const std::vector<RegTree::FVec> &thread_temp,
const size_t offset, const size_t block_size) {
std::vector<bst_float> &preds = *out_preds;
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
const size_t gid = model.tree_info[tree_id];
for (size_t i = 0; i < block_size; ++i) {
preds[(predict_offset + i) * num_group + gid] += PredValueByOneTree(thread_temp[offset + i],
model.trees[tree_id]);
}
}
}
template <typename DataView>
void FVecFill(const size_t block_size, const size_t batch_offset, DataView* batch,
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
for (size_t i = 0; i < block_size; ++i) {
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
const SparsePage::Inst inst = (*batch)[batch_offset + i];
feats.Fill(inst);
}
}
template <typename DataView>
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
for (size_t i = 0; i < block_size; ++i) {
RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
const SparsePage::Inst inst = (*batch)[batch_offset + i];
feats.Drop(inst);
}
}
template <size_t kUnrollLen = 8>
struct SparsePageView {
bst_row_t base_rowid;
@@ -99,52 +140,31 @@ class AdapterView {
bst_row_t const static base_rowid = 0; // NOLINT
};
template <typename DataView>
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin,
int32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp) {
template <typename DataView, size_t block_of_rows_size>
void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin,
int32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp) {
auto& thread_temp = *p_thread_temp;
int32_t const num_group = model.learner_model_param->num_output_group;
std::vector<bst_float> &preds = *out_preds;
CHECK_EQ(model.param.size_leaf_vector, 0)
<< "size_leaf_vector is enforced to 0 so far";
// parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
auto constexpr kUnroll = DataView::kUnroll;
const bst_omp_uint rest = nsize % kUnroll;
if (nsize >= kUnroll) {
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
const int tid = omp_get_thread_num();
RegTree::FVec &feats = thread_temp[tid];
int64_t ridx[kUnroll];
SparsePage::Inst inst[kUnroll];
for (size_t k = 0; k < kUnroll; ++k) {
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
}
for (size_t k = 0; k < kUnroll; ++k) {
inst[k] = batch[i + k];
}
for (size_t k = 0; k < kUnroll; ++k) {
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx[k] * num_group + gid;
preds[offset] += PredValue(inst[k], model.trees, model.tree_info, gid,
&feats, tree_begin, tree_end);
}
}
}
}
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
RegTree::FVec &feats = thread_temp[0];
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
auto inst = batch[i];
for (int gid = 0; gid < num_group; ++gid) {
const size_t offset = ridx * num_group + gid;
preds[offset] += PredValue(inst, model.trees, model.tree_info, gid,
&feats, tree_begin, tree_end);
}
const bst_omp_uint n_row_blocks = (nsize) / block_of_rows_size + !!((nsize) % block_of_rows_size);
#pragma omp parallel for schedule(guided)
for (bst_omp_uint block_id = 0; block_id < n_row_blocks; ++block_id) {
const size_t batch_offset = block_id * block_of_rows_size;
const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size);
const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size;
FVecFill(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
// process block of rows through all trees to keep cache locality
PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
num_group, thread_temp, fvec_offset, block_size);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
}
}
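
Two things make this kernel fast: each thread fills a block of dense FVecs once and pushes the whole block through every tree, so a tree's nodes stay hot in cache across the rows of the block, and n_row_blocks rounds the row count up to whole blocks with an integer ceiling division. A small sketch of that idiom (CeilDiv is a hypothetical name):

#include <cassert>
#include <cstddef>

// Integer ceiling division as written above: !!(a % b) contributes the
// extra partial block whenever rows are left over.
size_t CeilDiv(size_t nsize, size_t block_size) {
  return nsize / block_size + !!(nsize % block_size);
}

int main() {
  assert(CeilDiv(128, 64) == 2);  // exact multiple: full blocks only
  assert(CeilDiv(130, 64) == 3);  // two full blocks + one partial block of 2 rows
  assert(CeilDiv(10, 64) == 1);   // a single partial block
  return 0;
}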
@@ -166,13 +186,16 @@ class CPUPredictor : public Predictor {
int32_t tree_end) {
std::lock_guard<std::mutex> guard(lock_);
const int threads = omp_get_max_threads();
InitThreadTemp(threads, model.learner_model_param->num_feature, &this->thread_temp_);
InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature,
&this->thread_temp_);
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
CHECK_EQ(out_preds->size(),
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
size_t constexpr kUnroll = 8;
PredictBatchKernel(SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
tree_end, &thread_temp_);
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
kBlockOfRowsSize>(SparsePageView<kUnroll>{&batch},
out_preds, model, tree_begin,
tree_end, &thread_temp_);
}
}
@@ -279,11 +302,12 @@ class CPUPredictor : public Predictor {
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
auto &predictions = out_preds->predictions.HostVector();
std::vector<RegTree::FVec> thread_temp;
InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp);
size_t constexpr kUnroll = 8;
PredictBatchKernel(AdapterView<Adapter, kUnroll>(
m.get(), missing, common::Span<Entry>{workspace}),
&predictions, model, tree_begin, tree_end, &thread_temp);
InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature,
&thread_temp);
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>,
kBlockOfRowsSize>(AdapterView<Adapter>(
m.get(), missing, common::Span<Entry>{workspace}),
&predictions, model, tree_begin, tree_end, &thread_temp);
}
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
@@ -477,6 +501,7 @@ class CPUPredictor : public Predictor {
private:
std::mutex lock_;
std::vector<RegTree::FVec> thread_temp_;
static size_t constexpr kBlockOfRowsSize = 64;
};
XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")

src/tree/updater_quantile_hist.cc

@@ -1155,7 +1155,7 @@ void QuantileHistMaker::Builder<GradientSumT>::AddSplitsToRowSet(
const int32_t nid = nodes[i].nid;
const size_t n_left = partition_builder_.GetNLeftElems(i);
const size_t n_right = partition_builder_.GetNRightElems(i);
CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(),
(*p_tree)[nid].RightChild(), n_left, n_right);
}