chg root index to booster info, need review

This commit is contained in:
tqchen
2014-08-22 16:26:37 -07:00
parent 58d74861b9
commit 58354643b0
10 changed files with 48 additions and 44 deletions

View File

@@ -29,8 +29,7 @@ class IUpdater {
* \brief peform update to the tree models
* \param gpair the gradient pair statistics of the data
* \param fmat feature matrix that provide access to features
* \param root_index pre-partitioned root_index of each instance,
* root_index.size() can be 0 which indicates that no pre-partition involved
* \param info extra side information that may be need, such as root index
* \param trees pointer to the trese to be updated, upater will change the content of the tree
* note: all the trees in the vector are updated, with the same statistics,
* but maybe different random seeds, usually one tree is passed in at a time,
@@ -38,7 +37,7 @@ class IUpdater {
*/
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) = 0;
// destructor
virtual ~IUpdater(void) {}

View File

@@ -25,7 +25,7 @@ class ColMaker: public IUpdater<FMatrix> {
}
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
// rescale learning rate according to size of trees
float lr = param.learning_rate;
@@ -33,7 +33,7 @@ class ColMaker: public IUpdater<FMatrix> {
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param);
builder.Update(gpair, fmat, root_index, trees[i]);
builder.Update(gpair, fmat, info, trees[i]);
}
param.learning_rate = lr;
}
@@ -77,9 +77,9 @@ class ColMaker: public IUpdater<FMatrix> {
// update one tree, growing
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const BoosterInfo &info,
RegTree *p_tree) {
this->InitData(gpair, fmat, root_index, *p_tree);
this->InitData(gpair, fmat, info.root_index, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) {

View File

@@ -24,7 +24,7 @@ class TreePruner: public IUpdater<FMatrix> {
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
// rescale learning rate according to size of trees
float lr = param.learning_rate;

View File

@@ -24,7 +24,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
if (trees.size() == 0) return;
// number of threads
@@ -66,7 +66,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
feats.Fill(inst);
for (size_t j = 0; j < trees.size(); ++j) {
AddStats(*trees[j], feats, gpair[ridx],
root_index.size() == 0 ? 0 : root_index[ridx],
info.GetRoot(j),
&stemp[tid * trees.size() + j]);
}
feats.Drop(inst);