Implement new save_raw in Python. (#7572)
* Expose the new C API function to Python. * Remove old document and helper script. * Small optimization to the `save_raw` and Json ctors.
This commit is contained in:
parent
9f20a3315e
commit
dac9eb13bd
@ -1,79 +0,0 @@
|
|||||||
'''This is a simple script that converts a pickled XGBoost
|
|
||||||
Scikit-Learn interface object from 0.90 to a native model. Pickle
|
|
||||||
format is not stable as it's a direct serialization of Python object.
|
|
||||||
We advice not to use it when stability is needed.
|
|
||||||
|
|
||||||
'''
|
|
||||||
import pickle
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
|
||||||
import xgboost
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
|
|
||||||
def save_label_encoder(le):
|
|
||||||
'''Save the label encoder in XGBClassifier'''
|
|
||||||
meta = dict()
|
|
||||||
for k, v in le.__dict__.items():
|
|
||||||
if isinstance(v, np.ndarray):
|
|
||||||
meta[k] = v.tolist()
|
|
||||||
else:
|
|
||||||
meta[k] = v
|
|
||||||
return meta
|
|
||||||
|
|
||||||
|
|
||||||
def xgboost_skl_90to100(skl_model):
|
|
||||||
'''Extract the model and related metadata in SKL model.'''
|
|
||||||
model = {}
|
|
||||||
with open(skl_model, 'rb') as fd:
|
|
||||||
old = pickle.load(fd)
|
|
||||||
if not isinstance(old, xgboost.XGBModel):
|
|
||||||
raise TypeError(
|
|
||||||
'The script only handes Scikit-Learn interface object')
|
|
||||||
|
|
||||||
# Save Scikit-Learn specific Python attributes into a JSON document.
|
|
||||||
for k, v in old.__dict__.items():
|
|
||||||
if k == '_le':
|
|
||||||
model[k] = save_label_encoder(v)
|
|
||||||
elif k == 'classes_':
|
|
||||||
model[k] = v.tolist()
|
|
||||||
elif k == '_Booster':
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
json.dumps({k: v})
|
|
||||||
model[k] = v
|
|
||||||
except TypeError:
|
|
||||||
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
|
|
||||||
booster = old.get_booster()
|
|
||||||
# Store the JSON serialization as an attribute
|
|
||||||
booster.set_attr(scikit_learn=json.dumps(model))
|
|
||||||
|
|
||||||
# Save it into a native model.
|
|
||||||
i = 0
|
|
||||||
while True:
|
|
||||||
path = 'xgboost_native_model_from_' + skl_model + '-' + str(i) + '.bin'
|
|
||||||
if os.path.exists(path):
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
booster.save_model(path)
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
assert xgboost.__version__ != '1.0.0', ('Please use the XGBoost version'
|
|
||||||
' that generates this pickle.')
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description=('A simple script to convert pickle generated by'
|
|
||||||
' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
|
|
||||||
parser.add_argument(
|
|
||||||
'--old-pickle',
|
|
||||||
type=str,
|
|
||||||
help='Path to old pickle file of Scikit-Learn interface object. '
|
|
||||||
'Will output a native model converted from this pickle file',
|
|
||||||
required=True)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
xgboost_skl_90to100(args.old_pickle)
|
|
||||||
@ -2,16 +2,18 @@
|
|||||||
Introduction to Model IO
|
Introduction to Model IO
|
||||||
########################
|
########################
|
||||||
|
|
||||||
In XGBoost 1.0.0, we introduced experimental support of using `JSON
|
In XGBoost 1.0.0, we introduced support of using `JSON
|
||||||
<https://www.json.org/json-en.html>`_ for saving/loading XGBoost models and related
|
<https://www.json.org/json-en.html>`_ for saving/loading XGBoost models and related
|
||||||
hyper-parameters for training, aiming to replace the old binary internal format with an
|
hyper-parameters for training, aiming to replace the old binary internal format with an
|
||||||
open format that can be easily reused. The support for binary format will be continued in
|
open format that can be easily reused. Later in XGBoost 1.6.0, additional support for
|
||||||
the future until JSON format is no-longer experimental and has satisfying performance.
|
`Universal Binary JSON <https://ubjson.org/>`__ is added as an optimization for more
|
||||||
This tutorial aims to share some basic insights into the JSON serialisation method used in
|
efficient model IO. They have the same document structure with different representations,
|
||||||
XGBoost. Without explicitly mentioned, the following sections assume you are using the
|
and we will refer them collectively as the JSON format. This tutorial aims to share some
|
||||||
JSON format, which can be enabled by providing the file name with ``.json`` as file
|
basic insights into the JSON serialisation method used in XGBoost. Without explicitly
|
||||||
extension when saving/loading model: ``booster.save_model('model.json')``. More details
|
mentioned, the following sections assume you are using the one of the 2 outputs formats,
|
||||||
below.
|
which can be enabled by providing the file name with ``.json`` (or ``.ubj`` for binary
|
||||||
|
JSON) as file extension when saving/loading model: ``booster.save_model('model.json')``.
|
||||||
|
More details below.
|
||||||
|
|
||||||
Before we get started, XGBoost is a gradient boosting library with focus on tree model,
|
Before we get started, XGBoost is a gradient boosting library with focus on tree model,
|
||||||
which means inside XGBoost, there are 2 distinct parts:
|
which means inside XGBoost, there are 2 distinct parts:
|
||||||
@ -53,7 +55,8 @@ Other language bindings are still working in progress.
|
|||||||
based serialisation methods.
|
based serialisation methods.
|
||||||
|
|
||||||
To enable JSON format support for model IO (saving only the trees and objective), provide
|
To enable JSON format support for model IO (saving only the trees and objective), provide
|
||||||
a filename with ``.json`` as file extension:
|
a filename with ``.json`` or ``.ubj`` as file extension, the latter is the extension for
|
||||||
|
`Universal Binary JSON <https://ubjson.org/>`__
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
:caption: Python
|
:caption: Python
|
||||||
@ -65,7 +68,7 @@ a filename with ``.json`` as file extension:
|
|||||||
|
|
||||||
xgb.save(bst, 'model_file_name.json')
|
xgb.save(bst, 'model_file_name.json')
|
||||||
|
|
||||||
While for memory snapshot, JSON is the default starting with xgboost 1.3.
|
While for memory snapshot, UBJSON is the default starting with xgboost 1.6.
|
||||||
|
|
||||||
***************************************************************
|
***************************************************************
|
||||||
A note on backward compatibility of models and memory snapshots
|
A note on backward compatibility of models and memory snapshots
|
||||||
@ -105,15 +108,10 @@ Loading pickled file from different version of XGBoost
|
|||||||
|
|
||||||
As noted, pickled model is neither portable nor stable, but in some cases the pickled
|
As noted, pickled model is neither portable nor stable, but in some cases the pickled
|
||||||
models are valuable. One way to restore it in the future is to load it back with that
|
models are valuable. One way to restore it in the future is to load it back with that
|
||||||
specific version of Python and XGBoost, export the model by calling `save_model`. To help
|
specific version of Python and XGBoost, export the model by calling `save_model`.
|
||||||
easing the mitigation, we created a simple script for converting pickled XGBoost 0.90
|
|
||||||
Scikit-Learn interface object to XGBoost 1.0.0 native model. Please note that the script
|
|
||||||
suits simple use cases, and it's advised not to use pickle when stability is needed. It's
|
|
||||||
located in ``xgboost/doc/python`` with the name ``convert_090to100.py``. See comments in
|
|
||||||
the script for more details.
|
|
||||||
|
|
||||||
A similar procedure may be used to recover the model persisted in an old RDS file. In R, you are
|
A similar procedure may be used to recover the model persisted in an old RDS file. In R,
|
||||||
able to install an older version of XGBoost using the ``remotes`` package:
|
you are able to install an older version of XGBoost using the ``remotes`` package:
|
||||||
|
|
||||||
.. code-block:: r
|
.. code-block:: r
|
||||||
|
|
||||||
@ -244,10 +242,3 @@ leaf directly, instead it saves the weights as a separated array.
|
|||||||
|
|
||||||
.. include:: ../model.schema
|
.. include:: ../model.schema
|
||||||
:code: json
|
:code: json
|
||||||
|
|
||||||
************
|
|
||||||
Future Plans
|
|
||||||
************
|
|
||||||
|
|
||||||
Right now using the JSON format incurs longer serialisation time, we have been working on
|
|
||||||
optimizing the JSON implementation to close the gap between binary format and JSON format.
|
|
||||||
|
|||||||
@ -89,9 +89,10 @@ class JsonString : public Value {
|
|||||||
JsonString(std::string const& str) : // NOLINT
|
JsonString(std::string const& str) : // NOLINT
|
||||||
Value(ValueKind::kString), str_{str} {}
|
Value(ValueKind::kString), str_{str} {}
|
||||||
JsonString(std::string&& str) noexcept : // NOLINT
|
JsonString(std::string&& str) noexcept : // NOLINT
|
||||||
Value(ValueKind::kString), str_{std::move(str)} {}
|
Value(ValueKind::kString), str_{std::forward<std::string>(str)} {}
|
||||||
JsonString(JsonString&& str) noexcept : // NOLINT
|
JsonString(JsonString&& str) noexcept : Value(ValueKind::kString) { // NOLINT
|
||||||
Value(ValueKind::kString), str_{std::move(str.str_)} {}
|
std::swap(str.str_, this->str_);
|
||||||
|
}
|
||||||
|
|
||||||
void Save(JsonWriter* writer) const override;
|
void Save(JsonWriter* writer) const override;
|
||||||
|
|
||||||
@ -111,8 +112,8 @@ class JsonArray : public Value {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
JsonArray() : Value(ValueKind::kArray) {}
|
JsonArray() : Value(ValueKind::kArray) {}
|
||||||
JsonArray(std::vector<Json>&& arr) noexcept : // NOLINT
|
JsonArray(std::vector<Json>&& arr) noexcept // NOLINT
|
||||||
Value(ValueKind::kArray), vec_{std::move(arr)} {}
|
: Value(ValueKind::kArray), vec_{std::forward<std::vector<Json>>(arr)} {}
|
||||||
JsonArray(std::vector<Json> const& arr) : // NOLINT
|
JsonArray(std::vector<Json> const& arr) : // NOLINT
|
||||||
Value(ValueKind::kArray), vec_{arr} {}
|
Value(ValueKind::kArray), vec_{arr} {}
|
||||||
JsonArray(JsonArray const& that) = delete;
|
JsonArray(JsonArray const& that) = delete;
|
||||||
@ -381,10 +382,9 @@ class Json {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
// array
|
// array
|
||||||
explicit Json(JsonArray list) :
|
explicit Json(JsonArray&& list) : ptr_{new JsonArray(std::forward<JsonArray>(list))} {}
|
||||||
ptr_ {new JsonArray(std::move(list))} {}
|
Json& operator=(JsonArray&& array) {
|
||||||
Json& operator=(JsonArray array) {
|
ptr_.reset(new JsonArray(std::forward<JsonArray>(array)));
|
||||||
ptr_.reset(new JsonArray(std::move(array)));
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
// typed array
|
// typed array
|
||||||
@ -397,17 +397,15 @@ class Json {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
// object
|
// object
|
||||||
explicit Json(JsonObject object) :
|
explicit Json(JsonObject&& object) : ptr_{new JsonObject(std::forward<JsonObject>(object))} {}
|
||||||
ptr_{new JsonObject(std::move(object))} {}
|
Json& operator=(JsonObject&& object) {
|
||||||
Json& operator=(JsonObject object) {
|
ptr_.reset(new JsonObject(std::forward<JsonObject>(object)));
|
||||||
ptr_.reset(new JsonObject(std::move(object)));
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
// string
|
// string
|
||||||
explicit Json(JsonString str) :
|
explicit Json(JsonString&& str) : ptr_{new JsonString(std::forward<JsonString>(str))} {}
|
||||||
ptr_{new JsonString(std::move(str))} {}
|
Json& operator=(JsonString&& str) {
|
||||||
Json& operator=(JsonString str) {
|
ptr_.reset(new JsonString(std::forward<JsonString>(str)));
|
||||||
ptr_.reset(new JsonString(std::move(str)));
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
// bool
|
// bool
|
||||||
|
|||||||
@ -45,6 +45,8 @@ enum class PredictionType : std::uint8_t { // NOLINT
|
|||||||
struct XGBAPIThreadLocalEntry {
|
struct XGBAPIThreadLocalEntry {
|
||||||
/*! \brief result holder for returning string */
|
/*! \brief result holder for returning string */
|
||||||
std::string ret_str;
|
std::string ret_str;
|
||||||
|
/*! \brief result holder for returning raw buffer */
|
||||||
|
std::vector<char> ret_char_vec;
|
||||||
/*! \brief result holder for returning strings */
|
/*! \brief result holder for returning strings */
|
||||||
std::vector<std::string> ret_vec_str;
|
std::vector<std::string> ret_vec_str;
|
||||||
/*! \brief result holder for returning string pointers */
|
/*! \brief result holder for returning string pointers */
|
||||||
|
|||||||
@ -2135,9 +2135,15 @@ class Booster:
|
|||||||
|
|
||||||
The model is saved in an XGBoost internal format which is universal among the
|
The model is saved in an XGBoost internal format which is universal among the
|
||||||
various XGBoost interfaces. Auxiliary attributes of the Python Booster object
|
various XGBoost interfaces. Auxiliary attributes of the Python Booster object
|
||||||
(such as feature_names) will not be saved when using binary format. To save those
|
(such as feature_names) will not be saved when using binary format. To save
|
||||||
attributes, use JSON instead. See :doc:`Model IO </tutorials/saving_model>` for
|
those attributes, use JSON/UBJ instead. See :doc:`Model IO
|
||||||
more info.
|
</tutorials/saving_model>` for more info.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
model.save_model("model.json")
|
||||||
|
# or
|
||||||
|
model.save_model("model.ubj")
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -2152,18 +2158,28 @@ class Booster:
|
|||||||
else:
|
else:
|
||||||
raise TypeError("fname must be a string or os PathLike")
|
raise TypeError("fname must be a string or os PathLike")
|
||||||
|
|
||||||
def save_raw(self) -> bytearray:
|
def save_raw(self, raw_format: str = "deprecated") -> bytearray:
|
||||||
"""Save the model to a in memory buffer representation instead of file.
|
"""Save the model to a in memory buffer representation instead of file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw_format :
|
||||||
|
Format of output buffer. Can be `json`, `ubj` or `deprecated`. Right now
|
||||||
|
the default is `deprecated` but it will be changed to `ubj` (univeral binary
|
||||||
|
json) in the future.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
a in memory buffer representation of the model
|
An in memory buffer representation of the model
|
||||||
"""
|
"""
|
||||||
length = c_bst_ulong()
|
length = c_bst_ulong()
|
||||||
cptr = ctypes.POINTER(ctypes.c_char)()
|
cptr = ctypes.POINTER(ctypes.c_char)()
|
||||||
_check_call(_LIB.XGBoosterGetModelRaw(self.handle,
|
config = from_pystr_to_cstr(json.dumps({"format": raw_format}))
|
||||||
ctypes.byref(length),
|
_check_call(
|
||||||
ctypes.byref(cptr)))
|
_LIB.XGBoosterSaveModelToBuffer(
|
||||||
|
self.handle, config, ctypes.byref(length), ctypes.byref(cptr)
|
||||||
|
)
|
||||||
|
)
|
||||||
return ctypes2buffer(cptr, length.value)
|
return ctypes2buffer(cptr, length.value)
|
||||||
|
|
||||||
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
|
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
|
||||||
@ -2173,8 +2189,14 @@ class Booster:
|
|||||||
The model is loaded from XGBoost format which is universal among the various
|
The model is loaded from XGBoost format which is universal among the various
|
||||||
XGBoost interfaces. Auxiliary attributes of the Python Booster object (such as
|
XGBoost interfaces. Auxiliary attributes of the Python Booster object (such as
|
||||||
feature_names) will not be loaded when using binary format. To save those
|
feature_names) will not be loaded when using binary format. To save those
|
||||||
attributes, use JSON instead. See :doc:`Model IO </tutorials/saving_model>` for
|
attributes, use JSON/UBJ instead. See :doc:`Model IO </tutorials/saving_model>`
|
||||||
more info.
|
for more info.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
model.load_model("model.json")
|
||||||
|
# or
|
||||||
|
model.load_model("model.ubj")
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|||||||
@ -971,28 +971,34 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
|
|||||||
auto format = RequiredArg<String>(config, "format", __func__);
|
auto format = RequiredArg<String>(config, "format", __func__);
|
||||||
|
|
||||||
auto *learner = static_cast<Learner *>(handle);
|
auto *learner = static_cast<Learner *>(handle);
|
||||||
std::string &raw_str = learner->GetThreadLocal().ret_str;
|
|
||||||
raw_str.clear();
|
|
||||||
|
|
||||||
learner->Configure();
|
learner->Configure();
|
||||||
|
|
||||||
|
auto save_json = [&](std::ios::openmode mode) {
|
||||||
|
std::vector<char> &raw_char_vec = learner->GetThreadLocal().ret_char_vec;
|
||||||
|
Json out{Object{}};
|
||||||
|
learner->SaveModel(&out);
|
||||||
|
Json::Dump(out, &raw_char_vec, mode);
|
||||||
|
*out_dptr = dmlc::BeginPtr(raw_char_vec);
|
||||||
|
*out_len = static_cast<xgboost::bst_ulong>(raw_char_vec.size());
|
||||||
|
};
|
||||||
|
|
||||||
Json out{Object{}};
|
Json out{Object{}};
|
||||||
if (format == "json") {
|
if (format == "json") {
|
||||||
learner->SaveModel(&out);
|
save_json(std::ios::out);
|
||||||
Json::Dump(out, &raw_str);
|
|
||||||
} else if (format == "ubj") {
|
} else if (format == "ubj") {
|
||||||
learner->SaveModel(&out);
|
save_json(std::ios::binary);
|
||||||
Json::Dump(out, &raw_str, std::ios::binary);
|
|
||||||
} else if (format == "deprecated") {
|
} else if (format == "deprecated") {
|
||||||
WarnOldModel();
|
WarnOldModel();
|
||||||
|
auto &raw_str = learner->GetThreadLocal().ret_str;
|
||||||
|
raw_str.clear();
|
||||||
common::MemoryBufferStream fo(&raw_str);
|
common::MemoryBufferStream fo(&raw_str);
|
||||||
learner->SaveModel(&fo);
|
learner->SaveModel(&fo);
|
||||||
|
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||||
|
*out_len = static_cast<xgboost::bst_ulong>(raw_str.size());
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unknown format: `" << format << "`";
|
LOG(FATAL) << "Unknown format: `" << format << "`";
|
||||||
}
|
}
|
||||||
|
|
||||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
|
||||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
|
||||||
|
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -195,11 +195,12 @@ Json& Value::operator[](int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Json Object
|
// Json Object
|
||||||
JsonObject::JsonObject(JsonObject && that) noexcept :
|
JsonObject::JsonObject(JsonObject&& that) noexcept : Value(ValueKind::kObject) {
|
||||||
Value(ValueKind::kObject), object_{std::move(that.object_)} {}
|
std::swap(that.object_, this->object_);
|
||||||
|
}
|
||||||
|
|
||||||
JsonObject::JsonObject(std::map<std::string, Json>&& object) noexcept
|
JsonObject::JsonObject(std::map<std::string, Json>&& object) noexcept
|
||||||
: Value(ValueKind::kObject), object_{std::move(object)} {}
|
: Value(ValueKind::kObject), object_{std::forward<std::map<std::string, Json>>(object)} {}
|
||||||
|
|
||||||
bool JsonObject::operator==(Value const& rhs) const {
|
bool JsonObject::operator==(Value const& rhs) const {
|
||||||
if (!IsA<JsonObject>(&rhs)) {
|
if (!IsA<JsonObject>(&rhs)) {
|
||||||
@ -220,8 +221,9 @@ bool JsonString::operator==(Value const& rhs) const {
|
|||||||
void JsonString::Save(JsonWriter* writer) const { writer->Visit(this); }
|
void JsonString::Save(JsonWriter* writer) const { writer->Visit(this); }
|
||||||
|
|
||||||
// Json Array
|
// Json Array
|
||||||
JsonArray::JsonArray(JsonArray && that) noexcept :
|
JsonArray::JsonArray(JsonArray&& that) noexcept : Value(ValueKind::kArray) {
|
||||||
Value(ValueKind::kArray), vec_{std::move(that.vec_)} {}
|
std::swap(that.vec_, this->vec_);
|
||||||
|
}
|
||||||
|
|
||||||
bool JsonArray::operator==(Value const& rhs) const {
|
bool JsonArray::operator==(Value const& rhs) const {
|
||||||
if (!IsA<JsonArray>(&rhs)) {
|
if (!IsA<JsonArray>(&rhs)) {
|
||||||
@ -696,6 +698,7 @@ void Json::Dump(Json json, std::string* str, std::ios::openmode mode) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Json::Dump(Json json, std::vector<char>* str, std::ios::openmode mode) {
|
void Json::Dump(Json json, std::vector<char>* str, std::ios::openmode mode) {
|
||||||
|
str->clear();
|
||||||
if (mode & std::ios::binary) {
|
if (mode & std::ios::binary) {
|
||||||
UBJWriter writer{str};
|
UBJWriter writer{str};
|
||||||
writer.Save(json);
|
writer.Save(json);
|
||||||
@ -768,9 +771,7 @@ std::string UBJReader::DecodeStr() {
|
|||||||
str.resize(bsize);
|
str.resize(bsize);
|
||||||
auto ptr = raw_str_.c_str() + cursor_.Pos();
|
auto ptr = raw_str_.c_str() + cursor_.Pos();
|
||||||
std::memcpy(&str[0], ptr, bsize);
|
std::memcpy(&str[0], ptr, bsize);
|
||||||
for (int64_t i = 0; i < bsize; ++i) {
|
this->cursor_.Forward(bsize);
|
||||||
this->cursor_.Forward();
|
|
||||||
}
|
|
||||||
return str;
|
return str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -289,6 +289,19 @@ class TestModels:
|
|||||||
os.remove(model_path)
|
os.remove(model_path)
|
||||||
assert locale.getpreferredencoding(False) == loc
|
assert locale.getpreferredencoding(False) == loc
|
||||||
|
|
||||||
|
json_raw = bst.save_raw(raw_format="json")
|
||||||
|
from_jraw = xgb.Booster()
|
||||||
|
from_jraw.load_model(json_raw)
|
||||||
|
|
||||||
|
ubj_raw = bst.save_raw(raw_format="ubj")
|
||||||
|
from_ubjraw = xgb.Booster()
|
||||||
|
from_ubjraw.load_model(ubj_raw)
|
||||||
|
|
||||||
|
old_from_json = from_jraw.save_raw(raw_format="deprecated")
|
||||||
|
old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
|
||||||
|
|
||||||
|
assert old_from_json == old_from_ubj
|
||||||
|
|
||||||
@pytest.mark.parametrize("ext", ["json", "ubj"])
|
@pytest.mark.parametrize("ext", ["json", "ubj"])
|
||||||
def test_model_json_io(self, ext: str) -> None:
|
def test_model_json_io(self, ext: str) -> None:
|
||||||
parameters = {"booster": "gbtree", "tree_method": "hist"}
|
parameters = {"booster": "gbtree", "tree_method": "hist"}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user