Model IO in JSON. (#5110)

This commit is contained in:
Jiaming Yuan
2019-12-11 11:20:40 +08:00
committed by GitHub
parent c7cc657a4d
commit 208ab3b1ff
25 changed files with 667 additions and 165 deletions

View File

@@ -428,7 +428,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
const float **out_result);
/*!
* \brief load model from existing file
* \brief Load model from existing file
* \param handle handle
* \param fname file name
* \return 0 when success, -1 when failure happens
@@ -436,7 +436,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle,
const char *fname);
/*!
* \brief save model into existing file
* \brief Save model into existing file
* \param handle handle
* \param fname file name
* \return 0 when success, -1 when failure happens

View File

@@ -32,7 +32,7 @@ struct LearnerModelParam;
/*!
* \brief interface of gradient boosting model.
*/
class GradientBooster {
class GradientBooster : public Model {
protected:
GenericParameter const* generic_param_;

View File

@@ -21,7 +21,7 @@ class FixedPrecisionStreamContainer : public std::basic_stringstream<
char, std::char_traits<char>, Allocator> {
public:
FixedPrecisionStreamContainer() {
this->precision(std::numeric_limits<Number::Float>::max_digits10);
this->precision(std::numeric_limits<double>::max_digits10);
this->imbue(std::locale("C"));
this->setf(std::ios::scientific);
}

View File

@@ -16,15 +16,15 @@ class Json;
struct Model {
/*!
* \brief Save the model to stream.
* \param fo output write stream
* \brief load the model from a json object
* \param in json object where to load the model from
*/
virtual void SaveModel(dmlc::Stream* fo) const = 0;
virtual void LoadModel(Json const& in) = 0;
/*!
* \brief Load the model from stream.
* \param fi input read stream
* \brief saves the model config to a json object
* \param out json container where to save the model to
*/
virtual void LoadModel(dmlc::Stream* fi) = 0;
virtual void SaveModel(Json* out) const = 0;
};
struct Configurable {

View File

@@ -303,12 +303,15 @@ class RegTree : public Model {
* \brief load model from stream
* \param fi input stream
*/
void LoadModel(dmlc::Stream* fi) override;
void Load(dmlc::Stream* fi);
/*!
* \brief save model to stream
* \param fo output stream
*/
void SaveModel(dmlc::Stream* fo) const override;
void Save(dmlc::Stream* fo) const;
void LoadModel(Json const& in) override;
void SaveModel(Json* out) const override;
bool operator==(const RegTree& b) const {
return nodes_ == b.nodes_ && stats_ == b.stats_ &&