[sklearn] Fix loading model attributes. (#9808)
This commit is contained in:
@@ -944,6 +944,7 @@ def save_load_model(model_path):
|
||||
predt_0 = clf.predict(X)
|
||||
clf.save_model(model_path)
|
||||
clf.load_model(model_path)
|
||||
assert clf.booster == "gblinear"
|
||||
predt_1 = clf.predict(X)
|
||||
np.testing.assert_allclose(predt_0, predt_1)
|
||||
assert clf.best_iteration == best_iteration
|
||||
@@ -959,25 +960,26 @@ def save_load_model(model_path):
|
||||
|
||||
def test_save_load_model():
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
model_path = os.path.join(tempdir, 'digits.model')
|
||||
model_path = os.path.join(tempdir, "digits.model")
|
||||
save_load_model(model_path)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
model_path = os.path.join(tempdir, 'digits.model.json')
|
||||
model_path = os.path.join(tempdir, "digits.model.json")
|
||||
save_load_model(model_path)
|
||||
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
model_path = os.path.join(tempdir, 'digits.model.ubj')
|
||||
model_path = os.path.join(tempdir, "digits.model.ubj")
|
||||
digits = load_digits(n_class=2)
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
booster = xgb.train({'tree_method': 'hist',
|
||||
'objective': 'binary:logistic'},
|
||||
dtrain=xgb.DMatrix(X, y),
|
||||
num_boost_round=4)
|
||||
y = digits["target"]
|
||||
X = digits["data"]
|
||||
booster = xgb.train(
|
||||
{"tree_method": "hist", "objective": "binary:logistic"},
|
||||
dtrain=xgb.DMatrix(X, y),
|
||||
num_boost_round=4,
|
||||
)
|
||||
predt_0 = booster.predict(xgb.DMatrix(X))
|
||||
booster.save_model(model_path)
|
||||
cls = xgb.XGBClassifier()
|
||||
@@ -1011,6 +1013,8 @@ def test_save_load_model():
|
||||
clf = xgb.XGBClassifier()
|
||||
clf.load_model(model_path)
|
||||
assert clf.classes_.size == 10
|
||||
assert clf.objective == "multi:softprob"
|
||||
|
||||
np.testing.assert_equal(clf.classes_, np.arange(10))
|
||||
assert clf.n_classes_ == 10
|
||||
|
||||
|
||||
Reference in New Issue
Block a user