This commit is contained in:
tqchen
2014-03-03 11:05:10 -08:00
parent 074a861e7b
commit 623e003923
11 changed files with 172 additions and 34 deletions

View File

@@ -154,18 +154,20 @@ namespace xgboost{
if( compute ){
sum_grad += grad[ ridx ];
sum_hess += hess[ ridx ];
}
}
}
tree.stat( tsk.nid ).sum_hess = static_cast<float>( sum_hess );
tree[ tsk.nid ].set_leaf( param.learning_rate * param.CalcWeight( sum_grad, sum_hess, tsk.parent_base_weight ) );
this->try_prune_leaf( tsk.nid, tree.GetDepth( tsk.nid ) );
}
private:
// make split for current task, re-arrange positions in idset
inline void make_split( Task tsk, const SCEntry *entry, int num, float loss_chg, double base_weight ){
inline void make_split( Task tsk, const SCEntry *entry, int num, float loss_chg, double sum_hess, double base_weight ){
// before split, first prepare statistics
RegTree::NodeStat &s = tree.stat( tsk.nid );
s.loss_chg = loss_chg;
s.leaf_child_cnt = 0;
s.sum_hess = static_cast<float>( sum_hess );
s.base_weight = static_cast<float>( base_weight );
// add childs to current node
@@ -345,7 +347,7 @@ namespace xgboost{
// add splits
tree[ tsk.nid ].set_split( e.split_index(), e.split_value, e.default_left() );
// re-arrange idset, push tasks
this->make_split( tsk, &entry[ e.start ], e.len, e.loss_chg, base_weight );
this->make_split( tsk, &entry[ e.start ], e.len, e.loss_chg, rsum_hess, base_weight );
}else{
// make leaf if we didn't meet requirement
this->make_leaf( tsk, rsum_grad, rsum_hess, false );