fix base score, and print message
This commit is contained in:
@@ -19,6 +19,7 @@ class TreePruner: public IUpdater<FMatrix> {
|
||||
// set training parameter
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
param.SetParam(name, val);
|
||||
if (!strcmp(name, "silent")) silent = atoi(val);
|
||||
}
|
||||
// update the tree, do pruning
|
||||
virtual void Update(const std::vector<bst_gpair> &gpair,
|
||||
@@ -32,33 +33,41 @@ class TreePruner: public IUpdater<FMatrix> {
|
||||
|
||||
private:
|
||||
// try to prune off current leaf
|
||||
inline void TryPruneLeaf(RegTree &tree, int nid, int depth) {
|
||||
if (tree[nid].is_root()) return;
|
||||
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) {
|
||||
if (tree[nid].is_root()) return npruned;
|
||||
int pid = tree[nid].parent();
|
||||
RegTree::NodeStat &s = tree.stat(pid);
|
||||
++s.leaf_child_cnt;
|
||||
|
||||
if (s.leaf_child_cnt >= 2 && param.need_prune(s.loss_chg, depth - 1)) {
|
||||
// need to be pruned
|
||||
tree.ChangeToLeaf(pid, param.learning_rate * s.base_weight);
|
||||
// tail recursion
|
||||
this->TryPruneLeaf(tree, pid, depth - 1);
|
||||
}
|
||||
return this->TryPruneLeaf(tree, pid, depth - 1, npruned+2);
|
||||
} else {
|
||||
return npruned;
|
||||
}
|
||||
}
|
||||
/*! \brief do prunning of a tree */
|
||||
inline void DoPrune(RegTree &tree) {
|
||||
int npruned = 0;
|
||||
// initialize auxiliary statistics
|
||||
for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
|
||||
tree.stat(nid).leaf_child_cnt = 0;
|
||||
}
|
||||
for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
|
||||
if (tree[nid].is_leaf()) {
|
||||
this->TryPruneLeaf(tree, nid, tree.GetDepth(nid));
|
||||
npruned = this->TryPruneLeaf(tree, nid, tree.GetDepth(nid), npruned);
|
||||
}
|
||||
}
|
||||
if (silent == 0) {
|
||||
printf("tree prunning end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
|
||||
tree.param.num_roots, tree.num_extra_nodes(), npruned, tree.MaxDepth());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// shutup
|
||||
int silent;
|
||||
// training parameter
|
||||
TrainParam param;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user