full omp support for regression
This commit is contained in:
parent
550010e9d2
commit
5cdc38648b
2
Makefile
2
Makefile
@ -11,7 +11,7 @@ all: $(BIN) $(OBJ)
|
||||
export LDFLAGS= -pthread -lm
|
||||
|
||||
xgboost.o: booster/xgboost.h booster/xgboost_data.h booster/xgboost.cpp booster/*/*.hpp booster/*/*.h
|
||||
xgboost: regression/xgboost_reg_main.cpp xgboost.o
|
||||
xgboost: regression/xgboost_reg_main.cpp regression/*.h xgboost.o
|
||||
|
||||
$(BIN) :
|
||||
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
||||
|
||||
@ -32,7 +32,9 @@ namespace xgboost{
|
||||
class RegTreeTrainer : public IBooster{
|
||||
public:
|
||||
RegTreeTrainer( void ){
|
||||
silent = 0; tree_maker = 1;
|
||||
silent = 0; tree_maker = 1;
|
||||
// normally we won't have more than 64 OpenMP threads
|
||||
threadtemp.resize( 64, ThreadEntry() );
|
||||
}
|
||||
virtual ~RegTreeTrainer( void ){}
|
||||
public:
|
||||
@ -74,25 +76,25 @@ namespace xgboost{
|
||||
|
||||
virtual void PredPath( std::vector<int> &path, const FMatrixS::Line &feat, unsigned gid = 0 ){
|
||||
path.clear();
|
||||
this->InitTmp();
|
||||
this->PrepareTmp( feat );
|
||||
ThreadEntry &e = this->InitTmp();
|
||||
this->PrepareTmp( feat, e );
|
||||
|
||||
int pid = (int)gid;
|
||||
path.push_back( pid );
|
||||
// tranverse tree
|
||||
while( !tree[ pid ].is_leaf() ){
|
||||
unsigned split_index = tree[ pid ].split_index();
|
||||
pid = this->GetNext( pid, tmp_feat[ split_index ], tmp_funknown[ split_index ] );
|
||||
pid = this->GetNext( pid, e.feat[ split_index ], e.funknown[ split_index ] );
|
||||
path.push_back( pid );
|
||||
}
|
||||
this->DropTmp( feat );
|
||||
this->DropTmp( feat, e );
|
||||
}
|
||||
|
||||
// make it OpenMP thread safe, but not thread safe in general
|
||||
virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){
|
||||
this->InitTmp();
|
||||
this->PrepareTmp( feat );
|
||||
int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid );
|
||||
this->DropTmp( feat );
|
||||
ThreadEntry &e = this->InitTmp();
|
||||
this->PrepareTmp( feat, e );
|
||||
int pid = this->GetLeafIndex( e.feat, e.funknown, gid );
|
||||
this->DropTmp( feat, e );
|
||||
return tree[ pid ].leaf_value();
|
||||
}
|
||||
virtual float Predict( const std::vector<float> &feat,
|
||||
@ -102,8 +104,7 @@ namespace xgboost{
|
||||
"input data smaller than num feature" );
|
||||
int pid = this->GetLeafIndex( feat, funknown, gid );
|
||||
return tree[ pid ].leaf_value();
|
||||
}
|
||||
|
||||
}
|
||||
virtual void DumpModel( FILE *fo ){
|
||||
tree.DumpModel( fo );
|
||||
}
|
||||
@ -137,25 +138,34 @@ namespace xgboost{
|
||||
RegTree tree;
|
||||
TreeParamTrain param;
|
||||
private:
|
||||
std::vector<float> tmp_feat;
|
||||
std::vector<bool> tmp_funknown;
|
||||
inline void InitTmp( void ){
|
||||
if( tmp_feat.size() != (size_t)tree.param.num_feature ){
|
||||
tmp_feat.resize( tree.param.num_feature );
|
||||
tmp_funknown.resize( tree.param.num_feature );
|
||||
std::fill( tmp_funknown.begin(), tmp_funknown.end(), true );
|
||||
struct ThreadEntry{
|
||||
std::vector<float> feat;
|
||||
std::vector<bool> funknown;
|
||||
};
|
||||
std::vector<ThreadEntry> threadtemp;
|
||||
private:
|
||||
|
||||
inline ThreadEntry& InitTmp( void ){
|
||||
const int tid = omp_get_thread_num();
|
||||
utils::Assert( tid < (int)threadtemp.size(), "RTreeUpdater: threadtemp pool is too small" );
|
||||
ThreadEntry &e = threadtemp[ tid ];
|
||||
if( e.feat.size() != (size_t)tree.param.num_feature ){
|
||||
e.feat.resize( tree.param.num_feature );
|
||||
e.funknown.resize( tree.param.num_feature );
|
||||
std::fill( e.funknown.begin(), e.funknown.end(), true );
|
||||
}
|
||||
return e;
|
||||
}
|
||||
inline void PrepareTmp( const FMatrixS::Line &feat ){
|
||||
inline void PrepareTmp( const FMatrixS::Line &feat, ThreadEntry &e ){
|
||||
for( unsigned i = 0; i < feat.len; i ++ ){
|
||||
utils::Assert( feat[i].findex < (unsigned)tmp_funknown.size() , "input feature execeed bound" );
|
||||
tmp_funknown[ feat[i].findex ] = false;
|
||||
tmp_feat[ feat[i].findex ] = feat[i].fvalue;
|
||||
utils::Assert( feat[i].findex < (unsigned)tree.param.num_feature , "input feature execeed bound" );
|
||||
e.funknown[ feat[i].findex ] = false;
|
||||
e.feat[ feat[i].findex ] = feat[i].fvalue;
|
||||
}
|
||||
}
|
||||
inline void DropTmp( const FMatrixS::Line &feat ){
|
||||
inline void DropTmp( const FMatrixS::Line &feat, ThreadEntry &e ){
|
||||
for( unsigned i = 0; i < feat.len; i ++ ){
|
||||
tmp_funknown[ feat[i].findex ] = true;
|
||||
e.funknown[ feat[i].findex ] = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,4 +184,3 @@ namespace xgboost{
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@ -75,7 +75,9 @@ namespace xgboost{
|
||||
}
|
||||
/*!
|
||||
* \brief predict values for given sparse feature vector
|
||||
* NOTE: in tree implementation, this is not threadsafe, used dense version to ensure threadsafety
|
||||
*
|
||||
* NOTE: in tree implementation, Sparse Predict is OpenMP threadsafe, but not threadsafe in general,
|
||||
* dense version of Predict to ensures threadsafety
|
||||
* \param feat vector in sparse format
|
||||
* \param rid root id of current instance, default = 0
|
||||
* \return prediction
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#ifndef _XGBOOST_GBMBASE_H_
|
||||
#define _XGBOOST_GBMBASE_H_
|
||||
|
||||
#include <omp.h>
|
||||
#include <cstring>
|
||||
#include "xgboost.h"
|
||||
#include "../utils/xgboost_config.h"
|
||||
@ -88,6 +89,10 @@ namespace xgboost{
|
||||
}
|
||||
};
|
||||
public:
|
||||
/*! \brief number of thread used */
|
||||
GBMBaseModel( void ){
|
||||
this->nthread = 1;
|
||||
}
|
||||
/*! \brief destructor */
|
||||
virtual ~GBMBaseModel( void ){
|
||||
this->FreeSpace();
|
||||
@ -104,6 +109,7 @@ namespace xgboost{
|
||||
if( !strcmp( name, "silent") ){
|
||||
cfg.PushBack( name, val );
|
||||
}
|
||||
if( !strcmp( name, "nthread") ) nthread = atoi( val );
|
||||
if( boosters.size() == 0 ) param.SetParam( name, val );
|
||||
}
|
||||
/*!
|
||||
@ -164,6 +170,9 @@ namespace xgboost{
|
||||
* this function is reserved for solver to allocate necessary space and do other preparation
|
||||
*/
|
||||
inline void InitTrainer( void ){
|
||||
if( nthread != 0 ){
|
||||
omp_set_num_threads( nthread );
|
||||
}
|
||||
// make sure all the boosters get the latest parameters
|
||||
for( size_t i = 0; i < this->boosters.size(); i ++ ){
|
||||
this->ConfigBooster( this->boosters[i] );
|
||||
@ -312,6 +321,8 @@ namespace xgboost{
|
||||
return boosters.back();
|
||||
}
|
||||
protected:
|
||||
/*! \brief number of OpenMP threads */
|
||||
int nthread;
|
||||
/*! \brief model parameters */
|
||||
Param param;
|
||||
/*! \brief component boosters */
|
||||
|
||||
@ -8,7 +8,9 @@
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <omp.h>
|
||||
#include "xgboost_regdata.h"
|
||||
#include "xgboost_regeval.h"
|
||||
#include "../booster/xgboost_gbmbase.h"
|
||||
#include "../utils/xgboost_utils.h"
|
||||
#include "../utils/xgboost_stream.h"
|
||||
@ -16,11 +18,11 @@
|
||||
namespace xgboost{
|
||||
namespace regression{
|
||||
/*! \brief class for gradient boosted regression */
|
||||
class RegBoostLearner{
|
||||
class RegBoostLearner{
|
||||
public:
|
||||
/*! \brief constructor */
|
||||
RegBoostLearner( void ){
|
||||
silent = 0;
|
||||
silent = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief a regression booter associated with training and evaluating data
|
||||
@ -59,6 +61,8 @@ namespace xgboost{
|
||||
printf( "buffer_size=%u\n", buffer_size );
|
||||
}
|
||||
base_model.SetParam( "num_pbuffer",snum_pbuffer );
|
||||
|
||||
this->eval_preds_.resize( evals.size(), std::vector<float>() );
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
@ -66,7 +70,8 @@ namespace xgboost{
|
||||
* \param val value of the parameter
|
||||
*/
|
||||
inline void SetParam( const char *name, const char *val ){
|
||||
if( !strcmp( name, "silent") ) silent = atoi( val );
|
||||
if( !strcmp( name, "silent") ) silent = atoi( val );
|
||||
if( !strcmp( name, "eval_metric") ) evaluator_.AddEval( val );
|
||||
mparam.SetParam( name, val );
|
||||
base_model.SetParam( name, val );
|
||||
}
|
||||
@ -76,6 +81,12 @@ namespace xgboost{
|
||||
*/
|
||||
inline void InitTrainer( void ){
|
||||
base_model.InitTrainer();
|
||||
if( mparam.loss_type == kLogisticClassify ){
|
||||
evaluator_.AddEval( "error" );
|
||||
}else{
|
||||
evaluator_.AddEval( "rmse" );
|
||||
}
|
||||
evaluator_.Init();
|
||||
}
|
||||
/*!
|
||||
* \brief initialize the current data storage for model, if the model is used first time, call this function
|
||||
@ -120,12 +131,10 @@ namespace xgboost{
|
||||
* \param iteration iteration number
|
||||
*/
|
||||
inline void UpdateOneIter( int iter ){
|
||||
std::vector<float> grad, hess, preds;
|
||||
this->Predict( preds, *train_, 0 );
|
||||
this->GetGradient( preds, train_->labels, grad, hess );
|
||||
|
||||
this->PredictBuffer( preds_, *train_, 0 );
|
||||
this->GetGradient( preds_, train_->labels, grad_, hess_ );
|
||||
std::vector<unsigned> root_index;
|
||||
base_model.DoBoost(grad,hess,train_->data,root_index);
|
||||
base_model.DoBoost( grad_, hess_, train_->data, root_index );
|
||||
}
|
||||
/*!
|
||||
* \brief evaluate the model for specific iteration
|
||||
@ -133,13 +142,13 @@ namespace xgboost{
|
||||
* \param fo file to output log
|
||||
*/
|
||||
inline void EvalOneIter( int iter, FILE *fo = stderr ){
|
||||
std::vector<float> preds;
|
||||
fprintf( fo, "[%d]", iter );
|
||||
int buffer_offset = static_cast<int>( train_->Size() );
|
||||
|
||||
for(size_t i = 0; i < evals_.size();i++){
|
||||
this->Predict(preds, *evals_[i], buffer_offset);
|
||||
this->Eval( fo, evname_[i].c_str(), preds, (*evals_[i]).labels );
|
||||
|
||||
for( size_t i = 0; i < evals_.size(); ++i ){
|
||||
std::vector<float> &preds = this->eval_preds_[ i ];
|
||||
this->PredictBuffer( preds, *evals_[i], buffer_offset);
|
||||
evaluator_.Eval( fo, evname_[i].c_str(), preds, (*evals_[i]).labels );
|
||||
buffer_offset += static_cast<int>( evals_[i]->Size() );
|
||||
}
|
||||
fprintf( fo,"\n" );
|
||||
@ -148,23 +157,22 @@ namespace xgboost{
|
||||
/*! \brief get prediction, without buffering */
|
||||
inline void Predict( std::vector<float> &preds, const DMatrix &data ){
|
||||
preds.resize( data.Size() );
|
||||
for( size_t j = 0; j < data.Size(); j++ ){
|
||||
|
||||
const unsigned ndata = static_cast<unsigned>( data.Size() );
|
||||
#pragma omp parallel for schedule( static )
|
||||
for( unsigned j = 0; j < ndata; ++ j ){
|
||||
preds[j] = mparam.PredTransform
|
||||
( mparam.base_score + base_model.Predict( data.data[j], -1 ) );
|
||||
}
|
||||
}
|
||||
private:
|
||||
/*! \brief print evaluation results */
|
||||
inline void Eval( FILE *fo, const char *evname,
|
||||
const std::vector<float> &preds,
|
||||
const std::vector<float> &labels ){
|
||||
const float loss = mparam.Loss( preds, labels );
|
||||
fprintf( fo, "\t%s:%f", evname, loss );
|
||||
}
|
||||
/*! \brief get the transformed predictions, given data */
|
||||
inline void Predict( std::vector<float> &preds, const DMatrix &data, unsigned buffer_offset ){
|
||||
inline void PredictBuffer( std::vector<float> &preds, const DMatrix &data, unsigned buffer_offset ){
|
||||
preds.resize( data.Size() );
|
||||
for( size_t j = 0; j < data.Size(); j++ ){
|
||||
|
||||
const unsigned ndata = static_cast<unsigned>( data.Size() );
|
||||
#pragma omp parallel for schedule( static )
|
||||
for( unsigned j = 0; j < ndata; ++ j ){
|
||||
preds[j] = mparam.PredTransform
|
||||
( mparam.base_score + base_model.Predict( data.data[j], buffer_offset + j ) );
|
||||
}
|
||||
@ -175,13 +183,16 @@ namespace xgboost{
|
||||
const std::vector<float> &labels,
|
||||
std::vector<float> &grad,
|
||||
std::vector<float> &hess ){
|
||||
grad.clear(); hess.clear();
|
||||
for( size_t j = 0; j < preds.size(); j++ ){
|
||||
grad.push_back( mparam.FirstOrderGradient (preds[j],labels[j]) );
|
||||
hess.push_back( mparam.SecondOrderGradient(preds[j],labels[j]) );
|
||||
grad.resize( preds.size() ); hess.resize( preds.size() );
|
||||
|
||||
const unsigned ndata = static_cast<unsigned>( preds.size() );
|
||||
#pragma omp parallel for schedule( static )
|
||||
for( unsigned j = 0; j < ndata; ++ j ){
|
||||
grad[j] = mparam.FirstOrderGradient( preds[j], labels[j] );
|
||||
hess[j] = mparam.SecondOrderGradient( preds[j], labels[j] );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
enum LossType{
|
||||
kLinearSquare = 0,
|
||||
@ -262,26 +273,26 @@ namespace xgboost{
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief calculating the loss, given the predictions, labels and the loss type
|
||||
* \param preds the given predictions
|
||||
* \param labels the given labels
|
||||
* \return the specified loss
|
||||
*/
|
||||
* \brief calculating the loss, given the predictions, labels and the loss type
|
||||
* \param preds the given predictions
|
||||
* \param labels the given labels
|
||||
* \return the specified loss
|
||||
*/
|
||||
inline float Loss(const std::vector<float> &preds, const std::vector<float> &labels) const{
|
||||
switch( loss_type ){
|
||||
case kLinearSquare: return SquareLoss(preds,labels);
|
||||
case kLogisticNeglik: return NegLoglikelihoodLoss(preds,labels);
|
||||
case kLogisticClassify: return ClassificationError(preds, labels);
|
||||
case kLogisticNeglik:
|
||||
case kLogisticClassify: return NegLoglikelihoodLoss(preds,labels);
|
||||
default: utils::Error("unknown loss_type"); return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief calculating the square loss, given the predictions and labels
|
||||
* \param preds the given predictions
|
||||
* \param labels the given labels
|
||||
* \return the summation of square loss
|
||||
*/
|
||||
* \brief calculating the square loss, given the predictions and labels
|
||||
* \param preds the given predictions
|
||||
* \param labels the given labels
|
||||
* \return the summation of square loss
|
||||
*/
|
||||
inline float SquareLoss(const std::vector<float> &preds, const std::vector<float> &labels) const{
|
||||
float ans = 0.0;
|
||||
for(size_t i = 0; i < preds.size(); i++){
|
||||
@ -292,44 +303,30 @@ namespace xgboost{
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief calculating the square loss, given the predictions and labels
|
||||
* \param preds the given predictions
|
||||
* \param labels the given labels
|
||||
* \return the summation of square loss
|
||||
*/
|
||||
* \brief calculating the square loss, given the predictions and labels
|
||||
* \param preds the given predictions
|
||||
* \param labels the given labels
|
||||
* \return the summation of square loss
|
||||
*/
|
||||
inline float NegLoglikelihoodLoss(const std::vector<float> &preds, const std::vector<float> &labels) const{
|
||||
float ans = 0.0;
|
||||
for(size_t i = 0; i < preds.size(); i++)
|
||||
ans -= labels[i] * logf(preds[i]) + ( 1 - labels[i] ) * logf(1 - preds[i]);
|
||||
return ans;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief calculating the ClassificationError loss, given the predictions and labels
|
||||
* \param preds the given predictions
|
||||
* \param labels the given labels
|
||||
* \return the summation of square loss
|
||||
*/
|
||||
inline float ClassificationError(const std::vector<float> &preds, const std::vector<float> &labels) const{
|
||||
int nerr = 0;
|
||||
for(size_t i = 0; i < preds.size(); i++){
|
||||
if( preds[i] > 0.5f ){
|
||||
if( labels[i] < 0.5f ) nerr ++;
|
||||
}else{
|
||||
if( labels[i] > 0.5f ) nerr ++;
|
||||
}
|
||||
}
|
||||
return (float)nerr/preds.size();
|
||||
}
|
||||
};
|
||||
private:
|
||||
int silent;
|
||||
EvalSet evaluator_;
|
||||
booster::GBMBaseModel base_model;
|
||||
ModelParam mparam;
|
||||
const DMatrix *train_;
|
||||
std::vector<DMatrix *> evals_;
|
||||
std::vector<std::string> evname_;
|
||||
std::vector<unsigned> buffer_index_;
|
||||
private:
|
||||
std::vector<float> grad_, hess_, preds_;
|
||||
std::vector< std::vector<float> > eval_preds_;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
@ -1,100 +0,0 @@
|
||||
#ifndef _XGBOOST_REG_TEST_H_
|
||||
#define _XGBOOST_REG_TEST_H_
|
||||
|
||||
#include<iostream>
|
||||
#include<string>
|
||||
#include<fstream>
|
||||
#include"../utils/xgboost_config.h"
|
||||
#include"xgboost_reg.h"
|
||||
#include"xgboost_regdata.h"
|
||||
#include"../utils/xgboost_string.h"
|
||||
|
||||
using namespace xgboost::utils;
|
||||
namespace xgboost{
|
||||
namespace regression{
|
||||
/*!
|
||||
* \brief wrapping the testing process of the gradient
|
||||
boosting regression model,given the configuation
|
||||
* \author Kailong Chen: chenkl198812@gmail.com
|
||||
*/
|
||||
class RegBoostTest{
|
||||
public:
|
||||
/*!
|
||||
* \brief to start the testing process of gradient boosting regression
|
||||
* model given the configuation, and finally save the prediction
|
||||
* results to the specified paths.
|
||||
* \param config_path the location of the configuration
|
||||
* \param silent whether to print feedback messages
|
||||
*/
|
||||
void test(char* config_path,bool silent = false){
|
||||
reg_boost_learner = new xgboost::regression::RegBoostLearner();
|
||||
ConfigIterator config_itr(config_path);
|
||||
//Get the training data and validation data paths, config the Learner
|
||||
while (config_itr.Next()){
|
||||
reg_boost_learner->SetParam(config_itr.name(),config_itr.val());
|
||||
test_param.SetParam(config_itr.name(),config_itr.val());
|
||||
}
|
||||
|
||||
Assert(test_param.test_paths.size() == test_param.test_names.size(),
|
||||
"The number of test data set paths is not the same as the number of test data set data set names");
|
||||
|
||||
//begin testing
|
||||
reg_boost_learner->InitModel();
|
||||
char model_path[256];
|
||||
std::vector<float> preds;
|
||||
for(size_t i = 0; i < test_param.test_paths.size(); i++){
|
||||
xgboost::regression::DMatrix test_data;
|
||||
test_data.LoadText(test_param.test_paths[i].c_str());
|
||||
sprintf(model_path,"%s/final.model",test_param.model_dir_path);
|
||||
// BUG: model need to be rb
|
||||
FileStream fin(fopen(model_path,"r"));
|
||||
reg_boost_learner->LoadModel(fin);
|
||||
fin.Close();
|
||||
reg_boost_learner->Predict(preds,test_data);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
struct TestParam{
|
||||
/* \brief upperbound of the number of boosters */
|
||||
int boost_iterations;
|
||||
|
||||
/* \brief the period to save the model, -1 means only save the final round model */
|
||||
int save_period;
|
||||
|
||||
/* \brief the path of directory containing the saved models */
|
||||
char model_dir_path[256];
|
||||
|
||||
/* \brief the path of directory containing the output prediction results */
|
||||
char pred_dir_path[256];
|
||||
|
||||
/* \brief the paths of test data sets */
|
||||
std::vector<std::string> test_paths;
|
||||
|
||||
/* \brief the names of the test data sets */
|
||||
std::vector<std::string> test_names;
|
||||
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
* \param name name of the parameter
|
||||
* \param val value of the parameter
|
||||
*/
|
||||
inline void SetParam(const char *name,const char *val ){
|
||||
if( !strcmp("model_dir_path", name ) ) strcpy(model_dir_path,val);
|
||||
if( !strcmp("pred_dir_path", name ) ) strcpy(pred_dir_path,val);
|
||||
if( !strcmp("test_paths", name) ) {
|
||||
test_paths = StringProcessing::split(val,';');
|
||||
}
|
||||
if( !strcmp("test_names", name) ) {
|
||||
test_names = StringProcessing::split(val,';');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TestParam test_param;
|
||||
xgboost::regression::RegBoostLearner* reg_boost_learner;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -1,132 +0,0 @@
|
||||
#ifndef _XGBOOST_REG_TRAIN_H_
|
||||
#define _XGBOOST_REG_TRAIN_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include "../utils/xgboost_config.h"
|
||||
#include "xgboost_reg.h"
|
||||
#include "xgboost_regdata.h"
|
||||
#include "../utils/xgboost_string.h"
|
||||
|
||||
using namespace xgboost::utils;
|
||||
|
||||
namespace xgboost{
|
||||
namespace regression{
|
||||
/*!
|
||||
* \brief wrapping the training process of the gradient
|
||||
boosting regression model,given the configuation
|
||||
* \author Kailong Chen: chenkl198812@gmail.com
|
||||
*/
|
||||
class RegBoostTrain{
|
||||
public:
|
||||
/*!
|
||||
* \brief to start the training process of gradient boosting regression
|
||||
* model given the configuation, and finally saved the models
|
||||
* to the specified model directory
|
||||
* \param config_path the location of the configuration
|
||||
* \param silent whether to print feedback messages
|
||||
*/
|
||||
void train(char* config_path,bool silent = false){
|
||||
reg_boost_learner = new xgboost::regression::RegBoostLearner();
|
||||
|
||||
ConfigIterator config_itr(config_path);
|
||||
//Get the training data and validation data paths, config the Learner
|
||||
while (config_itr.Next()){
|
||||
printf("%s %s\n",config_itr.name(),config_itr.val());
|
||||
reg_boost_learner->SetParam(config_itr.name(),config_itr.val());
|
||||
train_param.SetParam(config_itr.name(),config_itr.val());
|
||||
}
|
||||
|
||||
Assert(train_param.validation_data_paths.size() == train_param.validation_data_names.size(),
|
||||
"The number of validation paths is not the same as the number of validation data set names");
|
||||
|
||||
//Load Data
|
||||
xgboost::regression::DMatrix train;
|
||||
printf("%s",train_param.train_path);
|
||||
train.LoadText(train_param.train_path);
|
||||
std::vector<const xgboost::regression::DMatrix*> evals;
|
||||
for(size_t i = 0; i < train_param.validation_data_paths.size(); i++){
|
||||
xgboost::regression::DMatrix eval;
|
||||
eval.LoadText(train_param.validation_data_paths[i].c_str());
|
||||
evals.push_back(&eval);
|
||||
}
|
||||
reg_boost_learner->SetData(&train,evals,train_param.validation_data_names);
|
||||
|
||||
//begin training
|
||||
reg_boost_learner->InitTrainer();
|
||||
char suffix[256];
|
||||
for(int i = 1; i <= train_param.boost_iterations; i++){
|
||||
reg_boost_learner->UpdateOneIter(i);
|
||||
if(train_param.save_period != 0 && i % train_param.save_period == 0){
|
||||
sprintf(suffix,"%d.model",i);
|
||||
SaveModel(suffix);
|
||||
}
|
||||
}
|
||||
|
||||
//save the final round model
|
||||
SaveModel("final.model");
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief save model in the model directory with specified suffix*/
|
||||
void SaveModel(const char* suffix){
|
||||
char model_path[256];
|
||||
//save the final round model
|
||||
sprintf(model_path,"%s/%s",train_param.model_dir_path,suffix);
|
||||
FILE* file = fopen(model_path,"w");
|
||||
FileStream fin(file);
|
||||
reg_boost_learner->SaveModel(fin);
|
||||
fin.Close();
|
||||
}
|
||||
|
||||
struct TrainParam{
|
||||
/* \brief upperbound of the number of boosters */
|
||||
int boost_iterations;
|
||||
|
||||
/* \brief the period to save the model, -1 means only save the final round model */
|
||||
int save_period;
|
||||
|
||||
/* \brief the path of training data set */
|
||||
char train_path[256];
|
||||
|
||||
/* \brief the path of directory containing the saved models */
|
||||
char model_dir_path[256];
|
||||
|
||||
/* \brief the paths of validation data sets */
|
||||
std::vector<std::string> validation_data_paths;
|
||||
|
||||
/* \brief the names of the validation data sets */
|
||||
std::vector<std::string> validation_data_names;
|
||||
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
* \param name name of the parameter
|
||||
* \param val value of the parameter
|
||||
*/
|
||||
inline void SetParam(const char *name,const char *val ){
|
||||
if( !strcmp("boost_iterations", name ) ) boost_iterations = atoi( val );
|
||||
if( !strcmp("save_period", name ) ) save_period = atoi( val );
|
||||
if( !strcmp("train_path", name ) ) strcpy(train_path,val);
|
||||
if( !strcmp("model_dir_path", name ) ) {
|
||||
strcpy(model_dir_path,val);
|
||||
}
|
||||
if( !strcmp("validation_paths", name) ) {
|
||||
validation_data_paths = StringProcessing::split(val,';');
|
||||
}
|
||||
if( !strcmp("validation_names", name) ) {
|
||||
validation_data_names = StringProcessing::split(val,';');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief the parameters of the training process*/
|
||||
TrainParam train_param;
|
||||
|
||||
/*! \brief the gradient boosting regression tree model*/
|
||||
xgboost::regression::RegBoostLearner* reg_boost_learner;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
96
regression/xgboost_regeval.h
Normal file
96
regression/xgboost_regeval.h
Normal file
@ -0,0 +1,96 @@
|
||||
#ifndef _XGBOOST_REGEVAL_H_
|
||||
#define _XGBOOST_REGEVAL_H_
|
||||
/*!
|
||||
* \file xgboost_regeval.h
|
||||
* \brief evaluation metrics for regression and classification
|
||||
* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com
|
||||
*/
|
||||
#include <omp.h>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "../utils/xgboost_utils.h"
|
||||
|
||||
namespace xgboost{
|
||||
namespace regression{
|
||||
/*! \brief evaluator that evaluates the loss metrics */
|
||||
struct IEvaluator{
|
||||
/*!
|
||||
* \brief evaluate a specific metric
|
||||
* \param preds prediction
|
||||
* \param labels label
|
||||
*/
|
||||
virtual float Eval( const std::vector<float> &preds,
|
||||
const std::vector<float> &labels ) const= 0;
|
||||
/*! \return name of metric */
|
||||
virtual const char *Name( void ) const= 0;
|
||||
};
|
||||
|
||||
/*! \brief RMSE */
|
||||
struct EvalRMSE : public IEvaluator{
|
||||
virtual float Eval( const std::vector<float> &preds,
|
||||
const std::vector<float> &labels ) const{
|
||||
const unsigned ndata = static_cast<unsigned>( preds.size() );
|
||||
float sum = 0.0;
|
||||
#pragma omp parallel for reduction(+:sum) schedule( static )
|
||||
for( unsigned i = 0; i < ndata; ++ i ){
|
||||
float diff = preds[i] - labels[i];
|
||||
sum += diff * diff;
|
||||
}
|
||||
return sqrtf( sum / ndata );
|
||||
}
|
||||
virtual const char *Name( void ) const{
|
||||
return "rmse";
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Error */
|
||||
struct EvalError : public IEvaluator{
|
||||
virtual float Eval( const std::vector<float> &preds,
|
||||
const std::vector<float> &labels ) const{
|
||||
const unsigned ndata = static_cast<unsigned>( preds.size() );
|
||||
unsigned nerr = 0;
|
||||
#pragma omp parallel for reduction(+:nerr) schedule( static )
|
||||
for( unsigned i = 0; i < ndata; ++ i ){
|
||||
if( preds[i] > 0.5f ){
|
||||
if( labels[i] < 0.5f ) nerr += 1;
|
||||
}else{
|
||||
if( labels[i] > 0.5f ) nerr += 1;
|
||||
}
|
||||
}
|
||||
return static_cast<float>(nerr) / ndata;
|
||||
}
|
||||
virtual const char *Name( void ) const{
|
||||
return "error";
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
namespace regression{
|
||||
/*! \brief a set of evaluators */
|
||||
struct EvalSet{
|
||||
public:
|
||||
inline void AddEval( const char *name ){
|
||||
if( !strcmp( name, "rmse") ) evals_.push_back( &rmse_ );
|
||||
if( !strcmp( name, "error") ) evals_.push_back( &error_ );
|
||||
}
|
||||
inline void Init( void ){
|
||||
std::sort( evals_.begin(), evals_.end() );
|
||||
evals_.resize( std::unique( evals_.begin(), evals_.end() ) - evals_.begin() );
|
||||
}
|
||||
inline void Eval( FILE *fo, const char *evname,
|
||||
const std::vector<float> &preds,
|
||||
const std::vector<float> &labels ) const{
|
||||
for( size_t i = 0; i < evals_.size(); ++ i ){
|
||||
float res = evals_[i]->Eval( preds, labels );
|
||||
fprintf( fo, "\t%s-%s:%f", evname, evals_[i]->Name(), res );
|
||||
}
|
||||
}
|
||||
private:
|
||||
EvalRMSE rmse_;
|
||||
EvalError error_;
|
||||
std::vector<const IEvaluator*> evals_;
|
||||
};
|
||||
};
|
||||
};
|
||||
#endif
|
||||
@ -1,31 +0,0 @@
|
||||
#ifndef _XGBOOST_STRING_H_
|
||||
#define _XGBOOST_STRING_H_
|
||||
#include<vector>
|
||||
#include<sstream>
|
||||
|
||||
namespace xgboost{
|
||||
namespace utils{
|
||||
class StringProcessing{
|
||||
|
||||
public:
|
||||
static std::vector<std::string> &split(const std::string &s, char delim, std::vector<std::string> &elems) {
|
||||
std::stringstream ss(s);
|
||||
std::string item;
|
||||
while (std::getline(ss, item, delim)) {
|
||||
elems.push_back(item);
|
||||
}
|
||||
return elems;
|
||||
}
|
||||
|
||||
|
||||
static std::vector<std::string> split(const std::string &s, char delim) {
|
||||
std::vector<std::string> elems;
|
||||
split(s, delim, elems);
|
||||
return elems;
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
Loading…
x
Reference in New Issue
Block a user