start add coltree maker

This commit is contained in:
tqchen
2014-02-28 11:44:50 -08:00
parent 82807b3a55
commit b57656902e
5 changed files with 217 additions and 35 deletions

View File

@@ -23,6 +23,7 @@ namespace xgboost{
};
#include "xgboost_svdf_tree.hpp"
#include "xgboost_col_treemaker.hpp"
namespace xgboost{
namespace booster{
@@ -30,7 +31,9 @@ namespace xgboost{
// see RegTreeUpdater
class RegTreeTrainer : public IBooster{
public:
RegTreeTrainer( void ){ silent = 0; }
RegTreeTrainer( void ){
silent = 0; tree_maker = 0;
}
virtual ~RegTreeTrainer( void ){}
public:
virtual void SetParam( const char *name, const char *val ){
@@ -51,8 +54,8 @@ namespace xgboost{
virtual void DoBoost( std::vector<float> &grad,
std::vector<float> &hess,
const FMatrixS &smat,
const std::vector<unsigned> &group_id ){
this->DoBoost_( grad, hess, smat, group_id );
const std::vector<unsigned> &root_index ){
this->DoBoost_( grad, hess, smat, root_index );
}
virtual int GetLeafIndex( const std::vector<float> &feat,
@@ -108,23 +111,28 @@ namespace xgboost{
inline void DoBoost_( std::vector<float> &grad,
std::vector<float> &hess,
const FMatrix &smat,
const std::vector<unsigned> &group_id ){
const std::vector<unsigned> &root_index ){
utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
if( !silent ){
printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
}
// start with a id set
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, group_id );
int num_pruned;
tree.param.max_depth = updater.do_boost( num_pruned );
if( !silent ){
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
if( tree_maker == 0 ){
// start with a id set
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
int num_pruned;
tree.param.max_depth = updater.do_boost( num_pruned );
if( !silent ){
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
}
}else{
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
maker.Make();
}
}
private:
int silent;
int tree_maker;
RegTree tree;
TreeParamTrain param;
private: