try interact mode

This commit is contained in:
tqchen
2014-03-05 15:28:53 -08:00
parent 2bdcad9630
commit ef5a389ecf
7 changed files with 142 additions and 24 deletions

View File

@@ -34,7 +34,10 @@ namespace xgboost{
class RegTreeTrainer : public InterfaceBooster<FMatrix>{
public:
RegTreeTrainer( void ){
silent = 0; tree_maker = 1;
silent = 0; tree_maker = 1;
// interact mode
interact_type = 0;
interact_node = 0;
// normally we won't have more than 64 OpenMP threads
threadtemp.resize( 64, ThreadEntry() );
}
@@ -43,6 +46,16 @@ namespace xgboost{
virtual void SetParam( const char *name, const char *val ){
if( !strcmp( name, "silent") ) silent = atoi( val );
if( !strcmp( name, "tree_maker") ) tree_maker = atoi( val );
if( !strncmp( name, "interact:", 9) ){
const char *ename = name + 9;
interact_node = atoi( val );
if( !strcmp( ename, "expand") ) {
interact_type = 1;
}
if( !strcmp( ename, "remove") ) {
interact_type = 2;
}
}
param.SetParam( name, val );
tree.param.SetParam( name, val );
}
@@ -61,6 +74,16 @@ namespace xgboost{
const FMatrix &smat,
const std::vector<unsigned> &root_index ){
utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
// interactive update
if( interact_type != 0 ){
switch( interact_type ){
case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return;
case 2:
default: utils::Error("unknown interact type");
}
}
if( !silent ){
printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
}
@@ -135,10 +158,38 @@ namespace xgboost{
tree.DumpModel( fo, fmap, with_stats );
}
private:
inline void ExpandNode( std::vector<float> &grad,
std::vector<float> &hess,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
int nid ){
std::vector<bst_uint> valid_index;
for( size_t i = 0; i < grad.size(); i ++ ){
ThreadEntry &e = this->InitTmp();
this->PrepareTmp( fmat.GetRow(i), e );
unsigned rtidx = root_index.size() == 0 ? 0 : root_index[i];
int pid = this->GetLeafIndex( e.feat, e.funknown, rtidx );
this->DropTmp( fmat.GetRow(i), e );
if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
}
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
bool success = maker.Expand( valid_index, nid );
if( !silent ){
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.param.max_depth );
}
}
private:
// silent
int silent;
int tree_maker;
RegTree tree;
TreeParamTrain param;
private:
// some training parameters
// tree maker
int tree_maker;
// interaction
int interact_type;
int interact_node;
private:
struct ThreadEntry{
std::vector<float> feat;