diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp index 9335ef8e7..c06dc51a1 100644 --- a/src/gbm/gbtree-inl.hpp +++ b/src/gbm/gbtree-inl.hpp @@ -138,10 +138,7 @@ class GBTree : public IGradBooster { { nthread = omp_get_num_threads(); } - thread_temp.resize(nthread, tree::RegTree::FVec()); - for (int i = 0; i < nthread; ++i) { - thread_temp[i].Init(mparam.num_feature); - } + InitThreadTemp(nthread); std::vector &preds = *out_preds; const size_t stride = info.num_row * mparam.num_output_group; preds.resize(stride * (mparam.size_leaf_vector+1)); @@ -194,10 +191,7 @@ class GBTree : public IGradBooster { { nthread = omp_get_num_threads(); } - thread_temp.resize(nthread, tree::RegTree::FVec()); - for (int i = 0; i < nthread; ++i) { - thread_temp[i].Init(mparam.num_feature); - } + InitThreadTemp(nthread); this->PredPath(p_fmat, info, out_preds, ntree_limit); } virtual std::vector DumpModel(const utils::FeatMap& fmap, int option) { @@ -391,6 +385,16 @@ class GBTree : public IGradBooster { } } } + // init thread buffers + inline void InitThreadTemp(int nthread) { + int prev_thread_temp_size = thread_temp.size(); + if (prev_thread_temp_size < nthread) { + thread_temp.resize(nthread, tree::RegTree::FVec()); + for (int i = prev_thread_temp_size; i < nthread; ++i) { + thread_temp[i].Init(mparam.num_feature); + } + } + } // --- data structure --- /*! \brief training parameters */