diff --git a/booster/gbrt.h b/booster/gbrt.h index 6046364ea..0ce830b9e 100644 --- a/booster/gbrt.h +++ b/booster/gbrt.h @@ -2,6 +2,7 @@ #define _GBRT_H_ #include "../utils/xgboost_config.h" +#include "../utils/xgboost_stream.h" #include "xgboost_regression_data_reader.h" #include "xgboost_gbmbase.h" #include @@ -36,13 +37,32 @@ public: grad.clear();hess.clear(); for(int j = 0; j < instance_num; j++){ label = data_reader.GetLabel(j); - pred_transform = Logistic(base_model.Predict(data_reader.GetLine(j))); + pred_transform = Logistic(Predict(data_reader.GetLine(j))); grad.push_back(FirstOrderGradient(pred_transform,label)); hess.push_back(SecondOrderGradient(pred_transform)); } base_model.DoBoost(grad,hess,data_reader.GetImage(),root_index ); } } + + inline void SaveModel(IStream &fo ){ + base_model.SaveModel(fo); + } + + inline void LoadModel(IStream &fi ){ + base_model.LoadModel(fi); + } + + float Predict( const FMatrixS::Line &feat, int buffer_index = -1, unsigned rid = 0 ){ + return base_model.Predict(feat,buffer_index,rid); + } + + float Predict( const std::vector &feat, + const std::vector &funknown, + int buffer_index = -1, + unsigned rid = 0 ){ + return base_model.Predict(feat,funknown,buffer_index,rid); + } struct GBRTParam{