JSON configuration IO. (#5111)

* Add saving/loading JSON configuration.
* Implement Python pickle interface with new IO routines.
* Basic tests for training continuation.
This commit is contained in:
Jiaming Yuan
2019-12-15 17:31:53 +08:00
committed by GitHub
parent 5aa007d7b2
commit 3136185bc5
24 changed files with 761 additions and 390 deletions

View File

@@ -461,9 +461,69 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
* \param out_dptr the argument to hold the output data pointer
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
bst_ulong *out_len,
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, bst_ulong *out_len,
const char **out_dptr);
/*!
* \brief Memory snapshot based serialization method. Saves all state
* into a buffer.
*
* \param handle handle
* \param out_len the argument to hold the output length
* \param out_dptr the argument to hold the output data pointer
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
const char **out_dptr);
/*!
* \brief Memory snapshot based serialization method. Loads the buffer returned
* from `XGBoosterSerializeToBuffer`.
*
* \param handle handle
* \param buf pointer to the buffer
* \param len the length of the buffer
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf, bst_ulong len);
/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version);
/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
/*!
* \brief Save XGBoost's internal configuration into a JSON document.
* \param handle handle to Booster object.
* \param out_len the argument to hold the output length
* \param out_str A valid pointer to an array of characters. The character array
* is allocated and managed by XGBoost, while the pointer to that array
* needs to be managed by the caller.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, bst_ulong *out_len,
char const **out_str);
/*!
* \brief Load XGBoost's internal configuration from a JSON document.
* \param handle handle to Booster object.
* \param json_parameters string representation of a JSON document.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle,
char const *json_parameters);
/*!
* \brief dump model, return array of strings representing model dump
* \param handle handle
@@ -570,25 +630,4 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
bst_ulong* out_len,
const char*** out);
// --- Distributed training API----
// NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(
BoosterHandle handle,
int* version);
/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
#endif // XGBOOST_C_API_H_

View File

@@ -32,7 +32,7 @@ struct LearnerModelParam;
/*!
* \brief interface of gradient boosting model.
*/
class GradientBooster : public Model {
class GradientBooster : public Model, public Configurable {
protected:
GenericParameter const* generic_param_;

View File

@@ -45,7 +45,7 @@ class Json;
*
* \endcode
*/
class Learner : public Model, public rabit::Serializable {
class Learner : public Model, public Configurable, public rabit::Serializable {
public:
/*! \brief virtual destructor */
~Learner() override;
@@ -53,16 +53,6 @@ class Learner : public Model, public rabit::Serializable {
* \brief Configure Learner based on set parameters.
*/
virtual void Configure() = 0;
/*!
* \brief load model from stream
* \param fi input stream.
*/
void Load(dmlc::Stream* fi) override = 0;
/*!
* \brief save model to stream.
* \param fo output stream
*/
void Save(dmlc::Stream* fo) const override = 0;
/*!
* \brief update the model for one iteration
* With the specified objective function.
@@ -110,6 +100,13 @@ class Learner : public Model, public rabit::Serializable {
bool pred_contribs = false,
bool approx_contribs = false,
bool pred_interactions = false) = 0;
void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0;
virtual void LoadModel(dmlc::Stream* fi) = 0;
virtual void SaveModel(dmlc::Stream* fo) const = 0;
/*!
* \brief Set multiple parameters at once.
*

View File

@@ -99,6 +99,7 @@ struct XGBoostParameter : public dmlc::Parameter<Type> {
return unknown;
}
}
/*! \brief Return the value of the initialised_ member converted to bool. */
bool GetInitialised() const { return static_cast<bool>(this->initialised_); }
};
} // namespace xgboost