try predpath
This commit is contained in:
parent
75aa5bd258
commit
19a1ee24a5
@ -107,7 +107,7 @@ class GBTree : public IGradBooster {
|
|||||||
int64_t buffer_offset,
|
int64_t buffer_offset,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
std::vector<float> *out_preds,
|
std::vector<float> *out_preds,
|
||||||
unsigned ntree_limit = 0) {
|
unsigned ntree_limit = 0) {
|
||||||
int nthread;
|
int nthread;
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
@ -117,6 +117,10 @@ class GBTree : public IGradBooster {
|
|||||||
for (int i = 0; i < nthread; ++i) {
|
for (int i = 0; i < nthread; ++i) {
|
||||||
thread_temp[i].Init(mparam.num_feature);
|
thread_temp[i].Init(mparam.num_feature);
|
||||||
}
|
}
|
||||||
|
if (tparam.pred_path != 0) {
|
||||||
|
this->PredPath(p_fmat, info, out_preds);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<float> &preds = *out_preds;
|
std::vector<float> &preds = *out_preds;
|
||||||
const size_t stride = info.num_row * mparam.num_output_group;
|
const size_t stride = info.num_row * mparam.num_output_group;
|
||||||
@ -144,7 +148,7 @@ class GBTree : public IGradBooster {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
virtual std::vector<std::string> DumpModel(const utils::FeatMap& fmap, int option) {
|
||||||
std::vector<std::string> dump;
|
std::vector<std::string> dump;
|
||||||
for (size_t i = 0; i < trees.size(); i++) {
|
for (size_t i = 0; i < trees.size(); i++) {
|
||||||
@ -258,6 +262,34 @@ class GBTree : public IGradBooster {
|
|||||||
out_pred[stride * (i + 1)] = vec_psum[i];
|
out_pred[stride * (i + 1)] = vec_psum[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// predict independent leaf index
|
||||||
|
inline void PredPath(IFMatrix *p_fmat,
|
||||||
|
const BoosterInfo &info,
|
||||||
|
std::vector<float> *out_preds) {
|
||||||
|
std::vector<float> &preds = *out_preds;
|
||||||
|
preds.resize(info.num_row * mparam.num_trees);
|
||||||
|
// start collecting the prediction
|
||||||
|
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||||
|
iter->BeforeFirst();
|
||||||
|
while (iter->Next()) {
|
||||||
|
const RowBatch &batch = iter->Value();
|
||||||
|
// parallel over local batch
|
||||||
|
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||||
|
const int tid = omp_get_thread_num();
|
||||||
|
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||||
|
tree::RegTree::FVec &feats = thread_temp[tid];
|
||||||
|
feats.Fill(batch[i]);
|
||||||
|
for (size_t j = 0; j < trees.size(); ++j) {
|
||||||
|
int tid = trees[i]->GetLeafIndex(feats, info.GetRoot(ridx));
|
||||||
|
preds[ridx * mparam.num_trees + j] = static_cast<float>(tid);
|
||||||
|
}
|
||||||
|
feats.Drop(batch[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- data structure ---
|
// --- data structure ---
|
||||||
/*! \brief training parameters */
|
/*! \brief training parameters */
|
||||||
struct TrainParam {
|
struct TrainParam {
|
||||||
@ -268,6 +300,8 @@ class GBTree : public IGradBooster {
|
|||||||
* use this option to support boosted random forest
|
* use this option to support boosted random forest
|
||||||
*/
|
*/
|
||||||
int num_parallel_tree;
|
int num_parallel_tree;
|
||||||
|
/*! \brief predict path in prediction */
|
||||||
|
int pred_path;
|
||||||
/*! \brief whether updater is already initialized */
|
/*! \brief whether updater is already initialized */
|
||||||
int updater_initialized;
|
int updater_initialized;
|
||||||
/*! \brief tree updater sequence */
|
/*! \brief tree updater sequence */
|
||||||
@ -278,6 +312,7 @@ class GBTree : public IGradBooster {
|
|||||||
updater_seq = "grow_colmaker,prune";
|
updater_seq = "grow_colmaker,prune";
|
||||||
num_parallel_tree = 1;
|
num_parallel_tree = 1;
|
||||||
updater_initialized = 0;
|
updater_initialized = 0;
|
||||||
|
pred_path = 0;
|
||||||
}
|
}
|
||||||
inline void SetParam(const char *name, const char *val){
|
inline void SetParam(const char *name, const char *val){
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -292,6 +327,7 @@ class GBTree : public IGradBooster {
|
|||||||
if (!strcmp(name, "num_parallel_tree")) {
|
if (!strcmp(name, "num_parallel_tree")) {
|
||||||
num_parallel_tree = atoi(val);
|
num_parallel_tree = atoi(val);
|
||||||
}
|
}
|
||||||
|
if (!strcmp(name, "pred_path")) pred_path = atoi(val);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
/*! \brief model parameters */
|
/*! \brief model parameters */
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user