full omp support for regression
This commit is contained in:
@@ -32,7 +32,9 @@ namespace xgboost{
|
||||
class RegTreeTrainer : public IBooster{
|
||||
public:
|
||||
RegTreeTrainer( void ){
|
||||
silent = 0; tree_maker = 1;
|
||||
silent = 0; tree_maker = 1;
|
||||
// normally we won't have more than 64 OpenMP threads
|
||||
threadtemp.resize( 64, ThreadEntry() );
|
||||
}
|
||||
virtual ~RegTreeTrainer( void ){}
|
||||
public:
|
||||
@@ -74,25 +76,25 @@ namespace xgboost{
|
||||
|
||||
virtual void PredPath( std::vector<int> &path, const FMatrixS::Line &feat, unsigned gid = 0 ){
|
||||
path.clear();
|
||||
this->InitTmp();
|
||||
this->PrepareTmp( feat );
|
||||
ThreadEntry &e = this->InitTmp();
|
||||
this->PrepareTmp( feat, e );
|
||||
|
||||
int pid = (int)gid;
|
||||
path.push_back( pid );
|
||||
// tranverse tree
|
||||
while( !tree[ pid ].is_leaf() ){
|
||||
unsigned split_index = tree[ pid ].split_index();
|
||||
pid = this->GetNext( pid, tmp_feat[ split_index ], tmp_funknown[ split_index ] );
|
||||
pid = this->GetNext( pid, e.feat[ split_index ], e.funknown[ split_index ] );
|
||||
path.push_back( pid );
|
||||
}
|
||||
this->DropTmp( feat );
|
||||
this->DropTmp( feat, e );
|
||||
}
|
||||
|
||||
// make it OpenMP thread safe, but not thread safe in general
|
||||
virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){
|
||||
this->InitTmp();
|
||||
this->PrepareTmp( feat );
|
||||
int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid );
|
||||
this->DropTmp( feat );
|
||||
ThreadEntry &e = this->InitTmp();
|
||||
this->PrepareTmp( feat, e );
|
||||
int pid = this->GetLeafIndex( e.feat, e.funknown, gid );
|
||||
this->DropTmp( feat, e );
|
||||
return tree[ pid ].leaf_value();
|
||||
}
|
||||
virtual float Predict( const std::vector<float> &feat,
|
||||
@@ -102,8 +104,7 @@ namespace xgboost{
|
||||
"input data smaller than num feature" );
|
||||
int pid = this->GetLeafIndex( feat, funknown, gid );
|
||||
return tree[ pid ].leaf_value();
|
||||
}
|
||||
|
||||
}
|
||||
virtual void DumpModel( FILE *fo ){
|
||||
tree.DumpModel( fo );
|
||||
}
|
||||
@@ -137,25 +138,34 @@ namespace xgboost{
|
||||
RegTree tree;
|
||||
TreeParamTrain param;
|
||||
private:
|
||||
std::vector<float> tmp_feat;
|
||||
std::vector<bool> tmp_funknown;
|
||||
inline void InitTmp( void ){
|
||||
if( tmp_feat.size() != (size_t)tree.param.num_feature ){
|
||||
tmp_feat.resize( tree.param.num_feature );
|
||||
tmp_funknown.resize( tree.param.num_feature );
|
||||
std::fill( tmp_funknown.begin(), tmp_funknown.end(), true );
|
||||
struct ThreadEntry{
|
||||
std::vector<float> feat;
|
||||
std::vector<bool> funknown;
|
||||
};
|
||||
std::vector<ThreadEntry> threadtemp;
|
||||
private:
|
||||
|
||||
inline ThreadEntry& InitTmp( void ){
|
||||
const int tid = omp_get_thread_num();
|
||||
utils::Assert( tid < (int)threadtemp.size(), "RTreeUpdater: threadtemp pool is too small" );
|
||||
ThreadEntry &e = threadtemp[ tid ];
|
||||
if( e.feat.size() != (size_t)tree.param.num_feature ){
|
||||
e.feat.resize( tree.param.num_feature );
|
||||
e.funknown.resize( tree.param.num_feature );
|
||||
std::fill( e.funknown.begin(), e.funknown.end(), true );
|
||||
}
|
||||
return e;
|
||||
}
|
||||
inline void PrepareTmp( const FMatrixS::Line &feat ){
|
||||
inline void PrepareTmp( const FMatrixS::Line &feat, ThreadEntry &e ){
|
||||
for( unsigned i = 0; i < feat.len; i ++ ){
|
||||
utils::Assert( feat[i].findex < (unsigned)tmp_funknown.size() , "input feature execeed bound" );
|
||||
tmp_funknown[ feat[i].findex ] = false;
|
||||
tmp_feat[ feat[i].findex ] = feat[i].fvalue;
|
||||
utils::Assert( feat[i].findex < (unsigned)tree.param.num_feature , "input feature execeed bound" );
|
||||
e.funknown[ feat[i].findex ] = false;
|
||||
e.feat[ feat[i].findex ] = feat[i].fvalue;
|
||||
}
|
||||
}
|
||||
inline void DropTmp( const FMatrixS::Line &feat ){
|
||||
inline void DropTmp( const FMatrixS::Line &feat, ThreadEntry &e ){
|
||||
for( unsigned i = 0; i < feat.len; i ++ ){
|
||||
tmp_funknown[ feat[i].findex ] = true;
|
||||
e.funknown[ feat[i].findex ] = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,4 +184,3 @@ namespace xgboost{
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@@ -75,7 +75,9 @@ namespace xgboost{
|
||||
}
|
||||
/*!
|
||||
* \brief predict values for given sparse feature vector
|
||||
* NOTE: in tree implementation, this is not threadsafe, used dense version to ensure threadsafety
|
||||
*
|
||||
* NOTE: in tree implementation, Sparse Predict is OpenMP threadsafe, but not threadsafe in general,
|
||||
* dense version of Predict to ensures threadsafety
|
||||
* \param feat vector in sparse format
|
||||
* \param rid root id of current instance, default = 0
|
||||
* \return prediction
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#ifndef _XGBOOST_GBMBASE_H_
|
||||
#define _XGBOOST_GBMBASE_H_
|
||||
|
||||
#include <omp.h>
|
||||
#include <cstring>
|
||||
#include "xgboost.h"
|
||||
#include "../utils/xgboost_config.h"
|
||||
@@ -88,6 +89,10 @@ namespace xgboost{
|
||||
}
|
||||
};
|
||||
public:
|
||||
/*! \brief number of thread used */
|
||||
GBMBaseModel( void ){
|
||||
this->nthread = 1;
|
||||
}
|
||||
/*! \brief destructor */
|
||||
virtual ~GBMBaseModel( void ){
|
||||
this->FreeSpace();
|
||||
@@ -104,6 +109,7 @@ namespace xgboost{
|
||||
if( !strcmp( name, "silent") ){
|
||||
cfg.PushBack( name, val );
|
||||
}
|
||||
if( !strcmp( name, "nthread") ) nthread = atoi( val );
|
||||
if( boosters.size() == 0 ) param.SetParam( name, val );
|
||||
}
|
||||
/*!
|
||||
@@ -164,6 +170,9 @@ namespace xgboost{
|
||||
* this function is reserved for solver to allocate necessary space and do other preparation
|
||||
*/
|
||||
inline void InitTrainer( void ){
|
||||
if( nthread != 0 ){
|
||||
omp_set_num_threads( nthread );
|
||||
}
|
||||
// make sure all the boosters get the latest parameters
|
||||
for( size_t i = 0; i < this->boosters.size(); i ++ ){
|
||||
this->ConfigBooster( this->boosters[i] );
|
||||
@@ -312,6 +321,8 @@ namespace xgboost{
|
||||
return boosters.back();
|
||||
}
|
||||
protected:
|
||||
/*! \brief number of OpenMP threads */
|
||||
int nthread;
|
||||
/*! \brief model parameters */
|
||||
Param param;
|
||||
/*! \brief component boosters */
|
||||
|
||||
Reference in New Issue
Block a user