Make prediction functions thread safe. (#6648)
This commit is contained in:
@@ -183,11 +183,11 @@ class CPUPredictor : public Predictor {
|
||||
|
||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end) {
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
int32_t tree_end) const {
|
||||
const int threads = omp_get_max_threads();
|
||||
InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature,
|
||||
&this->thread_temp_);
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
InitThreadTemp(threads * kBlockOfRowsSize,
|
||||
model.learner_model_param->num_feature, &feat_vecs);
|
||||
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group);
|
||||
@@ -195,7 +195,7 @@ class CPUPredictor : public Predictor {
|
||||
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
|
||||
kBlockOfRowsSize>(SparsePageView<kUnroll>{&batch},
|
||||
out_preds, model, tree_begin,
|
||||
tree_end, &thread_temp_);
|
||||
tree_end, &feat_vecs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +238,7 @@ class CPUPredictor : public Predictor {
|
||||
// multi-output and forest. Same problem exists for tree_begin
|
||||
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
uint32_t const ntree_limit = 0) override {
|
||||
uint32_t const ntree_limit = 0) const override {
|
||||
// tree_begin is not used; for now we just enforce it to be 0.
|
||||
CHECK_EQ(tree_begin, 0);
|
||||
auto* out_preds = &predts->predictions;
|
||||
@@ -326,11 +326,10 @@ class CPUPredictor : public Predictor {
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
std::vector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||
if (thread_temp_.size() == 0) {
|
||||
thread_temp_.resize(1, RegTree::FVec());
|
||||
thread_temp_[0].Init(model.learner_model_param->num_feature);
|
||||
}
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
feat_vecs.resize(1, RegTree::FVec());
|
||||
feat_vecs[0].Init(model.learner_model_param->num_feature);
|
||||
ntree_limit *= model.learner_model_param->num_output_group;
|
||||
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(model.trees.size());
|
||||
@@ -340,15 +339,16 @@ class CPUPredictor : public Predictor {
|
||||
// loop over output groups
|
||||
for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) {
|
||||
(*out_preds)[gid] = PredValue(inst, model.trees, model.tree_info, gid,
|
||||
&thread_temp_[0], 0, ntree_limit) +
|
||||
&feat_vecs[0], 0, ntree_limit) +
|
||||
model.learner_model_param->base_score;
|
||||
}
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
ntree_limit *= model.learner_model_param->num_output_group;
|
||||
@@ -366,7 +366,7 @@ class CPUPredictor : public Predictor {
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
auto ridx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec &feats = thread_temp_[tid];
|
||||
RegTree::FVec &feats = feat_vecs[tid];
|
||||
feats.Fill(page[i]);
|
||||
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||
int tid = model.trees[j]->GetLeafIndex(feats);
|
||||
@@ -381,9 +381,10 @@ class CPUPredictor : public Predictor {
|
||||
const gbm::GBTreeModel& model, uint32_t ntree_limit,
|
||||
std::vector<bst_float>* tree_weights,
|
||||
bool approximate, int condition,
|
||||
unsigned condition_feature) override {
|
||||
unsigned condition_feature) const override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread, model.learner_model_param->num_feature, &this->thread_temp_);
|
||||
std::vector<RegTree::FVec> feat_vecs;
|
||||
InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
// number of valid trees
|
||||
ntree_limit *= model.learner_model_param->num_output_group;
|
||||
@@ -414,7 +415,7 @@ class CPUPredictor : public Predictor {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
|
||||
RegTree::FVec &feats = thread_temp_[omp_get_thread_num()];
|
||||
RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
|
||||
std::vector<bst_float> this_tree_contribs(ncolumns);
|
||||
// loop over all classes
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
@@ -452,7 +453,7 @@ class CPUPredictor : public Predictor {
|
||||
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit,
|
||||
std::vector<bst_float>* tree_weights,
|
||||
bool approximate) override {
|
||||
bool approximate) const override {
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
const int ngroup = model.learner_model_param->num_output_group;
|
||||
size_t const ncolumns = model.learner_model_param->num_feature;
|
||||
@@ -501,8 +502,6 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
std::vector<RegTree::FVec> thread_temp_;
|
||||
static size_t constexpr kBlockOfRowsSize = 64;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user