start add coltree maker
This commit is contained in:
@@ -23,6 +23,7 @@ namespace xgboost{
|
||||
};
|
||||
|
||||
#include "xgboost_svdf_tree.hpp"
|
||||
#include "xgboost_col_treemaker.hpp"
|
||||
|
||||
namespace xgboost{
|
||||
namespace booster{
|
||||
@@ -30,7 +31,9 @@ namespace xgboost{
|
||||
// see RegTreeUpdater
|
||||
class RegTreeTrainer : public IBooster{
|
||||
public:
|
||||
RegTreeTrainer( void ){ silent = 0; }
|
||||
RegTreeTrainer( void ){
|
||||
silent = 0; tree_maker = 0;
|
||||
}
|
||||
virtual ~RegTreeTrainer( void ){}
|
||||
public:
|
||||
virtual void SetParam( const char *name, const char *val ){
|
||||
@@ -51,8 +54,8 @@ namespace xgboost{
|
||||
virtual void DoBoost( std::vector<float> &grad,
|
||||
std::vector<float> &hess,
|
||||
const FMatrixS &smat,
|
||||
const std::vector<unsigned> &group_id ){
|
||||
this->DoBoost_( grad, hess, smat, group_id );
|
||||
const std::vector<unsigned> &root_index ){
|
||||
this->DoBoost_( grad, hess, smat, root_index );
|
||||
}
|
||||
|
||||
virtual int GetLeafIndex( const std::vector<float> &feat,
|
||||
@@ -108,23 +111,28 @@ namespace xgboost{
|
||||
inline void DoBoost_( std::vector<float> &grad,
|
||||
std::vector<float> &hess,
|
||||
const FMatrix &smat,
|
||||
const std::vector<unsigned> &group_id ){
|
||||
const std::vector<unsigned> &root_index ){
|
||||
utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
|
||||
if( !silent ){
|
||||
printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
|
||||
}
|
||||
// start with a id set
|
||||
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, group_id );
|
||||
int num_pruned;
|
||||
tree.param.max_depth = updater.do_boost( num_pruned );
|
||||
|
||||
if( !silent ){
|
||||
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
|
||||
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
|
||||
if( tree_maker == 0 ){
|
||||
// start with a id set
|
||||
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
|
||||
int num_pruned;
|
||||
tree.param.max_depth = updater.do_boost( num_pruned );
|
||||
if( !silent ){
|
||||
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
|
||||
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
|
||||
}
|
||||
}else{
|
||||
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
|
||||
maker.Make();
|
||||
}
|
||||
}
|
||||
private:
|
||||
int silent;
|
||||
int tree_maker;
|
||||
RegTree tree;
|
||||
TreeParamTrain param;
|
||||
private:
|
||||
|
||||
Reference in New Issue
Block a user