modify tree so that training is standalone

This commit is contained in:
tqchen 2014-02-26 16:03:00 -08:00
parent 2c6922f432
commit 4a612eb3ba
8 changed files with 197 additions and 287 deletions

View File

@ -1,6 +1,6 @@
export CC = gcc export CC = gcc
export CXX = g++ export CXX = g++
export CFLAGS = -Wall -O3 -msse2 export CFLAGS = -Wall -O3 -msse2 -fopenmp
# specify tensor path # specify tensor path
BIN = xgboost BIN = xgboost

View File

@ -3,7 +3,7 @@
/*! /*!
* \file xgboost_linear.h * \file xgboost_linear.h
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net * \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
* the update rule is coordinate descent * the update rule is coordinate descent, require column major format
* \author Tianqi Chen: tianqi.tchen@gmail.com * \author Tianqi Chen: tianqi.tchen@gmail.com
*/ */
#include <vector> #include <vector>
@ -11,7 +11,6 @@
#include "../xgboost.h" #include "../xgboost.h"
#include "../../utils/xgboost_utils.h" #include "../../utils/xgboost_utils.h"
#include "../../utils/xgboost_matrix_csr.h"
namespace xgboost{ namespace xgboost{
namespace booster{ namespace booster{
@ -41,7 +40,7 @@ namespace xgboost{
const FMatrixS &smat, const FMatrixS &smat,
const std::vector<unsigned> &root_index ){ const std::vector<unsigned> &root_index ){
utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" ); utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
this->Update( smat, grad, hess ); this->UpdateWeights( grad, hess, smat );
} }
virtual float Predict( const FMatrixS::Line &sp, unsigned rid = 0 ){ virtual float Predict( const FMatrixS::Line &sp, unsigned rid = 0 ){
float sum = model.bias(); float sum = model.bias();
@ -149,29 +148,17 @@ namespace xgboost{
return weight.back(); return weight.back();
} }
}; };
/*! \brief array entry for column based feature construction */
struct SCEntry{
/*! \brief feature value */
float fvalue;
/*! \brief row index related to each row */
unsigned rindex;
/*! \brief default constructor */
SCEntry( void ){}
/*! \brief constructor using entry */
SCEntry( float fvalue, unsigned rindex ){
this->fvalue = fvalue; this->rindex = rindex;
}
};
private: private:
int silent; int silent;
protected: protected:
Model model; Model model;
ParamTrain param; ParamTrain param;
protected: protected:
// update weights, should work for any FMatrix
template<typename FMatrix>
inline void UpdateWeights( std::vector<float> &grad, inline void UpdateWeights( std::vector<float> &grad,
const std::vector<float> &hess, const std::vector<float> &hess,
const std::vector<size_t> &rptr, const FMatrix &smat ){
const std::vector<SCEntry> &entry ){
{// optimize bias {// optimize bias
double sum_grad = 0.0, sum_hess = 0.0; double sum_grad = 0.0, sum_hess = 0.0;
for( size_t i = 0; i < grad.size(); i ++ ){ for( size_t i = 0; i < grad.size(); i ++ ){
@ -187,70 +174,25 @@ namespace xgboost{
} }
// optimize weight // optimize weight
const int nfeat = model.param.num_feature; const unsigned nfeat= (unsigned)smat.NumCol();
for( int i = 0; i < nfeat; i ++ ){ for( unsigned i = 0; i < nfeat; i ++ ){
size_t start = rptr[i]; if( !smat.GetSortedCol( i ).Next() ) continue;
size_t end = rptr[i+1];
if( start >= end ) continue;
double sum_grad = 0.0, sum_hess = 0.0; double sum_grad = 0.0, sum_hess = 0.0;
for( size_t j = start; j < end; j ++ ){ for( typename FMatrix::ColIter it = smat.GetSortedCol(i); it.Next(); ){
const float v = entry[j].fvalue; const float v = it.fvalue();
sum_grad += grad[ entry[j].rindex ] * v; sum_grad += grad[ it.rindex() ] * v;
sum_hess += hess[ entry[j].rindex ] * v * v; sum_hess += hess[ it.rindex() ] * v * v;
} }
float w = model.weight[ i ]; float w = model.weight[ i ];
double dw = param.learning_rate * param.CalcDelta( sum_grad, sum_hess, w ); double dw = param.learning_rate * param.CalcDelta( sum_grad, sum_hess, w );
model.weight[ i ] += dw; model.weight[ i ] += dw;
// update grad value // update grad value
for( size_t j = start; j < end; j ++ ){ for( typename FMatrix::ColIter it = smat.GetSortedCol(i); it.Next(); ){
const float v = entry[j].fvalue; const float v = it.fvalue();
grad[ entry[j].rindex ] += hess[ entry[j].rindex ] * v * dw; grad[ it.rindex() ] += hess[ it.rindex() ] * v * dw;
} }
} }
} }
inline void MakeCmajor( std::vector<size_t> &rptr,
std::vector<SCEntry> &entry,
const std::vector<float> &hess,
const FMatrixS &smat ){
// transform to column order first
const int nfeat = model.param.num_feature;
// build CSR column major format data
utils::SparseCSRMBuilder<SCEntry> builder( rptr, entry );
builder.InitBudget( nfeat );
for( unsigned i = 0; i < (unsigned)hess.size(); i ++ ){
// skip deleted entries
if( hess[i] < 0.0f ) continue;
// add sparse part budget
FMatrixS::Line sp = smat[ i ];
for( unsigned j = 0; j < sp.len; j ++ ){
if( j == 0 || sp[j-1].findex != sp[j].findex ){
builder.AddBudget( sp[j].findex );
}
}
}
builder.InitStorage();
for( unsigned i = 0; i < (unsigned)hess.size(); i ++ ){
// skip deleted entries
if( hess[i] < 0.0f ) continue;
// add sparse part budget
FMatrixS::Line sp = smat[ i ];
for( unsigned j = 0; j < sp.len; j ++ ){
// skip duplicated terms
if( j == 0 || sp[j-1].findex != sp[j].findex ){
builder.PushElem( sp[j].findex, SCEntry( sp[j].fvalue, i ) );
}
}
}
}
protected:
virtual void Update( const FMatrixS &smat,
std::vector<float> &grad,
const std::vector<float> &hess ){
std::vector<size_t> rptr;
std::vector<SCEntry> entry;
this->MakeCmajor( rptr, entry, hess, smat );
this->UpdateWeights( grad, hess, rptr, entry );
}
}; };
}; };
}; };

View File

@ -2,7 +2,7 @@
#define _XGBOOST_APEX_TREE_HPP_ #define _XGBOOST_APEX_TREE_HPP_
/*! /*!
* \file xgboost_svdf_tree.hpp * \file xgboost_svdf_tree.hpp
* \brief implementation of regression tree, with layerwise support * \brief implementation of regression tree constructor, with layerwise support
* this file is adapted from GBRT implementation in SVDFeature project * this file is adapted from GBRT implementation in SVDFeature project
* \author Tianqi Chen: tqchen@apex.sjtu.edu.cn, tianqi.tchen@gmail.com * \author Tianqi Chen: tqchen@apex.sjtu.edu.cn, tianqi.tchen@gmail.com
*/ */
@ -13,17 +13,6 @@
namespace xgboost{ namespace xgboost{
namespace booster{ namespace booster{
const bool rt_debug = false;
// whether to check bugs
const bool check_bug = false;
const float rt_eps = 1e-5f;
const float rt_2eps = rt_eps * 2.0f;
inline double sqr( double a ){
return a * a;
}
inline void assert_sorted( unsigned *idset, int len ){ inline void assert_sorted( unsigned *idset, int len ){
if( !rt_debug || !check_bug ) return; if( !rt_debug || !check_bug ) return;
for( int i = 1; i < len; i ++ ){ for( int i = 1; i < len; i ++ ){
@ -33,20 +22,6 @@ namespace xgboost{
}; };
namespace booster{ namespace booster{
// node stat used in rtree
struct RTreeNodeStat{
// loss chg caused by current split
float loss_chg;
// weight of current node
float base_weight;
// number of child that is leaf node known up to now
int leaf_child_cnt;
};
// structure of Regression Tree
class RTree: public TreeModel<float,RTreeNodeStat>{
};
// selecter of rtree to find the suitable candidate // selecter of rtree to find the suitable candidate
class RTSelecter{ class RTSelecter{
public: public:
@ -88,7 +63,9 @@ namespace xgboost{
} }
}; };
// updater of rtree, allows the parameters to be stored inside, key solver // updater of rtree, allows the parameters to be stored inside, key solver
template<typename FMatrix>
class RTreeUpdater{ class RTreeUpdater{
protected: protected:
// training task, element of single task // training task, element of single task
@ -128,10 +105,10 @@ namespace xgboost{
// training parameter // training parameter
const TreeParamTrain &param; const TreeParamTrain &param;
// parameters, reference // parameters, reference
RTree &tree; RegTree &tree;
std::vector<float> &grad; std::vector<float> &grad;
std::vector<float> &hess; std::vector<float> &hess;
const FMatrixS &smat; const FMatrix &smat;
const std::vector<unsigned> &group_id; const std::vector<unsigned> &group_id;
private: private:
// maximum depth up to now // maximum depth up to now
@ -158,7 +135,7 @@ namespace xgboost{
inline void try_prune_leaf( int nid, int depth ){ inline void try_prune_leaf( int nid, int depth ){
if( tree[ nid ].is_root() ) return; if( tree[ nid ].is_root() ) return;
int pid = tree[ nid ].parent(); int pid = tree[ nid ].parent();
RTree::NodeStat &s = tree.stat( pid ); RegTree::NodeStat &s = tree.stat( pid );
s.leaf_child_cnt ++; s.leaf_child_cnt ++;
if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){ if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
@ -186,7 +163,7 @@ namespace xgboost{
// make split for current task, re-arrange positions in idset // make split for current task, re-arrange positions in idset
inline void make_split( Task tsk, const SCEntry *entry, int num, float loss_chg, double base_weight ){ inline void make_split( Task tsk, const SCEntry *entry, int num, float loss_chg, double base_weight ){
// before split, first prepare statistics // before split, first prepare statistics
RTree::NodeStat &s = tree.stat( tsk.nid ); RegTree::NodeStat &s = tree.stat( tsk.nid );
s.loss_chg = loss_chg; s.loss_chg = loss_chg;
s.leaf_child_cnt = 0; s.leaf_child_cnt = 0;
s.base_weight = static_cast<float>( base_weight ); s.base_weight = static_cast<float>( base_weight );
@ -214,7 +191,7 @@ namespace xgboost{
} }
} }
// get two parts // get two parts
RTree::Node &n = tree[ tsk.nid ]; RegTree::Node &n = tree[ tsk.nid ];
Task def_part( n.default_left() ? n.cleft() : n.cright(), tsk.idset, tsk.len - qset.size(), s.base_weight ); Task def_part( n.default_left() ? n.cleft() : n.cright(), tsk.idset, tsk.len - qset.size(), s.base_weight );
Task spl_part( n.default_left() ? n.cright(): n.cleft() , tsk.idset + def_part.len, qset.size(), s.base_weight ); Task spl_part( n.default_left() ? n.cright(): n.cleft() , tsk.idset + def_part.len, qset.size(), s.base_weight );
// fill back split part // fill back split part
@ -320,9 +297,8 @@ namespace xgboost{
rsum_grad += grad[ ridx ]; rsum_grad += grad[ ridx ];
rsum_hess += hess[ ridx ]; rsum_hess += hess[ ridx ];
FMatrixS::Line sp = smat[ ridx ]; for( typename FMatrix::RowIter it = smat.GetRow(ridx); it.Next(); ){
for( unsigned j = 0; j < sp.len; j ++ ){ builder.AddBudget( it.findex() );
builder.AddBudget( sp[j].findex );
} }
} }
@ -334,10 +310,9 @@ namespace xgboost{
builder.InitStorage(); builder.InitStorage();
for( unsigned i = 0; i < tsk.len; i ++ ){ for( unsigned i = 0; i < tsk.len; i ++ ){
const unsigned ridx = tsk.idset[i]; const unsigned ridx = tsk.idset[i];
FMatrixS::Line sp = smat[ ridx ]; for( typename FMatrix::RowIter it = smat.GetRow(ridx); it.Next(); ){
for( unsigned j = 0; j < sp.len; j ++ ){ builder.PushElem( it.findex(), SCEntry( it.fvalue(), ridx ) );
builder.PushElem( sp[j].findex, SCEntry( sp[j].fvalue, ridx ) ); }
}
} }
// --- end of building column major matrix --- // --- end of building column major matrix ---
// after this point, tmp_rptr and entry is ready to use // after this point, tmp_rptr and entry is ready to use
@ -426,10 +401,10 @@ namespace xgboost{
} }
public: public:
RTreeUpdater( const TreeParamTrain &pparam, RTreeUpdater( const TreeParamTrain &pparam,
RTree &ptree, RegTree &ptree,
std::vector<float> &pgrad, std::vector<float> &pgrad,
std::vector<float> &phess, std::vector<float> &phess,
const FMatrixS &psmat, const FMatrix &psmat,
const std::vector<unsigned> &pgroup_id ): const std::vector<unsigned> &pgroup_id ):
param( pparam ), tree( ptree ), grad( pgrad ), hess( phess ), param( pparam ), tree( ptree ), grad( pgrad ), hess( phess ),
smat( psmat ), group_id( pgroup_id ){ smat( psmat ), group_id( pgroup_id ){
@ -446,113 +421,6 @@ namespace xgboost{
return max_depth; return max_depth;
} }
}; };
class RTreeTrainer : public IBooster{
private:
int silent;
// tree of current shape
RTree tree;
TreeParamTrain param;
private:
std::vector<float> tmp_feat;
std::vector<bool> tmp_funknown;
inline void init_tmpfeat( void ){
if( tmp_feat.size() != (size_t)tree.param.num_feature ){
tmp_feat.resize( tree.param.num_feature );
tmp_funknown.resize( tree.param.num_feature );
std::fill( tmp_funknown.begin(), tmp_funknown.end(), true );
}
}
public:
virtual void SetParam( const char *name, const char *val ){
if( !strcmp( name, "silent") ) silent = atoi( val );
param.SetParam( name, val );
tree.param.SetParam( name, val );
}
virtual void LoadModel( utils::IStream &fi ){
tree.LoadModel( fi );
}
virtual void SaveModel( utils::IStream &fo ) const{
tree.SaveModel( fo );
}
virtual void InitModel( void ){
tree.InitModel();
}
private:
inline int get_next( int pid, float fvalue, bool is_unknown ){
float split_value = tree[ pid ].split_cond();
if( is_unknown ){
if( tree[ pid ].default_left() ) return tree[ pid ].cleft();
else return tree[ pid ].cright();
}else{
if( fvalue < split_value ) return tree[ pid ].cleft();
else return tree[ pid ].cright();
}
}
public:
virtual void DoBoost( std::vector<float> &grad,
std::vector<float> &hess,
const FMatrixS &smat,
const std::vector<unsigned> &group_id ){
utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
if( !silent ){
printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
}
// start with a id set
RTreeUpdater updater( param, tree, grad, hess, smat, group_id );
int num_pruned;
tree.param.max_depth = updater.do_boost( num_pruned );
if( !silent ){
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
}
}
virtual int GetLeafIndex( const std::vector<float> &feat,
const std::vector<bool> &funknown,
unsigned gid = 0 ){
// start from groups that belongs to current data
int pid = (int)gid;
// tranverse tree
while( !tree[ pid ].is_leaf() ){
unsigned split_index = tree[ pid ].split_index();
pid = this->get_next( pid, feat[ split_index ], funknown[ split_index ] );
}
return pid;
}
virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){
this->init_tmpfeat();
for( unsigned i = 0; i < feat.len; i ++ ){
utils::Assert( feat[i].findex < (unsigned)tmp_funknown.size() , "input feature execeed bound" );
tmp_funknown[ feat[i].findex ] = false;
tmp_feat[ feat[i].findex ] = feat[i].fvalue;
}
int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid );
// set back
for( unsigned i = 0; i < feat.len; i ++ ){
tmp_funknown[ feat[i].findex ] = true;
}
return tree[ pid ].leaf_value();
}
virtual float Predict( const std::vector<float> &feat,
const std::vector<bool> &funknown,
unsigned gid = 0 ){
utils::Assert( feat.size() >= (size_t)tree.param.num_feature,
"input data smaller than num feature" );
int pid = this->GetLeafIndex( feat, funknown, gid );
return tree[ pid ].leaf_value();
}
virtual void DumpModel( FILE *fo ){
tree.DumpModel( fo );
}
public:
RTreeTrainer( void ){ silent = 0; }
virtual ~RTreeTrainer( void ){}
};
}; };
}; };
#endif #endif

View File

@ -431,5 +431,20 @@ namespace xgboost{
} }
}; };
}; };
namespace booster{
/*! \brief node statistics used in regression tree */
struct RTreeNodeStat{
// loss chg caused by current split
float loss_chg;
// weight of current node
float base_weight;
// number of child that is leaf node known up to now
int leaf_child_cnt;
};
/*! \brief most comment structure of regression tree */
class RegTree: public TreeModel<bst_float,RTreeNodeStat>{
};
};
}; };
#endif #endif

View File

@ -11,18 +11,11 @@
#include "../utils/xgboost_utils.h" #include "../utils/xgboost_utils.h"
#include "xgboost_gbmbase.h" #include "xgboost_gbmbase.h"
// implementations of boosters // implementations of boosters
#include "tree/xgboost_svdf_tree.hpp" #include "tree/xgboost_tree.hpp"
#include "linear/xgboost_linear.hpp" #include "linear/xgboost_linear.hpp"
namespace xgboost{ namespace xgboost{
namespace booster{ namespace booster{
/*
* \brief listing the types of boosters
*/
enum BOOSTER_TYPE_LIST{
TREE,
LINEAR,
};
/*! /*!
* \brief create a gradient booster, given type of booster * \brief create a gradient booster, given type of booster
* \param booster_type type of gradient booster, can be used to specify implements * \param booster_type type of gradient booster, can be used to specify implements
@ -30,8 +23,8 @@ namespace xgboost{
*/ */
IBooster *CreateBooster( int booster_type ){ IBooster *CreateBooster( int booster_type ){
switch( booster_type ){ switch( booster_type ){
case TREE: return new RTreeTrainer(); case 0: return new RegTreeTrainer();
case LINEAR: return new LinearBooster(); case 1: return new LinearBooster();
default: utils::Error("unknown booster_type"); return NULL; default: utils::Error("unknown booster_type"); return NULL;
} }
} }

View File

@ -11,6 +11,7 @@
#include <climits> #include <climits>
#include "../utils/xgboost_utils.h" #include "../utils/xgboost_utils.h"
#include "../utils/xgboost_stream.h" #include "../utils/xgboost_stream.h"
#include "../utils/xgboost_matrix_csr.h"
namespace xgboost{ namespace xgboost{
namespace booster{ namespace booster{
@ -66,10 +67,6 @@ namespace xgboost{
inline bst_float fvalue( void ) const; inline bst_float fvalue( void ) const;
}; };
public: public:
/*!
* \brief prepare sorted columns so that GetSortedCol can be called
*/
inline void MakeSortedCol( void );
/*! /*!
* \brief get number of rows * \brief get number of rows
* \return number of rows * \return number of rows
@ -114,13 +111,13 @@ namespace xgboost{
bst_uint findex; bst_uint findex;
/*! \brief feature value */ /*! \brief feature value */
bst_float fvalue; bst_float fvalue;
}; /*! \brief constructor */
/*! \brief one entry in a row */ REntry( void ){}
struct CEntry{ /*! \brief constructor */
/*! \brief row index */ REntry( bst_uint findex, bst_float fvalue ) : findex(findex), fvalue(fvalue){}
bst_uint rindex; inline static bool cmp_fvalue( const REntry &a, const REntry &b ){
/*! \brief feature value */ return a.fvalue < b.fvalue;
bst_float fvalue; }
}; };
/*! \brief one row of sparse feature matrix */ /*! \brief one row of sparse feature matrix */
struct Line{ struct Line{
@ -128,24 +125,31 @@ namespace xgboost{
const REntry *data_; const REntry *data_;
/*! \brief size of the data */ /*! \brief size of the data */
bst_uint len; bst_uint len;
/*! \brief get k-th element */
inline const REntry& operator[]( unsigned i ) const{ inline const REntry& operator[]( unsigned i ) const{
return data_[i]; return data_[i];
} }
}; };
public: /*! \brief row iterator */
struct RowIter{ struct RowIter{
const REntry *dptr, *end; const REntry *dptr_, *end_;
inline bool Next( void ){ inline bool Next( void ){
if( dptr == end ) return false; if( dptr_ == end_ ) return false;
else{ else{
++ dptr; return true; ++ dptr_; return true;
} }
} }
inline bst_uint findex( void ) const{ inline bst_uint findex( void ) const{
return dptr->findex; return dptr_->findex;
} }
inline bst_float fvalue( void ) const{ inline bst_float fvalue( void ) const{
return dptr->fvalue; return dptr_->fvalue;
}
};
/*! \brief column iterator */
struct ColIter: public RowIter{
inline bst_uint rindex( void ) const{
return this->findex();
} }
}; };
public: public:
@ -167,6 +171,8 @@ namespace xgboost{
row_ptr_.clear(); row_ptr_.clear();
row_ptr_.push_back( 0 ); row_ptr_.push_back( 0 );
row_data_.clear(); row_data_.clear();
col_ptr_.clear();
col_data_.clear();
} }
/*! \brief get sparse part of current row */ /*! \brief get sparse part of current row */
inline Line operator[]( size_t sidx ) const{ inline Line operator[]( size_t sidx ) const{
@ -176,14 +182,6 @@ namespace xgboost{
sp.data_ = &row_data_[ row_ptr_[ sidx ] ]; sp.data_ = &row_data_[ row_ptr_[ sidx ] ];
return sp; return sp;
} }
/*! \brief get row iterator*/
inline RowIter GetRow( size_t ridx ) const{
utils::Assert( !bst_debug || ridx < this->NumRow(), "row id exceed bound" );
RowIter it;
it.dptr = &row_data_[ row_ptr_[ridx] ] - 1;
it.dptr = &row_data_[ row_ptr_[ridx+1] ] - 1;
return it;
}
/*! /*!
* \brief add a row to the matrix, with data stored in STL container * \brief add a row to the matrix, with data stored in STL container
* \param findex feature index * \param findex feature index
@ -199,43 +197,124 @@ namespace xgboost{
unsigned cnt = 0; unsigned cnt = 0;
for( size_t i = 0; i < findex.size(); i ++ ){ for( size_t i = 0; i < findex.size(); i ++ ){
if( findex[i] < fstart || findex[i] >= fend ) continue; if( findex[i] < fstart || findex[i] >= fend ) continue;
REntry e; e.findex = findex[i]; e.fvalue = fvalue[i]; row_data_.push_back( REntry( findex[i], fvalue[i] ) );
row_data_.push_back( e );
cnt ++; cnt ++;
} }
row_ptr_.push_back( row_ptr_.back() + cnt ); row_ptr_.push_back( row_ptr_.back() + cnt );
return row_ptr_.size() - 2; return row_ptr_.size() - 2;
} }
/*! \brief get row iterator*/
inline RowIter GetRow( size_t ridx ) const{
utils::Assert( !bst_debug || ridx < this->NumRow(), "row id exceed bound" );
RowIter it;
it.dptr_ = &row_data_[ row_ptr_[ridx] ] - 1;
it.end_ = &row_data_[ row_ptr_[ridx+1] ] - 1;
return it;
}
public: public:
/*! \return whether column access is enabled */
inline bool HaveColAccess( void ) const{
return col_ptr_.size() != 0 && col_data_.size() == row_data_.size();
}
/*! \brief get number of colmuns */
inline size_t NumCol( void ) const{
utils::Assert( this->HaveColAccess() );
return col_ptr_.size() - 1;
}
/*! \brief get col iterator*/
inline ColIter GetSortedCol( size_t cidx ) const{
utils::Assert( !bst_debug || cidx < this->NumCol(), "col id exceed bound" );
ColIter it;
it.dptr_ = &col_data_[ col_ptr_[cidx] ] - 1;
it.end_ = &col_data_[ col_ptr_[cidx+1] ] - 1;
return it;
}
/*!
* \brief intialize the data so that we have both column and row major
* access, call this whenever we need column access
*/
inline void InitData( void ){
utils::SparseCSRMBuilder<REntry> builder( col_ptr_, col_data_ );
builder.InitBudget( 0 );
for( size_t i = 0; i < this->NumRow(); i ++ ){
for( RowIter it = this->GetRow(i); it.Next(); ){
builder.AddBudget( it.findex() );
}
}
builder.InitStorage();
for( size_t i = 0; i < this->NumRow(); i ++ ){
for( RowIter it = this->GetRow(i); it.Next(); ){
builder.PushElem( it.findex(), REntry( (bst_uint)i, it.fvalue() ) );
}
}
// sort columns
unsigned ncol = static_cast<unsigned>( this->NumCol() );
for( unsigned i = 0; i < ncol; i ++ ){
std::sort( &col_data_[ col_ptr_[ i ] ], &col_data_[ col_ptr_[ i+1 ] ], REntry::cmp_fvalue );
}
}
/*! /*!
* \brief save data to binary stream * \brief save data to binary stream
* note: since we have size_t in row_ptr, * note: since we have size_t in ptr,
* the function is not consistent between 64bit and 32bit machine * the function is not consistent between 64bit and 32bit machine
* \param fo output stream * \param fo output stream
*/ */
inline void SaveBinary(utils::IStream &fo ) const{ inline void SaveBinary( utils::IStream &fo ) const{
size_t nrow = this->NumRow(); FMatrixS::SaveBinary( fo, row_ptr_, row_data_ );
fo.Write( &nrow, sizeof(size_t) ); int col_access = this->HaveColAccess() ? 1 : 0;
fo.Write( &row_ptr_[0], row_ptr_.size() * sizeof(size_t) ); fo.Write( &col_access, sizeof(int) );
if( row_data_.size() != 0 ){ if( col_access != 0 ){
fo.Write( &row_data_[0] , row_data_.size() * sizeof(REntry) ); FMatrixS::SaveBinary( fo, col_ptr_, col_data_ );
} }
} }
/*! /*!
* \brief load data from binary stream * \brief load data from binary stream
* note: since we have size_t in row_ptr, * note: since we have size_t in ptr,
* the function is not consistent between 64bit and 32bit machine * the function is not consistent between 64bit and 32bit machin
* \param fi output stream * \param fi input stream
*/ */
inline void LoadBinary( utils::IStream &fi ){ inline void LoadBinary( utils::IStream &fi ){
FMatrixS::LoadBinary( fi, row_ptr_, row_data_ );
int col_access;
fi.Read( &col_access, sizeof(int) );
if( col_access != 0 ){
FMatrixS::LoadBinary( fi, col_ptr_, col_data_ );
}
}
private:
/*!
* \brief save data to binary stream
* \param fo output stream
* \param ptr pointer data
* \param data data content
*/
inline static void SaveBinary( utils::IStream &fo,
const std::vector<size_t> &ptr,
const std::vector<REntry> &data ){
size_t nrow = ptr.size() - 1;
fo.Write( &nrow, sizeof(size_t) );
fo.Write( &ptr[0], ptr.size() * sizeof(size_t) );
if( data.size() != 0 ){
fo.Write( &data[0] , data.size() * sizeof(REntry) );
}
}
/*!
* \brief load data from binary stream
* \param fi input stream
* \param ptr pointer data
* \param data data content
*/
inline static void LoadBinary( utils::IStream &fi,
std::vector<size_t> &ptr,
std::vector<REntry> &data ){
size_t nrow; size_t nrow;
utils::Assert( fi.Read( &nrow, sizeof(size_t) ) != 0, "Load FMatrixS" ); utils::Assert( fi.Read( &nrow, sizeof(size_t) ) != 0, "Load FMatrixS" );
row_ptr_.resize( nrow + 1 ); ptr.resize( nrow + 1 );
utils::Assert( fi.Read( &row_ptr_[0], row_ptr_.size() * sizeof(size_t) ), "Load FMatrixS" ); utils::Assert( fi.Read( &ptr[0], ptr.size() * sizeof(size_t) ), "Load FMatrixS" );
row_data_.resize( row_ptr_.back() ); data.resize( ptr.back() );
if( row_data_.size() != 0 ){ if( data.size() != 0 ){
utils::Assert( fi.Read( &row_data_[0] , row_data_.size() * sizeof(REntry) ) , "Load FMatrixS" ); utils::Assert( fi.Read( &data[0] , data.size() * sizeof(REntry) ) , "Load FMatrixS" );
} }
} }
private: private:
@ -243,6 +322,10 @@ namespace xgboost{
std::vector<size_t> row_ptr_; std::vector<size_t> row_ptr_;
/*! \brief data in the row */ /*! \brief data in the row */
std::vector<REntry> row_data_; std::vector<REntry> row_data_;
/*! \brief column pointer of CSC format */
std::vector<size_t> col_ptr_;
/*! \brief column datas */
std::vector<REntry> col_data_;
}; };
}; };
}; };

View File

@ -65,11 +65,12 @@ namespace xgboost{
labels.push_back( label ); labels.push_back( label );
data.AddRow( findex, fvalue ); data.AddRow( findex, fvalue );
// initialize column support as well
data.InitData();
this->UpdateInfo();
if( !silent ){ if( !silent ){
printf("%ux%u matrix with %lu entries is loaded from %s\n", printf("%ux%u matrix with %lu entries is loaded from %s\n",
(unsigned)labels.size(), num_feature, (unsigned long)data.NumEntry(), fname ); (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname );
} }
fclose(file); fclose(file);
} }
@ -87,10 +88,12 @@ namespace xgboost{
labels.resize( data.NumRow() ); labels.resize( data.NumRow() );
utils::Assert( fs.Read( &labels[0], sizeof(float) * data.NumRow() ) != 0, "DMatrix LoadBinary" ); utils::Assert( fs.Read( &labels[0], sizeof(float) * data.NumRow() ) != 0, "DMatrix LoadBinary" );
fs.Close(); fs.Close();
this->UpdateInfo(); // initialize column support as well
data.InitData();
if( !silent ){ if( !silent ){
printf("%ux%u matrix with %lu entries is loaded from %s\n", printf("%ux%u matrix with %lu entries is loaded from %s\n",
(unsigned)labels.size(), num_feature, (unsigned long)data.NumEntry(), fname ); (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname );
} }
return true; return true;
} }
@ -100,13 +103,16 @@ namespace xgboost{
* \param silent whether print information or not * \param silent whether print information or not
*/ */
inline void SaveBinary( const char* fname, bool silent = false ){ inline void SaveBinary( const char* fname, bool silent = false ){
// initialize column support as well
data.InitData();
utils::FileStream fs( utils::FopenCheck( fname, "wb" ) ); utils::FileStream fs( utils::FopenCheck( fname, "wb" ) );
data.SaveBinary( fs ); data.SaveBinary( fs );
fs.Write( &labels[0], sizeof(float) * data.NumRow() ); fs.Write( &labels[0], sizeof(float) * data.NumRow() );
fs.Close(); fs.Close();
if( !silent ){ if( !silent ){
printf("%ux%u matrix with %lu entries is saved to %s\n", printf("%ux%u matrix with %lu entries is saved to %s\n",
(unsigned)labels.size(), num_feature, (unsigned long)data.NumEntry(), fname ); (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname );
} }
} }
/*! /*!

View File

@ -43,13 +43,13 @@ namespace xgboost{
} }
public: public:
/*! /*!
* \brief step 1: initialize the number of rows in the data * \brief step 1: initialize the number of rows in the data, not necessary exact
* \nrows number of rows in the matrix * \nrows number of rows in the matrix, can be smaller than expected
*/ */
inline void InitBudget( size_t nrows ){ inline void InitBudget( size_t nrows = 0 ){
if( !UseAcList ){ if( !UseAcList ){
rptr.resize( nrows + 1 ); rptr.clear();
std::fill( rptr.begin(), rptr.end(), 0 ); rptr.resize( nrows + 1, 0 );
}else{ }else{
Assert( nrows + 1 == rptr.size(), "rptr must be initialized already" ); Assert( nrows + 1 == rptr.size(), "rptr must be initialized already" );
this->Cleanup(); this->Cleanup();
@ -61,6 +61,9 @@ namespace xgboost{
* \param nelem number of element budget add to this row * \param nelem number of element budget add to this row
*/ */
inline void AddBudget( size_t row_id, size_t nelem = 1 ){ inline void AddBudget( size_t row_id, size_t nelem = 1 ){
if( rptr.size() < row_id + 2 ){
rptr.resize( row_id + 2, 0 );
}
if( UseAcList ){ if( UseAcList ){
if( rptr[ row_id + 1 ] == 0 ) aclist.push_back( row_id ); if( rptr[ row_id + 1 ] == 0 ) aclist.push_back( row_id );
} }