fix col maker, make it default
This commit is contained in:
parent
394d325078
commit
550010e9d2
@ -65,15 +65,24 @@ namespace xgboost{
|
||||
loss_chg = 0.0f;
|
||||
split_value = 0.0f; sindex = 0;
|
||||
}
|
||||
// Tie-breaking rule: when loss_chg is equal, the lower split index is preferred.
|
||||
// Not the best way, but it helps to give consistent results during multi-threaded execution.
|
||||
// Returns true when a candidate with the given loss_chg and split_index should replace this entry.
inline bool NeedReplace( float loss_chg, unsigned split_index ) const{
|
||||
if( this->split_index() <= split_index ){
|
||||
// stored index is lower or equal: replace only on a strictly larger loss_chg
return loss_chg > this->loss_chg;
|
||||
}else{
|
||||
// candidate index is lower: replace unless the stored loss_chg is strictly larger
return !(this->loss_chg > loss_chg);
|
||||
}
|
||||
}
|
||||
inline void Update( const SplitEntry &e ){
|
||||
if( e.loss_chg > this->loss_chg ){
|
||||
if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
|
||||
this->loss_chg = e.loss_chg;
|
||||
this->sindex = e.sindex;
|
||||
this->split_value = e.split_value;
|
||||
}
|
||||
}
|
||||
inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
|
||||
if( loss_chg > this->loss_chg ){
|
||||
if( this->NeedReplace( loss_chg, split_index ) ){
|
||||
this->loss_chg = loss_chg;
|
||||
if( default_left ) split_index |= (1U << 31);
|
||||
this->sindex = split_index;
|
||||
@ -153,14 +162,6 @@ namespace xgboost{
|
||||
return this->stat_num_pruned;
|
||||
}
|
||||
private:
|
||||
// main procedure
|
||||
// Calls ClearStats() on stemp[i][nid] for every row i of stemp and every node id nid in qexpand.
// qexpand: node ids whose entries are cleared; entries for other node ids are left untouched.
inline void CleanSTemp( const std::vector<int> &qexpand ){
|
||||
for( size_t i = 0; i < stemp.size(); ++ i ){
|
||||
for( size_t j = 0; j < qexpand.size(); ++ j ){
|
||||
stemp[i][ qexpand[j] ].ClearStats();
|
||||
}
|
||||
}
|
||||
}
|
||||
// update queue expand
|
||||
inline void UpdateQueueExpand( std::vector<int> &qexpand ){
|
||||
std::vector<int> newnodes;
|
||||
@ -216,6 +217,11 @@ namespace xgboost{
|
||||
// enumerate the split values of specific feature
|
||||
template<typename Iter>
|
||||
inline void EnumerateSplit( Iter it, const unsigned fid, std::vector<ThreadEntry> &temp, bool is_forward_search ){
|
||||
// clear all the temp statistics
|
||||
for( size_t j = 0; j < qexpand.size(); ++ j ){
|
||||
temp[ qexpand[j] ].ClearStats();
|
||||
}
|
||||
|
||||
while( it.Next() ){
|
||||
const unsigned ridx = it.rindex();
|
||||
const int nid = position[ ridx ];
|
||||
@ -253,6 +259,7 @@ namespace xgboost{
|
||||
const int nid = qexpand[ i ];
|
||||
ThreadEntry &e = temp[ nid ];
|
||||
const double csum_hess = snode[nid].sum_hess - e.sum_hess;
|
||||
|
||||
if( e.sum_hess >= param.min_child_weight && csum_hess >= param.min_child_weight ){
|
||||
const double csum_grad = snode[nid].sum_grad - e.sum_grad;
|
||||
const double loss_chg =
|
||||
@ -269,21 +276,14 @@ namespace xgboost{
|
||||
inline void FindSplit( int depth ){
|
||||
const unsigned nsize = static_cast<unsigned>( feat_index.size() );
|
||||
|
||||
if( param.need_forward_search() ){
|
||||
this->CleanSTemp( this->qexpand );
|
||||
#pragma omp parallel for schedule( dynamic, 1 )
|
||||
for( unsigned i = 0; i < nsize; ++ i ){
|
||||
const unsigned fid = feat_index[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
if( param.need_forward_search() ){
|
||||
this->EnumerateSplit( smat.GetSortedCol(fid), fid, stemp[tid], true );
|
||||
}
|
||||
}
|
||||
if( param.need_backward_search() ){
|
||||
this->CleanSTemp( this->qexpand );
|
||||
#pragma omp parallel for schedule( dynamic, 1 )
|
||||
for( unsigned i = 0; i < nsize; ++ i ){
|
||||
const unsigned fid = feat_index[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
this->EnumerateSplit( smat.GetReverseSortedCol(fid), fid, stemp[tid], false );
|
||||
}
|
||||
}
|
||||
@ -295,8 +295,9 @@ namespace xgboost{
|
||||
for( int tid = 0; tid < this->nthread; ++ tid ){
|
||||
e.best.Update( stemp[ tid ][ nid ].best );
|
||||
}
|
||||
|
||||
// now we know the solution in snode[ nid ], set split
|
||||
if( snode[ nid ].best.loss_chg > rt_eps ){
|
||||
if( e.best.loss_chg > rt_eps ){
|
||||
tree.AddChilds( nid );
|
||||
tree[ nid ].set_split( e.best.split_index(), e.best.split_value, e.best.default_left() );
|
||||
} else{
|
||||
|
||||
@ -32,12 +32,13 @@ namespace xgboost{
|
||||
class RegTreeTrainer : public IBooster{
|
||||
public:
|
||||
RegTreeTrainer( void ){
|
||||
silent = 0; tree_maker = 0;
|
||||
silent = 0; tree_maker = 1;
|
||||
}
|
||||
virtual ~RegTreeTrainer( void ){}
|
||||
public:
|
||||
// Sets a parameter from a name/value string pair.
// name: parameter name; val: parameter value as text.
virtual void SetParam( const char *name, const char *val ){
|
||||
// "silent" and "tree_maker" are stored on this object, parsed as integers with atoi
if( !strcmp( name, "silent") ) silent = atoi( val );
|
||||
if( !strcmp( name, "tree_maker") ) tree_maker = atoi( val );
|
||||
// every pair, recognized above or not, is also forwarded to both parameter holders
param.SetParam( name, val );
|
||||
tree.param.SetParam( name, val );
|
||||
}
|
||||
|
||||
@ -12,9 +12,6 @@ loss_type = 2
|
||||
|
||||
bst:num_feature=126
|
||||
bst:eta=1.0
|
||||
bst:gamma=1
|
||||
bst:gamma=1.0
|
||||
bst:min_child_weight=1
|
||||
bst:max_depth=3
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user