Merge branch 'master' of ssh://github.com/tqchen/xgboost

This commit is contained in:
tqchen 2014-02-15 17:42:31 -08:00 committed by tqchen
commit ece5f00ca1
9 changed files with 548 additions and 146 deletions

View File

@ -1,4 +1,4 @@
xgboost
xgboost: A Gradient Boosting Library
=======
Creator: Tianqi Chen: tianqi.tchen AT gmail
@ -7,16 +7,16 @@ General Purpose Gradient Boosting Library
Goal: A stand-alone efficient library to do learning via boosting in functional space
Features:
(1) Sparse feature format, handling of missing features. This allows efficient categorical feature encoding as indicators. The speed of booster only depens on number of existing features.
(2) Layout of gradient boosting algorithm to support generic tasks, see project wiki.
* Sparse feature format, handling of missing features. This allows efficient categorical feature encoding as indicators. The speed of the booster depends only on the number of existing features.
* Layout of gradient boosting algorithm to support generic tasks, see project wiki.
Planned key components:
(1) Gradient boosting models:
* Gradient boosting models:
- regression tree (GBRT)
- linear model/lasso
(2) Objectives to support tasks:
* Objectives to support tasks:
- regression
- classification
- ranking
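To make the sparse feature format concrete: a categorical feature is expanded into indicator features, and a row stores entries only for features that are present, so the booster's cost scales with the stored entries rather than the full feature dimension. A minimal sketch in C++ (hypothetical struct, not the library's actual layout):

// one sparse training row: only features that are present are stored
struct Entry { unsigned findex; float fvalue; };
// color in {red, green, blue} one-hot encoded as feature indices 10..12:
// a "green" instance stores the single entry (11, 1.0f); absent categories
// and missing features simply never appear in the row
Entry row[] = { {3, 0.7f}, {11, 1.0f} };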

View File

@ -9,6 +9,7 @@
#include <vector>
#include "../utils/xgboost_utils.h"
#include "../utils/xgboost_stream.h"
namespace xgboost{
namespace booster{
@ -143,7 +144,7 @@ namespace xgboost{
* the function is not consistent between 64-bit and 32-bit machines
* \param fo output stream
*/
inline void SaveBinary( utils::IStream &fo ) const{
inline void SaveBinary(utils::IStream &fo ) const{
size_t nrow = this->NumRow();
fo.Write( &nrow, sizeof(size_t) );
fo.Write( &row_ptr[0], row_ptr.size() * sizeof(size_t) );
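As the comment above warns, writing raw size_t makes the saved binary differ between 64-bit and 32-bit builds. A hedged sketch of one portable alternative (not the committed code), pinning every length field to 8 bytes on disk:

inline void SaveBinaryPortable( utils::IStream &fo ) const{
// fix the on-disk width of length fields at 8 bytes, so that 32-bit
// and 64-bit builds produce and consume the identical byte layout
unsigned long long nrow = (unsigned long long)this->NumRow();
fo.Write( &nrow, sizeof(nrow) );
for( size_t i = 0; i < row_ptr.size(); i ++ ){
unsigned long long ptr = (unsigned long long)row_ptr[i];
fo.Write( &ptr, sizeof(ptr) );
}
}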

View File

@ -1,10 +1,10 @@
#ifndef _XGBOOST_REG_H_
#define _XGBOOST_REG_H_
/*!
* \file xgboost_reg.h
* \brief class for gradient boosted regression
* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com
*/
* \file xgboost_reg.h
* \brief class for gradient boosted regression
* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com
*/
#include <cmath>
#include "xgboost_regdata.h"
#include "../booster/xgboost_gbmbase.h"
@ -12,143 +12,265 @@
#include "../utils/xgboost_stream.h"
namespace xgboost{
namespace regression{
/*! \brief class for gradient boosted regression */
class RegBoostLearner{
public:
/*!
* \brief a regression booster associated with training and evaluating data
* \param train pointer to the training data
* \param evals array of evaluating data
* \param evname name of evaluation data, used to print statistics
namespace regression{
/*! \brief class for gradient boosted regression */
class RegBoostLearner{
public:
RegBoostLearner(bool silent = false){
this->silent = silent;
}
/*!
* \brief a regression booster associated with training and evaluating data
* \param train pointer to the training data
* \param evals array of evaluating data
* \param evname name of evaluation data, used to print statistics
*/
RegBoostLearner( const DMatrix *train,
std::vector<const DMatrix *> evals,
std::vector<std::string> evname, bool silent = false ){
this->silent = silent;
SetData(train,evals,evname);
}
/*!
* \brief associate regression booster with training and evaluating data
* \param train pointer to the training data
* \param evals array of evaluating data
* \param evname name of evaluation data, used to print statistics
*/
inline void SetData(const DMatrix *train,
std::vector<const DMatrix *> evals,
std::vector<std::string> evname){
this->train_ = train;
this->evals_ = evals;
this->evname_ = evname;
//assign buffer index
int buffer_size = (*train).size();
for(size_t i = 0; i < evals.size(); i++){
buffer_size += (*evals[i]).size();
}
char str[25];
//note: itoa is non-standard; sprintf is the portable equivalent
sprintf(str,"%d",buffer_size);
base_model.SetParam("num_pbuffer",str);
}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam( const char *name, const char *val ){
mparam.SetParam( name, val );
base_model.SetParam( name, val );
}
/*!
* \brief initialize solver before training, called before training
* this function is reserved for solver to allocate necessary space and do other preparation
*/
inline void InitTrainer( void ){
base_model.InitTrainer();
mparam.AdjustBase();
}
/*!
* \brief initialize the model's data storage; call this the first time the model is used
*/
RegBoostLearner( const DMatrix *train,
std::vector<const DMatrix *> evals,
std::vector<std::string> evname ){
this->train_ = train;
this->evals_ = evals;
this->evname_ = evname;
//TODO: assign buffer index
}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam( const char *name, const char *val ){
mparam.SetParam( name, val );
base_model.SetParam( name, val );
}
/*!
* \brief initialize solver before training, called before training
* this function is reserved for solver to allocate necessary space and do other preparation
*/
inline void InitTrainer( void ){
base_model.InitTrainer();
mparam.AdjustBase();
}
/*!
* \brief load model from stream
* \param fi input stream
*/
inline void LoadModel( utils::IStream &fi ){
utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
base_model.LoadModel( fi );
}
/*!
* \brief save model to stream
* \param fo output stream
*/
inline void SaveModel( utils::IStream &fo ) const{
fo.Write( &mparam, sizeof(ModelParam) );
base_model.SaveModel( fo );
}
/*!
* \brief update the model for one iteration
*/
inline void UpdateOneIter( void ){
//TODO
}
/*! \brief predict the results, given data */
inline void Predict( std::vector<float> &preds, const DMatrix &data ){
//TODO
}
private:
/*! \brief training parameter for regression */
struct ModelParam{
/* \brief global bias */
float base_score;
/* \brief type of loss function */
int loss_type;
ModelParam( void ){
base_score = 0.5f;
loss_type = 0;
}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam( const char *name, const char *val ){
if( !strcmp("base_score", name ) ) base_score = (float)atof( val );
if( !strcmp("loss_type", name ) ) loss_type = atoi( val );
}
/*!
* \brief adjust base_score
*/
inline void AdjustBase( void ){
if( loss_type == 1 ){
utils::Assert( base_score > 0.0f && base_score < 1.0f, "sigmoid range constraint" );
base_score = - logf( 1.0f / base_score - 1.0f );
}
}
/*!
* \brief calculate first order gradient of loss, given transformed prediction
* \param predt transformed prediction
* \param label true label
* \return first order gradient
*/
inline float FirstOrderGradient( float predt, float label ) const{
switch( loss_type ){
case 0: return predt - label;
case 1: return predt - label;
default: utils::Error("unknown loss_type"); return 0.0f;
}
}
/*!
* \brief calculate second order gradient of loss, given transformed prediction
* \param predt transformed prediction
* \param label true label
* \return second order gradient
*/
inline float SecondOrderGradient( float predt, float label ) const{
switch( loss_type ){
case 0: return 1.0f;
case 1: return predt * ( 1 - predt );
default: utils::Error("unknown loss_type"); return 0.0f;
}
}
/*!
* \brief transform the linear sum to prediction
* \param x linear sum of boosting ensemble
* \return transformed prediction
*/
inline float PredTransform( float x ){
switch( loss_type ){
case 0: return x;
case 1: return 1.0f/(1.0f + expf(-x));
default: utils::Error("unknown loss_type"); return 0.0f;
}
}
};
private:
booster::GBMBaseModel base_model;
ModelParam mparam;
const DMatrix *train_;
std::vector<const DMatrix *> evals_;
std::vector<std::string> evname_;
};
};
inline void InitModel( void ){
base_model.InitModel();
}
/*!
* \brief load model from stream
* \param fi input stream
*/
inline void LoadModel( utils::IStream &fi ){
utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
base_model.LoadModel( fi );
}
/*!
* \brief save model to stream
* \param fo output stream
*/
inline void SaveModel( utils::IStream &fo ) const{
fo.Write( &mparam, sizeof(ModelParam) );
base_model.SaveModel( fo );
}
/*!
* \brief update the model for one iteration
* \param iteration the index of the current updating iteration, used to print statistics
*/
inline void UpdateOneIter( int iteration ){
std::vector<float> grad,hess,preds;
std::vector<unsigned> root_index;
booster::FMatrixS::Image train_image((*train_).data);
Predict(preds,*train_,0);
Gradient(preds,(*train_).labels,grad,hess);
base_model.DoBoost(grad,hess,train_image,root_index);
int buffer_index_offset = (*train_).size();
float loss = 0.0;
for(size_t i = 0; i < evals_.size(); i++){
Predict(preds, *evals_[i], buffer_index_offset);
loss = mparam.Loss(preds,(*evals_[i]).labels);
if(!silent){
//note: printf takes the values themselves, not their addresses
printf("The loss of %s data set in iteration %d is %f\n",evname_[i].c_str(),iteration,loss);
}
buffer_index_offset += (*evals_[i]).size();
}
}
/*! \brief get the transformed predictions, given data */
inline void Predict( std::vector<float> &preds, const DMatrix &data,int buffer_index_offset = 0 ){
int data_size = data.size();
preds.resize(data_size);
for(int j = 0; j < data_size; j++){
preds[j] = mparam.PredTransform(mparam.base_score +
base_model.Predict(data.data[j],buffer_index_offset + j));
}
}
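// a note on the buffer layout (assuming GBMBaseModel caches one running
// prediction per buffer index, as num_pbuffer suggests): SetData sizes the
// buffer as train.size() plus the sizes of all evaluation sets, so each data
// set owns a contiguous slice, [0, train.size()) for the training rows
// followed by one slice per evaluation set in order; passing
// buffer_index_offset + j therefore reuses the cached prediction of exactly
// instance j across boosting iterations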
private:
/*! \brief get the first order and second order gradient, given the transformed predictions and labels*/
inline void Gradient(const std::vector<float> &preds, const std::vector<float> &labels, std::vector<float> &grad,
std::vector<float> &hess){
grad.clear();
hess.clear();
for(size_t j = 0; j < preds.size(); j++){
grad.push_back(mparam.FirstOrderGradient(preds[j],labels[j]));
hess.push_back(mparam.SecondOrderGradient(preds[j],labels[j]));
}
}
enum LOSS_TYPE_LIST{
LINEAR_SQUARE,
LOGISTIC_NEGLOGLIKELIHOOD,
};
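// SetParam parses loss_type with atoi, so config values map onto the enum
// order above: 0 = LINEAR_SQUARE, 1 = LOGISTIC_NEGLOGLIKELIHOOD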
/*! \brief training parameter for regression */
struct ModelParam{
/* \brief global bias */
float base_score;
/* \brief type of loss function */
int loss_type;
ModelParam( void ){
base_score = 0.5f;
loss_type = 0;
}
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam( const char *name, const char *val ){
if( !strcmp("base_score", name ) ) base_score = (float)atof( val );
if( !strcmp("loss_type", name ) ) loss_type = atoi( val );
}
/*!
* \brief adjust base_score
*/
inline void AdjustBase( void ){
if( loss_type == 1 ){
utils::Assert( base_score > 0.0f && base_score < 1.0f, "sigmoid range constraint" );
base_score = - logf( 1.0f / base_score - 1.0f );
}
}
/*!
* \brief calculate first order gradient of loss, given transformed prediction
* \param predt transformed prediction
* \param label true label
* \return first order gradient
*/
inline float FirstOrderGradient( float predt, float label ) const{
switch( loss_type ){
case LINEAR_SQUARE: return predt - label;
case LOGISTIC_NEGLOGLIKELIHOOD: return predt - label;
default: utils::Error("unknown loss_type"); return 0.0f;
}
}
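// why both cases share the same form: for LINEAR_SQUARE the loss is
// 0.5*(predt - label)^2, so d(loss)/d(predt) = predt - label; for
// LOGISTIC_NEGLOGLIKELIHOOD the gradient is taken w.r.t. the linear sum x,
// with predt = sigmoid(x) and loss = -[label*ln(predt) + (1-label)*ln(1-predt)],
// giving d(loss)/dx = predt - label and d2(loss)/dx2 = predt*(1-predt),
// which is exactly what SecondOrderGradient below returns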
/*!
* \brief calculate second order gradient of loss, given transformed prediction
* \param predt transformed prediction
* \param label true label
* \return second order gradient
*/
inline float SecondOrderGradient( float predt, float label ) const{
switch( loss_type ){
case LINEAR_SQUARE: return 1.0f;
case LOGISTIC_NEGLOGLIKELIHOOD: return predt * ( 1 - predt );
default: utils::Error("unknown loss_type"); return 0.0f;
}
}
/*!
* \brief calculating the loss, given the predictions, labels and the loss type
* \param preds the given predictions
* \param labels the given labels
* \return the specified loss
*/
inline float Loss(const std::vector<float> &preds, const std::vector<float> &labels) const{
switch( loss_type ){
case LINEAR_SQUARE: return SquareLoss(preds,labels);
case LOGISTIC_NEGLOGLIKELIHOOD: return NegLoglikelihoodLoss(preds,labels);
default: utils::Error("unknown loss_type"); return 0.0f;
}
}
/*!
* \brief calculating the square loss, given the predictions and labels
* \param preds the given predictions
* \param labels the given labels
* \return the summation of square loss
*/
inline float SquareLoss(const std::vector<float> &preds, const std::vector<float> &labels) const{
float ans = 0.0;
for(size_t i = 0; i < preds.size(); i++)
ans += pow(preds[i] - labels[i], 2);
return ans;
}
/*!
* \brief calculating the negative log-likelihood loss, given the predictions and labels
* \param preds the given predictions
* \param labels the given labels
* \return the summation of negative log-likelihood loss
*/
inline float NegLoglikelihoodLoss(const std::vector<float> &preds, const std::vector<float> &labels) const{
float ans = 0.0;
for(size_t i = 0; i < preds.size(); i++)
ans -= labels[i] * log(preds[i]) + ( 1 - labels[i] ) * log(1 - preds[i]);
return ans;
}
/*!
* \brief transform the linear sum to prediction
* \param x linear sum of boosting ensemble
* \return transformed prediction
*/
inline float PredTransform( float x ){
switch( loss_type ){
case LINEAR_SQUARE: return x;
case LOGISTIC_NEGLOGLIKELIHOOD: return 1.0f/(1.0f + expf(-x));
default: utils::Error("unknown loss_type"); return 0.0f;
}
}
};
private:
booster::GBMBaseModel base_model;
ModelParam mparam;
const DMatrix *train_;
std::vector<const DMatrix *> evals_;
std::vector<std::string> evname_;
bool silent;
};
}
};
#endif

View File

@ -0,0 +1,14 @@
#include"xgboost_reg_train.h"
#include"xgboost_reg_test.h"
using namespace xgboost::regression;
int main(int argc, char *argv[]){
//read the configuration path and the silent flag from the command line
if(argc < 2){
printf("usage: %s <config_path> [silent]\n", argv[0]);
return 0;
}
char* config_path = argv[1];
bool silent = (argc > 2 && atoi(argv[2]) == 1);
RegBoostTrain train;
RegBoostTest test;
train.train(config_path,silent);
test.test(config_path,silent);
}
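For reference, a minimal configuration for this demo might look like the following. This is a hedged sketch: the exact syntax depends on what ConfigIterator accepts (assumed here to be simple name = value lines), and every path is a placeholder, but each key is one the SetParam handlers in this commit actually read:

train_path = train.txt
model_dir_path = models
boost_iterations = 10
save_period = 2
loss_type = 1
base_score = 0.5
validation_paths = valid.txt
validation_names = valid
test_paths = test.txt
test_names = test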

View File

@ -0,0 +1,99 @@
#ifndef _XGBOOST_REG_TEST_H_
#define _XGBOOST_REG_TEST_H_
#include<iostream>
#include<string>
#include<fstream>
#include"../utils/xgboost_config.h"
#include"xgboost_reg.h"
#include"xgboost_regdata.h"
#include"../utils/xgboost_string.h"
using namespace xgboost::utils;
namespace xgboost{
namespace regression{
/*!
* \brief wrapping the testing process of the gradient
boosting regression model, given the configuration
* \author Kailong Chen: chenkl198812@gmail.com
*/
class RegBoostTest{
public:
/*!
* \brief to start the testing process of gradient boosting regression
* model given the configuration, and finally save the prediction
* results to the specified paths.
* \param config_path the location of the configuration
* \param silent whether to print feedback messages
*/
void test(char* config_path,bool silent = false){
reg_boost_learner = new xgboost::regression::RegBoostLearner(silent);
ConfigIterator config_itr(config_path);
//Get the test data paths and configure the learner
while (config_itr.Next()){
reg_boost_learner->SetParam(config_itr.name(),config_itr.val());
test_param.SetParam(config_itr.name(),config_itr.val());
}
Assert(test_param.test_paths.size() == test_param.test_names.size(),
"The number of test data set paths is not the same as the number of test data set data set names");
//begin testing
reg_boost_learner->InitModel();
char model_path[256];
std::vector<float> preds;
for(size_t i = 0; i < test_param.test_paths.size(); i++){
xgboost::regression::DMatrix test_data;
test_data.LoadText(test_param.test_paths[i].c_str());
sprintf(model_path,"%s/final.model",test_param.model_dir_path);
//open in binary mode, since the saved model contains raw bytes
FileStream fin(fopen(model_path,"rb"));
reg_boost_learner->LoadModel(fin);
fin.Close();
reg_boost_learner->Predict(preds,test_data);
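//hypothetical sketch: TestParam carries a pred_dir_path, so persist the
//predictions there, one value per line in test-set order (the output
//format is an assumption, not part of the original code)
char pred_path[256];
sprintf(pred_path,"%s/%s.pred",test_param.pred_dir_path,test_param.test_names[i].c_str());
FILE* fp = fopen(pred_path,"w");
for(size_t j = 0; j < preds.size(); j++){
fprintf(fp,"%f\n",preds[j]);
}
fclose(fp);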
}
}
private:
struct TestParam{
/* \brief upperbound of the number of boosters */
int boost_iterations;
/* \brief the period to save the model, -1 means only save the final round model */
int save_period;
/* \brief the path of directory containing the saved models */
const char* model_dir_path;
/* \brief the path of directory containing the output prediction results */
const char* pred_dir_path;
/* \brief the paths of test data sets */
std::vector<std::string> test_paths;
/* \brief the names of the test data sets */
std::vector<std::string> test_names;
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam(const char *name,const char *val ){
if( !strcmp("model_dir_path", name ) ) model_dir_path = val;
if( !strcmp("pred_dir_path", name ) ) model_dir_path = val;
if( !strcmp("test_paths", name) ) {
test_paths = StringProcessing::split(val,';');
}
if( !strcmp("test_names", name) ) {
test_names = StringProcessing::split(val,';');
}
}
};
TestParam test_param;
xgboost::regression::RegBoostLearner* reg_boost_learner;
};
}
}
#endif

View File

@ -0,0 +1,127 @@
#ifndef _XGBOOST_REG_TRAIN_H_
#define _XGBOOST_REG_TRAIN_H_
#include<iostream>
#include<string>
#include<fstream>
#include"../utils/xgboost_config.h"
#include"xgboost_reg.h"
#include"xgboost_regdata.h"
#include"../utils/xgboost_string.h"
using namespace xgboost::utils;
namespace xgboost{
namespace regression{
/*!
* \brief wrapping the training process of the gradient
boosting regression model, given the configuration
* \author Kailong Chen: chenkl198812@gmail.com
*/
class RegBoostTrain{
public:
/*!
* \brief to start the training process of gradient boosting regression
* model given the configuration, and finally save the models
* to the specified model directory
* \param config_path the location of the configuration
* \param silent whether to print feedback messages
*/
void train(char* config_path,bool silent = false){
reg_boost_learner = new xgboost::regression::RegBoostLearner(silent);
ConfigIterator config_itr(config_path);
//Get the training data and validation data paths, config the Learner
while (config_itr.Next()){
reg_boost_learner->SetParam(config_itr.name(),config_itr.val());
train_param.SetParam(config_itr.name(),config_itr.val());
}
Assert(train_param.validation_data_paths.size() == train_param.validation_data_names.size(),
"The number of validation paths is not the same as the number of validation data set names");
//Load Data
xgboost::regression::DMatrix train;
train.LoadText(train_param.train_path);
std::vector<const xgboost::regression::DMatrix*> evals;
//note: the DMatrix objects must outlive the learner's use of them, so
//keep them outside the loop instead of taking the address of a loop-local
std::vector<xgboost::regression::DMatrix> eval_data(train_param.validation_data_paths.size());
for(size_t i = 0; i < train_param.validation_data_paths.size(); i++){
eval_data[i].LoadText(train_param.validation_data_paths[i].c_str());
evals.push_back(&eval_data[i]);
}
reg_boost_learner->SetData(&train,evals,train_param.validation_data_names);
//begin training
reg_boost_learner->InitTrainer();
char suffix[256];
for(int i = 1; i <= train_param.boost_iterations; i++){
reg_boost_learner->UpdateOneIter(i);
if(train_param.save_period != 0 && i % train_param.save_period == 0){
sprintf(suffix,"%d.model",i);
SaveModel(suffix);
}
}
//save the final round model
SaveModel("final.model");
}
private:
/*! \brief save model in the model directory with specified suffix*/
void SaveModel(const char* suffix){
char model_path[256];
//save the final round model
sprintf(model_path,"%s/%s",train_param.model_dir_path,suffix);
//open in binary mode, since SaveModel writes raw bytes
FILE* file = fopen(model_path,"wb");
FileStream fout(file);
reg_boost_learner->SaveModel(fout);
fout.Close();
}
struct TrainParam{
/* \brief upperbound of the number of boosters */
int boost_iterations;
/* \brief the period to save the model, -1 means only save the final round model */
int save_period;
/* \brief the path of training data set */
const char* train_path;
/* \brief the path of directory containing the saved models */
const char* model_dir_path;
/* \brief the paths of validation data sets */
std::vector<std::string> validation_data_paths;
/* \brief the names of the validation data sets */
std::vector<std::string> validation_data_names;
/*!
* \brief set parameters from outside
* \param name name of the parameter
* \param val value of the parameter
*/
inline void SetParam(const char *name,const char *val ){
if( !strcmp("boost_iterations", name ) ) boost_iterations = (float)atof( val );
if( !strcmp("save_period", name ) ) save_period = atoi( val );
if( !strcmp("train_path", name ) ) train_path = val;
if( !strcmp("model_dir_path", name ) ) model_dir_path = val;
if( !strcmp("validation_paths", name) ) {
validation_data_paths = StringProcessing::split(val,';');
}
if( !strcmp("validation_names", name) ) {
validation_data_names = StringProcessing::split(val,';');
}
}
};
/*! \brief the parameters of the training process*/
TrainParam train_param;
/*! \brief the gradient boosting regression tree model*/
xgboost::regression::RegBoostLearner* reg_boost_learner;
};
}
}
#endif

View File

@ -30,6 +30,13 @@ namespace xgboost{
public:
/*! \brief default constructor */
DMatrix( void ){}
/*! \brief get the number of instances */
inline int size() const{
return labels.size();
}
/*!
* \brief load from text file
* \param fname name of text data

View File

@ -10,6 +10,7 @@
#include <cstring>
#include <string>
#include "xgboost_utils.h"
#include <vector>
namespace xgboost{
namespace utils{

utils/xgboost_string.h Normal file
View File

@ -0,0 +1,31 @@
#ifndef _XGBOOST_STRING_H_
#define _XGBOOST_STRING_H_
#include<vector>
#include<string>
#include<sstream>
namespace xgboost{
namespace utils{
class StringProcessing{
public:
static std::vector<std::string> &split(const std::string &s, char delim, std::vector<std::string> &elems) {
std::stringstream ss(s);
std::string item;
while (std::getline(ss, item, delim)) {
elems.push_back(item);
}
return elems;
}
static std::vector<std::string> split(const std::string &s, char delim) {
std::vector<std::string> elems;
split(s, delim, elems);
return elems;
}
};
}
}
#endif
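A minimal usage sketch of the splitter (hypothetical values), matching how the trainer splits semicolon-separated config entries such as validation_paths:

#include <cstdio>
#include "utils/xgboost_string.h"
int main(void){
std::vector<std::string> parts =
xgboost::utils::StringProcessing::split("train.txt;valid.txt", ';');
for(size_t i = 0; i < parts.size(); i++){
printf("%s\n", parts[i].c_str()); //prints train.txt, then valid.txt
}
return 0;
}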