split new base treemaker, not very good abstraction, but ok
This commit is contained in:
parent
128e94be1a
commit
2910bdedf4
161
booster/tree/xgboost_base_treemaker.hpp
Normal file
161
booster/tree/xgboost_base_treemaker.hpp
Normal file
@ -0,0 +1,161 @@
|
||||
#ifndef XGBOOST_BASE_TREEMAKER_HPP
|
||||
#define XGBOOST_BASE_TREEMAKER_HPP
|
||||
/*!
|
||||
* \file xgboost_base_treemaker.hpp
|
||||
* \brief implementation of base data structure for regression tree maker,
|
||||
* gives common operations of tree construction steps template
|
||||
*
|
||||
* \author Tianqi Chen: tianqi.tchen@gmail.com
|
||||
*/
|
||||
#include <vector>
|
||||
#include "xgboost_tree_model.h"
|
||||
|
||||
namespace xgboost{
|
||||
namespace booster{
|
||||
class BaseTreeMaker{
|
||||
protected:
|
||||
BaseTreeMaker( RegTree &tree,
|
||||
const TreeParamTrain ¶m )
|
||||
: tree( tree ), param( param ){}
|
||||
protected:
|
||||
// statistics that is helpful to decide a split
|
||||
struct SplitEntry{
|
||||
/*! \brief loss change after split this node */
|
||||
float loss_chg;
|
||||
/*! \brief split index */
|
||||
unsigned sindex;
|
||||
/*! \brief split value */
|
||||
float split_value;
|
||||
/*! \brief constructor */
|
||||
SplitEntry( void ){
|
||||
loss_chg = 0.0f;
|
||||
split_value = 0.0f; sindex = 0;
|
||||
}
|
||||
// This function gives better priority to lower index when loss_chg equals
|
||||
// not the best way, but helps to give consistent result during multi-thread execution
|
||||
inline bool NeedReplace( float loss_chg, unsigned split_index ) const{
|
||||
if( this->split_index() <= split_index ){
|
||||
return loss_chg > this->loss_chg;
|
||||
}else{
|
||||
return !(this->loss_chg > loss_chg);
|
||||
}
|
||||
}
|
||||
inline void Update( const SplitEntry &e ){
|
||||
if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
|
||||
this->loss_chg = e.loss_chg;
|
||||
this->sindex = e.sindex;
|
||||
this->split_value = e.split_value;
|
||||
}
|
||||
}
|
||||
inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
|
||||
if( this->NeedReplace( loss_chg, split_index ) ){
|
||||
this->loss_chg = loss_chg;
|
||||
if( default_left ) split_index |= (1U << 31);
|
||||
this->sindex = split_index;
|
||||
this->split_value = split_value;
|
||||
}
|
||||
}
|
||||
inline unsigned split_index( void ) const{
|
||||
return sindex & ( (1U<<31) - 1U );
|
||||
}
|
||||
inline bool default_left( void ) const{
|
||||
return (sindex >> 31) != 0;
|
||||
}
|
||||
};
|
||||
struct NodeEntry{
|
||||
/*! \brief sum gradient statistics */
|
||||
double sum_grad;
|
||||
/*! \brief sum hessian statistics */
|
||||
double sum_hess;
|
||||
/*! \brief loss of this node, without split */
|
||||
float root_gain;
|
||||
/*! \brief weight calculated related to current data */
|
||||
float weight;
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
NodeEntry( void ){
|
||||
sum_grad = sum_hess = 0.0;
|
||||
weight = root_gain = 0.0f;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief per thread x per node entry to store tmp data */
|
||||
struct ThreadEntry{
|
||||
/*! \brief sum gradient statistics */
|
||||
double sum_grad;
|
||||
/*! \brief sum hessian statistics */
|
||||
double sum_hess;
|
||||
/*! \brief last feature value scanned */
|
||||
float last_fvalue;
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
/*! \brief constructor */
|
||||
ThreadEntry( void ){
|
||||
this->ClearStats();
|
||||
}
|
||||
/*! \brief clear statistics */
|
||||
inline void ClearStats( void ){
|
||||
sum_grad = sum_hess = 0.0;
|
||||
}
|
||||
};
|
||||
private:
|
||||
// try to prune off current leaf, return true if successful
|
||||
inline void TryPruneLeaf( int nid, int depth ){
|
||||
if( tree[ nid ].is_root() ) return;
|
||||
int pid = tree[ nid ].parent();
|
||||
RegTree::NodeStat &s = tree.stat( pid );
|
||||
++ s.leaf_child_cnt;
|
||||
|
||||
if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
|
||||
this->stat_num_pruned += 2;
|
||||
// need to be pruned
|
||||
tree.ChangeToLeaf( pid, param.learning_rate * snode[pid].weight );
|
||||
// tail recursion
|
||||
this->TryPruneLeaf( pid, depth - 1 );
|
||||
}
|
||||
}
|
||||
protected:
|
||||
/*! \brief do prunning of a tree */
|
||||
inline int DoPrune( void ){
|
||||
this->stat_num_pruned = 0;
|
||||
// initialize auxiliary statistics
|
||||
for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
|
||||
tree.stat( nid ).leaf_child_cnt = 0;
|
||||
tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
|
||||
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
||||
}
|
||||
for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
|
||||
if( tree[ nid ].is_leaf() ) this->TryPruneLeaf( nid, tree.GetDepth(nid) );
|
||||
}
|
||||
return this->stat_num_pruned;
|
||||
}
|
||||
protected:
|
||||
/*! \brief update queue expand add in new leaves */
|
||||
inline void UpdateQueueExpand( std::vector<int> &qexpand ){
|
||||
std::vector<int> newnodes;
|
||||
for( size_t i = 0; i < qexpand.size(); ++ i ){
|
||||
const int nid = qexpand[i];
|
||||
if( !tree[ nid ].is_leaf() ){
|
||||
newnodes.push_back( tree[nid].cleft() );
|
||||
newnodes.push_back( tree[nid].cright() );
|
||||
}
|
||||
}
|
||||
// use new nodes for qexpand
|
||||
qexpand = newnodes;
|
||||
}
|
||||
protected:
|
||||
// local helper tmp data structure
|
||||
// statistics
|
||||
int stat_num_pruned;
|
||||
/*! \brief queue of nodes to be expanded */
|
||||
std::vector<int> qexpand;
|
||||
/*! \brief TreeNode Data: statistics for each constructed node, the derived class must maintain this */
|
||||
std::vector<NodeEntry> snode;
|
||||
protected:
|
||||
// original data that supports tree construction
|
||||
RegTree &tree;
|
||||
const TreeParamTrain ¶m;
|
||||
};
|
||||
}; // namespace booster
|
||||
}; // namespace xgboost
|
||||
#endif // XGBOOST_BASE_TREEMAKER_HPP
|
||||
@ -11,23 +11,25 @@
|
||||
#include "xgboost_tree_model.h"
|
||||
#include "../../utils/xgboost_omp.h"
|
||||
#include "../../utils/xgboost_random.h"
|
||||
#include "xgboost_base_treemaker.hpp"
|
||||
|
||||
namespace xgboost{
|
||||
namespace booster{
|
||||
template<typename FMatrix>
|
||||
class ColTreeMaker{
|
||||
class ColTreeMaker : public BaseTreeMaker{
|
||||
public:
|
||||
ColTreeMaker( RegTree &tree,
|
||||
const TreeParamTrain ¶m,
|
||||
const std::vector<float> &grad,
|
||||
const std::vector<float> &hess,
|
||||
const FMatrix &smat,
|
||||
const std::vector<unsigned> &root_index ):
|
||||
tree( tree ), param( param ), grad( grad ), hess( hess ),
|
||||
smat( smat ), root_index( root_index ){
|
||||
const std::vector<unsigned> &root_index )
|
||||
: BaseTreeMaker( tree, param ),
|
||||
grad(grad), hess(hess),
|
||||
smat(smat), root_index(root_index) {
|
||||
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( smat.HaveColAccess(), "ColTreeMaker: need column access matrix" );
|
||||
}
|
||||
inline void Make( int& stat_max_depth, int& stat_num_pruned ){
|
||||
@ -52,131 +54,6 @@ namespace xgboost{
|
||||
stat_num_pruned = this->DoPrune();
|
||||
}
|
||||
private:
|
||||
// statistics that is helpful to decide a split
|
||||
struct SplitEntry{
|
||||
/*! \brief loss change after split this node */
|
||||
float loss_chg;
|
||||
/*! \brief split index */
|
||||
unsigned sindex;
|
||||
/*! \brief split value */
|
||||
float split_value;
|
||||
/*! \brief constructor */
|
||||
SplitEntry( void ){
|
||||
loss_chg = 0.0f;
|
||||
split_value = 0.0f; sindex = 0;
|
||||
}
|
||||
// This function gives better priority to lower index when loss_chg equals
|
||||
// not the best way, but helps to give consistent result during multi-thread execution
|
||||
inline bool NeedReplace( float loss_chg, unsigned split_index ) const{
|
||||
if( this->split_index() <= split_index ){
|
||||
return loss_chg > this->loss_chg;
|
||||
}else{
|
||||
return !(this->loss_chg > loss_chg);
|
||||
}
|
||||
}
|
||||
inline void Update( const SplitEntry &e ){
|
||||
if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
|
||||
this->loss_chg = e.loss_chg;
|
||||
this->sindex = e.sindex;
|
||||
this->split_value = e.split_value;
|
||||
}
|
||||
}
|
||||
inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
|
||||
if( this->NeedReplace( loss_chg, split_index ) ){
|
||||
this->loss_chg = loss_chg;
|
||||
if( default_left ) split_index |= (1U << 31);
|
||||
this->sindex = split_index;
|
||||
this->split_value = split_value;
|
||||
}
|
||||
}
|
||||
inline unsigned split_index( void ) const{
|
||||
return sindex & ( (1U<<31) - 1U );
|
||||
}
|
||||
inline bool default_left( void ) const{
|
||||
return (sindex >> 31) != 0;
|
||||
}
|
||||
};
|
||||
struct NodeEntry{
|
||||
/*! \brief sum gradient statistics */
|
||||
double sum_grad;
|
||||
/*! \brief sum hessian statistics */
|
||||
double sum_hess;
|
||||
/*! \brief loss of this node, without split */
|
||||
float root_gain;
|
||||
/*! \brief weight calculated related to current data */
|
||||
float weight;
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
NodeEntry( void ){
|
||||
sum_grad = sum_hess = 0.0;
|
||||
weight = root_gain = 0.0f;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief per thread x per node entry to store tmp data */
|
||||
struct ThreadEntry{
|
||||
/*! \brief sum gradient statistics */
|
||||
double sum_grad;
|
||||
/*! \brief sum hessian statistics */
|
||||
double sum_hess;
|
||||
/*! \brief last feature value scanned */
|
||||
float last_fvalue;
|
||||
/*! \brief current best solution */
|
||||
SplitEntry best;
|
||||
/*! \brief constructor */
|
||||
ThreadEntry( void ){
|
||||
this->ClearStats();
|
||||
}
|
||||
/*! \brief clear statistics */
|
||||
inline void ClearStats( void ){
|
||||
sum_grad = sum_hess = 0.0;
|
||||
}
|
||||
};
|
||||
private:
|
||||
// try to prune off current leaf, return true if successful
|
||||
inline void TryPruneLeaf( int nid, int depth ){
|
||||
if( tree[ nid ].is_root() ) return;
|
||||
int pid = tree[ nid ].parent();
|
||||
RegTree::NodeStat &s = tree.stat( pid );
|
||||
++ s.leaf_child_cnt;
|
||||
|
||||
if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
|
||||
this->stat_num_pruned += 2;
|
||||
// need to be pruned
|
||||
tree.ChangeToLeaf( pid, param.learning_rate * snode[pid].weight );
|
||||
// tail recursion
|
||||
this->TryPruneLeaf( pid, depth - 1 );
|
||||
}
|
||||
}
|
||||
// prune tree
|
||||
inline int DoPrune( void ){
|
||||
this->stat_num_pruned = 0;
|
||||
// initialize auxiliary statistics
|
||||
for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
|
||||
tree.stat( nid ).leaf_child_cnt = 0;
|
||||
tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
|
||||
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
||||
}
|
||||
for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
|
||||
if( tree[ nid ].is_leaf() ) this->TryPruneLeaf( nid, tree.GetDepth(nid) );
|
||||
}
|
||||
return this->stat_num_pruned;
|
||||
}
|
||||
private:
|
||||
// update queue expand
|
||||
inline void UpdateQueueExpand( std::vector<int> &qexpand ){
|
||||
std::vector<int> newnodes;
|
||||
for( size_t i = 0; i < qexpand.size(); ++ i ){
|
||||
const int nid = qexpand[i];
|
||||
if( !tree[ nid ].is_leaf() ){
|
||||
newnodes.push_back( tree[nid].cleft() );
|
||||
newnodes.push_back( tree[nid].cright() );
|
||||
}
|
||||
}
|
||||
// use new nodes for qexpand
|
||||
qexpand = newnodes;
|
||||
}
|
||||
|
||||
// make leaf nodes for all qexpand, update node statistics, mark leaf value
|
||||
inline void InitNewNode( const std::vector<int> &qexpand ){
|
||||
{// setup statistics space for each tree node
|
||||
@ -414,30 +291,19 @@ namespace xgboost{
|
||||
}
|
||||
}
|
||||
private:
|
||||
// local helper tmp data structure
|
||||
// statistics
|
||||
int stat_num_pruned;
|
||||
// number of omp thread used during training
|
||||
int nthread;
|
||||
// queue of nodes to be expanded
|
||||
std::vector<int> qexpand;
|
||||
// Per feature: shuffle index of each feature index
|
||||
std::vector<int> feat_index;
|
||||
// Instance Data: current node position in the tree of each instance
|
||||
std::vector<int> position;
|
||||
// TreeNode Data: statistics for each constructed node
|
||||
std::vector<NodeEntry> snode;
|
||||
// PerThread x PerTreeNode: statistics for per thread construction
|
||||
std::vector< std::vector<ThreadEntry> > stemp;
|
||||
private:
|
||||
// original data that supports tree construction
|
||||
RegTree &tree;
|
||||
const TreeParamTrain ¶m;
|
||||
const std::vector<float> &grad;
|
||||
const std::vector<float> &hess;
|
||||
const FMatrix &smat;
|
||||
const std::vector<unsigned> &root_index;
|
||||
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user