try interact mode
This commit is contained in:
@@ -34,7 +34,10 @@ namespace xgboost{
|
||||
class RegTreeTrainer : public InterfaceBooster<FMatrix>{
|
||||
public:
|
||||
RegTreeTrainer( void ){
|
||||
silent = 0; tree_maker = 1;
|
||||
silent = 0; tree_maker = 1;
|
||||
// interact mode
|
||||
interact_type = 0;
|
||||
interact_node = 0;
|
||||
// normally we won't have more than 64 OpenMP threads
|
||||
threadtemp.resize( 64, ThreadEntry() );
|
||||
}
|
||||
@@ -43,6 +46,16 @@ namespace xgboost{
|
||||
virtual void SetParam( const char *name, const char *val ){
|
||||
if( !strcmp( name, "silent") ) silent = atoi( val );
|
||||
if( !strcmp( name, "tree_maker") ) tree_maker = atoi( val );
|
||||
if( !strncmp( name, "interact:", 9) ){
|
||||
const char *ename = name + 9;
|
||||
interact_node = atoi( val );
|
||||
if( !strcmp( ename, "expand") ) {
|
||||
interact_type = 1;
|
||||
}
|
||||
if( !strcmp( ename, "remove") ) {
|
||||
interact_type = 2;
|
||||
}
|
||||
}
|
||||
param.SetParam( name, val );
|
||||
tree.param.SetParam( name, val );
|
||||
}
|
||||
@@ -61,6 +74,16 @@ namespace xgboost{
|
||||
const FMatrix &smat,
|
||||
const std::vector<unsigned> &root_index ){
|
||||
utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
|
||||
|
||||
// interactive update
|
||||
if( interact_type != 0 ){
|
||||
switch( interact_type ){
|
||||
case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return;
|
||||
case 2:
|
||||
default: utils::Error("unknown interact type");
|
||||
}
|
||||
}
|
||||
|
||||
if( !silent ){
|
||||
printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
|
||||
}
|
||||
@@ -135,10 +158,38 @@ namespace xgboost{
|
||||
tree.DumpModel( fo, fmap, with_stats );
|
||||
}
|
||||
private:
|
||||
inline void ExpandNode( std::vector<float> &grad,
|
||||
std::vector<float> &hess,
|
||||
const FMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index,
|
||||
int nid ){
|
||||
std::vector<bst_uint> valid_index;
|
||||
for( size_t i = 0; i < grad.size(); i ++ ){
|
||||
ThreadEntry &e = this->InitTmp();
|
||||
this->PrepareTmp( fmat.GetRow(i), e );
|
||||
unsigned rtidx = root_index.size() == 0 ? 0 : root_index[i];
|
||||
int pid = this->GetLeafIndex( e.feat, e.funknown, rtidx );
|
||||
this->DropTmp( fmat.GetRow(i), e );
|
||||
if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
|
||||
}
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||
bool success = maker.Expand( valid_index, nid );
|
||||
if( !silent ){
|
||||
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.param.max_depth );
|
||||
}
|
||||
}
|
||||
private:
|
||||
// silent
|
||||
int silent;
|
||||
int tree_maker;
|
||||
RegTree tree;
|
||||
TreeParamTrain param;
|
||||
private:
|
||||
// some training parameters
|
||||
// tree maker
|
||||
int tree_maker;
|
||||
// interaction
|
||||
int interact_type;
|
||||
int interact_node;
|
||||
private:
|
||||
struct ThreadEntry{
|
||||
std::vector<float> feat;
|
||||
|
||||
Reference in New Issue
Block a user