add feature constraint

This commit is contained in:
tqchen
2014-03-19 10:47:56 -07:00
parent d3fe4b26a9
commit 255b1f4043
5 changed files with 90 additions and 51 deletions

View File

@@ -21,7 +21,7 @@ namespace xgboost{
}
};
};
#include "../../utils/xgboost_fmap.h"
#include "xgboost_svdf_tree.hpp"
#include "xgboost_col_treemaker.hpp"
#include "xgboost_row_treemaker.hpp"
@@ -57,6 +57,7 @@ namespace xgboost{
}
}
param.SetParam( name, val );
constrain.SetParam( name, val );
tree.param.SetParam( name, val );
}
virtual void LoadModel( utils::IStream &fi ){
@@ -90,17 +91,18 @@ namespace xgboost{
int num_pruned;
switch( tree_maker ){
case 0: {
utils::Assert( !constrain.HasConstrain(), "tree maker 0 does not support constrain" );
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
tree.param.max_depth = updater.do_boost( num_pruned );
break;
}
case 1:{
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index, constrain );
maker.Make( tree.param.max_depth, num_pruned );
break;
}
case 2:{
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index, constrain );
maker.Make( tree.param.max_depth, num_pruned );
break;
}
@@ -178,7 +180,7 @@ namespace xgboost{
}
this->DropTmp( fmat.GetRow(i), e );
}
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index, constrain );
maker.Collapse( valid_index, nid );
if( !silent ){
printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth );
@@ -198,7 +200,7 @@ namespace xgboost{
this->DropTmp( fmat.GetRow(i), e );
if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
}
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index, constrain );
bool success = maker.Expand( valid_index, nid );
if( !silent ){
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() );
@@ -215,7 +217,9 @@ namespace xgboost{
int tree_maker;
// interaction
int interact_type;
int interact_node;
int interact_node;
// feature constrain
utils::FeatConstrain constrain;
private:
struct ThreadEntry{
std::vector<float> feat;