fix some bugs

This commit is contained in:
kalenhaha 2014-02-16 11:44:03 +08:00
parent 32e670a4da
commit 6d500b2964
6 changed files with 43 additions and 30 deletions

View File

@ -1,8 +1,8 @@
/*!
* \file xgboost.cpp
* \brief bootser implementations
* \author Tianqi Chen: tianqi.tchen@gmail.com
*/
* \file xgboost.cpp
* \brief bootser implementations
* \author Tianqi Chen: tianqi.tchen@gmail.com
*/
// implementation of boosters go to here
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
@ -15,20 +15,27 @@
#include "linear/xgboost_linear.hpp"
namespace xgboost{
namespace booster{
/*!
* \brief create a gradient booster, given type of booster
* \param booster_type type of gradient booster, can be used to specify implements
* \return the pointer to the gradient booster created
*/
IBooster *CreateBooster( int booster_type ){
switch( booster_type ){
case 0: return new RTreeTrainer();
case 1: return new LinearBooster();
default: utils::Error("unknown booster_type"); return NULL;
}
}
};
namespace booster{
/*£¡
* \brief listing the types of boosters
*/
enum BOOSTER_TYPE_LIST{
TREE,
LINEAR,
};
/*!
* \brief create a gradient booster, given type of booster
* \param booster_type type of gradient booster, can be used to specify implements
* \return the pointer to the gradient booster created
*/
IBooster *CreateBooster( int booster_type ){
switch( booster_type ){
case TREE: return new RTreeTrainer();
case LINEAR: return new LinearBooster();
default: utils::Error("unknown booster_type"); return NULL;
}
}
};
};

View File

@ -52,7 +52,8 @@ namespace xgboost{
buffer_size += (*evals[i]).size();
}
char str[25];
itoa(buffer_size,str,10);
_itoa(buffer_size,str,10);
base_model.SetParam("num_pbuffer",str);
base_model.SetParam("num_pbuffer",str);
}
@ -71,6 +72,7 @@ namespace xgboost{
*/
inline void InitTrainer( void ){
base_model.InitTrainer();
InitModel();
mparam.AdjustBase();
}

View File

@ -5,7 +5,7 @@ using namespace xgboost::regression;
int main(int argc, char *argv[]){
// char* config_path = argv[1];
// bool silent = ( atoi(argv[2]) == 1 );
char* config_path = "c:\\cygwin64\\home\\chen\\github\\gboost\\demo\\regression\\reg.conf";
char* config_path = "c:\\cygwin64\\home\\chen\\github\\xgboost\\demo\\regression\\reg.conf";
bool silent = false;
RegBoostTrain train;
RegBoostTest test;

View File

@ -32,6 +32,7 @@ namespace xgboost{
ConfigIterator config_itr(config_path);
//Get the training data and validation data paths, config the Learner
while (config_itr.Next()){
printf("%s %s\n",config_itr.name(),config_itr.val());
reg_boost_learner->SetParam(config_itr.name(),config_itr.val());
train_param.SetParam(config_itr.name(),config_itr.val());
}
@ -41,6 +42,7 @@ namespace xgboost{
//Load Data
xgboost::regression::DMatrix train;
printf("%s",train_param.train_path);
train.LoadText(train_param.train_path);
std::vector<const xgboost::regression::DMatrix*> evals;
for(int i = 0; i < train_param.validation_data_paths.size(); i++){
@ -70,7 +72,7 @@ namespace xgboost{
void SaveModel(const char* suffix){
char model_path[256];
//save the final round model
sscanf(model_path,"%s/%s",train_param.model_dir_path,suffix);
sprintf(model_path,"%s/%s",train_param.model_dir_path,suffix);
FILE* file = fopen(model_path,"w");
FileStream fin(file);
reg_boost_learner->SaveModel(fin);
@ -85,10 +87,10 @@ namespace xgboost{
int save_period;
/* \brief the path of training data set */
const char* train_path;
char train_path[256];
/* \brief the path of directory containing the saved models */
const char* model_dir_path;
char model_dir_path[256];
/* \brief the paths of validation data sets */
std::vector<std::string> validation_data_paths;
@ -102,10 +104,12 @@ namespace xgboost{
* \param val value of the parameter
*/
inline void SetParam(const char *name,const char *val ){
if( !strcmp("boost_iterations", name ) ) boost_iterations = (float)atof( val );
if( !strcmp("boost_iterations", name ) ) boost_iterations = atoi( val );
if( !strcmp("save_period", name ) ) save_period = atoi( val );
if( !strcmp("train_path", name ) ) train_path = val;
if( !strcmp("model_dir_path", name ) ) model_dir_path = val;
if( !strcmp("train_path", name ) ) strcpy(train_path,val);
if( !strcmp("model_dir_path", name ) ) {
strcpy(model_dir_path,val);
}
if( !strcmp("validation_paths", name) ) {
validation_data_paths = StringProcessing::split(val,';');
}

View File

@ -64,10 +64,9 @@ namespace xgboost{
init = false;
}
}
if( init ){
labels.push_back( label );
data.AddRow( findex, fvalue );
}
labels.push_back( label );
data.AddRow( findex, fvalue );
this->UpdateInfo();
if( !silent ){

View File

@ -5,6 +5,7 @@
* \brief simple utils to support the code
* \author Tianqi Chen: tianqi.tchen@gmail.com
*/
#define _CRT_SECURE_NO_WARNINGS
#ifdef _MSC_VER
#define fopen64 fopen