add feature constraint
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "xgboost_tree_model.h"
|
||||
#include "../../utils/xgboost_omp.h"
|
||||
#include "../../utils/xgboost_random.h"
|
||||
#include "../../utils/xgboost_fmap.h"
|
||||
#include "xgboost_base_treemaker.hpp"
|
||||
|
||||
namespace xgboost{
|
||||
@@ -23,10 +24,11 @@ namespace xgboost{
|
||||
const std::vector<float> &grad,
|
||||
const std::vector<float> &hess,
|
||||
const FMatrix &smat,
|
||||
const std::vector<unsigned> &root_index )
|
||||
const std::vector<unsigned> &root_index,
|
||||
const utils::FeatConstrain &constrain )
|
||||
: BaseTreeMaker( tree, param ),
|
||||
grad(grad), hess(hess),
|
||||
smat(smat), root_index(root_index) {
|
||||
smat(smat), root_index(root_index), constrain(constrain) {
|
||||
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
||||
@@ -282,10 +284,10 @@ namespace xgboost{
|
||||
{// initialize feature index
|
||||
int ncol = static_cast<int>( smat.NumCol() );
|
||||
for( int i = 0; i < ncol; i ++ ){
|
||||
if( smat.GetSortedCol(i).Next() ){
|
||||
if( smat.GetSortedCol(i).Next() && constrain.NotBanned(i) ){
|
||||
feat_index.push_back( i );
|
||||
}
|
||||
}
|
||||
}
|
||||
random::Shuffle( feat_index );
|
||||
}
|
||||
{// setup temp space for each thread
|
||||
@@ -326,6 +328,7 @@ namespace xgboost{
|
||||
const std::vector<float> &hess;
|
||||
const FMatrix &smat;
|
||||
const std::vector<unsigned> &root_index;
|
||||
const utils::FeatConstrain &constrain;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user