chg root index to booster info, need review

This commit is contained in:
tqchen
2014-08-22 16:26:37 -07:00
parent 58d74861b9
commit 58354643b0
10 changed files with 48 additions and 44 deletions

View File

@@ -43,12 +43,11 @@ class IGradBooster {
* \brief perform update to the model (boosting)
* \param gpair the gradient pair statistics of the data
* \param fmat feature matrix that provides access to features
* \param root_index pre-partitioned root_index of each instance,
* root_index.size() can be 0 which indicates that no pre-partition involved
* \param info meta information about training
*/
virtual void DoBoost(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index) = 0;
const BoosterInfo &info) = 0;
/*!
* \brief generate predictions for given feature matrix
* \param fmat feature matrix
@@ -56,13 +55,12 @@ class IGradBooster {
* this means we do not have buffer index allocated to the gbm
* a buffer index is assigned to each instance that requires repeated prediction
* the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size")
* \param root_index pre-partitioned root_index of each instance,
* root_index.size() can be 0 which indicates that no pre-partition involved
* \param info extra side information that may be needed for prediction
* \param out_preds output vector to hold the predictions
*/
virtual void Predict(const FMatrix &fmat,
int64_t buffer_offset,
const std::vector<unsigned> &root_index,
const BoosterInfo &info,
std::vector<float> *out_preds) = 0;
/*!
* \brief dump the model in text format

View File

@@ -84,9 +84,9 @@ class GBTree : public IGradBooster<FMatrix> {
}
virtual void DoBoost(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index) {
const BoosterInfo &info) {
if (mparam.num_output_group == 1) {
this->BoostNewTrees(gpair, fmat, root_index, 0);
this->BoostNewTrees(gpair, fmat, info, 0);
} else {
const int ngroup = mparam.num_output_group;
utils::Check(gpair.size() % ngroup == 0,
@@ -97,13 +97,13 @@ class GBTree : public IGradBooster<FMatrix> {
for (size_t i = 0; i < tmp.size(); ++i) {
tmp[i] = gpair[i * ngroup + gid];
}
this->BoostNewTrees(tmp, fmat, root_index, gid);
this->BoostNewTrees(tmp, fmat, info, gid);
}
}
}
virtual void Predict(const FMatrix &fmat,
int64_t buffer_offset,
const std::vector<unsigned> &root_index,
const BoosterInfo &info,
std::vector<float> *out_preds) {
int nthread;
#pragma omp parallel
@@ -134,7 +134,7 @@ class GBTree : public IGradBooster<FMatrix> {
const int tid = omp_get_thread_num();
tree::RegTree::FVec &feats = thread_temp[tid];
const size_t ridx = batch.base_rowid + i;
const unsigned root_idx = root_index.size() == 0 ? 0 : root_index[ridx];
const unsigned root_idx = info.GetRoot(i);
// loop over output groups
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
preds[ridx * mparam.num_output_group + gid] =
@@ -186,7 +186,7 @@ class GBTree : public IGradBooster<FMatrix> {
// boost new trees for the given output group (bst_group)
inline void BoostNewTrees(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const BoosterInfo &info,
int bst_group) {
this->InitUpdater();
// create the trees
@@ -200,7 +200,7 @@ class GBTree : public IGradBooster<FMatrix> {
}
// update the trees
for (size_t i = 0; i < updaters.size(); ++i) {
updaters[i]->Update(gpair, fmat, root_index, new_trees);
updaters[i]->Update(gpair, fmat, info, new_trees);
}
// push back to model
for (size_t i = 0; i < new_trees.size(); ++i) {