complete row maker
This commit is contained in:
parent
73dfdc539b
commit
74828295fe
@ -36,8 +36,8 @@ namespace xgboost{
|
||||
this->InitNewNode( this->qexpand );
|
||||
stat_max_depth = 0;
|
||||
|
||||
for( int depth = 0; depth < param.max_depth; ++ depth ){
|
||||
//this->FindSplit( this->qexpand );
|
||||
for( int depth = 0; depth < param.max_depth; ++ depth ){
|
||||
this->FindSplit( this->qexpand, depth );
|
||||
this->UpdateQueueExpand( this->qexpand );
|
||||
this->InitNewNode( this->qexpand );
|
||||
// if nothing left to be expand, break
|
||||
@ -51,20 +51,24 @@ namespace xgboost{
|
||||
}
|
||||
// start prunning the tree
|
||||
stat_num_pruned = this->DoPrune();
|
||||
}
|
||||
}
|
||||
private:
|
||||
// make leaf nodes for all qexpand, update node statistics, mark leaf value
|
||||
inline void InitNewNode( const std::vector<int> &qexpand ){
|
||||
snode.resize( tree.param.num_nodes, NodeEntry() );
|
||||
|
||||
for( size_t j = 0; j < qexpand.size(); ++ j ){
|
||||
for( size_t j = 0; j < qexpand.size(); ++j ){
|
||||
const int nid = qexpand[ j ];
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
// TODO: get sum statistics for nid
|
||||
|
||||
for( bst_uint i = node_bound[nid].first; i < node_bound[nid].second; ++i ){
|
||||
const bst_uint ridx = row_index_set[i];
|
||||
sum_grad += grad[ridx]; sum_hess += hess[ridx];
|
||||
}
|
||||
// update node statistics
|
||||
snode[nid].sum_grad = sum_grad;
|
||||
snode[nid].sum_hess = sum_hess;
|
||||
|
||||
snode[nid].root_gain = param.CalcRootGain( sum_grad, sum_hess );
|
||||
if( !tree[nid].is_root() ){
|
||||
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, snode[ tree[nid].parent() ].weight );
|
||||
@ -73,10 +77,182 @@ namespace xgboost{
|
||||
}
|
||||
}
|
||||
}
|
||||
// find splits at current level
|
||||
inline void FindSplit( int nid ){
|
||||
// TODO
|
||||
private:
|
||||
// enumerate the split values of specific feature
|
||||
template<typename Iter>
|
||||
inline void EnumerateSplit( Iter it, SplitEntry &best, const int nid, const unsigned fid, bool is_forward_search ){
|
||||
float last_fvalue = 0.0f;
|
||||
double sum_hess = 0.0, sum_grad = 0.0;
|
||||
const NodeEntry enode = snode[ nid ];
|
||||
|
||||
while( it.Next() ){
|
||||
const bst_uint ridx = it.rindex();
|
||||
const float fvalue = it.fvalue();
|
||||
|
||||
if( sum_hess == 0.0 ){
|
||||
sum_grad = grad[ ridx ];
|
||||
sum_hess = hess[ ridx ];
|
||||
last_fvalue = fvalue;
|
||||
}else{
|
||||
// try to find a split
|
||||
if( fabsf(fvalue - last_fvalue) > rt_2eps && sum_hess >= param.min_child_weight ){
|
||||
const double csum_hess = enode.sum_hess - sum_hess;
|
||||
if( csum_hess >= param.min_child_weight ){
|
||||
const double csum_grad = enode.sum_grad - sum_grad;
|
||||
const double loss_chg =
|
||||
+ param.CalcGain( sum_grad, sum_hess, enode.weight )
|
||||
+ param.CalcGain( csum_grad, csum_hess, enode.weight )
|
||||
- enode.root_gain;
|
||||
best.Update( loss_chg, fid, (fvalue + last_fvalue) * 0.5f, !is_forward_search );
|
||||
}else{
|
||||
// the rest part doesn't meet split condition anyway, return
|
||||
return;
|
||||
}
|
||||
}
|
||||
// update the statistics
|
||||
sum_grad += grad[ ridx ];
|
||||
sum_hess += hess[ ridx ];
|
||||
last_fvalue = fvalue;
|
||||
}
|
||||
}
|
||||
|
||||
const double csum_hess = enode.sum_hess - sum_hess;
|
||||
if( sum_hess >= param.min_child_weight && csum_hess >= param.min_child_weight ){
|
||||
const double csum_grad = enode.sum_grad - sum_grad;
|
||||
const double loss_chg =
|
||||
+ param.CalcGain( sum_grad, sum_hess, enode.weight )
|
||||
+ param.CalcGain( csum_grad, csum_hess, enode.weight )
|
||||
- snode[nid].root_gain;
|
||||
const float delta = is_forward_search ? rt_eps:-rt_eps;
|
||||
best.Update( loss_chg, fid, last_fvalue + delta, !is_forward_search );
|
||||
}
|
||||
}
|
||||
private:
|
||||
inline void FindSplit( const std::vector<int> &qexpand, int depth ){
|
||||
int nexpand = (int)qexpand.size();
|
||||
if( depth < 3 ){
|
||||
for( int i = 0; i < nexpand; ++ i ){
|
||||
this->FindSplit( qexpand[i], tmp_rptr[0] );
|
||||
}
|
||||
}else{
|
||||
// if get to enough depth, parallelize over node
|
||||
#pragma omp parallel for schedule(dynamic,1)
|
||||
for( int i = 0; i < nexpand; ++ i ){
|
||||
const int tid = omp_get_thread_num();
|
||||
utils::Assert( tid < (int)tmp_rptr.size(), "BUG: FindSplit, tid exceed tmp_rptr size" );
|
||||
this->FindSplit( qexpand[i], tmp_rptr[tid] );
|
||||
}
|
||||
}
|
||||
}
|
||||
private:
|
||||
inline void MakeSplit( int nid, unsigned gid ){
|
||||
node_bound.resize( tree.param.num_nodes );
|
||||
// re-organize the row_index_set after split on nid
|
||||
const unsigned split_index = tree[nid].split_index();
|
||||
const float split_value = tree[nid].split_cond();
|
||||
|
||||
std::vector<bst_uint> right;
|
||||
bst_uint top = node_bound[nid].first;
|
||||
for( bst_uint i = node_bound[ nid ].first; i < node_bound[ nid ].second; ++i ){
|
||||
const bst_uint ridx = row_index_set[i];
|
||||
bool goleft = tree[ nid ].default_left();
|
||||
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
||||
if( it.findex() == split_index ){
|
||||
if( it.fvalue() < split_value ){
|
||||
goleft = true; break;
|
||||
}else{
|
||||
goleft = false; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if( goleft ) {
|
||||
row_index_set[ top ++ ] = ridx;
|
||||
}else{
|
||||
right.push_back( ridx );
|
||||
}
|
||||
}
|
||||
node_bound[ tree[nid].cleft() ] = std::make_pair( node_bound[nid].first, top );
|
||||
node_bound[ tree[nid].cright() ] = std::make_pair( top, node_bound[nid].second );
|
||||
|
||||
utils::Assert( node_bound[nid].second - top == (bst_uint)right.size(), "BUG:MakeSplit" );
|
||||
for( size_t i = 0; i < right.size(); ++ i ){
|
||||
row_index_set[ top ++ ] = right[ i ];
|
||||
}
|
||||
}
|
||||
|
||||
// find splits at current level
|
||||
inline void FindSplit( int nid, std::vector<size_t> &tmp_rptr ){
|
||||
if( tmp_rptr.size() == 0 ){
|
||||
tmp_rptr.resize( tree.param.num_feature + 1, 0 );
|
||||
}
|
||||
const bst_uint begin = node_bound[ nid ].first;
|
||||
const bst_uint end = node_bound[ nid ].second;
|
||||
const unsigned ncgroup = smat.NumColGroup();
|
||||
unsigned best_group = 0;
|
||||
|
||||
for( unsigned gid = 0; gid < ncgroup; ++gid ){
|
||||
// records the columns
|
||||
std::vector<FMatrixS::REntry> centry;
|
||||
// records the active features
|
||||
std::vector<size_t> aclist;
|
||||
utils::SparseCSRMBuilder<FMatrixS::REntry,true> builder( tmp_rptr, centry, aclist );
|
||||
builder.InitBudget( tree.param.num_feature );
|
||||
for( bst_uint i = begin; i < end; ++i ){
|
||||
const bst_uint ridx = row_index_set[i];
|
||||
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
||||
builder.AddBudget( it.findex() );
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
for( bst_uint i = begin; i < end; ++i ){
|
||||
const bst_uint ridx = row_index_set[i];
|
||||
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
||||
builder.PushElem( it.findex(), FMatrixS::REntry( ridx, it.fvalue() ) );
|
||||
}
|
||||
}
|
||||
// --- end of building column major matrix ---
|
||||
// after this point, tmp_rptr and entry is ready to use
|
||||
int naclist = (int)aclist.size();
|
||||
// best entry for each thread
|
||||
SplitEntry nbest, tbest;
|
||||
#pragma omp parallel private(tbest)
|
||||
{
|
||||
#pragma omp for schedule(dynamic,1)
|
||||
for( int j = 0; j < naclist; ++j ){
|
||||
bst_uint findex = static_cast<bst_uint>( aclist[j] );
|
||||
// local sort can be faster when the features are sparse
|
||||
std::sort( centry.begin() + tmp_rptr[findex], centry.begin() + tmp_rptr[findex+1], FMatrixS::REntry::cmp_fvalue );
|
||||
if( param.need_forward_search() ){
|
||||
this->EnumerateSplit( FMatrixS::ColIter( ¢ry[tmp_rptr[findex]]-1, ¢ry[tmp_rptr[findex+1]] - 1 ),
|
||||
tbest, nid, findex, true );
|
||||
}
|
||||
if( param.need_backward_search() ){
|
||||
this->EnumerateSplit( FMatrixS::ColBackIter( ¢ry[tmp_rptr[findex+1]], ¢ry[tmp_rptr[findex]] ),
|
||||
tbest, nid, findex, false );
|
||||
}
|
||||
}
|
||||
#pragma omp critical
|
||||
{
|
||||
nbest.Update( tbest );
|
||||
}
|
||||
}
|
||||
// if current solution gives the best
|
||||
if( snode[nid].best.Update( nbest ) ){
|
||||
best_group = gid;
|
||||
}
|
||||
// cleanup tmp_rptr for next usage
|
||||
builder.Cleanup();
|
||||
}
|
||||
|
||||
// at this point, we already know the best split
|
||||
if( snode[nid].best.loss_chg > rt_eps ){
|
||||
const SplitEntry &e = snode[nid].best;
|
||||
tree.AddChilds( nid );
|
||||
tree[ nid ].set_split( e.split_index(), e.split_value, e.default_left() );
|
||||
this->MakeSplit( nid, best_group );
|
||||
}else{
|
||||
tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
|
||||
}
|
||||
}
|
||||
private:
|
||||
// initialize temp data structure
|
||||
@ -121,9 +297,9 @@ namespace xgboost{
|
||||
{
|
||||
this->nthread = omp_get_num_threads();
|
||||
}
|
||||
tmp_rptr.resize( this->nthread, std::vector<size_t>() );
|
||||
snode.reserve( 256 );
|
||||
}
|
||||
|
||||
}
|
||||
{// expand query
|
||||
qexpand.reserve( 256 ); qexpand.clear();
|
||||
for( int i = 0; i < tree.param.num_roots; ++ i ){
|
||||
@ -134,6 +310,8 @@ namespace xgboost{
|
||||
private:
|
||||
// number of omp thread used during training
|
||||
int nthread;
|
||||
// tmp row pointer, per thread, used for tmp data construction
|
||||
std::vector< std::vector<size_t> > tmp_rptr;
|
||||
// Instance row indexes corresponding to each node
|
||||
std::vector<bst_uint> row_index_set;
|
||||
// lower and upper bound of each nodes' row_index
|
||||
|
||||
@ -86,6 +86,20 @@ namespace xgboost{
|
||||
* \return row iterator
|
||||
*/
|
||||
inline RowIter GetRow( size_t ridx ) const;
|
||||
/*!
|
||||
* \brief get number of column groups, this ise used together with GetRow( ridx, gid )
|
||||
* \return number of column group
|
||||
*/
|
||||
inline unsigned NumColGroup( void ) const{
|
||||
return 1;
|
||||
}
|
||||
/*!
|
||||
* \brief get row iterator, return iterator of specific column group
|
||||
* \param ridx row index
|
||||
* \param gid colmun group id
|
||||
* \return row iterator, only iterates over features of specified column group
|
||||
*/
|
||||
inline RowIter GetRow( size_t ridx, unsigned gid ) const;
|
||||
|
||||
/*! \return whether column access is enabled */
|
||||
inline bool HaveColAccess( void ) const;
|
||||
@ -140,6 +154,8 @@ namespace xgboost{
|
||||
/*! \brief row iterator */
|
||||
struct RowIter{
|
||||
const REntry *dptr_, *end_;
|
||||
RowIter( const REntry* dptr, const REntry* end )
|
||||
:dptr_(dptr),end_(end){}
|
||||
inline bool Next( void ){
|
||||
if( dptr_ == end_ ) return false;
|
||||
else{
|
||||
@ -155,12 +171,16 @@ namespace xgboost{
|
||||
};
|
||||
/*! \brief column iterator */
|
||||
struct ColIter: public RowIter{
|
||||
ColIter( const REntry* dptr, const REntry* end )
|
||||
:RowIter( dptr, end ){}
|
||||
inline bst_uint rindex( void ) const{
|
||||
return this->findex();
|
||||
}
|
||||
};
|
||||
/*! \brief reverse column iterator */
|
||||
struct ColBackIter: public ColIter{
|
||||
ColBackIter( const REntry* dptr, const REntry* end )
|
||||
:ColIter( dptr, end ){}
|
||||
// shadows RowIter::Next
|
||||
inline bool Next( void ){
|
||||
if( dptr_ == end_ ) return false;
|
||||
@ -223,10 +243,12 @@ namespace xgboost{
|
||||
/*! \brief get row iterator*/
|
||||
inline RowIter GetRow( size_t ridx ) const{
|
||||
utils::Assert( !bst_debug || ridx < this->NumRow(), "row id exceed bound" );
|
||||
RowIter it;
|
||||
it.dptr_ = &row_data_[ row_ptr_[ridx] ] - 1;
|
||||
it.end_ = &row_data_[ row_ptr_[ridx+1] ] - 1;
|
||||
return it;
|
||||
return RowIter( &row_data_[ row_ptr_[ridx] ] - 1, &row_data_[ row_ptr_[ridx+1] ] - 1 );
|
||||
}
|
||||
/*! \brief get row iterator*/
|
||||
inline RowIter GetRow( size_t ridx, unsigned gid ) const{
|
||||
utils::Assert( gid == 0, "FMatrixS only have 1 column group" );
|
||||
return FMatrixS::GetRow( ridx );
|
||||
}
|
||||
public:
|
||||
/*! \return whether column access is enabled */
|
||||
@ -241,18 +263,12 @@ namespace xgboost{
|
||||
/*! \brief get col iterator*/
|
||||
inline ColIter GetSortedCol( size_t cidx ) const{
|
||||
utils::Assert( !bst_debug || cidx < this->NumCol(), "col id exceed bound" );
|
||||
ColIter it;
|
||||
it.dptr_ = &col_data_[ col_ptr_[cidx] ] - 1;
|
||||
it.end_ = &col_data_[ col_ptr_[cidx+1] ] - 1;
|
||||
return it;
|
||||
return ColIter( &col_data_[ col_ptr_[cidx] ] - 1, &col_data_[ col_ptr_[cidx+1] ] - 1 );
|
||||
}
|
||||
/*! \brief get col iterator */
|
||||
inline ColBackIter GetReverseSortedCol( size_t cidx ) const{
|
||||
utils::Assert( !bst_debug || cidx < this->NumCol(), "col id exceed bound" );
|
||||
ColBackIter it;
|
||||
it.dptr_ = &col_data_[ col_ptr_[cidx+1] ];
|
||||
it.end_ = &col_data_[ col_ptr_[cidx] ];
|
||||
return it;
|
||||
return ColBackIter( &col_data_[ col_ptr_[cidx+1] ], &col_data_[ col_ptr_[cidx] ] );
|
||||
}
|
||||
/*!
|
||||
* \brief intialize the data so that we have both column and row major
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user