gbrt modified
This commit is contained in:
parent
3afd186ea9
commit
fb568a7a47
@ -2,6 +2,7 @@
|
||||
#define _GBRT_H_
|
||||
|
||||
#include "../utils/xgboost_config.h"
|
||||
#include "../utils/xgboost_stream.h"
|
||||
#include "xgboost_regression_data_reader.h"
|
||||
#include "xgboost_gbmbase.h"
|
||||
#include <math.h>
|
||||
@ -36,13 +37,32 @@ public:
|
||||
grad.clear();hess.clear();
|
||||
for(int j = 0; j < instance_num; j++){
|
||||
label = data_reader.GetLabel(j);
|
||||
pred_transform = Logistic(base_model.Predict(data_reader.GetLine(j)));
|
||||
pred_transform = Logistic(Predict(data_reader.GetLine(j)));
|
||||
grad.push_back(FirstOrderGradient(pred_transform,label));
|
||||
hess.push_back(SecondOrderGradient(pred_transform));
|
||||
}
|
||||
base_model.DoBoost(grad,hess,data_reader.GetImage(),root_index );
|
||||
}
|
||||
}
|
||||
|
||||
inline void SaveModel(IStream &fo ){
|
||||
base_model.SaveModel(fo);
|
||||
}
|
||||
|
||||
inline void LoadModel(IStream &fi ){
|
||||
base_model.LoadModel(fi);
|
||||
}
|
||||
|
||||
float Predict( const FMatrixS::Line &feat, int buffer_index = -1, unsigned rid = 0 ){
|
||||
return base_model.Predict(feat,buffer_index,rid);
|
||||
}
|
||||
|
||||
float Predict( const std::vector<float> &feat,
|
||||
const std::vector<bool> &funknown,
|
||||
int buffer_index = -1,
|
||||
unsigned rid = 0 ){
|
||||
return base_model.Predict(feat,funknown,buffer_index,rid);
|
||||
}
|
||||
|
||||
struct GBRTParam{
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user