add feature constraint
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "xgboost_tree_model.h"
|
||||
#include "../../utils/xgboost_omp.h"
|
||||
#include "../../utils/xgboost_random.h"
|
||||
#include "../../utils/xgboost_fmap.h"
|
||||
#include "xgboost_base_treemaker.hpp"
|
||||
|
||||
namespace xgboost{
|
||||
@@ -23,10 +24,11 @@ namespace xgboost{
|
||||
const std::vector<float> &grad,
|
||||
const std::vector<float> &hess,
|
||||
const FMatrix &smat,
|
||||
const std::vector<unsigned> &root_index )
|
||||
const std::vector<unsigned> &root_index,
|
||||
const utils::FeatConstrain &constrain )
|
||||
: BaseTreeMaker( tree, param ),
|
||||
grad(grad), hess(hess),
|
||||
smat(smat), root_index(root_index) {
|
||||
smat(smat), root_index(root_index), constrain(constrain) {
|
||||
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
||||
@@ -282,10 +284,10 @@ namespace xgboost{
|
||||
{// initialize feature index
|
||||
int ncol = static_cast<int>( smat.NumCol() );
|
||||
for( int i = 0; i < ncol; i ++ ){
|
||||
if( smat.GetSortedCol(i).Next() ){
|
||||
if( smat.GetSortedCol(i).Next() && constrain.NotBanned(i) ){
|
||||
feat_index.push_back( i );
|
||||
}
|
||||
}
|
||||
}
|
||||
random::Shuffle( feat_index );
|
||||
}
|
||||
{// setup temp space for each thread
|
||||
@@ -326,6 +328,7 @@ namespace xgboost{
|
||||
const std::vector<float> &hess;
|
||||
const FMatrix &smat;
|
||||
const std::vector<unsigned> &root_index;
|
||||
const utils::FeatConstrain &constrain;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "xgboost_tree_model.h"
|
||||
#include "../../utils/xgboost_omp.h"
|
||||
#include "../../utils/xgboost_random.h"
|
||||
#include "../../utils/xgboost_fmap.h"
|
||||
#include "xgboost_base_treemaker.hpp"
|
||||
|
||||
namespace xgboost{
|
||||
@@ -23,10 +24,11 @@ namespace xgboost{
|
||||
const std::vector<float> &grad,
|
||||
const std::vector<float> &hess,
|
||||
const FMatrix &smat,
|
||||
const std::vector<unsigned> &root_index )
|
||||
const std::vector<unsigned> &root_index,
|
||||
const utils::FeatConstrain &constrain )
|
||||
: BaseTreeMaker( tree, param ),
|
||||
grad(grad), hess(hess),
|
||||
smat(smat), root_index(root_index) {
|
||||
smat(smat), root_index(root_index), constrain(constrain) {
|
||||
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
||||
@@ -254,14 +256,18 @@ namespace xgboost{
|
||||
for( bst_uint i = begin; i < end; ++i ){
|
||||
const bst_uint ridx = row_index_set[i];
|
||||
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
||||
builder.AddBudget( it.findex() );
|
||||
const bst_uint findex = it.findex();
|
||||
if( constrain.NotBanned( findex ) ) builder.AddBudget( findex );
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
for( bst_uint i = begin; i < end; ++i ){
|
||||
const bst_uint ridx = row_index_set[i];
|
||||
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
||||
builder.PushElem( it.findex(), FMatrixS::REntry( ridx, it.fvalue() ) );
|
||||
const bst_uint findex = it.findex();
|
||||
if( constrain.NotBanned( findex ) ) {
|
||||
builder.PushElem( findex, FMatrixS::REntry( ridx, it.fvalue() ) );
|
||||
}
|
||||
}
|
||||
}
|
||||
// --- end of building column major matrix ---
|
||||
@@ -373,6 +379,7 @@ namespace xgboost{
|
||||
const std::vector<float> &hess;
|
||||
const FMatrix &smat;
|
||||
const std::vector<unsigned> &root_index;
|
||||
const utils::FeatConstrain &constrain;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace xgboost{
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
#include "../../utils/xgboost_fmap.h"
|
||||
#include "xgboost_svdf_tree.hpp"
|
||||
#include "xgboost_col_treemaker.hpp"
|
||||
#include "xgboost_row_treemaker.hpp"
|
||||
@@ -57,6 +57,7 @@ namespace xgboost{
|
||||
}
|
||||
}
|
||||
param.SetParam( name, val );
|
||||
constrain.SetParam( name, val );
|
||||
tree.param.SetParam( name, val );
|
||||
}
|
||||
virtual void LoadModel( utils::IStream &fi ){
|
||||
@@ -90,17 +91,18 @@ namespace xgboost{
|
||||
int num_pruned;
|
||||
switch( tree_maker ){
|
||||
case 0: {
|
||||
utils::Assert( !constrain.HasConstrain(), "tree maker 0 does not support constrain" );
|
||||
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
|
||||
tree.param.max_depth = updater.do_boost( num_pruned );
|
||||
break;
|
||||
}
|
||||
case 1:{
|
||||
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
|
||||
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index, constrain );
|
||||
maker.Make( tree.param.max_depth, num_pruned );
|
||||
break;
|
||||
}
|
||||
case 2:{
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index, constrain );
|
||||
maker.Make( tree.param.max_depth, num_pruned );
|
||||
break;
|
||||
}
|
||||
@@ -178,7 +180,7 @@ namespace xgboost{
|
||||
}
|
||||
this->DropTmp( fmat.GetRow(i), e );
|
||||
}
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index, constrain );
|
||||
maker.Collapse( valid_index, nid );
|
||||
if( !silent ){
|
||||
printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth );
|
||||
@@ -198,7 +200,7 @@ namespace xgboost{
|
||||
this->DropTmp( fmat.GetRow(i), e );
|
||||
if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
|
||||
}
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index, constrain );
|
||||
bool success = maker.Expand( valid_index, nid );
|
||||
if( !silent ){
|
||||
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() );
|
||||
@@ -215,7 +217,9 @@ namespace xgboost{
|
||||
int tree_maker;
|
||||
// interaction
|
||||
int interact_type;
|
||||
int interact_node;
|
||||
int interact_node;
|
||||
// feature constrain
|
||||
utils::FeatConstrain constrain;
|
||||
private:
|
||||
struct ThreadEntry{
|
||||
std::vector<float> feat;
|
||||
|
||||
Reference in New Issue
Block a user