try predpath
This commit is contained in:
parent
75aa5bd258
commit
19a1ee24a5
@ -107,7 +107,7 @@ class GBTree : public IGradBooster {
|
||||
int64_t buffer_offset,
|
||||
const BoosterInfo &info,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0) {
|
||||
unsigned ntree_limit = 0) {
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
@ -117,6 +117,10 @@ class GBTree : public IGradBooster {
|
||||
for (int i = 0; i < nthread; ++i) {
|
||||
thread_temp[i].Init(mparam.num_feature);
|
||||
}
|
||||
if (tparam.pred_path != 0) {
|
||||
this->PredPath(p_fmat, info, out_preds);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<float> &preds = *out_preds;
|
||||
const size_t stride = info.num_row * mparam.num_output_group;
|
||||
@ -144,7 +148,7 @@ class GBTree : public IGradBooster {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||
std::vector<std::string> dump;
|
||||
for (size_t i = 0; i < trees.size(); i++) {
|
||||
@ -258,6 +262,34 @@ class GBTree : public IGradBooster {
|
||||
out_pred[stride * (i + 1)] = vec_psum[i];
|
||||
}
|
||||
}
|
||||
// predict independent leaf index
|
||||
inline void PredPath(IFMatrix *p_fmat,
|
||||
const BoosterInfo &info,
|
||||
std::vector<float> *out_preds) {
|
||||
std::vector<float> &preds = *out_preds;
|
||||
preds.resize(info.num_row * mparam.num_trees);
|
||||
// start collecting the prediction
|
||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
// parallel over local batch
|
||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
tree::RegTree::FVec &feats = thread_temp[tid];
|
||||
feats.Fill(batch[i]);
|
||||
for (size_t j = 0; j < trees.size(); ++j) {
|
||||
int tid = trees[i]->GetLeafIndex(feats, info.GetRoot(ridx));
|
||||
preds[ridx * mparam.num_trees + j] = static_cast<float>(tid);
|
||||
}
|
||||
feats.Drop(batch[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- data structure ---
|
||||
/*! \brief training parameters */
|
||||
struct TrainParam {
|
||||
@ -268,6 +300,8 @@ class GBTree : public IGradBooster {
|
||||
* use this option to support boosted random forest
|
||||
*/
|
||||
int num_parallel_tree;
|
||||
/*! \brief predict path in prediction */
|
||||
int pred_path;
|
||||
/*! \brief whether updater is already initialized */
|
||||
int updater_initialized;
|
||||
/*! \brief tree updater sequence */
|
||||
@ -278,6 +312,7 @@ class GBTree : public IGradBooster {
|
||||
updater_seq = "grow_colmaker,prune";
|
||||
num_parallel_tree = 1;
|
||||
updater_initialized = 0;
|
||||
pred_path = 0;
|
||||
}
|
||||
inline void SetParam(const char *name, const char *val){
|
||||
using namespace std;
|
||||
@ -292,6 +327,7 @@ class GBTree : public IGradBooster {
|
||||
if (!strcmp(name, "num_parallel_tree")) {
|
||||
num_parallel_tree = atoi(val);
|
||||
}
|
||||
if (!strcmp(name, "pred_path")) pred_path = atoi(val);
|
||||
}
|
||||
};
|
||||
/*! \brief model parameters */
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user