JSON configuration IO. (#5111)

* Add saving/loading JSON configuration.
* Implement Python pickle interface with new IO routines.
* Basic tests for training continuation.
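Taken together, the diff adds three user-visible pieces. A minimal sketch of how they combine, using the Python API introduced below (synthetic data, illustrative only):

import json
import pickle

import numpy as np
import xgboost as xgb

X, y = np.random.randn(64, 8), np.random.randn(64)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=4)

config = json.loads(bst.save_config())  # new JSON configuration IO
bst2 = pickle.loads(pickle.dumps(bst))  # pickle now goes through the snapshot routines
bst3 = xgb.train({'tree_method': 'hist'}, dtrain,
                 num_boost_round=4, xgb_model=bst2)  # training continuation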
Jiaming Yuan 2019-12-15 17:31:53 +08:00 committed by GitHub
parent 5aa007d7b2
commit 3136185bc5
24 changed files with 761 additions and 390 deletions

View File

@@ -461,9 +461,69 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
* \param out_dptr the argument to hold the output data pointer
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, bst_ulong *out_len,
const char **out_dptr);
/*!
* \brief Memory snapshot based serialization method. Saves all state into a
* buffer.
*
* \param handle handle
* \param out_len the argument to hold the output length
* \param out_dptr the argument to hold the output data pointer
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
const char **out_dptr);
/*!
* \brief Memory snapshot based serialization method. Loads the buffer returned
* from `XGBoosterSerializeToBuffer'.
*
* \param handle handle
* \param buf pointer to the buffer
* \param len the length of the buffer
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf, bst_ulong len);
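A hedged sketch of driving this pair from Python over ctypes, mirroring what the pickle interface below does. `lib` (a loaded libxgboost CDLL) and the two handles are stand-ins, not names from this diff:

import ctypes

def snapshot_roundtrip(lib, src_handle, dst_handle):
    # bst_ulong is a 64-bit unsigned integer in the XGBoost C API.
    length = ctypes.c_uint64()
    cptr = ctypes.POINTER(ctypes.c_char)()
    # Take a full memory snapshot (model + config) of the source booster.
    assert lib.XGBoosterSerializeToBuffer(
        src_handle, ctypes.byref(length), ctypes.byref(cptr)) == 0
    buf = ctypes.string_at(cptr, length.value)  # copy before the buffer is reused
    # Restore everything into the destination booster.
    assert lib.XGBoosterUnserializeFromBuffer(
        dst_handle, buf, ctypes.c_uint64(len(buf))) == 0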
/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version);
/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
/*!
* \brief Save XGBoost's internal configuration into a JSON document.
* \param handle handle to Booster object.
* \param out_len length of the output string.
* \param out_str A valid pointer to an array of characters. The array is
* allocated and managed by XGBoost, while the pointer to it needs to be
* managed by the caller.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle, bst_ulong *out_len,
char const **out_str);
/*!
* \brief Load XGBoost's internal configuration from a JSON document.
* \param handle handle to Booster object.
* \param json_parameters string representation of a JSON document.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle,
char const *json_parameters);
/*!
* \brief dump model, return array of strings representing model dump
* \param handle handle
@@ -570,25 +630,4 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
bst_ulong* out_len,
const char*** out);
// --- Distributed training API----
// NOTE: functions in rabit/c_api.h will be also available in libxgboost.so
/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(
BoosterHandle handle,
int* version);
/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);
#endif  // XGBOOST_C_API_H_

View File

@@ -32,7 +32,7 @@ struct LearnerModelParam;
/*!
* \brief interface of gradient boosting model.
*/
class GradientBooster : public Model, public Configurable {
protected:
GenericParameter const* generic_param_;

View File

@@ -45,7 +45,7 @@ class Json;
*
* \endcode
*/
class Learner : public Model, public Configurable, public rabit::Serializable {
public:
/*! \brief virtual destructor */
~Learner() override;
@@ -53,16 +53,6 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
/*!
* \brief Configure Learner based on set parameters.
*/
virtual void Configure() = 0;
/*!
* \brief load model from stream
* \param fi input stream.
*/
void Load(dmlc::Stream* fi) override = 0;
/*!
* \brief save model to stream.
* \param fo output stream
*/
void Save(dmlc::Stream* fo) const override = 0;
/*!
* \brief update the model for one iteration
* With the specified objective function.
@@ -110,6 +100,13 @@ class Learner : public Model, public Configurable, public rabit::Serializable {
bool pred_contribs = false,
bool approx_contribs = false,
bool pred_interactions = false) = 0;
void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0;
virtual void LoadModel(dmlc::Stream* fi) = 0;
virtual void SaveModel(dmlc::Stream* fo) const = 0;
/*!
* \brief Set multiple parameters at once.
*

View File

@@ -99,6 +99,7 @@ struct XGBoostParameter : public dmlc::Parameter<Type> {
return unknown;
}
}
bool GetInitialised() const { return static_cast<bool>(this->initialised_); }
};
}  // namespace xgboost

View File

@@ -1076,28 +1076,47 @@ class Booster(object):
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
ctypes.byref(self.handle)))
self.set_param(params or {})
if (params is not None) and ('booster' in params):
self.booster = params['booster']
else:
self.booster = 'gbtree'
if isinstance(model_file, Booster):
assert self.handle is not None
# We use the pickle interface for getting memory snapshot from
# another model, and load the snapshot with this booster.
state = model_file.__getstate__()
handle = state['handle']
del state['handle']
ptr = (ctypes.c_char * len(handle)).from_buffer(handle)
length = c_bst_ulong(len(handle))
_check_call(
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
self.__dict__.update(state)
elif isinstance(model_file, (STRING_TYPES, os_PathLike)):
self.load_model(model_file)
elif model_file is None:
pass
else:
raise TypeError('Unknown type:', model_file)
def __del__(self):
if hasattr(self, 'handle') and self.handle is not None:
_check_call(_LIB.XGBoosterFree(self.handle))
self.handle = None
def __getstate__(self):
# can't pickle ctypes pointers, put model content in bytearray
this = self.__dict__.copy()
handle = this['handle']
if handle is not None:
length = c_bst_ulong()
cptr = ctypes.POINTER(ctypes.c_char)()
_check_call(_LIB.XGBoosterSerializeToBuffer(self.handle,
ctypes.byref(length),
ctypes.byref(cptr)))
buf = ctypes2buffer(cptr, length.value)
this["handle"] = buf
return this
def __setstate__(self, state):
@@ -1107,18 +1126,44 @@ class Booster(object):
buf = handle
dmats = c_array(ctypes.c_void_p, [])
handle = ctypes.c_void_p()
_check_call(_LIB.XGBoosterCreate(
dmats, c_bst_ulong(0), ctypes.byref(handle)))
length = c_bst_ulong(len(buf))
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
_check_call(
_LIB.XGBoosterUnserializeFromBuffer(handle, ptr, length))
state['handle'] = handle
self.__dict__.update(state)
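What the two methods above buy in practice: unlike save_raw(), which keeps only the model, pickling now snapshots the model plus its internal configuration, so training parameters survive the round trip. A small sketch (`bst` is any trained Booster):

import pickle

buf = pickle.dumps(bst)       # __getstate__ -> XGBoosterSerializeToBuffer
restored = pickle.loads(buf)  # __setstate__ -> XGBoosterUnserializeFromBuffer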
def save_config(self):
'''Output internal parameter configuration of Booster as a JSON
string.'''
json_string = ctypes.c_char_p()
length = c_bst_ulong()
_check_call(_LIB.XGBoosterSaveJsonConfig(
self.handle,
ctypes.byref(length),
ctypes.byref(json_string)))
json_string = json_string.value.decode()
return json_string
def load_config(self, config):
'''Load configuration returned by `save_config`.'''
assert isinstance(config, str)
_check_call(_LIB.XGBoosterLoadJsonConfig(
self.handle,
c_str(config)))
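Usage sketch for the two methods above; the returned string is a JSON document, so it can be parsed, inspected, and fed back unchanged (`bst` is any configured Booster, and the key path is illustrative):

import json

config = json.loads(bst.save_config())
print(config['learner']['learner_train_param']['booster'])  # e.g. 'gbtree'
bst.load_config(json.dumps(config))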
def __copy__(self):
return self.__deepcopy__(None)
def __deepcopy__(self, _):
'''Return a copy of booster. Caches for DMatrix are not copied, so
continuing training on the copied booster will result in lower
performance and a slightly different result.
'''
return Booster(model_file=self)
def copy(self):
"""Copy the booster object.
@@ -1451,20 +1496,22 @@ class Booster(object):
def save_model(self, fname):
"""Save the model to a file.
The model is saved in an XGBoost internal format which is universal
among the various XGBoost interfaces. Auxiliary attributes of the
Python Booster object (such as feature_names) will not be saved. To
preserve all attributes, pickle the Booster object.
Parameters
----------
fname : string or os.PathLike
Output file name
"""
if isinstance(fname, (STRING_TYPES, os_PathLike)):  # assume file name
_check_call(_LIB.XGBoosterSaveModel(
self.handle, c_str(os_fspath(fname))))
else:
raise TypeError("fname must be a string or os_PathLike")
def save_raw(self):
"""Save the model to an in-memory buffer representation
@@ -1481,26 +1528,26 @@ class Booster(object):
return ctypes2buffer(cptr, length.value)
def load_model(self, fname):
"""Load the model from a file, local or as URI.
The model is loaded from an XGBoost format which is universal among the
various XGBoost interfaces. Auxiliary attributes of the Python Booster
object (such as feature_names) will not be loaded. To preserve all
attributes, pickle the Booster object.
Parameters
----------
fname : string, os.PathLike, or a memory buffer
Input file name or memory buffer (see also save_raw)
"""
if isinstance(fname, (STRING_TYPES, os_PathLike)):
# assume file name, cannot use os.path.exist to check, file can be
# from URL.
_check_call(_LIB.XGBoosterLoadModel(
self.handle, c_str(os_fspath(fname))))
else:
raise TypeError('Unknown file type: ', fname)
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
"""Dump model into a text or JSON file.

View File

@@ -34,9 +34,8 @@ def _train_internal(params, dtrain,
num_parallel_tree = 1
if xgb_model is not None:
bst = Booster(params, [dtrain] + [d[0] for d in evals],
model_file=xgb_model)
nboost = len(bst.get_dump())
_params = dict(params) if isinstance(params, list) else params

View File

@@ -485,6 +485,31 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
API_END();
}
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
API_BEGIN();
CHECK_HANDLE();
std::string str {json_parameters};
Json config { Json::Load(StringView{str.c_str(), str.size()}) };
static_cast<Learner*>(handle)->LoadConfig(config);
API_END();
}
XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle,
xgboost::bst_ulong *out_len,
char const** out_str) {
API_BEGIN();
CHECK_HANDLE();
Json config { Object() };
auto* learner = static_cast<Learner*>(handle);
learner->Configure();
learner->SaveConfig(&config);
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
Json::Dump(config, &raw_str);
*out_str = raw_str.c_str();
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
}
XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
int iter,
DMatrixHandle dtrain) {
@@ -579,7 +604,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
static_cast<Learner*>(handle)->LoadModel(in);
} else {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
static_cast<Learner*>(handle)->LoadModel(fi.get());
}
API_END();
}
@@ -598,20 +623,18 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* c_fname) {
fo->Write(str.c_str(), str.size());
} else {
auto *bst = static_cast<Learner*>(handle);
bst->SaveModel(fo.get());
}
API_END();
}
// The following two functions are `Load` and `Save` for memory based serialization
// methods. E.g. Python pickle.
XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
const void* buf,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
common::MemoryFixSizeBuffer fs((void*)buf, len);  // NOLINT(*)
static_cast<Learner*>(handle)->LoadModel(&fs);
API_END();
}
@@ -621,6 +644,25 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
raw_str.resize(0);
API_BEGIN();
CHECK_HANDLE();
common::MemoryBufferStream fo(&raw_str);
auto *learner = static_cast<Learner*>(handle);
learner->Configure();
learner->SaveModel(&fo);
*out_dptr = dmlc::BeginPtr(raw_str);
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
API_END();
}
// The following two functions are `Load` and `Save` for memory based
// serialization methods. E.g. Python pickle.
XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle,
xgboost::bst_ulong *out_len,
const char **out_dptr) {
std::string &raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
raw_str.resize(0);
API_BEGIN();
CHECK_HANDLE();
common::MemoryBufferStream fo(&raw_str);
@@ -632,6 +674,41 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
API_END();
}
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
static_cast<Learner*>(handle)->Load(&fs);
API_END();
}
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
*version = rabit::LoadCheckPoint(bst);
if (*version != 0) {
bst->Configure();
}
API_END();
}
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto* learner = static_cast<Learner*>(handle);
learner->Configure();
if (learner->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(learner);
} else {
rabit::CheckPoint(learner);
}
API_END();
}
inline void XGBoostDumpModelImpl(
BoosterHandle handle,
const FeatureMap& fmap,
@@ -758,29 +835,5 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
API_END();
}
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
*version = rabit::LoadCheckPoint(bst);
if (*version != 0) {
bst->Configure();
}
API_END();
}
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Learner*>(handle);
if (bst->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(bst);
} else {
rabit::CheckPoint(bst);
}
API_END();
}
// force link rabit
static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();

View File

@@ -99,6 +99,16 @@ class GBLinear : public GradientBooster {
model_.LoadModel(model);
}
void LoadConfig(Json const& in) override {
CHECK_EQ(get<String>(in["name"]), "gblinear");
fromJson(in["gblinear_train_param"], &param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String{"gblinear"};
out["gblinear_train_param"] = toJson(param_);
}
void DoBoost(DMatrix *p_fmat,
HostDeviceVector<GradientPair> *in_gpair,
ObjFunction* obj) override {

View File

@@ -112,7 +112,8 @@ class GBLinearModel : public Model {
<< " \"weight\": [" << std::endl;
for (unsigned i = 0; i < nfeature; ++i) {
for (int gid = 0; gid < ngroup; ++gid) {
if (i != 0 || gid != 0)
fo << "," << std::endl;
fo << " " << (*this)[i][gid];
}
}
@@ -134,5 +135,6 @@ class GBLinearModel : public Model {
return v;
}
};
}  // namespace gbm
}  // namespace xgboost

View File

@@ -34,6 +34,7 @@ DMLC_REGISTRY_FILE_TAG(gbtree);
void GBTree::Configure(const Args& cfg) {
this->cfg_ = cfg;
std::string updater_seq = tparam_.updater_seq;
tparam_.UpdateAllowUnknown(cfg);
model_.Configure(cfg);
@@ -75,24 +76,31 @@ void GBTree::Configure(const Args& cfg) {
"`tree_method` parameter instead.";
// Don't drive users to silent XGBoost.
showed_updater_warning_ = true;
}
this->ConfigureUpdaters();
if (updater_seq != tparam_.updater_seq) {
updaters_.clear();
this->InitUpdater(cfg);
} else {
for (auto &up : updaters_) {
up->Configure(cfg);
}
}
configured_ = true;
}
// FIXME(trivialfis): This handles updaters. Because the choice of updaters depends on
// whether external memory is used and how large the dataset is, we can remove the
// dependency on DMatrix once the `hist` tree method can handle external memory, so
// that we can make it the default.
void GBTree::ConfigureWithKnownData(Args const& cfg, DMatrix* fmat) {
std::string updater_seq = tparam_.updater_seq;
CHECK(tparam_.GetInitialised());
tparam_.UpdateAllowUnknown(cfg);
this->PerformTreeMethodHeuristic(fmat);
this->ConfigureUpdaters();
@@ -101,10 +109,9 @@ void GBTree::ConfigureWithKnownData(Args const& cfg, DMatrix* fmat) {
if (updater_seq != tparam_.updater_seq) {
LOG(DEBUG) << "Using updaters: " << tparam_.updater_seq;
this->updaters_.clear();
}
this->InitUpdater(cfg);
}
void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
if (specified_updater_) {
@@ -141,6 +148,9 @@
}
}
void GBTree::ConfigureUpdaters() {
if (specified_updater_) {
return;
}
// `updater` parameter was manually specified // `updater` parameter was manually specified
/* Choose updaters according to tree_method parameters */
switch (tparam_.tree_method) {
@@ -289,6 +299,46 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
monitor_.Stop("CommitModel");
}
void GBTree::LoadConfig(Json const& in) {
CHECK_EQ(get<String>(in["name"]), "gbtree");
fromJson(in["gbtree_train_param"], &tparam_);
int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
tparam_.UpdateAllowUnknown(Args{{"predictor", "auto"}});
}
if (n_gpus == 0 && tparam_.tree_method == TreeMethod::kGPUHist) {
tparam_.UpdateAllowUnknown(Args{{"tree_method", "hist"}});
LOG(WARNING)
<< "Loading from a raw memory buffer on CPU only machine. "
"Change tree_method to hist.";
}
auto const& j_updaters = get<Object const>(in["updater"]);
updaters_.clear();
for (auto const& kv : j_updaters) {
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(kv.first, generic_param_));
up->LoadConfig(kv.second);
updaters_.push_back(std::move(up));
}
specified_updater_ = get<Boolean>(in["specified_updater"]);
}
void GBTree::SaveConfig(Json* p_out) const {
auto& out = *p_out;
out["name"] = String("gbtree");
out["gbtree_train_param"] = toJson(tparam_);
out["updater"] = Object();
auto& j_updaters = out["updater"];
for (auto const& up : updaters_) {
j_updaters[up->Name()] = Object();
auto& j_up = j_updaters[up->Name()];
up->SaveConfig(&j_up);
}
out["specified_updater"] = Boolean{specified_updater_};
}
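The GPU fallback in LoadConfig above is observable from Python; a sketch, assuming `bst` was trained with gpu_hist/gpu_predictor, pickled, and unpickled on a machine with no visible GPU (the key path matches the tests further down):

import json

config = json.loads(bst.save_config())
tparam = config['learner']['gradient_booster']['gbtree_train_param']
assert tparam['predictor'] == 'auto'    # rewritten from 'gpu_predictor'
assert tparam['tree_method'] == 'hist'  # rewritten from 'gpu_hist'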
void GBTree::LoadModel(Json const& in) {
CHECK_EQ(get<String>(in["name"]), "gbtree");
model_.LoadModel(in["model"]);
@@ -324,7 +374,7 @@ class Dart : public GBTree {
for (size_t i = 0; i < weight_drop_.size(); ++i) {
j_weight_drop[i] = Number(weight_drop_[i]);
}
out["weight_drop"] = Array(std::move(j_weight_drop));
}
void LoadModel(Json const& in) override {
CHECK_EQ(get<String>(in["name"]), "dart");
@@ -352,6 +402,21 @@ class Dart : public GBTree {
}
}
void LoadConfig(Json const& in) override {
CHECK_EQ(get<String>(in["name"]), "dart");
auto const& gbtree = in["gbtree"];
GBTree::LoadConfig(gbtree);
fromJson(in["dart_train_param"], &dparam_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("dart");
out["gbtree"] = Object();
auto& gbtree = out["gbtree"];
GBTree::SaveConfig(&gbtree);
out["dart_train_param"] = toJson(dparam_);
}
// predict the leaf scores with dropout if ntree_limit = 0
void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds,

View File

@@ -192,6 +192,9 @@ class GBTree : public GradientBooster {
model_.Save(fo);
}
void LoadConfig(Json const& in) override;
void SaveConfig(Json* p_out) const override;
void SaveModel(Json* p_out) const override;
void LoadModel(Json const& in) override;
View File

@@ -46,7 +46,8 @@ void GBTreeModel::SaveModel(Json* p_out) const {
for (auto const& tree : trees) {
Json tree_json{Object()};
tree->SaveModel(&tree_json);
// The field is not used in XGBoost, but might be useful for external projects.
tree_json["id"] = Integer(t);
trees_json.emplace_back(tree_json);
t++;
}

View File

@@ -30,6 +30,7 @@
#include "common/common.h"
#include "common/io.h"
#include "common/observer.h"
#include "common/random.h"
#include "common/timer.h"
#include "common/version.h"
@@ -37,27 +38,6 @@
namespace {
const char* kMaxDeltaStepDefaultValue = "0.7";
inline bool IsFloat(const std::string& str) {
std::stringstream ss(str);
float f{};
return !((ss >> std::noskipws >> f).rdstate() ^ std::ios_base::eofbit);
}
inline bool IsInt(const std::string& str) {
std::stringstream ss(str);
int i{};
return !((ss >> std::noskipws >> i).rdstate() ^ std::ios_base::eofbit);
}
inline std::string RenderParamVal(const std::string& str) {
if (IsFloat(str) || IsInt(str)) {
return str;
} else {
return std::string("'") + str + "'";
}
}
}  // anonymous namespace
namespace xgboost {
@@ -323,11 +303,77 @@ class LearnerImpl : public Learner {
}
}
void LoadConfig(Json const& in) override {
CHECK(IsA<Object>(in));
Version::Load(in, true);
auto const& learner_parameters = get<Object>(in["learner"]);
fromJson(learner_parameters.at("learner_train_param"), &tparam_);
auto const& gradient_booster = learner_parameters.at("gradient_booster");
auto const& objective_fn = learner_parameters.at("objective");
if (!obj_) {
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
}
obj_->LoadConfig(objective_fn);
tparam_.booster = get<String>(gradient_booster["name"]);
if (!gbm_) {
gbm_.reset(GradientBooster::Create(tparam_.booster,
&generic_parameters_, &learner_model_param_,
cache_));
}
gbm_->LoadConfig(gradient_booster);
auto const& j_metrics = learner_parameters.at("metrics");
auto n_metrics = get<Array const>(j_metrics).size();
metric_names_.resize(n_metrics);
metrics_.resize(n_metrics);
for (size_t i = 0; i < n_metrics; ++i) {
metric_names_[i] = get<String>(j_metrics[i]);
metrics_[i] = std::unique_ptr<Metric>(
Metric::Create(metric_names_[i], &generic_parameters_));
}
fromJson(learner_parameters.at("generic_param"), &generic_parameters_);
this->need_configuration_ = true;
}
void SaveConfig(Json* p_out) const override {
CHECK(!this->need_configuration_) << "Call Configure before saving model.";
Version::Save(p_out);
Json& out { *p_out };
// parameters
out["learner"] = Object();
auto& learner_parameters = out["learner"];
learner_parameters["learner_train_param"] = toJson(tparam_);
learner_parameters["gradient_booster"] = Object();
auto& gradient_booster = learner_parameters["gradient_booster"];
gbm_->SaveConfig(&gradient_booster);
learner_parameters["objective"] = Object();
auto& objective_fn = learner_parameters["objective"];
obj_->SaveConfig(&objective_fn);
std::vector<Json> metrics(metrics_.size());
for (size_t i = 0; i < metrics_.size(); ++i) {
metrics[i] = String(metrics_[i]->Name());
}
learner_parameters["metrics"] = Array(metrics);
learner_parameters["generic_param"] = toJson(generic_parameters_);
}
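SaveConfig above produces a document shaped roughly as follows; a sketch reconstructed from the code, with illustrative placeholder values:

# Rough shape of the configuration document; leaf values are placeholders.
expected_config_shape = {
    'version': [1, 0, 0],  # written by Version::Save
    'learner': {
        'learner_train_param': {'booster': 'gbtree'},  # toJson(tparam_)
        'gradient_booster': {'name': 'gbtree'},        # gbm_->SaveConfig
        'objective': {'name': 'reg:squarederror'},     # obj_->SaveConfig
        'metrics': ['rmse'],             # names only; objects recreated on load
        'generic_param': {'seed': '0'},  # toJson(generic_parameters_)
    },
}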
// About to be deprecated by JSON format
void LoadModel(dmlc::Stream* fi) override {
generic_parameters_.UpdateAllowUnknown(Args{});
tparam_.Init(std::vector<std::pair<std::string, std::string>>{});
// TODO(tqchen) mark deprecation of old format.
common::PeekableInStream fp(fi);
// backward compatible header check.
std::string header;
header.resize(4);
@@ -338,6 +384,15 @@ class LearnerImpl : public Learner {
CHECK_EQ(fp.Read(&header[0], 4), 4U);
}
}
if (header[0] == '{') {
auto json_stream = common::FixedSizeStream(&fp);
std::string buffer;
json_stream.Take(&buffer);
auto model = Json::Load({buffer.c_str(), buffer.size()});
this->LoadModel(model);
return;
}
// use the peekable reader.
fi = &fp;
// read parameter
@@ -370,43 +425,9 @@ class LearnerImpl : public Learner {
std::vector<std::pair<std::string, std::string> > attr;
fi->Read(&attr);
for (auto& kv : attr) {
// Load `predictor`, `gpu_id` parameters from extra attributes
const std::string prefix = "SAVED_PARAM_";
if (kv.first.find(prefix) == 0) {
const std::string saved_param = kv.first.substr(prefix.length());
bool is_gpu_predictor = saved_param == "predictor" && kv.second == "gpu_predictor";
#ifdef XGBOOST_USE_CUDA
if (saved_param == "predictor" || saved_param == "gpu_id") {
cfg_[saved_param] = kv.second;
LOG(INFO)
<< "Parameter '" << saved_param << "' has been recovered from "
<< "the saved model. It will be set to "
<< RenderParamVal(kv.second) << " for prediction. To "
<< "override the predictor behavior, explicitly set '"
<< saved_param << "' parameter as follows:\n"
<< " * Python package: bst.set_param('"
<< saved_param << "', [new value])\n"
<< " * R package: xgb.parameters(bst) <- list("
<< saved_param << " = [new value])\n"
<< " * JVM packages: bst.setParam(\""
<< saved_param << "\", [new value])";
}
#else
if (is_gpu_predictor) {
cfg_["predictor"] = "cpu_predictor";
kv.second = "cpu_predictor";
}
#endif // XGBOOST_USE_CUDA
#if defined(XGBOOST_USE_CUDA)
// NO visible GPU in current environment
if (is_gpu_predictor && common::AllVisibleGPUs() == 0) {
cfg_["predictor"] = "cpu_predictor";
kv.second = "cpu_predictor";
LOG(INFO) << "Switch gpu_predictor to cpu_predictor.";
} else if (is_gpu_predictor) {
cfg_["predictor"] = "gpu_predictor";
}
#endif // defined(XGBOOST_USE_CUDA)
if (saved_configs_.find(saved_param) != saved_configs_.end()) {
cfg_[saved_param] = kv.second;
}
@@ -447,26 +468,12 @@ class LearnerImpl : public Learner {
tparam_.dsplit = DataSplitMode::kRow;
}
// There's no state machine logic for binary IO, as it is a mix of everything
// and a half-loaded model.
this->Configure();
}
// Save model into binary format. The code is about to be deprecated by a more
// robust JSON serialization format.
void SaveModel(dmlc::Stream* fo) const override {
// Save empty model. Calling Configure in a dummy LearnerImpl avoids violating
// constness.
LearnerImpl empty(std::move(this->cache_));
empty.SetParams({this->cfg_.cbegin(), this->cfg_.cend()});
for (auto const& kv : attributes_) {
empty.SetAttr(kv.first, kv.second);
}
empty.Configure();
empty.Save(fo);
return;
}
LearnerModelParamLegacy mparam = mparam_;  // make a copy to potentially modify
std::vector<std::pair<std::string, std::string> > extra_attr;
// extra attributes to be added just before saving
@@ -479,14 +486,13 @@ class LearnerImpl : public Learner {
}
}
{
std::vector<std::string> saved_params;
// check if rabit_bootstrap_cache was set to non-zero before adding to checkpoint
if (cfg_.find("rabit_bootstrap_cache") != cfg_.end() &&
(cfg_.find("rabit_bootstrap_cache"))->second != "0") {
std::copy(saved_configs_.begin(), saved_configs_.end(),
std::back_inserter(saved_params));
}
for (const auto& key : saved_params) {
auto it = cfg_.find(key);
if (it != cfg_.end()) {
@@ -495,19 +501,6 @@ class LearnerImpl : public Learner {
}
}
}
#if defined(XGBOOST_USE_CUDA)
{
// Force save gpu_id.
if (std::none_of(extra_attr.cbegin(), extra_attr.cend(),
[](std::pair<std::string, std::string> const& it) {
return it.first == "SAVED_PARAM_gpu_id";
})) {
mparam.contain_extra_attrs = 1;
extra_attr.emplace_back("SAVED_PARAM_gpu_id",
std::to_string(generic_parameters_.gpu_id));
}
}
#endif // defined(XGBOOST_USE_CUDA)
fo->Write(&mparam, sizeof(LearnerModelParamLegacy));
fo->Write(tparam_.objective);
fo->Write(tparam_.booster);
@@ -541,6 +534,69 @@ class LearnerImpl : public Learner {
}
}
void Save(dmlc::Stream* fo) const override {
if (generic_parameters_.enable_experimental_json_serialization) {
Json memory_snapshot{Object()};
memory_snapshot["Model"] = Object();
auto &model = memory_snapshot["Model"];
this->SaveModel(&model);
memory_snapshot["Config"] = Object();
auto &config = memory_snapshot["Config"];
this->SaveConfig(&config);
std::string out_str;
Json::Dump(memory_snapshot, &out_str);
fo->Write(out_str.c_str(), out_str.size());
} else {
std::string binary_buf;
common::MemoryBufferStream s(&binary_buf);
this->SaveModel(&s);
Json config{ Object() };
// Do not use std::size_t as it's not portable.
int64_t const json_offset = binary_buf.size();
this->SaveConfig(&config);
std::string config_str;
Json::Dump(config, &config_str);
// concatenate the model and config in the final output; it's a temporary
// solution for continuing to support the binary model format
fo->Write(&serialisation_header_[0], serialisation_header_.size());
fo->Write(&json_offset, sizeof(json_offset));
fo->Write(&binary_buf[0], binary_buf.size());
fo->Write(&config_str[0], config_str.size());
}
}
void Load(dmlc::Stream* fi) override {
common::PeekableInStream fp(fi);
char c {0};
fp.PeekRead(&c, 1);
if (c == '{') {
std::string buffer;
common::FixedSizeStream{&fp}.Take(&buffer);
auto memory_snapshot = Json::Load({buffer.c_str(), buffer.size()});
this->LoadModel(memory_snapshot["Model"]);
this->LoadConfig(memory_snapshot["Config"]);
} else {
std::string header;
header.resize(serialisation_header_.size());
CHECK_EQ(fp.Read(&header[0], header.size()), serialisation_header_.size());
CHECK_EQ(header, serialisation_header_);
int64_t json_offset {-1};
CHECK_EQ(fp.Read(&json_offset, sizeof(json_offset)), sizeof(json_offset));
CHECK_GT(json_offset, 0);
std::string buffer;
common::FixedSizeStream{&fp}.Take(&buffer);
common::MemoryFixSizeBuffer binary_buf(&buffer[0], json_offset);
this->LoadModel(&binary_buf);
common::MemoryFixSizeBuffer json_buf {&buffer[0] + json_offset,
buffer.size() - json_offset};
auto config = Json::Load({buffer.c_str() + json_offset, buffer.size() - json_offset});
this->LoadConfig(config);
}
}
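When `enable_experimental_json_serialization` is off, Save above writes the `CONFIG-offset:` header, an int64 offset, the binary model, then the JSON config. A hedged Python sketch of splitting such a buffer (not a public API):

import struct

HEADER = b'CONFIG-offset:'  # serialisation_header_ above

def split_snapshot(buf):
    # Returns (binary_model, json_config) from a non-JSON memory snapshot.
    assert buf.startswith(HEADER)
    body = buf[len(HEADER):]
    # int64 offset of the JSON config within the payload (native endianness).
    (json_offset,) = struct.unpack('=q', body[:8])
    payload = body[8:]
    return payload[:json_offset], payload[json_offset:]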
std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const override {
@@ -551,6 +607,7 @@ class LearnerImpl : public Learner {
void UpdateOneIter(int iter, DMatrix* train) override {
monitor_.Start("UpdateOneIter");
TrainingObserver::Instance().Update(iter);
this->Configure();
if (generic_parameters_.seed_per_iteration || rabit::IsDistributed()) {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
@@ -561,9 +618,13 @@ class LearnerImpl : public Learner {
monitor_.Start("PredictRaw");
this->PredictRaw(train, &preds_[train]);
monitor_.Stop("PredictRaw");
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
monitor_.Start("GetGradient");
obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(gpair_, "Gradients");
gbm_->DoBoost(train, &gpair_, obj_.get());
monitor_.Stop("UpdateOneIter");
}
@@ -792,6 +853,10 @@ class LearnerImpl : public Learner {
LearnerModelParamLegacy mparam_;
LearnerModelParam learner_model_param_;
LearnerTrainParam tparam_;
// Used to identify the offset of the JSON string when
// `enable_experimental_json_serialization' is set to false. Will be removed
// once JSON takes over.
std::string const serialisation_header_ { u8"CONFIG-offset:" };
// configurations
std::map<std::string, std::string> cfg_;
std::map<std::string, std::string> attributes_;
@@ -811,9 +876,8 @@ class LearnerImpl : public Learner {
common::Monitor monitor_;
/*! \brief (Deprecated) saved config keys used to restore failed worker */
std::set<std::string> saved_configs_ = {"num_round"};
};
std::string const LearnerImpl::kEvalMetric {"eval_metric"};  // NOLINT

View File

@@ -682,13 +682,13 @@ void RegTree::LoadModel(Json const& in) {
s.leaf_child_cnt = get<Integer const>(leaf_child_counts[i]);
auto& n = nodes_[i];
bst_node_t left = get<Integer const>(lefts[i]);
bst_node_t right = get<Integer const>(rights[i]);
bst_node_t parent = get<Integer const>(parents[i]);
bst_feature_t ind = get<Integer const>(indices[i]);
float cond { get<Number const>(conds[i]) };
bool dft_left { get<Boolean const>(default_left[i]) };
n = Node{left, right, parent, ind, cond, dft_left};
} }

View File

@@ -1027,8 +1027,6 @@ class GPUHistMakerSpecialised {
param_.UpdateAllowUnknown(args);
generic_param_ = generic_param;
hist_maker_param_.UpdateAllowUnknown(args);
device_ = generic_param_->gpu_id;
CHECK_GE(device_, 0) << "Must have at least one device";
dh::CheckComputeCapability();
monitor_.Init("updater_gpu_hist");
@@ -1041,6 +1039,7 @@ class GPUHistMakerSpecialised {
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
monitor_.StartCuda("Update");
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
@@ -1064,6 +1063,8 @@ class GPUHistMakerSpecialised {
}
void InitDataOnce(DMatrix* dmat) {
device_ = generic_param_->gpu_id;
CHECK_GE(device_, 0) << "Must have at least one device";
info_ = &dmat->Info();
reducer_.Init({device_});
@@ -1162,14 +1163,24 @@ class GPUHistMakerSpecialised {
class GPUHistMaker : public TreeUpdater {
public:
void Configure(const Args& args) override {
// Used in tests to count how many configurations are performed
LOG(DEBUG) << "[GPU Hist]: Configure";
hist_maker_param_.UpdateAllowUnknown(args);
// The passed-in args can be empty; if we simply purged the old maker without
// preserving parameters, we couldn't call Update on it afterwards.
TrainParam param;
if (float_maker_) {
param = float_maker_->param_;
} else if (double_maker_) {
param = double_maker_->param_;
}
if (hist_maker_param_.single_precision_histogram) {
float_maker_.reset(new GPUHistMakerSpecialised<GradientPair>());
float_maker_->param_ = param;
float_maker_->Configure(args, tparam_);
} else {
double_maker_.reset(new GPUHistMakerSpecialised<GradientPairPrecise>());
double_maker_->param_ = param;
double_maker_->Configure(args, tparam_);
}
}

View File

@@ -8,6 +8,7 @@
#include "../helpers.h"
#include "../../../src/common/io.h"
TEST(c_api, XGDMatrixCreateFromMatDT) {
std::vector<int> col0 = {0, -1, 3};
std::vector<float> col1 = {-4.0f, 2.0f, 0.0f};
@@ -77,6 +78,40 @@ TEST(c_api, Version) {
ASSERT_EQ(patch, XGBOOST_VER_PATCH);
}
TEST(c_api, ConfigIO) {
size_t constexpr kRows = 10;
auto pp_dmat = CreateDMatrix(kRows, 10, 0);
auto p_dmat = *pp_dmat;
std::vector<std::shared_ptr<DMatrix>> mat {p_dmat};
std::vector<bst_float> labels(kRows);
for (size_t i = 0; i < labels.size(); ++i) {
labels[i] = i;
}
p_dmat->Info().labels_.HostVector() = labels;
std::shared_ptr<Learner> learner { Learner::Create(mat) };
BoosterHandle handle = learner.get();
learner->UpdateOneIter(0, p_dmat.get());
char const* out[1];
bst_ulong len {0};
XGBoosterSaveJsonConfig(handle, &len, out);
std::string config_str_0 { out[0] };
auto config_0 = Json::Load({config_str_0.c_str(), config_str_0.size()});
XGBoosterLoadJsonConfig(handle, out[0]);
bst_ulong len_1 {0};
std::string config_str_1 { out[0] };
XGBoosterSaveJsonConfig(handle, &len_1, out);
auto config_1 = Json::Load({config_str_1.c_str(), config_str_1.size()});
ASSERT_EQ(config_0, config_1);
delete pp_dmat;
}
TEST(c_api, Json_ModelIO) {
size_t constexpr kRows = 10;
dmlc::TemporaryDirectory tempdir;

View File

@@ -117,15 +117,28 @@ TEST(GBTree, Json_IO) {
CreateTrainedGBM("gbtree", Args{}, kRows, kCols, &mparam, &gparam) };
Json model {Object()};
model["model"] = Object();
auto& j_model = model["model"];
model["config"] = Object();
auto& j_param = model["config"];
gbm->SaveModel(&j_model);
gbm->SaveConfig(&j_param);
std::string model_str;
Json::Dump(model, &model_str);
model = Json::Load({model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(model["model"]["name"]), "gbtree");
auto const& gbtree_model = model["model"]["model"];
ASSERT_EQ(get<Array>(gbtree_model["trees"]).size(), 1);
ASSERT_EQ(get<Integer>(get<Object>(get<Array>(gbtree_model["trees"]).front()).at("id")), 0);
ASSERT_EQ(get<Array>(gbtree_model["tree_info"]).size(), 1);
auto j_train_param = model["config"]["gbtree_train_param"];
ASSERT_EQ(get<String>(j_train_param["num_parallel_tree"]), "1");
}
TEST(Dart, Json_IO) {
@@ -145,20 +158,21 @@ TEST(Dart, Json_IO) {
Json model {Object()};
model["model"] = Object();
auto& j_model = model["model"];
model["config"] = Object();
auto& j_param = model["config"];
gbm->SaveModel(&j_model);
gbm->SaveConfig(&j_param);
std::string model_str;
Json::Dump(model, &model_str);
model = Json::Load({model_str.c_str(), model_str.size()});
ASSERT_EQ(get<String>(model["model"]["name"]), "dart") << model;
ASSERT_EQ(get<String>(model["config"]["name"]), "dart");
ASSERT_TRUE(IsA<Object>(model["model"]["gbtree"]));
ASSERT_NE(get<Array>(model["model"]["weight_drop"]).size(), 0);
}
}  // namespace xgboost

View File

@@ -13,23 +13,6 @@
#include "../helpers.h"
#include "../../../src/gbm/gbtree_model.h"
namespace {
inline void CheckCAPICall(int ret) {
ASSERT_EQ(ret, 0) << XGBGetLastError();
}
} // namespace anonymous
const std::map<std::string, std::string>&
QueryBoosterConfigurationArguments(BoosterHandle handle) {
CHECK_NE(handle, static_cast<void*>(nullptr));
auto* bst = static_cast<xgboost::Learner*>(handle);
bst->Configure();
return bst->GetConfigurationArguments();
}
namespace xgboost {
namespace predictor {
@@ -110,77 +93,5 @@ TEST(gpu_predictor, ExternalMemoryTest) {
}
}
}
// Test whether pickling preserves predictor parameters
TEST(gpu_predictor, PicklingTest) {
int const gpuid = 0;
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateBigTestData(tmp_file, 600);
DMatrixHandle dmat[1];
BoosterHandle bst, bst2;
std::vector<bst_float> label;
for (int i = 0; i < 200; ++i) {
label.push_back((i % 2 ? 1 : 0));
}
// Load data matrix
ASSERT_EQ(XGDMatrixCreateFromFile(
tmp_file.c_str(), 0, &dmat[0]), 0) << XGBGetLastError();
ASSERT_EQ(XGDMatrixSetFloatInfo(
dmat[0], "label", label.data(), 200), 0) << XGBGetLastError();
// Create booster
ASSERT_EQ(XGBoosterCreate(dmat, 1, &bst), 0) << XGBGetLastError();
// Set parameters
ASSERT_EQ(XGBoosterSetParam(bst, "seed", "0"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(bst, "base_score", "0.5"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(bst, "booster", "gbtree"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(bst, "learning_rate", "0.01"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(bst, "max_depth", "8"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(
bst, "objective", "binary:logistic"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(bst, "seed", "123"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(
bst, "tree_method", "gpu_hist"), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(
bst, "gpu_id", std::to_string(gpuid).c_str()), 0) << XGBGetLastError();
ASSERT_EQ(XGBoosterSetParam(bst, "predictor", "gpu_predictor"), 0) << XGBGetLastError();
// Run boosting iterations
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(XGBoosterUpdateOneIter(bst, i, dmat[0]), 0) << XGBGetLastError();
}
// Delete matrix
CheckCAPICall(XGDMatrixFree(dmat[0]));
// Pickle
const char* dptr;
bst_ulong len;
std::string buf;
CheckCAPICall(XGBoosterGetModelRaw(bst, &len, &dptr));
buf = std::string(dptr, len);
CheckCAPICall(XGBoosterFree(bst));
// Unpickle
CheckCAPICall(XGBoosterCreate(nullptr, 0, &bst2));
CheckCAPICall(XGBoosterLoadModelFromBuffer(bst2, buf.c_str(), len));
{ // Query predictor
const auto& kwargs = QueryBoosterConfigurationArguments(bst2);
ASSERT_EQ(kwargs.at("predictor"), "gpu_predictor");
ASSERT_EQ(kwargs.at("gpu_id"), std::to_string(gpuid).c_str());
}
{ // Change predictor and query again
CheckCAPICall(XGBoosterSetParam(bst2, "predictor", "cpu_predictor"));
const auto& kwargs = QueryBoosterConfigurationArguments(bst2);
ASSERT_EQ(kwargs.at("predictor"), "cpu_predictor");
}
CheckCAPICall(XGBoosterFree(bst2));
}
}  // namespace predictor
}  // namespace xgboost

View File

@@ -1,20 +1,39 @@
'''Loading a pickled model generated by test_pickling.py, only used by
`test_gpu_with_dask.py`'''
import unittest
import os
import xgboost as xgb
import json
from test_gpu_pickling import build_dataset, model_path, load_pickle
class TestLoadPickle(unittest.TestCase):
def test_load_pkl(self):
'''Test whether prediction is correct.'''
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
bst = load_pickle(model_path)
x, y = build_dataset()
test_x = xgb.DMatrix(x)
res = bst.predict(test_x)
assert len(res) == 10
def test_predictor_type_is_auto(self):
'''Under invalid CUDA_VISIBLE_DEVICES, predictor should be set to
auto'''
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'auto'
def test_predictor_type_is_gpu(self):
'''When CUDA_VISIBLE_DEVICES is not specified, keep using
`gpu_predictor`'''
assert 'CUDA_VISIBLE_DEVICES' not in os.environ.keys()
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'gpu_predictor'

View File

@@ -4,7 +4,7 @@ import unittest
import numpy as np
import subprocess
import os
import json
import xgboost as xgb
from xgboost import XGBClassifier
@@ -39,18 +39,17 @@ class TestPickling(unittest.TestCase):
bst = xgb.train(param, train_x)
save_pickle(bst, model_path)
args = [
"pytest", "--verbose", "-s", "--fulltrace",
"./tests/python-gpu/load_pickle.py::TestLoadPickle::test_load_pkl"
]
command = ''
for arg in args:
command += arg
command += ' '
cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
env = os.environ.copy()
# Passing new_environment directly to `env' argument results # Passing new_environment directly to `env' argument results
# in failure on Windows: # in failure on Windows:
# Fatal Python error: _Py_HashRandomization_Init: failed to # Fatal Python error: _Py_HashRandomization_Init: failed to
@@ -62,12 +61,55 @@ class TestPickling(unittest.TestCase):
assert status == 0
os.remove(model_path)
def test_pickled_predictor(self):
args_template = [
"pytest",
"--verbose",
"-s",
"--fulltrace"]
x, y = build_dataset()
train_x = xgb.DMatrix(x, label=y)
param = {'tree_method': 'gpu_hist',
'verbosity': 1, 'predictor': 'gpu_predictor'}
bst = xgb.train(param, train_x)
config = json.loads(bst.save_config())
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'gpu_predictor'
save_pickle(bst, model_path)
args = args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_predictor_type_is_auto")
cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
env = os.environ.copy()
env.update(cuda_environment)
# Load model in a CPU only environment.
status = subprocess.call(args, env=env)
assert status == 0
args = args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_predictor_type_is_gpu")
# Load in environment that has GPU.
env = os.environ.copy()
assert 'CUDA_VISIBLE_DEVICES' not in env.keys()
status = subprocess.call(args, env=env)
assert status == 0
     def test_predict_sklearn_pickle(self):
         x, y = build_dataset()
         kwargs = {'tree_method': 'gpu_hist',
                   'predictor': 'gpu_predictor',
-                  'verbosity': 2,
+                  'verbosity': 1,
                   'objective': 'binary:logistic',
                   'n_estimators': 10}
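
These tests rely on `save_pickle`/`load_pickle` helpers defined elsewhere in the suite. A minimal equivalent (names assumed to match; nothing beyond the standard pickle module is needed):

    import pickle

    def save_pickle(bst, path):
        # Serialize the booster; with this commit, pickling routes through
        # the new memory-snapshot serialization.
        with open(path, 'wb') as fd:
            pickle.dump(bst, fd)

    def load_pickle(path):
        with open(path, 'rb') as fd:
            return pickle.load(fd)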


@@ -7,23 +7,25 @@ rng = np.random.RandomState(1994)

 class TestGPUTrainingContinuation(unittest.TestCase):
-    def test_training_continuation_binary(self):
-        kRows = 32
-        kCols = 16
+    def run_training_continuation(self, use_json):
+        kRows = 64
+        kCols = 32
         X = np.random.randn(kRows, kCols)
         y = np.random.randn(kRows)
         dtrain = xgb.DMatrix(X, y)
-        params = {'tree_method': 'gpu_hist', 'max_depth': '2'}
-        bst_0 = xgb.train(params, dtrain, num_boost_round=4)
+        params = {'tree_method': 'gpu_hist', 'max_depth': '2',
+                  'gamma': '0.1', 'alpha': '0.01',
+                  'enable_experimental_json_serialization': use_json}
+        bst_0 = xgb.train(params, dtrain, num_boost_round=64)
         dump_0 = bst_0.get_dump(dump_format='json')

-        bst_1 = xgb.train(params, dtrain, num_boost_round=2)
-        bst_1 = xgb.train(params, dtrain, num_boost_round=2, xgb_model=bst_1)
+        bst_1 = xgb.train(params, dtrain, num_boost_round=32)
+        bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
         dump_1 = bst_1.get_dump(dump_format='json')

         def recursive_compare(obj_0, obj_1):
             if isinstance(obj_0, float):
-                assert np.isclose(obj_0, obj_1)
+                assert np.isclose(obj_0, obj_1, atol=1e-6)
             elif isinstance(obj_0, str):
                 assert obj_0 == obj_1
             elif isinstance(obj_0, int):
@@ -42,7 +44,14 @@ class TestGPUTrainingContinuation(unittest.TestCase):
             for i in range(len(obj_0)):
                 recursive_compare(obj_0[i], obj_1[i])

+        assert len(dump_0) == len(dump_1)
         for i in range(len(dump_0)):
             obj_0 = json.loads(dump_0[i])
             obj_1 = json.loads(dump_1[i])
             recursive_compare(obj_0, obj_1)
+
+    def test_gpu_training_continuation_binary(self):
+        self.run_training_continuation(False)
+
+    def test_gpu_training_continuation_json(self):
+        self.run_training_continuation(True)
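
The refactored test drives the same continuation check with and without the experimental JSON path. A reduced sketch of the scenario it exercises (CPU `hist` substituted for `gpu_hist` so it runs without a GPU; data is toy):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(64, 32)
    y = np.random.randn(64)
    dtrain = xgb.DMatrix(X, y)
    params = {'tree_method': 'hist',
              'enable_experimental_json_serialization': True}
    bst = xgb.train(params, dtrain, num_boost_round=32)
    # Resuming from `bst` should yield the same trees as one 64-round run,
    # which is exactly what the test's dump comparison asserts.
    bst = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst)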


@@ -203,7 +203,7 @@ class TestModels(unittest.TestCase):
         self.assertRaises(ValueError, bst.predict, dm1)
         bst.predict(dm2)  # success

-    def test_json_model_io(self):
+    def test_model_json_io(self):
         X = np.random.random((10, 3))
         y = np.random.randint(2, size=(10,))
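
The renamed test covers JSON model IO. A minimal sketch, assuming (as this line of work suggests) that a `.json` file suffix selects the JSON model format; the file name is illustrative:

    import numpy as np
    import xgboost as xgb

    X = np.random.random((10, 3))
    y = np.random.randint(2, size=(10,))
    bst = xgb.train({}, xgb.DMatrix(X, y), num_boost_round=2)
    bst.save_model('model.json')                  # save as JSON
    bst2 = xgb.Booster(model_file='model.json')   # load it back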


@@ -2,6 +2,7 @@ import pickle
 import numpy as np
 import xgboost as xgb
 import os
+import unittest

 kRows = 100

@@ -14,13 +15,8 @@ def generate_data():
     return X, y

-def test_model_pickling():
-    xgb_params = {
-        'verbosity': 0,
-        'nthread': 1,
-        'tree_method': 'hist'
-    }
-    X, y = generate_data()
-    dtrain = xgb.DMatrix(X, y)
-    bst = xgb.train(xgb_params, dtrain)
+class TestPickling(unittest.TestCase):
+    def run_model_pickling(self, xgb_params):
+        X, y = generate_data()
+        dtrain = xgb.DMatrix(X, y)
+        bst = xgb.train(xgb_params, dtrain)

@@ -46,3 +42,18 @@ def test_model_pickling():
         if os.path.exists(filename):
             os.remove(filename)
+
+    def test_model_pickling_binary(self):
+        params = {
+            'nthread': 1,
+            'tree_method': 'hist'
+        }
+        self.run_model_pickling(params)
+
+    def test_model_pickling_json(self):
+        params = {
+            'nthread': 1,
+            'tree_method': 'hist',
+            'enable_experimental_json_serialization': True
+        }
+        self.run_model_pickling(params)
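
A condensed sketch of what `run_model_pickling` verifies, namely that predictions survive a pickle round trip unchanged (toy data assumed):

    import pickle
    import numpy as np
    import xgboost as xgb

    X = np.random.randn(100, 10)
    y = np.random.randn(100)
    dtrain = xgb.DMatrix(X, y)
    bst = xgb.train({'nthread': 1, 'tree_method': 'hist'}, dtrain)

    # Round-trip through pickle and compare predictions.
    bst2 = pickle.loads(pickle.dumps(bst))
    np.testing.assert_allclose(bst.predict(dtrain), bst2.predict(dtrain))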


@@ -10,26 +10,35 @@ rng = np.random.RandomState(1337)

 class TestTrainingContinuation(unittest.TestCase):
     num_parallel_tree = 3

-    xgb_params_01 = {
-        'verbosity': 0,
-        'nthread': 1,
-    }
-
-    xgb_params_02 = {
-        'verbosity': 0,
-        'nthread': 1,
-        'num_parallel_tree': num_parallel_tree
-    }
-
-    xgb_params_03 = {
-        'verbosity': 0,
-        'nthread': 1,
-        'num_class': 5,
-        'num_parallel_tree': num_parallel_tree
-    }
+    def generate_parameters(self, use_json):
+        xgb_params_01_binary = {
+            'nthread': 1,
+        }
+
+        xgb_params_02_binary = {
+            'nthread': 1,
+            'num_parallel_tree': self.num_parallel_tree
+        }
+
+        xgb_params_03_binary = {
+            'nthread': 1,
+            'num_class': 5,
+            'num_parallel_tree': self.num_parallel_tree
+        }
+
+        if use_json:
+            xgb_params_01_binary[
+                'enable_experimental_json_serialization'] = True
+            xgb_params_02_binary[
+                'enable_experimental_json_serialization'] = True
+            xgb_params_03_binary[
+                'enable_experimental_json_serialization'] = True
+        return [
+            xgb_params_01_binary, xgb_params_02_binary, xgb_params_03_binary
+        ]

-    @pytest.mark.skipif(**tm.no_sklearn())
-    def test_training_continuation(self):
+    def run_training_continuation(self, xgb_params_01, xgb_params_02,
+                                  xgb_params_03):
         from sklearn.datasets import load_digits
         from sklearn.metrics import mean_squared_error
@@ -45,18 +54,18 @@ class TestTrainingContinuation(unittest.TestCase):
         dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
         dtrain_5class = xgb.DMatrix(X_5class, label=y_5class)

-        gbdt_01 = xgb.train(self.xgb_params_01, dtrain_2class,
+        gbdt_01 = xgb.train(xgb_params_01, dtrain_2class,
                             num_boost_round=10)
         ntrees_01 = len(gbdt_01.get_dump())
         assert ntrees_01 == 10

-        gbdt_02 = xgb.train(self.xgb_params_01, dtrain_2class,
+        gbdt_02 = xgb.train(xgb_params_01, dtrain_2class,
                             num_boost_round=0)
         gbdt_02.save_model('xgb_tc.model')

-        gbdt_02a = xgb.train(self.xgb_params_01, dtrain_2class,
+        gbdt_02a = xgb.train(xgb_params_01, dtrain_2class,
                              num_boost_round=10, xgb_model=gbdt_02)
-        gbdt_02b = xgb.train(self.xgb_params_01, dtrain_2class,
+        gbdt_02b = xgb.train(xgb_params_01, dtrain_2class,
                              num_boost_round=10, xgb_model="xgb_tc.model")
         ntrees_02a = len(gbdt_02a.get_dump())
         ntrees_02b = len(gbdt_02b.get_dump())

@@ -71,13 +80,13 @@ class TestTrainingContinuation(unittest.TestCase):
         res2 = mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
         assert res1 == res2

-        gbdt_03 = xgb.train(self.xgb_params_01, dtrain_2class,
+        gbdt_03 = xgb.train(xgb_params_01, dtrain_2class,
                             num_boost_round=3)
         gbdt_03.save_model('xgb_tc.model')

-        gbdt_03a = xgb.train(self.xgb_params_01, dtrain_2class,
+        gbdt_03a = xgb.train(xgb_params_01, dtrain_2class,
                              num_boost_round=7, xgb_model=gbdt_03)
-        gbdt_03b = xgb.train(self.xgb_params_01, dtrain_2class,
+        gbdt_03b = xgb.train(xgb_params_01, dtrain_2class,
                              num_boost_round=7, xgb_model="xgb_tc.model")
         ntrees_03a = len(gbdt_03a.get_dump())
         ntrees_03b = len(gbdt_03b.get_dump())

@@ -88,7 +97,7 @@ class TestTrainingContinuation(unittest.TestCase):
         res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
         assert res1 == res2

-        gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class,
+        gbdt_04 = xgb.train(xgb_params_02, dtrain_2class,
                             num_boost_round=3)
         assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration +
                                             1) * self.num_parallel_tree

@@ -100,7 +109,7 @@ class TestTrainingContinuation(unittest.TestCase):
                                 ntree_limit=gbdt_04.best_ntree_limit))
         assert res1 == res2

-        gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class,
+        gbdt_04 = xgb.train(xgb_params_02, dtrain_2class,
                             num_boost_round=7, xgb_model=gbdt_04)
         assert gbdt_04.best_ntree_limit == (
             gbdt_04.best_iteration + 1) * self.num_parallel_tree

@@ -112,11 +121,11 @@ class TestTrainingContinuation(unittest.TestCase):
                                 ntree_limit=gbdt_04.best_ntree_limit))
         assert res1 == res2

-        gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class,
+        gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
                             num_boost_round=7)
         assert gbdt_05.best_ntree_limit == (
             gbdt_05.best_iteration + 1) * self.num_parallel_tree
-        gbdt_05 = xgb.train(self.xgb_params_03,
+        gbdt_05 = xgb.train(xgb_params_03,
                             dtrain_5class,
                             num_boost_round=3,
                             xgb_model=gbdt_05)
@@ -127,3 +136,32 @@ class TestTrainingContinuation(unittest.TestCase):
         res2 = gbdt_05.predict(dtrain_5class,
                                ntree_limit=gbdt_05.best_ntree_limit)
         np.testing.assert_almost_equal(res1, res2)
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_training_continuation_binary(self):
+        params = self.generate_parameters(False)
+        self.run_training_continuation(params[0], params[1], params[2])
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_training_continuation_json(self):
+        params = self.generate_parameters(True)
+        for p in params:
+            p['enable_experimental_json_serialization'] = True
+        self.run_training_continuation(params[0], params[1], params[2])
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_training_continuation_updaters_binary(self):
+        updaters = 'grow_colmaker,prune,refresh'
+        params = self.generate_parameters(False)
+        for p in params:
+            p['updater'] = updaters
+        self.run_training_continuation(params[0], params[1], params[2])
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_training_continuation_updaters_json(self):
+        # Picked up from R tests.
+        updaters = 'grow_colmaker,prune,refresh'
+        params = self.generate_parameters(True)
+        for p in params:
+            p['updater'] = updaters
+        self.run_training_continuation(params[0], params[1], params[2])
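
A reduced sketch of the updater variant exercised above: the `updater` string chains three tree updaters, and continuation must restore this configuration when resuming from an existing model (toy data assumed; the assertions of the full test are omitted):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(64, 8)
    y = np.random.randn(64)
    dtrain = xgb.DMatrix(X, y)
    # Chain the colmaker grower with pruning and refresh passes.
    params = {'nthread': 1, 'updater': 'grow_colmaker,prune,refresh'}
    bst = xgb.train(params, dtrain, num_boost_round=4)
    # Resume training; the restored booster keeps the same updater chain.
    bst = xgb.train(params, dtrain, num_boost_round=4, xgb_model=bst)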