add row tree maker, to be finished

parent cf14b11130
commit 73dfdc539b
@@ -40,19 +40,25 @@ namespace xgboost{
return !(this->loss_chg > loss_chg);
}
}
inline void Update( const SplitEntry &e ){
inline bool Update( const SplitEntry &e ){
if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
this->loss_chg = e.loss_chg;
this->sindex = e.sindex;
this->split_value = e.split_value;
}
return true;
} else{
return false;
}
}
inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
inline bool Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
if( this->NeedReplace( loss_chg, split_index ) ){
this->loss_chg = loss_chg;
if( default_left ) split_index |= (1U << 31);
this->sindex = split_index;
this->split_value = split_value;
return true;
}else{
return false;
}
}
inline unsigned split_index( void ) const{
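
The hunk above changes SplitEntry::Update from void to bool, so a caller can tell whether a candidate split actually replaced the stored best. A minimal, self-contained sketch of that contract (simplified SplitEntry, made-up numbers; not code from this commit):

    #include <cstdio>

    struct SplitEntry{
        float loss_chg;      // loss reduction of the stored best split
        unsigned sindex;     // feature index, bit 31 encodes the default direction
        float split_value;   // threshold of the stored best split
        SplitEntry( void ) : loss_chg(0.0f), sindex(0), split_value(0.0f){}
        inline bool NeedReplace( float new_loss_chg, unsigned split_index ) const{
            (void)split_index; // the real code also uses the feature index for deterministic tie-breaking
            return new_loss_chg > this->loss_chg;
        }
        inline bool Update( float new_loss_chg, unsigned split_index, float new_split_value, bool default_left ){
            if( this->NeedReplace( new_loss_chg, split_index ) ){
                this->loss_chg = new_loss_chg;
                if( default_left ) split_index |= (1U << 31);
                this->sindex = split_index;
                this->split_value = new_split_value;
                return true;
            }else{
                return false;
            }
        }
    };

    int main( void ){
        SplitEntry best;
        // the bool return lets a caller record whether any candidate improved the node
        bool improved = false;
        improved |= best.Update( 0.5f, 3, 1.25f, true );   // accepted, becomes the best
        improved |= best.Update( 0.2f, 7, 0.75f, false );  // rejected, returns false
        printf( "improved=%d feature=%u threshold=%g\n",
                improved, best.sindex & ((1U << 31) - 1), best.split_value );
        return 0;
    }
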
@@ -78,26 +84,6 @@ namespace xgboost{
weight = root_gain = 0.0f;
}
};

/*! \brief per thread x per node entry to store tmp data */
struct ThreadEntry{
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief last feature value scanned */
float last_fvalue;
/*! \brief current best solution */
SplitEntry best;
/*! \brief constructor */
ThreadEntry( void ){
this->ClearStats();
}
/*! \brief clear statistics */
inline void ClearStats( void ){
sum_grad = sum_hess = 0.0;
}
};
private:
// try to prune off current leaf, return true if successful
inline void TryPruneLeaf( int nid, int depth ){

@@ -16,7 +16,7 @@
namespace xgboost{
namespace booster{
template<typename FMatrix>
class ColTreeMaker : public BaseTreeMaker{
class ColTreeMaker : protected BaseTreeMaker{
public:
ColTreeMaker( RegTree &tree,
const TreeParamTrain &param,
@@ -53,6 +53,26 @@ namespace xgboost{
// start prunning the tree
stat_num_pruned = this->DoPrune();
}
private:
/*! \brief per thread x per node entry to store tmp data */
struct ThreadEntry{
/*! \brief sum gradient statistics */
double sum_grad;
/*! \brief sum hessian statistics */
double sum_hess;
/*! \brief last feature value scanned */
float last_fvalue;
/*! \brief current best solution */
SplitEntry best;
/*! \brief constructor */
ThreadEntry( void ){
this->ClearStats();
}
/*! \brief clear statistics */
inline void ClearStats( void ){
sum_grad = sum_hess = 0.0;
}
};
private:
// make leaf nodes for all qexpand, update node statistics, mark leaf value
inline void InitNewNode( const std::vector<int> &qexpand ){

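
The ThreadEntry struct moved into ColTreeMaker above is the usual per-thread scratch space: each OpenMP thread accumulates gradient statistics privately and the partial sums are reduced afterwards. A minimal sketch of that pattern with the same field layout (variable names, data, and sizes here are made up, not taken from the commit):

    #include <cstdio>
    #include <vector>
    #include <omp.h>

    struct ThreadEntry{
        double sum_grad, sum_hess;
        ThreadEntry( void ) : sum_grad(0.0), sum_hess(0.0){}
    };

    int main( void ){
        std::vector<float> grad( 1000, 0.1f ), hess( 1000, 1.0f );
        // one private accumulator per thread to avoid write conflicts
        std::vector<ThreadEntry> stemp( omp_get_max_threads() );
        #pragma omp parallel for
        for( int i = 0; i < (int)grad.size(); ++i ){
            ThreadEntry &e = stemp[ omp_get_thread_num() ];
            e.sum_grad += grad[i];
            e.sum_hess += hess[i];
        }
        // serial reduction over the per-thread partial sums
        double sum_grad = 0.0, sum_hess = 0.0;
        for( size_t t = 0; t < stemp.size(); ++t ){
            sum_grad += stemp[t].sum_grad;
            sum_hess += stemp[t].sum_hess;
        }
        printf( "sum_grad=%g sum_hess=%g\n", sum_grad, sum_hess );
        return 0;
    }
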
booster/tree/xgboost_row_treemaker.hpp (new file, 149 lines)
@@ -0,0 +1,149 @@
#ifndef XGBOOST_ROW_TREEMAKER_HPP
#define XGBOOST_ROW_TREEMAKER_HPP
/*!
 * \file xgboost_row_treemaker.hpp
 * \brief implementation of regression tree maker,
 *        use a row based approach
 * \author Tianqi Chen: tianqi.tchen@gmail.com
 */
// use openmp
#include <vector>
#include "xgboost_tree_model.h"
#include "../../utils/xgboost_omp.h"
#include "../../utils/xgboost_random.h"
#include "xgboost_base_treemaker.hpp"

namespace xgboost{
    namespace booster{
        template<typename FMatrix>
        class RowTreeMaker : protected BaseTreeMaker{
        public:
            RowTreeMaker( RegTree &tree,
                          const TreeParamTrain &param,
                          const std::vector<float> &grad,
                          const std::vector<float> &hess,
                          const FMatrix &smat,
                          const std::vector<unsigned> &root_index )
                : BaseTreeMaker( tree, param ),
                  grad(grad), hess(hess),
                  smat(smat), root_index(root_index) {
                utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
                utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
                utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
            }
            inline void Make( int& stat_max_depth, int& stat_num_pruned ){
                this->InitData();
                this->InitNewNode( this->qexpand );
                stat_max_depth = 0;

                for( int depth = 0; depth < param.max_depth; ++ depth ){
                    //this->FindSplit( this->qexpand );
                    this->UpdateQueueExpand( this->qexpand );
                    this->InitNewNode( this->qexpand );
                    // if nothing left to be expand, break
                    if( qexpand.size() == 0 ) break;
                    stat_max_depth = depth + 1;
                }
                // set all the rest expanding nodes to leaf
                for( size_t i = 0; i < qexpand.size(); ++ i ){
                    const int nid = qexpand[i];
                    tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
                }
                // start prunning the tree
                stat_num_pruned = this->DoPrune();
            }
        private:
            // make leaf nodes for all qexpand, update node statistics, mark leaf value
            inline void InitNewNode( const std::vector<int> &qexpand ){
                snode.resize( tree.param.num_nodes, NodeEntry() );

                for( size_t j = 0; j < qexpand.size(); ++ j ){
                    const int nid = qexpand[ j ];
                    double sum_grad = 0.0, sum_hess = 0.0;
                    // TODO: get sum statistics for nid

                    // update node statistics
                    snode[nid].sum_grad = sum_grad;
                    snode[nid].sum_hess = sum_hess;
                    snode[nid].root_gain = param.CalcRootGain( sum_grad, sum_hess );
                    if( !tree[nid].is_root() ){
                        snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, snode[ tree[nid].parent() ].weight );
                    }else{
                        snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, 0.0f );
                    }
                }
            }
            // find splits at current level
            inline void FindSplit( int nid ){
                // TODO

            }
        private:
            // initialize temp data structure
            inline void InitData( void ){
                std::vector<bst_uint> valid_index;
                for( size_t i = 0; i < grad.size(); ++i ){
                    if( hess[ i ] < 0.0f ) continue;
                    if( param.subsample > 1.0f-1e-6f || random::SampleBinary( param.subsample ) != 0 ){
                        valid_index.push_back( static_cast<bst_uint>(i) );
                    }
                }
                node_bound.resize( tree.param.num_roots );

                if( root_index.size() == 0 ){
                    row_index_set = valid_index;
                    // set bound of root node
                    node_bound[0] = std::make_pair( 0, (bst_uint)row_index_set.size() );
                }else{
                    std::vector<size_t> rptr;
                    utils::SparseCSRMBuilder<bst_uint> builder( rptr, row_index_set );
                    builder.InitBudget( tree.param.num_roots );
                    for( size_t i = 0; i < valid_index.size(); ++i ){
                        const bst_uint rid = valid_index[ i ];
                        utils::Assert( root_index[ rid ] < (unsigned)tree.param.num_roots, "root id exceed number of roots" );
                        builder.AddBudget( root_index[ rid ] );
                    }
                    builder.InitStorage();
                    for( size_t i = 0; i < valid_index.size(); ++i ){
                        const bst_uint rid = valid_index[ i ];
                        builder.PushElem( root_index[ rid ], rid );
                    }
                    for( size_t i = 1; i < rptr.size(); ++ i ){
                        node_bound[i-1] = std::make_pair( rptr[ i - 1 ], rptr[ i ] );
                    }
                }

                {// setup temp space for each thread
                    if( param.nthread != 0 ){
                        omp_set_num_threads( param.nthread );
                    }
                    #pragma omp parallel
                    {
                        this->nthread = omp_get_num_threads();
                    }
                    snode.reserve( 256 );
                }

                {// expand query
                    qexpand.reserve( 256 ); qexpand.clear();
                    for( int i = 0; i < tree.param.num_roots; ++ i ){
                        qexpand.push_back( i );
                    }
                }
            }
        private:
            // number of omp thread used during training
            int nthread;
            // Instance row indexes corresponding to each node
            std::vector<bst_uint> row_index_set;
            // lower and upper bound of each nodes' row_index
            std::vector< std::pair<bst_uint, bst_uint> > node_bound;
        private:
            const std::vector<float> &grad;
            const std::vector<float> &hess;
            const FMatrix &smat;
            const std::vector<unsigned> &root_index;
        };
    };
};
#endif
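
InitNewNode in the new file leaves the per-node sum statistics as a TODO, which matches the commit message ("to be finished"). Since InitData sets node_bound to [begin, end) ranges into row_index_set, one plausible way to fill that TODO is to sum grad/hess over the node's range, assuming node_bound is kept up to date for every expanded node. The following is only an illustrative, self-contained sketch with made-up data, not the author's implementation:

    #include <cstdio>
    #include <utility>
    #include <vector>

    typedef unsigned bst_uint;

    int main( void ){
        std::vector<float> grad( 8, 0.25f ), hess( 8, 1.0f );
        // pretend rows 0..7 all belong to root node 0
        std::vector<bst_uint> row_index_set;
        for( bst_uint i = 0; i < 8; ++i ) row_index_set.push_back( i );
        std::vector< std::pair<bst_uint, bst_uint> > node_bound( 1 );
        node_bound[0] = std::make_pair( (bst_uint)0, (bst_uint)row_index_set.size() );

        // for deeper nodes the bounds would have to be maintained when rows are
        // repartitioned after each split; only the root case is shown here
        const int nid = 0;
        double sum_grad = 0.0, sum_hess = 0.0;
        for( bst_uint i = node_bound[nid].first; i < node_bound[nid].second; ++i ){
            const bst_uint ridx = row_index_set[ i ];
            sum_grad += grad[ ridx ];
            sum_hess += hess[ ridx ];
        }
        printf( "node %d: sum_grad=%g sum_hess=%g\n", nid, sum_grad, sum_hess );
        return 0;
    }
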
@@ -24,6 +24,7 @@ namespace xgboost{

#include "xgboost_svdf_tree.hpp"
#include "xgboost_col_treemaker.hpp"
#include "xgboost_row_treemaker.hpp"

namespace xgboost{
namespace booster{
@@ -64,13 +65,23 @@ namespace xgboost{
printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
}
int num_pruned;
if( tree_maker == 0 ){
// start with a id set
switch( tree_maker ){
case 0: {
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
tree.param.max_depth = updater.do_boost( num_pruned );
}else{
break;
}
case 1:{
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
maker.Make( tree.param.max_depth, num_pruned );
break;
}
case 2:{
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
maker.Make( tree.param.max_depth, num_pruned );
break;
}
default: utils::Error("unknown tree maker");
}
if( !silent ){
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",