full omp support for regression

This commit is contained in:
tqchen
2014-03-01 20:56:25 -08:00
parent 550010e9d2
commit 5cdc38648b
9 changed files with 206 additions and 354 deletions

View File

@@ -32,7 +32,9 @@ namespace xgboost{
class RegTreeTrainer : public IBooster{
public:
// Default constructor: sets silent to 0 and tree_maker to 1, then sizes the
// per-thread scratch pool (threadtemp) to 64 default-constructed entries.
// InitTmp asserts that the OpenMP thread number is below this pool size.
RegTreeTrainer( void ){
// NOTE(review): the assignment line below appears twice in this view; this looks
// like the before/after pair of a whitespace-only change in the diff -- confirm.
silent = 0; tree_maker = 1;
silent = 0; tree_maker = 1;
// normally we won't have more than 64 OpenMP threads
threadtemp.resize( 64, ThreadEntry() );
}
// Virtual destructor with an empty body.
virtual ~RegTreeTrainer( void ){}
public:
@@ -74,25 +76,25 @@ namespace xgboost{
// Records in `path` the node ids visited for the sparse input `feat`, starting at
// node `gid` and ending at the first node whose is_leaf() is true.
// `path` is cleared first, so the start node is always its first element.
// NOTE(review): several statements below come in pairs -- one using the
// tmp_feat/tmp_funknown members, one using a ThreadEntry `e`. This looks like
// before/after lines of a diff shown without +/- markers; the ThreadEntry form is
// presumably the current one -- confirm against the real file.
virtual void PredPath( std::vector<int> &path, const FMatrixS::Line &feat, unsigned gid = 0 ){
path.clear();
this->InitTmp();
this->PrepareTmp( feat );
// obtain the scratch entry returned by InitTmp and load the sparse features into it
ThreadEntry &e = this->InitTmp();
this->PrepareTmp( feat, e );
int pid = (int)gid;
path.push_back( pid );
// traverse tree
while( !tree[ pid ].is_leaf() ){
unsigned split_index = tree[ pid ].split_index();
pid = this->GetNext( pid, tmp_feat[ split_index ], tmp_funknown[ split_index ] );
// GetNext yields the next node id from the split feature's value and unknown flag
pid = this->GetNext( pid, e.feat[ split_index ], e.funknown[ split_index ] );
path.push_back( pid );
}
this->DropTmp( feat );
// set the unknown flags that PrepareTmp cleared back to true
this->DropTmp( feat, e );
}
// Returns the leaf value of the node that GetLeafIndex selects for the sparse
// input `feat`, starting from node `gid`.
// Made OpenMP thread safe, but not thread safe in general: the scratch entry is
// chosen by omp_get_thread_num() inside InitTmp.
// NOTE(review): statements below come in pairs (tmp_feat/tmp_funknown form, then
// ThreadEntry `e` form); this looks like before/after diff lines without +/-
// markers -- confirm against the real file.
virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){
this->InitTmp();
this->PrepareTmp( feat );
int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid );
this->DropTmp( feat );
// load the sparse features into the scratch entry, look up the leaf, then
// set the touched unknown flags back to true before returning
ThreadEntry &e = this->InitTmp();
this->PrepareTmp( feat, e );
int pid = this->GetLeafIndex( e.feat, e.funknown, gid );
this->DropTmp( feat, e );
return tree[ pid ].leaf_value();
}
virtual float Predict( const std::vector<float> &feat,
@@ -102,8 +104,7 @@ namespace xgboost{
"input data smaller than num feature" );
int pid = this->GetLeafIndex( feat, funknown, gid );
return tree[ pid ].leaf_value();
}
}
virtual void DumpModel( FILE *fo ){
tree.DumpModel( fo );
}
@@ -137,25 +138,34 @@ namespace xgboost{
RegTree tree;
TreeParamTrain param;
private:
std::vector<float> tmp_feat;
std::vector<bool> tmp_funknown;
inline void InitTmp( void ){
if( tmp_feat.size() != (size_t)tree.param.num_feature ){
tmp_feat.resize( tree.param.num_feature );
tmp_funknown.resize( tree.param.num_feature );
std::fill( tmp_funknown.begin(), tmp_funknown.end(), true );
// Scratch buffers used while evaluating one sparse input. threadtemp holds a
// vector of these and InitTmp returns the one indexed by the OpenMP thread number.
struct ThreadEntry{
// feature values, indexed by feature index (written by PrepareTmp)
std::vector<float> feat;
// per-feature flag: true unless PrepareTmp has stored a value for that feature
std::vector<bool> funknown;
};
std::vector<ThreadEntry> threadtemp;
private:
// Returns the scratch entry indexed by the calling OpenMP thread's number.
// Asserts that the thread number fits inside the threadtemp pool. If the entry's
// feature buffer does not match the tree's feature count, both buffers are
// resized to that count and every unknown flag is set to true.
inline ThreadEntry& InitTmp( void ){
    const int thread_id = omp_get_thread_num();
    utils::Assert( thread_id < (int)threadtemp.size(), "RTreeUpdater: threadtemp pool is too small" );
    ThreadEntry &entry = threadtemp[ thread_id ];
    const size_t nfeat = (size_t)tree.param.num_feature;
    if( entry.feat.size() != nfeat ){
        entry.feat.resize( nfeat );
        // assign() resizes the flag vector and sets every flag to true in one step
        entry.funknown.assign( nfeat, true );
    }
    return entry;
}
// Loads the entries of the sparse line `feat` into the scratch buffers: for each
// entry, the unknown flag at index findex is set to false and fvalue is stored
// at the same index. Each findex is asserted to be below the bound first.
// NOTE(review): two signatures and two loop bodies appear below (tmp_feat/
// tmp_funknown form, then ThreadEntry `e` form); this looks like before/after
// diff lines without +/- markers -- confirm against the real file.
inline void PrepareTmp( const FMatrixS::Line &feat ){
inline void PrepareTmp( const FMatrixS::Line &feat, ThreadEntry &e ){
for( unsigned i = 0; i < feat.len; i ++ ){
utils::Assert( feat[i].findex < (unsigned)tmp_funknown.size() , "input feature execeed bound" );
tmp_funknown[ feat[i].findex ] = false;
tmp_feat[ feat[i].findex ] = feat[i].fvalue;
// the ThreadEntry form bounds-checks against tree.param.num_feature instead of the buffer size
utils::Assert( feat[i].findex < (unsigned)tree.param.num_feature , "input feature execeed bound" );
e.funknown[ feat[i].findex ] = false;
e.feat[ feat[i].findex ] = feat[i].fvalue;
}
}
// Sets the unknown flag back to true for every feature index that appears in the
// sparse line `feat`. The stored feature values themselves are not modified here.
// NOTE(review): two signatures and two loop bodies appear below (tmp_funknown
// form, then ThreadEntry `e` form); this looks like before/after diff lines
// without +/- markers -- confirm against the real file.
inline void DropTmp( const FMatrixS::Line &feat ){
inline void DropTmp( const FMatrixS::Line &feat, ThreadEntry &e ){
for( unsigned i = 0; i < feat.len; i ++ ){
tmp_funknown[ feat[i].findex ] = true;
e.funknown[ feat[i].findex ] = true;
}
}
@@ -174,4 +184,3 @@ namespace xgboost{
};
#endif