gbrt modified

This commit is contained in:
kalenhaha 2014-02-11 11:07:00 +08:00
parent 3afd186ea9
commit fb568a7a47

View File

@ -2,6 +2,7 @@
#define _GBRT_H_
#include "../utils/xgboost_config.h"
#include "../utils/xgboost_stream.h"
#include "xgboost_regression_data_reader.h"
#include "xgboost_gbmbase.h"
#include <math.h>
@ -36,13 +37,32 @@ public:
grad.clear();hess.clear();
for(int j = 0; j < instance_num; j++){
label = data_reader.GetLabel(j);
pred_transform = Logistic(base_model.Predict(data_reader.GetLine(j)));
pred_transform = Logistic(Predict(data_reader.GetLine(j)));
grad.push_back(FirstOrderGradient(pred_transform,label));
hess.push_back(SecondOrderGradient(pred_transform));
}
base_model.DoBoost(grad,hess,data_reader.GetImage(),root_index );
}
}
inline void SaveModel(IStream &fo ){
base_model.SaveModel(fo);
}
inline void LoadModel(IStream &fi ){
base_model.LoadModel(fi);
}
float Predict( const FMatrixS::Line &feat, int buffer_index = -1, unsigned rid = 0 ){
return base_model.Predict(feat,buffer_index,rid);
}
float Predict( const std::vector<float> &feat,
const std::vector<bool> &funknown,
int buffer_index = -1,
unsigned rid = 0 ){
return base_model.Predict(feat,funknown,buffer_index,rid);
}
struct GBRTParam{