Add a skeleton row-based tree maker (RowTreeMaker) and select tree makers by id.

Summary of the changes carried by this patch:
  * xgboost_base_treemaker.hpp: both SplitEntry::Update overloads now return
    bool (true when the candidate replaced the stored split, false otherwise);
    the ThreadEntry struct is removed from this file.
  * xgboost_col_treemaker.hpp: ColTreeMaker derives from BaseTreeMaker with
    protected instead of public inheritance and gains ThreadEntry under a
    private: section.
  * xgboost_row_treemaker.hpp (new): RowTreeMaker. InitData builds the sampled
    row index set and the per-root bounds; the statistics accumulation in
    InitNewNode and the body of FindSplit are still TODO, and the FindSplit
    call in Make is commented out.
  * xgboost_tree.hpp: tree_maker is dispatched with a switch
    (0 = RTreeUpdater, 1 = ColTreeMaker, 2 = RowTreeMaker); any other value
    calls utils::Error("unknown tree maker").

NOTE(review): indentation and template argument lists in this patch were
reconstructed from a whitespace-collapsed copy; confirm against upstream. A
whitespace-insensitive apply may be needed.

diff --git a/booster/tree/xgboost_base_treemaker.hpp b/booster/tree/xgboost_base_treemaker.hpp
index dcc8c06fd..1d604eccc 100644
--- a/booster/tree/xgboost_base_treemaker.hpp
+++ b/booster/tree/xgboost_base_treemaker.hpp
@@ -40,19 +40,25 @@ namespace xgboost{
                         return !(this->loss_chg > loss_chg);
                     }
                 }
-                inline void Update( const SplitEntry &e ){
+                inline bool Update( const SplitEntry &e ){
                     if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
                         this->loss_chg = e.loss_chg;
                         this->sindex = e.sindex;
                         this->split_value = e.split_value;
-                    }
+                        return true;
+                    }else{
+                        return false;
+                    }
                 }
-                inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
+                inline bool Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
                     if( this->NeedReplace( loss_chg, split_index ) ){
                         this->loss_chg = loss_chg;
                         if( default_left ) split_index |= (1U << 31);
                         this->sindex = split_index;
                         this->split_value = split_value;
+                        return true;
+                    }else{
+                        return false;
                     }
                 }
                 inline unsigned split_index( void ) const{
@@ -78,26 +84,6 @@ namespace xgboost{
                     weight = root_gain = 0.0f;
                 }
             };
-
-            /*! \brief per thread x per node entry to store tmp data */
-            struct ThreadEntry{
-                /*! \brief sum gradient statistics */
-                double sum_grad;
-                /*! \brief sum hessian statistics */
-                double sum_hess;
-                /*! \brief last feature value scanned */
-                float last_fvalue;
-                /*! \brief current best solution */
-                SplitEntry best;
-                /*! \brief constructor */
-                ThreadEntry( void ){
-                    this->ClearStats();
-                }
-                /*! \brief clear statistics */
-                inline void ClearStats( void ){
-                    sum_grad = sum_hess = 0.0;
-                }
-            };
         private:
             // try to prune off current leaf, return true if successful
             inline void TryPruneLeaf( int nid, int depth ){
diff --git a/booster/tree/xgboost_col_treemaker.hpp b/booster/tree/xgboost_col_treemaker.hpp
index 2782c67a9..16eca07d5 100644
--- a/booster/tree/xgboost_col_treemaker.hpp
+++ b/booster/tree/xgboost_col_treemaker.hpp
@@ -16,7 +16,7 @@
 namespace xgboost{
     namespace booster{
         template<typename FMatrix>
-        class ColTreeMaker : public BaseTreeMaker{
+        class ColTreeMaker : protected BaseTreeMaker{
         public:
             ColTreeMaker( RegTree &tree,
                           const TreeParamTrain &param,
@@ -53,6 +53,26 @@ namespace xgboost{
                 // start prunning the tree
                 stat_num_pruned = this->DoPrune();
             }
+        private:
+            /*! \brief per thread x per node entry to store tmp data */
+            struct ThreadEntry{
+                /*! \brief sum gradient statistics */
+                double sum_grad;
+                /*! \brief sum hessian statistics */
+                double sum_hess;
+                /*! \brief last feature value scanned */
+                float last_fvalue;
+                /*! \brief current best solution */
+                SplitEntry best;
+                /*! \brief constructor */
+                ThreadEntry( void ){
+                    this->ClearStats();
+                }
+                /*! \brief clear statistics */
+                inline void ClearStats( void ){
+                    sum_grad = sum_hess = 0.0;
+                }
+            };
         private:
             // make leaf nodes for all qexpand, update node statistics, mark leaf value
             inline void InitNewNode( const std::vector<int> &qexpand ){
diff --git a/booster/tree/xgboost_row_treemaker.hpp b/booster/tree/xgboost_row_treemaker.hpp
new file mode 100644
index 000000000..00f16af75
--- /dev/null
+++ b/booster/tree/xgboost_row_treemaker.hpp
@@ -0,0 +1,149 @@
+#ifndef XGBOOST_ROW_TREEMAKER_HPP
+#define XGBOOST_ROW_TREEMAKER_HPP
+/*!
+ * \file xgboost_row_treemaker.hpp
+ * \brief implementation of regression tree maker,
+ *         use a row based approach
+ * \author Tianqi Chen: tianqi.tchen@gmail.com
+ */
+// use openmp
+#include <vector>
+#include "xgboost_tree_model.h"
+#include "../../utils/xgboost_omp.h"
+#include "../../utils/xgboost_random.h"
+#include "xgboost_base_treemaker.hpp"
+
+namespace xgboost{
+    namespace booster{
+        template<typename FMatrix>
+        class RowTreeMaker : protected BaseTreeMaker{
+        public:
+            RowTreeMaker( RegTree &tree,
+                          const TreeParamTrain &param,
+                          const std::vector<float> &grad,
+                          const std::vector<float> &hess,
+                          const FMatrix &smat,
+                          const std::vector<unsigned> &root_index )
+                : BaseTreeMaker( tree, param ),
+                  grad(grad), hess(hess),
+                  smat(smat), root_index(root_index) {
+                utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
+                utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
+                utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
+            }
+            inline void Make( int& stat_max_depth, int& stat_num_pruned ){
+                this->InitData();
+                this->InitNewNode( this->qexpand );
+                stat_max_depth = 0;
+
+                for( int depth = 0; depth < param.max_depth; ++ depth ){
+                    //this->FindSplit( this->qexpand );
+                    this->UpdateQueueExpand( this->qexpand );
+                    this->InitNewNode( this->qexpand );
+                    // if nothing left to be expand, break
+                    if( qexpand.size() == 0 ) break;
+                    stat_max_depth = depth + 1;
+                }
+                // set all the rest expanding nodes to leaf
+                for( size_t i = 0; i < qexpand.size(); ++ i ){
+                    const int nid = qexpand[i];
+                    tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
+                }
+                // start pruning the tree
+                stat_num_pruned = this->DoPrune();
+            }
+        private:
+            // make leaf nodes for all qexpand, update node statistics, mark leaf value
+            inline void InitNewNode( const std::vector<int> &qexpand ){
+                snode.resize( tree.param.num_nodes, NodeEntry() );
+
+                for( size_t j = 0; j < qexpand.size(); ++ j ){
+                    const int nid = qexpand[ j ];
+                    double sum_grad = 0.0, sum_hess = 0.0;
+                    // TODO: get sum statistics for nid
+
+                    // update node statistics
+                    snode[nid].sum_grad = sum_grad;
+                    snode[nid].sum_hess = sum_hess;
+                    snode[nid].root_gain = param.CalcRootGain( sum_grad, sum_hess );
+                    if( !tree[nid].is_root() ){
+                        snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, snode[ tree[nid].parent() ].weight );
+                    }else{
+                        snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, 0.0f );
+                    }
+                }
+            }
+            // find splits at current level
+            inline void FindSplit( int nid ){
+                // TODO
+
+            }
+        private:
+            // initialize temp data structure
+            inline void InitData( void ){
+                std::vector<bst_uint> valid_index;
+                for( size_t i = 0; i < grad.size(); ++i ){
+                    if( hess[ i ] < 0.0f ) continue;
+                    if( param.subsample > 1.0f-1e-6f || random::SampleBinary( param.subsample ) != 0 ){
+                        valid_index.push_back( static_cast<bst_uint>(i) );
+                    }
+                }
+                node_bound.resize( tree.param.num_roots );
+
+                if( root_index.size() == 0 ){
+                    row_index_set = valid_index;
+                    // set bound of root node
+                    node_bound[0] = std::make_pair( 0, (bst_uint)row_index_set.size() );
+                }else{
+                    std::vector<size_t> rptr;
+                    utils::SparseCSRMBuilder<bst_uint> builder( rptr, row_index_set );
+                    builder.InitBudget( tree.param.num_roots );
+                    for( size_t i = 0; i < valid_index.size(); ++i ){
+                        const bst_uint rid = valid_index[ i ];
+                        utils::Assert( root_index[ rid ] < (unsigned)tree.param.num_roots, "root id exceed number of roots" );
+                        builder.AddBudget( root_index[ rid ] );
+                    }
+                    builder.InitStorage();
+                    for( size_t i = 0; i < valid_index.size(); ++i ){
+                        const bst_uint rid = valid_index[ i ];
+                        builder.PushElem( root_index[ rid ], rid );
+                    }
+                    for( size_t i = 1; i < rptr.size(); ++ i ){
+                        node_bound[i-1] = std::make_pair( rptr[ i - 1 ], rptr[ i ] );
+                    }
+                }
+
+                {// setup temp space for each thread
+                    if( param.nthread != 0 ){
+                        omp_set_num_threads( param.nthread );
+                    }
+                    #pragma omp parallel
+                    {
+                        this->nthread = omp_get_num_threads();
+                    }
+                    snode.reserve( 256 );
+                }
+
+                {// expand queue
+                    qexpand.reserve( 256 ); qexpand.clear();
+                    for( int i = 0; i < tree.param.num_roots; ++ i ){
+                        qexpand.push_back( i );
+                    }
+                }
+            }
+        private:
+            // number of omp thread used during training
+            int nthread;
+            // Instance row indexes corresponding to each node
+            std::vector<bst_uint> row_index_set;
+            // lower and upper bound of each nodes' row_index
+            std::vector< std::pair<bst_uint, bst_uint> > node_bound;
+        private:
+            const std::vector<float> &grad;
+            const std::vector<float> &hess;
+            const FMatrix &smat;
+            const std::vector<unsigned> &root_index;
+        };
+    };
+};
+#endif
diff --git a/booster/tree/xgboost_tree.hpp b/booster/tree/xgboost_tree.hpp
index 31cfcf01f..0ce7125e0 100644
--- a/booster/tree/xgboost_tree.hpp
+++ b/booster/tree/xgboost_tree.hpp
@@ -24,6 +24,7 @@ namespace xgboost{
 
 #include "xgboost_svdf_tree.hpp"
 #include "xgboost_col_treemaker.hpp"
+#include "xgboost_row_treemaker.hpp"
 
 namespace xgboost{
     namespace booster{
@@ -64,13 +65,23 @@ namespace xgboost{
                     printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
                 }
                 int num_pruned;
-                if( tree_maker == 0 ){
-                    // start with a id set
+                switch( tree_maker ){
+                case 0: {
                     RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
                     tree.param.max_depth = updater.do_boost( num_pruned );
-                }else{
+                    break;
+                }
+                case 1:{
                     ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
                     maker.Make( tree.param.max_depth, num_pruned );
+                    break;
+                }
+                case 2:{
+                    RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
+                    maker.Make( tree.param.max_depth, num_pruned );
+                    break;
+                }
+                default: utils::Error("unknown tree maker");
                 }
                 if( !silent ){
                     printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",