split new base treemaker, not very good abstraction, but ok

This commit is contained in:
tqchen 2014-03-05 10:20:36 -08:00
parent 128e94be1a
commit 2910bdedf4
2 changed files with 168 additions and 141 deletions

View File

@ -0,0 +1,161 @@
#ifndef XGBOOST_BASE_TREEMAKER_HPP
#define XGBOOST_BASE_TREEMAKER_HPP
/*!
* \file xgboost_base_treemaker.hpp
* \brief implementation of base data structure for regression tree maker,
* gives common operations of tree construction steps template
*
* \author Tianqi Chen: tianqi.tchen@gmail.com
*/
#include <vector>
#include "xgboost_tree_model.h"
namespace xgboost{
namespace booster{
class BaseTreeMaker{
protected:
BaseTreeMaker( RegTree &tree,
const TreeParamTrain &param )
: tree( tree ), param( param ){}
protected:
// statistics that is helpful to decide a split
struct SplitEntry{
/*! \brief loss change after split this node */
float loss_chg;
/*! \brief split index */
unsigned sindex;
/*! \brief split value */
float split_value;
/*! \brief constructor */
SplitEntry( void ){
loss_chg = 0.0f;
split_value = 0.0f; sindex = 0;
}
// This function gives better priority to lower index when loss_chg equals
// not the best way, but helps to give consistent result during multi-thread execution
inline bool NeedReplace( float loss_chg, unsigned split_index ) const{
if( this->split_index() <= split_index ){
return loss_chg > this->loss_chg;
}else{
return !(this->loss_chg > loss_chg);
}
}
inline void Update( const SplitEntry &e ){
if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
this->loss_chg = e.loss_chg;
this->sindex = e.sindex;
this->split_value = e.split_value;
}
}
inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
if( this->NeedReplace( loss_chg, split_index ) ){
this->loss_chg = loss_chg;
if( default_left ) split_index |= (1U << 31);
this->sindex = split_index;
this->split_value = split_value;
}
}
inline unsigned split_index( void ) const{
return sindex & ( (1U<<31) - 1U );
}
inline bool default_left( void ) const{
return (sindex >> 31) != 0;
}
};
struct NodeEntry{
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief loss of this node, without split */
float root_gain;
/*! \brief weight calculated related to current data */
float weight;
/*! \brief current best solution */
SplitEntry best;
NodeEntry( void ){
sum_grad = sum_hess = 0.0;
weight = root_gain = 0.0f;
}
};
/*! \brief per thread x per node entry to store tmp data */
struct ThreadEntry{
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief last feature value scanned */
float last_fvalue;
/*! \brief current best solution */
SplitEntry best;
/*! \brief constructor */
ThreadEntry( void ){
this->ClearStats();
}
/*! \brief clear statistics */
inline void ClearStats( void ){
sum_grad = sum_hess = 0.0;
}
};
private:
// try to prune off current leaf, return true if successful
inline void TryPruneLeaf( int nid, int depth ){
if( tree[ nid ].is_root() ) return;
int pid = tree[ nid ].parent();
RegTree::NodeStat &s = tree.stat( pid );
++ s.leaf_child_cnt;
if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
this->stat_num_pruned += 2;
// need to be pruned
tree.ChangeToLeaf( pid, param.learning_rate * snode[pid].weight );
// tail recursion
this->TryPruneLeaf( pid, depth - 1 );
}
}
protected:
/*! \brief do prunning of a tree */
inline int DoPrune( void ){
this->stat_num_pruned = 0;
// initialize auxiliary statistics
for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
tree.stat( nid ).leaf_child_cnt = 0;
tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
}
for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
if( tree[ nid ].is_leaf() ) this->TryPruneLeaf( nid, tree.GetDepth(nid) );
}
return this->stat_num_pruned;
}
protected:
/*! \brief update queue expand add in new leaves */
inline void UpdateQueueExpand( std::vector<int> &qexpand ){
std::vector<int> newnodes;
for( size_t i = 0; i < qexpand.size(); ++ i ){
const int nid = qexpand[i];
if( !tree[ nid ].is_leaf() ){
newnodes.push_back( tree[nid].cleft() );
newnodes.push_back( tree[nid].cright() );
}
}
// use new nodes for qexpand
qexpand = newnodes;
}
protected:
// local helper tmp data structure
// statistics
int stat_num_pruned;
/*! \brief queue of nodes to be expanded */
std::vector<int> qexpand;
/*! \brief TreeNode Data: statistics for each constructed node, the derived class must maintain this */
std::vector<NodeEntry> snode;
protected:
// original data that supports tree construction
RegTree &tree;
const TreeParamTrain &param;
};
}; // namespace booster
}; // namespace xgboost
#endif // XGBOOST_BASE_TREEMAKER_HPP

View File

@ -11,23 +11,25 @@
#include "xgboost_tree_model.h"
#include "../../utils/xgboost_omp.h"
#include "../../utils/xgboost_random.h"
#include "xgboost_base_treemaker.hpp"
namespace xgboost{
namespace booster{
template<typename FMatrix>
class ColTreeMaker{
class ColTreeMaker : public BaseTreeMaker{
public:
ColTreeMaker( RegTree &tree,
const TreeParamTrain &param,
const std::vector<float> &grad,
const std::vector<float> &hess,
const FMatrix &smat,
const std::vector<unsigned> &root_index ):
tree( tree ), param( param ), grad( grad ), hess( hess ),
smat( smat ), root_index( root_index ){
const std::vector<unsigned> &root_index )
: BaseTreeMaker( tree, param ),
grad(grad), hess(hess),
smat(smat), root_index(root_index) {
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
utils::Assert( smat.HaveColAccess(), "ColTreeMaker: need column access matrix" );
}
inline void Make( int& stat_max_depth, int& stat_num_pruned ){
@ -52,131 +54,6 @@ namespace xgboost{
stat_num_pruned = this->DoPrune();
}
private:
// statistics that is helpful to decide a split
struct SplitEntry{
/*! \brief loss change after split this node */
float loss_chg;
/*! \brief split index */
unsigned sindex;
/*! \brief split value */
float split_value;
/*! \brief constructor */
SplitEntry( void ){
loss_chg = 0.0f;
split_value = 0.0f; sindex = 0;
}
// This function gives better priority to lower index when loss_chg equals
// not the best way, but helps to give consistent result during multi-thread execution
inline bool NeedReplace( float loss_chg, unsigned split_index ) const{
if( this->split_index() <= split_index ){
return loss_chg > this->loss_chg;
}else{
return !(this->loss_chg > loss_chg);
}
}
inline void Update( const SplitEntry &e ){
if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
this->loss_chg = e.loss_chg;
this->sindex = e.sindex;
this->split_value = e.split_value;
}
}
inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
if( this->NeedReplace( loss_chg, split_index ) ){
this->loss_chg = loss_chg;
if( default_left ) split_index |= (1U << 31);
this->sindex = split_index;
this->split_value = split_value;
}
}
inline unsigned split_index( void ) const{
return sindex & ( (1U<<31) - 1U );
}
inline bool default_left( void ) const{
return (sindex >> 31) != 0;
}
};
struct NodeEntry{
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief loss of this node, without split */
float root_gain;
/*! \brief weight calculated related to current data */
float weight;
/*! \brief current best solution */
SplitEntry best;
NodeEntry( void ){
sum_grad = sum_hess = 0.0;
weight = root_gain = 0.0f;
}
};
/*! \brief per thread x per node entry to store tmp data */
struct ThreadEntry{
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief last feature value scanned */
float last_fvalue;
/*! \brief current best solution */
SplitEntry best;
/*! \brief constructor */
ThreadEntry( void ){
this->ClearStats();
}
/*! \brief clear statistics */
inline void ClearStats( void ){
sum_grad = sum_hess = 0.0;
}
};
private:
// try to prune off current leaf, return true if successful
inline void TryPruneLeaf( int nid, int depth ){
if( tree[ nid ].is_root() ) return;
int pid = tree[ nid ].parent();
RegTree::NodeStat &s = tree.stat( pid );
++ s.leaf_child_cnt;
if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
this->stat_num_pruned += 2;
// need to be pruned
tree.ChangeToLeaf( pid, param.learning_rate * snode[pid].weight );
// tail recursion
this->TryPruneLeaf( pid, depth - 1 );
}
}
// prune tree
inline int DoPrune( void ){
this->stat_num_pruned = 0;
// initialize auxiliary statistics
for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
tree.stat( nid ).leaf_child_cnt = 0;
tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
}
for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
if( tree[ nid ].is_leaf() ) this->TryPruneLeaf( nid, tree.GetDepth(nid) );
}
return this->stat_num_pruned;
}
private:
// update queue expand
inline void UpdateQueueExpand( std::vector<int> &qexpand ){
std::vector<int> newnodes;
for( size_t i = 0; i < qexpand.size(); ++ i ){
const int nid = qexpand[i];
if( !tree[ nid ].is_leaf() ){
newnodes.push_back( tree[nid].cleft() );
newnodes.push_back( tree[nid].cright() );
}
}
// use new nodes for qexpand
qexpand = newnodes;
}
// make leaf nodes for all qexpand, update node statistics, mark leaf value
inline void InitNewNode( const std::vector<int> &qexpand ){
{// setup statistics space for each tree node
@ -414,30 +291,19 @@ namespace xgboost{
}
}
private:
// local helper tmp data structure
// statistics
int stat_num_pruned;
// number of omp thread used during training
int nthread;
// queue of nodes to be expanded
std::vector<int> qexpand;
// Per feature: shuffle index of each feature index
std::vector<int> feat_index;
// Instance Data: current node position in the tree of each instance
std::vector<int> position;
// TreeNode Data: statistics for each constructed node
std::vector<NodeEntry> snode;
// PerThread x PerTreeNode: statistics for per thread construction
std::vector< std::vector<ThreadEntry> > stemp;
private:
// original data that supports tree construction
RegTree &tree;
const TreeParamTrain &param;
const std::vector<float> &grad;
const std::vector<float> &hess;
const FMatrix &smat;
const std::vector<unsigned> &root_index;
};
};
};