gbrt modified

This commit is contained in:
kalenhaha 2014-02-11 11:07:00 +08:00
parent 3afd186ea9
commit fb568a7a47

View File

@ -2,6 +2,7 @@
#define _GBRT_H_ #define _GBRT_H_
#include "../utils/xgboost_config.h" #include "../utils/xgboost_config.h"
#include "../utils/xgboost_stream.h"
#include "xgboost_regression_data_reader.h" #include "xgboost_regression_data_reader.h"
#include "xgboost_gbmbase.h" #include "xgboost_gbmbase.h"
#include <math.h> #include <math.h>
@ -36,7 +37,7 @@ public:
grad.clear();hess.clear(); grad.clear();hess.clear();
for(int j = 0; j < instance_num; j++){ for(int j = 0; j < instance_num; j++){
label = data_reader.GetLabel(j); label = data_reader.GetLabel(j);
pred_transform = Logistic(base_model.Predict(data_reader.GetLine(j))); pred_transform = Logistic(Predict(data_reader.GetLine(j)));
grad.push_back(FirstOrderGradient(pred_transform,label)); grad.push_back(FirstOrderGradient(pred_transform,label));
hess.push_back(SecondOrderGradient(pred_transform)); hess.push_back(SecondOrderGradient(pred_transform));
} }
@ -44,6 +45,25 @@ public:
} }
} }
inline void SaveModel(IStream &fo ){
base_model.SaveModel(fo);
}
inline void LoadModel(IStream &fi ){
base_model.LoadModel(fi);
}
float Predict( const FMatrixS::Line &feat, int buffer_index = -1, unsigned rid = 0 ){
return base_model.Predict(feat,buffer_index,rid);
}
float Predict( const std::vector<float> &feat,
const std::vector<bool> &funknown,
int buffer_index = -1,
unsigned rid = 0 ){
return base_model.Predict(feat,funknown,buffer_index,rid);
}
struct GBRTParam{ struct GBRTParam{
/*! \brief path of input training data */ /*! \brief path of input training data */