add pathdump

This commit is contained in:
tqchen
2014-02-26 17:08:23 -08:00
parent 4a612eb3ba
commit 733f8ae393
7 changed files with 91 additions and 19 deletions

View File

@@ -57,7 +57,7 @@ namespace xgboost{
virtual int GetLeafIndex( const std::vector<float> &feat,
const std::vector<bool> &funknown,
unsigned gid = 0 ){
unsigned gid = 0 ){
// start from groups that belongs to current data
int pid = (int)gid;
// tranverse tree
@@ -67,18 +67,28 @@ namespace xgboost{
}
return pid;
}
// Record the ids of all nodes visited while routing one sparse instance down the tree.
//   path : output; cleared first, then receives the start node followed by every node reached
//   feat : sparse feature line of the instance
//   gid  : id of the node to start from (defaults to 0)
virtual void PredPath( std::vector<int> &path, const FMatrixS::Line &feat, unsigned gid = 0 ){
    path.clear();
    // load the instance into the temporary buffers
    this->InitTmp();
    this->PrepareTmp( feat );
    // the starting node is always the first entry of the path
    int nid = static_cast<int>( gid );
    path.push_back( nid );
    // traverse the tree, logging each node we step into, until a leaf is reached
    for( ;; ){
        if( tree[ nid ].is_leaf() ) break;
        const unsigned fid = tree[ nid ].split_index();
        nid = this->GetNext( nid, tmp_feat[ fid ], tmp_funknown[ fid ] );
        path.push_back( nid );
    }
    // flag the entries touched by this instance as unknown again
    this->DropTmp( feat );
}
// Predict the leaf value for one sparse instance.
//   feat : sparse feature line of the instance
//   gid  : id of the node to start from (defaults to 0)
// Returns leaf_value() of the node selected by GetLeafIndex.
virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){
    this->InitTmp();
    // PrepareTmp bound-checks every feature index and copies the values into
    // tmp_feat / clears the matching tmp_funknown flags, so no inline loop is needed
    this->PrepareTmp( feat );
    const int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid );
    // set back: DropTmp marks the entries touched by this instance as unknown again
    this->DropTmp( feat );
    return tree[ pid ].leaf_value();
}
virtual float Predict( const std::vector<float> &feat,
@@ -127,6 +137,18 @@ namespace xgboost{
std::fill( tmp_funknown.begin(), tmp_funknown.end(), true );
}
}
// Copy one sparse instance into the temporary buffers.
// For every entry of feat the value is written to tmp_feat and the matching
// tmp_funknown flag is cleared; the feature index is asserted to lie inside tmp_funknown.
inline void PrepareTmp( const FMatrixS::Line &feat ){
    // upper bound for a valid feature index
    const unsigned bound = static_cast<unsigned>( tmp_funknown.size() );
    for( unsigned k = 0; k < feat.len; ++k ){
        utils::Assert( feat[k].findex < bound, "input feature execeed bound" );
        tmp_feat[ feat[k].findex ] = feat[k].fvalue;
        tmp_funknown[ feat[k].findex ] = false;
    }
}
// Mark every feature present in feat as unknown again in tmp_funknown.
// Only the flags are reset; values stored in tmp_feat are left as they are.
inline void DropTmp( const FMatrixS::Line &feat ){
    unsigned k = 0;
    while( k < feat.len ){
        tmp_funknown[ feat[k].findex ] = true;
        ++k;
    }
}
inline int GetNext( int pid, float fvalue, bool is_unknown ){
float split_value = tree[ pid ].split_cond();