diff --git a/booster/tree/xgboost_tree.hpp b/booster/tree/xgboost_tree.hpp new file mode 100644 index 000000000..76ec073ce --- /dev/null +++ b/booster/tree/xgboost_tree.hpp @@ -0,0 +1,146 @@ +#ifndef _XGBOOST_TREE_HPP_ +#define _XGBOOST_TREE_HPP_ +/*! + * \file xgboost_tree.hpp + * \brief implementation of regression tree + * \author Tianqi Chen: tianqi.tchen@gmail.com + */ +#include "xgboost_tree_model.h" + +namespace xgboost{ + namespace booster{ + const bool rt_debug = false; + // whether to check bugs + const bool check_bug = false; + + const float rt_eps = 1e-5f; + const float rt_2eps = rt_eps * 2.0f; + + inline double sqr( double a ){ + return a * a; + } + }; +}; + +#include "xgboost_svdf_tree.hpp" + +namespace xgboost{ + namespace booster{ + // regression tree, construction algorithm is seperated from this class + // see RegTreeUpdater + class RegTreeTrainer : public IBooster{ + public: + RegTreeTrainer( void ){ silent = 0; } + virtual ~RegTreeTrainer( void ){} + public: + virtual void SetParam( const char *name, const char *val ){ + if( !strcmp( name, "silent") ) silent = atoi( val ); + param.SetParam( name, val ); + tree.param.SetParam( name, val ); + } + virtual void LoadModel( utils::IStream &fi ){ + tree.LoadModel( fi ); + } + virtual void SaveModel( utils::IStream &fo ) const{ + tree.SaveModel( fo ); + } + virtual void InitModel( void ){ + tree.InitModel(); + } + public: + virtual void DoBoost( std::vector &grad, + std::vector &hess, + const FMatrixS &smat, + const std::vector &group_id ){ + this->DoBoost_( grad, hess, smat, group_id ); + } + + virtual int GetLeafIndex( const std::vector &feat, + const std::vector &funknown, + unsigned gid = 0 ){ + // start from groups that belongs to current data + int pid = (int)gid; + // tranverse tree + while( !tree[ pid ].is_leaf() ){ + unsigned split_index = tree[ pid ].split_index(); + pid = this->GetNext( pid, feat[ split_index ], funknown[ split_index ] ); + } + return pid; + } + virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){ + this->InitTmp(); + for( unsigned i = 0; i < feat.len; i ++ ){ + utils::Assert( feat[i].findex < (unsigned)tmp_funknown.size() , "input feature execeed bound" ); + tmp_funknown[ feat[i].findex ] = false; + tmp_feat[ feat[i].findex ] = feat[i].fvalue; + } + int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid ); + // set back + for( unsigned i = 0; i < feat.len; i ++ ){ + tmp_funknown[ feat[i].findex ] = true; + } + return tree[ pid ].leaf_value(); + } + virtual float Predict( const std::vector &feat, + const std::vector &funknown, + unsigned gid = 0 ){ + utils::Assert( feat.size() >= (size_t)tree.param.num_feature, + "input data smaller than num feature" ); + int pid = this->GetLeafIndex( feat, funknown, gid ); + return tree[ pid ].leaf_value(); + } + + virtual void DumpModel( FILE *fo ){ + tree.DumpModel( fo ); + } + private: + template + inline void DoBoost_( std::vector &grad, + std::vector &hess, + const FMatrix &smat, + const std::vector &group_id ){ + utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" ); + if( !silent ){ + printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() ); + } + // start with a id set + RTreeUpdater updater( param, tree, grad, hess, smat, group_id ); + int num_pruned; + tree.param.max_depth = updater.do_boost( num_pruned ); + + if( !silent ){ + printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n", + tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth ); + } + } + private: + int silent; + RegTree tree; + TreeParamTrain param; + private: + std::vector tmp_feat; + std::vector tmp_funknown; + inline void InitTmp( void ){ + if( tmp_feat.size() != (size_t)tree.param.num_feature ){ + tmp_feat.resize( tree.param.num_feature ); + tmp_funknown.resize( tree.param.num_feature ); + std::fill( tmp_funknown.begin(), tmp_funknown.end(), true ); + } + } + + inline int GetNext( int pid, float fvalue, bool is_unknown ){ + float split_value = tree[ pid ].split_cond(); + if( is_unknown ){ + if( tree[ pid ].default_left() ) return tree[ pid ].cleft(); + else return tree[ pid ].cright(); + }else{ + if( fvalue < split_value ) return tree[ pid ].cleft(); + else return tree[ pid ].cright(); + } + } + }; + }; +}; + +#endif +