JSON configuration IO. (#5111)
* Add saving/loading JSON configuration. * Implement Python pickle interface with new IO routines. * Basic tests for training continuation.
This commit is contained in:
@@ -461,9 +461,69 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
||||
* \param out_dptr the argument to hold the output data pointer
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
||||
bst_ulong *out_len,
|
||||
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, bst_ulong *out_len,
|
||||
const char **out_dptr);
|
||||
|
||||
/*!
|
||||
* \brief Memory snapshot based serialization method. Saves everything states
|
||||
* into buffer.
|
||||
*
|
||||
* \param handle handle
|
||||
* \param out_len the argument to hold the output length
|
||||
* \param out_dptr the argument to hold the output data pointer
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
|
||||
const char **out_dptr);
|
||||
/*!
|
||||
* \brief Memory snapshot based serialization method. Loads the buffer returned
|
||||
* from `XGBoosterSerializeToBuffer'.
|
||||
*
|
||||
* \param handle handle
|
||||
* \param buf pointer to the buffer
|
||||
* \param len the length of the buffer
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
|
||||
const void *buf, bst_ulong len);
|
||||
|
||||
/*!
|
||||
* \brief Initialize the booster from rabit checkpoint.
|
||||
* This is used in distributed training API.
|
||||
* \param handle handle
|
||||
* \param version The output version of the model.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version);
|
||||
|
||||
/*!
|
||||
* \brief Save the current checkpoint to rabit.
|
||||
* \param handle handle
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Save XGBoost's internal configuration into a JSON document.
|
||||
* \param handle handle to Booster object.
|
||||
* \param out_str A valid pointer to array of characters. The characters array is
|
||||
* allocated and managed by XGBoost, while pointer to that array needs to
|
||||
* be managed by caller.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, bst_ulong *out_len,
|
||||
char const **out_str);
|
||||
/*!
|
||||
* \brief Load XGBoost's internal configuration from a JSON document.
|
||||
* \param handle handle to Booster object.
|
||||
* \param json_parameters string representation of a JSON document.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle,
|
||||
char const *json_parameters);
|
||||
|
||||
/*!
|
||||
* \brief dump model, return array of strings representing model dump
|
||||
* \param handle handle
|
||||
@@ -570,25 +630,4 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
|
||||
XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
bst_ulong* out_len,
|
||||
const char*** out);
|
||||
|
||||
// --- Distributed training API----
|
||||
// NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
|
||||
/*!
|
||||
* \brief Initialize the booster from rabit checkpoint.
|
||||
* This is used in distributed training API.
|
||||
* \param handle handle
|
||||
* \param version The output version of the model.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(
|
||||
BoosterHandle handle,
|
||||
int* version);
|
||||
|
||||
/*!
|
||||
* \brief Save the current checkpoint to rabit.
|
||||
* \param handle handle
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
|
||||
|
||||
#endif // XGBOOST_C_API_H_
|
||||
|
||||
@@ -32,7 +32,7 @@ struct LearnerModelParam;
|
||||
/*!
|
||||
* \brief interface of gradient boosting model.
|
||||
*/
|
||||
class GradientBooster : public Model {
|
||||
class GradientBooster : public Model, public Configurable {
|
||||
protected:
|
||||
GenericParameter const* generic_param_;
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class Json;
|
||||
*
|
||||
* \endcode
|
||||
*/
|
||||
class Learner : public Model, public rabit::Serializable {
|
||||
class Learner : public Model, public Configurable, public rabit::Serializable {
|
||||
public:
|
||||
/*! \brief virtual destructor */
|
||||
~Learner() override;
|
||||
@@ -53,16 +53,6 @@ class Learner : public Model, public rabit::Serializable {
|
||||
* \brief Configure Learner based on set parameters.
|
||||
*/
|
||||
virtual void Configure() = 0;
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream.
|
||||
*/
|
||||
void Load(dmlc::Stream* fi) override = 0;
|
||||
/*!
|
||||
* \brief save model to stream.
|
||||
* \param fo output stream
|
||||
*/
|
||||
void Save(dmlc::Stream* fo) const override = 0;
|
||||
/*!
|
||||
* \brief update the model for one iteration
|
||||
* With the specified objective function.
|
||||
@@ -110,6 +100,13 @@ class Learner : public Model, public rabit::Serializable {
|
||||
bool pred_contribs = false,
|
||||
bool approx_contribs = false,
|
||||
bool pred_interactions = false) = 0;
|
||||
|
||||
void LoadModel(Json const& in) override = 0;
|
||||
void SaveModel(Json* out) const override = 0;
|
||||
|
||||
virtual void LoadModel(dmlc::Stream* fi) = 0;
|
||||
virtual void SaveModel(dmlc::Stream* fo) const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Set multiple parameters at once.
|
||||
*
|
||||
|
||||
@@ -99,6 +99,7 @@ struct XGBoostParameter : public dmlc::Parameter<Type> {
|
||||
return unknown;
|
||||
}
|
||||
}
|
||||
bool GetInitialised() const { return static_cast<bool>(this->initialised_); }
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
Reference in New Issue
Block a user