fix some bugs

This commit is contained in:
kalenhaha 2014-02-16 11:44:03 +08:00
parent 32e670a4da
commit 6d500b2964
6 changed files with 43 additions and 30 deletions

View File

@ -16,6 +16,13 @@
namespace xgboost{ namespace xgboost{
namespace booster{ namespace booster{
/*£¡
* \brief listing the types of boosters
*/
enum BOOSTER_TYPE_LIST{
TREE,
LINEAR,
};
/*! /*!
* \brief create a gradient booster, given type of booster * \brief create a gradient booster, given type of booster
* \param booster_type type of gradient booster, can be used to specify implements * \param booster_type type of gradient booster, can be used to specify implements
@ -23,8 +30,8 @@ namespace xgboost{
*/ */
IBooster *CreateBooster( int booster_type ){ IBooster *CreateBooster( int booster_type ){
switch( booster_type ){ switch( booster_type ){
case 0: return new RTreeTrainer(); case TREE: return new RTreeTrainer();
case 1: return new LinearBooster(); case LINEAR: return new LinearBooster();
default: utils::Error("unknown booster_type"); return NULL; default: utils::Error("unknown booster_type"); return NULL;
} }
} }

View File

@ -52,7 +52,8 @@ namespace xgboost{
buffer_size += (*evals[i]).size(); buffer_size += (*evals[i]).size();
} }
char str[25]; char str[25];
itoa(buffer_size,str,10); _itoa(buffer_size,str,10);
base_model.SetParam("num_pbuffer",str);
base_model.SetParam("num_pbuffer",str); base_model.SetParam("num_pbuffer",str);
} }
@ -71,6 +72,7 @@ namespace xgboost{
*/ */
inline void InitTrainer( void ){ inline void InitTrainer( void ){
base_model.InitTrainer(); base_model.InitTrainer();
InitModel();
mparam.AdjustBase(); mparam.AdjustBase();
} }

View File

@ -5,7 +5,7 @@ using namespace xgboost::regression;
int main(int argc, char *argv[]){ int main(int argc, char *argv[]){
// char* config_path = argv[1]; // char* config_path = argv[1];
// bool silent = ( atoi(argv[2]) == 1 ); // bool silent = ( atoi(argv[2]) == 1 );
char* config_path = "c:\\cygwin64\\home\\chen\\github\\gboost\\demo\\regression\\reg.conf"; char* config_path = "c:\\cygwin64\\home\\chen\\github\\xgboost\\demo\\regression\\reg.conf";
bool silent = false; bool silent = false;
RegBoostTrain train; RegBoostTrain train;
RegBoostTest test; RegBoostTest test;

View File

@ -32,6 +32,7 @@ namespace xgboost{
ConfigIterator config_itr(config_path); ConfigIterator config_itr(config_path);
//Get the training data and validation data paths, config the Learner //Get the training data and validation data paths, config the Learner
while (config_itr.Next()){ while (config_itr.Next()){
printf("%s %s\n",config_itr.name(),config_itr.val());
reg_boost_learner->SetParam(config_itr.name(),config_itr.val()); reg_boost_learner->SetParam(config_itr.name(),config_itr.val());
train_param.SetParam(config_itr.name(),config_itr.val()); train_param.SetParam(config_itr.name(),config_itr.val());
} }
@ -41,6 +42,7 @@ namespace xgboost{
//Load Data //Load Data
xgboost::regression::DMatrix train; xgboost::regression::DMatrix train;
printf("%s",train_param.train_path);
train.LoadText(train_param.train_path); train.LoadText(train_param.train_path);
std::vector<const xgboost::regression::DMatrix*> evals; std::vector<const xgboost::regression::DMatrix*> evals;
for(int i = 0; i < train_param.validation_data_paths.size(); i++){ for(int i = 0; i < train_param.validation_data_paths.size(); i++){
@ -70,7 +72,7 @@ namespace xgboost{
void SaveModel(const char* suffix){ void SaveModel(const char* suffix){
char model_path[256]; char model_path[256];
//save the final round model //save the final round model
sscanf(model_path,"%s/%s",train_param.model_dir_path,suffix); sprintf(model_path,"%s/%s",train_param.model_dir_path,suffix);
FILE* file = fopen(model_path,"w"); FILE* file = fopen(model_path,"w");
FileStream fin(file); FileStream fin(file);
reg_boost_learner->SaveModel(fin); reg_boost_learner->SaveModel(fin);
@ -85,10 +87,10 @@ namespace xgboost{
int save_period; int save_period;
/* \brief the path of training data set */ /* \brief the path of training data set */
const char* train_path; char train_path[256];
/* \brief the path of directory containing the saved models */ /* \brief the path of directory containing the saved models */
const char* model_dir_path; char model_dir_path[256];
/* \brief the paths of validation data sets */ /* \brief the paths of validation data sets */
std::vector<std::string> validation_data_paths; std::vector<std::string> validation_data_paths;
@ -102,10 +104,12 @@ namespace xgboost{
* \param val value of the parameter * \param val value of the parameter
*/ */
inline void SetParam(const char *name,const char *val ){ inline void SetParam(const char *name,const char *val ){
if( !strcmp("boost_iterations", name ) ) boost_iterations = (float)atof( val ); if( !strcmp("boost_iterations", name ) ) boost_iterations = atoi( val );
if( !strcmp("save_period", name ) ) save_period = atoi( val ); if( !strcmp("save_period", name ) ) save_period = atoi( val );
if( !strcmp("train_path", name ) ) train_path = val; if( !strcmp("train_path", name ) ) strcpy(train_path,val);
if( !strcmp("model_dir_path", name ) ) model_dir_path = val; if( !strcmp("model_dir_path", name ) ) {
strcpy(model_dir_path,val);
}
if( !strcmp("validation_paths", name) ) { if( !strcmp("validation_paths", name) ) {
validation_data_paths = StringProcessing::split(val,';'); validation_data_paths = StringProcessing::split(val,';');
} }

View File

@ -64,10 +64,9 @@ namespace xgboost{
init = false; init = false;
} }
} }
if( init ){
labels.push_back( label ); labels.push_back( label );
data.AddRow( findex, fvalue ); data.AddRow( findex, fvalue );
}
this->UpdateInfo(); this->UpdateInfo();
if( !silent ){ if( !silent ){

View File

@ -5,6 +5,7 @@
* \brief simple utils to support the code * \brief simple utils to support the code
* \author Tianqi Chen: tianqi.tchen@gmail.com * \author Tianqi Chen: tianqi.tchen@gmail.com
*/ */
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#ifdef _MSC_VER #ifdef _MSC_VER
#define fopen64 fopen #define fopen64 fopen