diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp index ed52afa7d..6688e3829 100644 --- a/src/gbm/gbtree-inl.hpp +++ b/src/gbm/gbtree-inl.hpp @@ -107,7 +107,7 @@ class GBTree : public IGradBooster { int64_t buffer_offset, const BoosterInfo &info, std::vector *out_preds, - unsigned ntree_limit = 0) { + unsigned ntree_limit = 0) { int nthread; #pragma omp parallel { @@ -117,6 +117,10 @@ class GBTree : public IGradBooster { for (int i = 0; i < nthread; ++i) { thread_temp[i].Init(mparam.num_feature); } + if (tparam.pred_path != 0) { + this->PredPath(p_fmat, info, out_preds); + return; + } std::vector &preds = *out_preds; const size_t stride = info.num_row * mparam.num_output_group; @@ -144,7 +148,7 @@ class GBTree : public IGradBooster { } } } - } + } virtual std::vector DumpModel(const utils::FeatMap& fmap, int option) { std::vector dump; for (size_t i = 0; i < trees.size(); i++) { @@ -258,6 +262,34 @@ class GBTree : public IGradBooster { out_pred[stride * (i + 1)] = vec_psum[i]; } } + // predict independent leaf index + inline void PredPath(IFMatrix *p_fmat, + const BoosterInfo &info, + std::vector *out_preds) { + std::vector &preds = *out_preds; + preds.resize(info.num_row * mparam.num_trees); + // start collecting the prediction + utils::IIterator *iter = p_fmat->RowIterator(); + iter->BeforeFirst(); + while (iter->Next()) { + const RowBatch &batch = iter->Value(); + // parallel over local batch + const bst_omp_uint nsize = static_cast(batch.size); + #pragma omp parallel for schedule(static) + for (bst_omp_uint i = 0; i < nsize; ++i) { + const int tid = omp_get_thread_num(); + int64_t ridx = static_cast(batch.base_rowid + i); + tree::RegTree::FVec &feats = thread_temp[tid]; + feats.Fill(batch[i]); + for (size_t j = 0; j < trees.size(); ++j) { + int tid = trees[i]->GetLeafIndex(feats, info.GetRoot(ridx)); + preds[ridx * mparam.num_trees + j] = static_cast(tid); + } + feats.Drop(batch[i]); + } + } + } + // --- data structure --- /*! \brief training parameters */ struct TrainParam { @@ -268,6 +300,8 @@ class GBTree : public IGradBooster { * use this option to support boosted random forest */ int num_parallel_tree; + /*! \brief predict path in prediction */ + int pred_path; /*! \brief whether updater is already initialized */ int updater_initialized; /*! \brief tree updater sequence */ @@ -278,6 +312,7 @@ class GBTree : public IGradBooster { updater_seq = "grow_colmaker,prune"; num_parallel_tree = 1; updater_initialized = 0; + pred_path = 0; } inline void SetParam(const char *name, const char *val){ using namespace std; @@ -292,6 +327,7 @@ class GBTree : public IGradBooster { if (!strcmp(name, "num_parallel_tree")) { num_parallel_tree = atoi(val); } + if (!strcmp(name, "pred_path")) pred_path = atoi(val); } }; /*! \brief model parameters */