add feature constraint

This commit is contained in:
tqchen
2014-03-19 10:47:56 -07:00
parent d3fe4b26a9
commit 255b1f4043
5 changed files with 90 additions and 51 deletions

View File

@@ -11,6 +11,7 @@
#include "xgboost_tree_model.h"
#include "../../utils/xgboost_omp.h"
#include "../../utils/xgboost_random.h"
#include "../../utils/xgboost_fmap.h"
#include "xgboost_base_treemaker.hpp"
namespace xgboost{
@@ -23,10 +24,11 @@ namespace xgboost{
const std::vector<float> &grad,
const std::vector<float> &hess,
const FMatrix &smat,
const std::vector<unsigned> &root_index )
const std::vector<unsigned> &root_index,
const utils::FeatConstrain &constrain )
: BaseTreeMaker( tree, param ),
grad(grad), hess(hess),
smat(smat), root_index(root_index) {
smat(smat), root_index(root_index), constrain(constrain) {
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
@@ -282,10 +284,10 @@ namespace xgboost{
{// initialize feature index
int ncol = static_cast<int>( smat.NumCol() );
for( int i = 0; i < ncol; i ++ ){
if( smat.GetSortedCol(i).Next() ){
if( smat.GetSortedCol(i).Next() && constrain.NotBanned(i) ){
feat_index.push_back( i );
}
}
}
random::Shuffle( feat_index );
}
{// setup temp space for each thread
@@ -326,6 +328,7 @@ namespace xgboost{
const std::vector<float> &hess;
const FMatrix &smat;
const std::vector<unsigned> &root_index;
const utils::FeatConstrain &constrain;
};
};
};