pass simple test

tqchen 2014-02-20 22:28:05 -08:00
parent e52720976c
commit daab1fef19
5 changed files with 30 additions and 33 deletions

View File

@ -1,35 +1,22 @@
boost_iterations=10
num_round=10
save_period=0
train_path=C:\cygwin64\home\Chen\GitHub\xgboost\demo\regression\train.txt
model_dir_path=C:\cygwin64\home\Chen\GitHub\xgboost\demo\regression\model
validation_paths=C:\cygwin64\home\Chen\GitHub\xgboost\demo\regression\validation.txt
validation_names=validation
test_paths=C:\cygwin64\home\Chen\GitHub\xgboost\demo\regression\test.txt
test_names=test
data = "train.txt"
test:data = "test.txt"
eval[valid] = "validation.txt"
booster_type=1
do_reboost=0
bst:num_feature=3
learning_rate=0.01
bst:learning_rate=0.01
min_child_weight=1
bst:min_child_weight=1
min_split_loss=0.1
bst:min_split_loss=0.1
max_depth=3
bst:max_depth=3
reg_lambda=0.1
bst:reg_lambda=0.1
subsample=1
use_layerwise=0
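For reference, the new configuration style can be read off the paired lines in this hunk: the old flat keys appear to map onto num_round, data, eval[name] and test:data, and booster options now carry the bst: prefix. A minimal sketch in the new format (abridged; the trailing comments are editorial and the old-to-new mapping is inferred from the hunk, not stated in the commit):

    num_round=10                     # replaces boost_iterations
    save_period=0                    # 0 disables the periodic model dump (see the training-loop change below)
    data = "train.txt"               # replaces train_path
    eval[valid] = "validation.txt"   # replaces validation_paths / validation_names
    test:data = "test.txt"           # replaces test_paths / test_names
    bst:max_depth=3                  # booster parameters take the bst: prefix
    bst:learning_rate=0.01
    bst:min_child_weight=1
    bst:min_split_loss=0.1
    bst:reg_lambda=0.1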

View File

@ -1,3 +1,4 @@
1 0:1 1:2 2:1
1 0:2 1:1 2:1
0 0:5 1:0 2:0
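The test data uses a sparse text layout: a label followed by feature_index:feature_value pairs. A minimal stand-alone parser for one such line might look like the sketch below; it only documents the format and is not the loader used by xgboost itself.

    #include <cstdio>
    #include <sstream>
    #include <string>
    #include <utility>
    #include <vector>

    int main( void ){
        // one line from the file above: label, then sparse index:value pairs
        std::string line = "1 0:1 1:2 2:1";
        std::istringstream is( line );
        float label; is >> label;
        std::vector< std::pair<unsigned,float> > feats;
        std::string tok;
        while( is >> tok ){
            unsigned idx; float val;
            if( sscanf( tok.c_str(), "%u:%f", &idx, &val ) == 2 ){
                feats.push_back( std::make_pair( idx, val ) );
            }
        }
        printf( "label=%g, %u nonzero features\n", label, (unsigned)feats.size() );
        return 0;
    }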

View File

@ -19,7 +19,9 @@ namespace xgboost{
class RegBoostLearner{
public:
/*! \brief constructor */
RegBoostLearner( void ){}
RegBoostLearner( void ){
silent = 0;
}
/*!
* \brief a regression booster associated with training and evaluating data
* \param train pointer to the training data
@ -29,6 +31,7 @@ namespace xgboost{
RegBoostLearner( const DMatrix *train,
const std::vector<DMatrix *> &evals,
const std::vector<std::string> &evname ){
silent = 0;
this->SetData(train,evals,evname);
}
@ -51,7 +54,10 @@ namespace xgboost{
buffer_size += static_cast<unsigned>( evals[i]->Size() );
}
char snum_pbuffer[25];
printf( snum_pbuffer, "%u", buffer_size );
sprintf( snum_pbuffer, "%u", buffer_size );
if( !silent ){
printf( "buffer_size=%u\n", buffer_size );
}
base_model.SetParam( "num_pbuffer",snum_pbuffer );
}
/*!
@ -60,6 +66,7 @@ namespace xgboost{
* \param val value of the parameter
*/
inline void SetParam( const char *name, const char *val ){
if( !strcmp( name, "silent") ) silent = atoi( val );
mparam.SetParam( name, val );
base_model.SetParam( name, val );
}
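The InitTrainer hunk above fixes a real bug: printf( snum_pbuffer, "%u", buffer_size ) treats the uninitialized buffer as a format string and never fills it, while sprintf writes the number into snum_pbuffer so it can be handed to the string-only SetParam interface. A self-contained sketch of that pattern, where SetParam is a stand-in rather than the GBMBaseModel method:

    #include <cstdio>

    // stand-in for base_model.SetParam, which only accepts string values
    static void SetParam( const char *name, const char *val ){
        printf( "SetParam( %s = %s )\n", name, val );
    }

    int main( void ){
        unsigned buffer_size = 42;   // stands in for the summed eval set sizes
        char snum_pbuffer[25];
        sprintf( snum_pbuffer, "%u", buffer_size );   // printf() would leave the buffer untouched
        SetParam( "num_pbuffer", snum_pbuffer );
        return 0;
    }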
@ -173,7 +180,6 @@ namespace xgboost{
float base_score;
/* \brief type of loss function */
int loss_type;
ModelParam( void ){
base_score = 0.5f;
loss_type = 0;
@ -280,6 +286,7 @@ namespace xgboost{
}
};
private:
int silent;
booster::GBMBaseModel base_model;
ModelParam mparam;
const DMatrix *train_;

View File

@ -96,11 +96,11 @@ namespace xgboost{
learner.SetParam( cfg.name(), cfg.val() );
}
if( strcmp( model_in.c_str(), "NULL" ) != 0 ){
utils::Assert( !strcmp( task.c_str(), "train"), "model_in not specified" );
utils::FileStream fi( utils::FopenCheck( model_in.c_str(), "rb") );
learner.LoadModel( fi );
fi.Close();
}else{
utils::Assert( !strcmp( task.c_str(), "train"), "model_in not specified" );
learner.InitModel();
}
learner.InitTrainer();
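This hunk moves the "model_in not specified" assertion from the load branch into the else branch, which is where it belongs: a missing model_in is only acceptable when the task is "train" and a fresh model can be initialized. A minimal sketch of the corrected control flow, with assert and strcmp standing in for the utils::Assert call:

    #include <cassert>
    #include <cstdio>
    #include <cstring>

    int main( void ){
        const char *task = "train";
        const char *model_in = "NULL";   // "NULL" is the sentinel used above for "not given"
        if( strcmp( model_in, "NULL" ) != 0 ){
            printf( "load model from %s\n", model_in );
        }else{
            // only training may start without an input model
            assert( strcmp( task, "train" ) == 0 && "model_in not specified" );
            printf( "init fresh model\n" );
        }
        return 0;
    }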
@ -114,19 +114,19 @@ namespace xgboost{
learner.UpdateOneIter( i );
learner.EvalOneIter( i );
if( save_period != 0 && (i+1) % save_period == 0 ){
SaveModel( i );
this->SaveModel( i );
}
elapsed = (unsigned long)(time(NULL) - start);
}
// always save final round
if( num_round % save_period != 0 ){
SaveModel( num_round );
if( save_period == 0 || num_round % save_period != 0 ){
this->SaveModel( num_round );
}
if( !silent ){
printf("\nupdating end, %lu sec in all\n", elapsed );
}
}
inline void SaveModel( int i ){
inline void SaveModel( int i ) const{
char fname[256];
sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 );
utils::FileStream fo( utils::FopenCheck( fname, "wb" ) );
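The changed final-save condition in this hunk matters for the demo config above, which sets save_period=0: the old check evaluated num_round % save_period and therefore divided by zero whenever periodic saving was disabled. The corrected logic in isolation, with SaveModel as a stand-in for the method shown above:

    #include <cstdio>

    static void SaveModel( int i ){ printf( "SaveModel( %d )\n", i ); }

    int main( void ){
        const int num_round = 10, save_period = 0;   // values from the demo config
        for( int i = 0; i < num_round; ++ i ){
            // periodic save only when a period is actually configured
            if( save_period != 0 && (i + 1) % save_period == 0 ) SaveModel( i );
        }
        // always save the final round; checking save_period == 0 first avoids
        // the modulo-by-zero of the old "num_round % save_period" test
        if( save_period == 0 || num_round % save_period != 0 ) SaveModel( num_round );
        return 0;
    }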
@ -135,7 +135,9 @@ namespace xgboost{
}
inline void TaskTest( void ){
std::vector<float> preds;
if( !silent ) printf("start prediction...\n");
learner.Predict( preds, data );
if( !silent ) printf("writing prediction to %s\n", name_pred.c_str() );
FILE *fo = utils::FopenCheck( name_pred.c_str(), "w" );
for( size_t i = 0; i < preds.size(); i ++ ){
fprintf( fo, "%f\n", preds[i] );

View File

@ -122,7 +122,7 @@ namespace xgboost{
sprintf( bname, "%s.buffer", fname );
if( !this->LoadBinary( bname, silent ) ){
this->LoadText( fname, silent );
this->SaveBinary( fname, silent );
this->SaveBinary( bname, silent );
}
}
private:
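The one-character change above is another genuine fix: the binary cache must be written to bname (the "%s.buffer" path) rather than fname, otherwise the freshly parsed text file would be overwritten with the binary dump. The overall load-or-build-cache pattern, sketched with stand-in functions in place of the DMatrix methods:

    #include <cstdio>
    #include <string>

    // stand-ins for the DMatrix methods referenced above
    static bool LoadBinary( const char *bname ){
        FILE *fp = fopen( bname, "rb" );
        if( fp == NULL ) return false;
        fclose( fp );          // a real implementation would read the matrix here
        return true;
    }
    static void LoadText( const char *fname ){ printf( "parsing text file %s\n", fname ); }
    static void SaveBinary( const char *bname ){ printf( "writing binary cache %s\n", bname ); }

    int main( void ){
        const char *fname = "train.txt";
        char bname[256];
        sprintf( bname, "%s.buffer", fname );
        if( !LoadBinary( bname ) ){
            LoadText( fname );
            SaveBinary( bname );   // cache goes to bname, never back onto fname
        }
        return 0;
    }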