start add coltree maker
This commit is contained in:
@@ -205,13 +205,13 @@ namespace xgboost{
|
||||
|
||||
// enumerate split point of the tree
|
||||
inline void enumerate_split( RTSelecter &sglobal, int tlen,
|
||||
double rsum_grad, double rsum_hess, double root_cost,
|
||||
double rsum_grad, double rsum_hess, double root_gain,
|
||||
const SCEntry *entry, size_t start, size_t end,
|
||||
int findex, float parent_base_weight ){
|
||||
// local selecter
|
||||
RTSelecter slocal( param );
|
||||
|
||||
if( param.default_direction != 1 ){
|
||||
if( param.need_forward_search() ){
|
||||
// forward process, default right
|
||||
double csum_grad = 0.0, csum_hess = 0.0;
|
||||
for( size_t j = start; j < end; j ++ ){
|
||||
@@ -225,8 +225,8 @@ namespace xgboost{
|
||||
if( dsum_hess < param.min_child_weight ) break;
|
||||
// change of loss
|
||||
double loss_chg =
|
||||
param.CalcCost( csum_grad, csum_hess, parent_base_weight ) +
|
||||
param.CalcCost( rsum_grad - csum_grad, dsum_hess, parent_base_weight ) - root_cost;
|
||||
param.CalcGain( csum_grad, csum_hess, parent_base_weight ) +
|
||||
param.CalcGain( rsum_grad - csum_grad, dsum_hess, parent_base_weight ) - root_gain;
|
||||
|
||||
const int clen = static_cast<int>( j + 1 - start );
|
||||
// add candidate to selecter
|
||||
@@ -237,7 +237,7 @@ namespace xgboost{
|
||||
}
|
||||
}
|
||||
|
||||
if( param.default_direction != 2 ){
|
||||
if( param.need_backward_search() ){
|
||||
// backward process, default left
|
||||
double csum_grad = 0.0, csum_hess = 0.0;
|
||||
for( size_t j = end; j > start; j -- ){
|
||||
@@ -249,8 +249,8 @@ namespace xgboost{
|
||||
if( csum_hess < param.min_child_weight ) continue;
|
||||
const double dsum_hess = rsum_hess - csum_hess;
|
||||
if( dsum_hess < param.min_child_weight ) break;
|
||||
double loss_chg = param.CalcCost( csum_grad, csum_hess, parent_base_weight ) +
|
||||
param.CalcCost( rsum_grad - csum_grad, dsum_hess, parent_base_weight ) - root_cost;
|
||||
double loss_chg = param.CalcGain( csum_grad, csum_hess, parent_base_weight ) +
|
||||
param.CalcGain( rsum_grad - csum_grad, dsum_hess, parent_base_weight ) - root_gain;
|
||||
const int clen = static_cast<int>( end - j + 1 );
|
||||
// add candidate to selecter
|
||||
slocal.push_back( RTSelecter::Entry( loss_chg, j - 1, clen, findex,
|
||||
@@ -319,8 +319,8 @@ namespace xgboost{
|
||||
|
||||
// global selecter
|
||||
RTSelecter sglobal( param );
|
||||
// cost root
|
||||
const double root_cost = param.CalcRootCost( rsum_grad, rsum_hess );
|
||||
// gain root
|
||||
const double root_gain = param.CalcRootGain( rsum_grad, rsum_hess );
|
||||
// KEY: layerwise, weight of current node if it is leaf
|
||||
const double base_weight = param.CalcWeight( rsum_grad, rsum_hess, tsk.parent_base_weight );
|
||||
// enumerate feature index
|
||||
@@ -333,7 +333,7 @@ namespace xgboost{
|
||||
std::sort( entry.begin() + start, entry.begin() + end );
|
||||
// local selecter
|
||||
this->enumerate_split( sglobal, tsk.len,
|
||||
rsum_grad, rsum_hess, root_cost,
|
||||
rsum_grad, rsum_hess, root_gain,
|
||||
&entry[0], start, end, findex, base_weight );
|
||||
}
|
||||
// Cleanup tmp_rptr for next use
|
||||
|
||||
Reference in New Issue
Block a user