modify tree so that training is standalone

This commit is contained in:
tqchen
2014-02-26 16:03:00 -08:00
parent 2c6922f432
commit 4a612eb3ba
8 changed files with 197 additions and 287 deletions

View File

@@ -2,7 +2,7 @@
#define _XGBOOST_APEX_TREE_HPP_
/*!
* \file xgboost_svdf_tree.hpp
* \brief implementation of regression tree, with layerwise support
* \brief implementation of regression tree constructor, with layerwise support
* this file is adapted from GBRT implementation in SVDFeature project
* \author Tianqi Chen: tqchen@apex.sjtu.edu.cn, tianqi.tchen@gmail.com
*/
@@ -12,18 +12,7 @@
#include "../../utils/xgboost_matrix_csr.h"
namespace xgboost{
namespace booster{
const bool rt_debug = false;
// whether to check bugs
const bool check_bug = false;
const float rt_eps = 1e-5f;
const float rt_2eps = rt_eps * 2.0f;
inline double sqr( double a ){
return a * a;
}
namespace booster{
inline void assert_sorted( unsigned *idset, int len ){
if( !rt_debug || !check_bug ) return;
for( int i = 1; i < len; i ++ ){
@@ -32,21 +21,7 @@ namespace xgboost{
}
};
namespace booster{
// node stat used in rtree
struct RTreeNodeStat{
// loss chg caused by current split
float loss_chg;
// weight of current node
float base_weight;
// number of child that is leaf node known up to now
int leaf_child_cnt;
};
// structure of Regression Tree
class RTree: public TreeModel<float,RTreeNodeStat>{
};
namespace booster{
// selecter of rtree to find the suitable candidate
class RTSelecter{
public:
@@ -88,7 +63,9 @@ namespace xgboost{
}
};
// updater of rtree, allows the parameters to be stored inside, key solver
template<typename FMatrix>
class RTreeUpdater{
protected:
// training task, element of single task
@@ -128,10 +105,10 @@ namespace xgboost{
// training parameter
const TreeParamTrain &param;
// parameters, reference
RTree &tree;
RegTree &tree;
std::vector<float> &grad;
std::vector<float> &hess;
const FMatrixS &smat;
const FMatrix &smat;
const std::vector<unsigned> &group_id;
private:
// maximum depth up to now
@@ -158,7 +135,7 @@ namespace xgboost{
inline void try_prune_leaf( int nid, int depth ){
if( tree[ nid ].is_root() ) return;
int pid = tree[ nid ].parent();
RTree::NodeStat &s = tree.stat( pid );
RegTree::NodeStat &s = tree.stat( pid );
s.leaf_child_cnt ++;
if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
@@ -186,7 +163,7 @@ namespace xgboost{
// make split for current task, re-arrange positions in idset
inline void make_split( Task tsk, const SCEntry *entry, int num, float loss_chg, double base_weight ){
// before split, first prepare statistics
RTree::NodeStat &s = tree.stat( tsk.nid );
RegTree::NodeStat &s = tree.stat( tsk.nid );
s.loss_chg = loss_chg;
s.leaf_child_cnt = 0;
s.base_weight = static_cast<float>( base_weight );
@@ -214,7 +191,7 @@ namespace xgboost{
}
}
// get two parts
RTree::Node &n = tree[ tsk.nid ];
RegTree::Node &n = tree[ tsk.nid ];
Task def_part( n.default_left() ? n.cleft() : n.cright(), tsk.idset, tsk.len - qset.size(), s.base_weight );
Task spl_part( n.default_left() ? n.cright(): n.cleft() , tsk.idset + def_part.len, qset.size(), s.base_weight );
// fill back split part
@@ -320,9 +297,8 @@ namespace xgboost{
rsum_grad += grad[ ridx ];
rsum_hess += hess[ ridx ];
FMatrixS::Line sp = smat[ ridx ];
for( unsigned j = 0; j < sp.len; j ++ ){
builder.AddBudget( sp[j].findex );
for( typename FMatrix::RowIter it = smat.GetRow(ridx); it.Next(); ){
builder.AddBudget( it.findex() );
}
}
@@ -334,10 +310,9 @@ namespace xgboost{
builder.InitStorage();
for( unsigned i = 0; i < tsk.len; i ++ ){
const unsigned ridx = tsk.idset[i];
FMatrixS::Line sp = smat[ ridx ];
for( unsigned j = 0; j < sp.len; j ++ ){
builder.PushElem( sp[j].findex, SCEntry( sp[j].fvalue, ridx ) );
}
for( typename FMatrix::RowIter it = smat.GetRow(ridx); it.Next(); ){
builder.PushElem( it.findex(), SCEntry( it.fvalue(), ridx ) );
}
}
// --- end of building column major matrix ---
// after this point, tmp_rptr and entry is ready to use
@@ -426,10 +401,10 @@ namespace xgboost{
}
public:
RTreeUpdater( const TreeParamTrain &pparam,
RTree &ptree,
RegTree &ptree,
std::vector<float> &pgrad,
std::vector<float> &phess,
const FMatrixS &psmat,
const FMatrix &psmat,
const std::vector<unsigned> &pgroup_id ):
param( pparam ), tree( ptree ), grad( pgrad ), hess( phess ),
smat( psmat ), group_id( pgroup_id ){
@@ -446,113 +421,6 @@ namespace xgboost{
return max_depth;
}
};
class RTreeTrainer : public IBooster{
private:
int silent;
// tree of current shape
RTree tree;
TreeParamTrain param;
private:
std::vector<float> tmp_feat;
std::vector<bool> tmp_funknown;
inline void init_tmpfeat( void ){
if( tmp_feat.size() != (size_t)tree.param.num_feature ){
tmp_feat.resize( tree.param.num_feature );
tmp_funknown.resize( tree.param.num_feature );
std::fill( tmp_funknown.begin(), tmp_funknown.end(), true );
}
}
public:
virtual void SetParam( const char *name, const char *val ){
if( !strcmp( name, "silent") ) silent = atoi( val );
param.SetParam( name, val );
tree.param.SetParam( name, val );
}
virtual void LoadModel( utils::IStream &fi ){
tree.LoadModel( fi );
}
virtual void SaveModel( utils::IStream &fo ) const{
tree.SaveModel( fo );
}
virtual void InitModel( void ){
tree.InitModel();
}
private:
inline int get_next( int pid, float fvalue, bool is_unknown ){
float split_value = tree[ pid ].split_cond();
if( is_unknown ){
if( tree[ pid ].default_left() ) return tree[ pid ].cleft();
else return tree[ pid ].cright();
}else{
if( fvalue < split_value ) return tree[ pid ].cleft();
else return tree[ pid ].cright();
}
}
public:
virtual void DoBoost( std::vector<float> &grad,
std::vector<float> &hess,
const FMatrixS &smat,
const std::vector<unsigned> &group_id ){
utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
if( !silent ){
printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
}
// start with a id set
RTreeUpdater updater( param, tree, grad, hess, smat, group_id );
int num_pruned;
tree.param.max_depth = updater.do_boost( num_pruned );
if( !silent ){
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
}
}
virtual int GetLeafIndex( const std::vector<float> &feat,
const std::vector<bool> &funknown,
unsigned gid = 0 ){
// start from groups that belongs to current data
int pid = (int)gid;
// tranverse tree
while( !tree[ pid ].is_leaf() ){
unsigned split_index = tree[ pid ].split_index();
pid = this->get_next( pid, feat[ split_index ], funknown[ split_index ] );
}
return pid;
}
virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){
this->init_tmpfeat();
for( unsigned i = 0; i < feat.len; i ++ ){
utils::Assert( feat[i].findex < (unsigned)tmp_funknown.size() , "input feature execeed bound" );
tmp_funknown[ feat[i].findex ] = false;
tmp_feat[ feat[i].findex ] = feat[i].fvalue;
}
int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid );
// set back
for( unsigned i = 0; i < feat.len; i ++ ){
tmp_funknown[ feat[i].findex ] = true;
}
return tree[ pid ].leaf_value();
}
virtual float Predict( const std::vector<float> &feat,
const std::vector<bool> &funknown,
unsigned gid = 0 ){
utils::Assert( feat.size() >= (size_t)tree.param.num_feature,
"input data smaller than num feature" );
int pid = this->GetLeafIndex( feat, funknown, gid );
return tree[ pid ].leaf_value();
}
virtual void DumpModel( FILE *fo ){
tree.DumpModel( fo );
}
public:
RTreeTrainer( void ){ silent = 0; }
virtual ~RTreeTrainer( void ){}
};
};
};
#endif

View File

@@ -306,7 +306,7 @@ namespace xgboost{
}
};
};
namespace booster{
/*! \brief training parameters for regression tree */
struct TreeParamTrain{
@@ -431,5 +431,20 @@ namespace xgboost{
}
};
};
namespace booster{
/*! \brief node statistics used in regression tree */
struct RTreeNodeStat{
// loss chg caused by current split
float loss_chg;
// weight of current node
float base_weight;
// number of child that is leaf node known up to now
int leaf_child_cnt;
};
/*! \brief most comment structure of regression tree */
class RegTree: public TreeModel<bst_float,RTreeNodeStat>{
};
};
};
#endif