Save Scikit-Learn attributes into learner attributes. (#5245)
* Remove the recommendation for pickle. * Save skl attributes in booster.attr * Test loading scikit-learn model with native booster.
This commit is contained in:
@@ -30,7 +30,8 @@ def json_model(model_path, parameters):
|
||||
class TestModels(unittest.TestCase):
|
||||
def test_glm(self):
|
||||
param = {'verbosity': 0, 'objective': 'binary:logistic',
|
||||
'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1, 'nthread': 1}
|
||||
'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1,
|
||||
'nthread': 1}
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 4
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
from xgboost.sklearn import XGBoostLabelEncoder
|
||||
import testing as tm
|
||||
import tempfile
|
||||
import os
|
||||
@@ -614,7 +615,7 @@ def test_validation_weights_xgbclassifier():
|
||||
for i in [0, 1]))
|
||||
|
||||
|
||||
def test_save_load_model():
|
||||
def save_load_model(model_path):
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
@@ -622,18 +623,64 @@ def test_save_load_model():
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||
with TemporaryDirectory() as tempdir:
|
||||
model_path = os.path.join(tempdir, 'digits.model')
|
||||
for train_index, test_index in kf.split(X, y):
|
||||
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
|
||||
xgb_model.save_model(model_path)
|
||||
for train_index, test_index in kf.split(X, y):
|
||||
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
|
||||
xgb_model.save_model(model_path)
|
||||
xgb_model = xgb.XGBClassifier()
|
||||
xgb_model.load_model(model_path)
|
||||
assert isinstance(xgb_model.classes_, np.ndarray)
|
||||
assert isinstance(xgb_model._Booster, xgb.Booster)
|
||||
assert isinstance(xgb_model._le, XGBoostLabelEncoder)
|
||||
assert isinstance(xgb_model._le.classes_, np.ndarray)
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
labels = y[test_index]
|
||||
err = sum(1 for i in range(len(preds))
|
||||
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
||||
assert err < 0.1
|
||||
assert xgb_model.get_booster().attr('scikit_learn') is None
|
||||
|
||||
# test native booster
|
||||
preds = xgb_model.predict(X[test_index], output_margin=True)
|
||||
booster = xgb.Booster(model_file=model_path)
|
||||
predt_1 = booster.predict(xgb.DMatrix(X[test_index]),
|
||||
output_margin=True)
|
||||
assert np.allclose(preds, predt_1)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
xgb_model = xgb.XGBModel()
|
||||
xgb_model.load_model(model_path)
|
||||
preds = xgb_model.predict(X[test_index])
|
||||
labels = y[test_index]
|
||||
err = sum(1 for i in range(len(preds))
|
||||
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
|
||||
assert err < 0.1
|
||||
|
||||
|
||||
def test_save_load_model():
|
||||
with TemporaryDirectory() as tempdir:
|
||||
model_path = os.path.join(tempdir, 'digits.model')
|
||||
save_load_model(model_path)
|
||||
|
||||
with TemporaryDirectory() as tempdir:
|
||||
model_path = os.path.join(tempdir, 'digits.model.json')
|
||||
save_load_model(model_path)
|
||||
|
||||
from sklearn.datasets import load_digits
|
||||
with TemporaryDirectory() as tempdir:
|
||||
model_path = os.path.join(tempdir, 'digits.model.json')
|
||||
digits = load_digits(2)
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
booster = xgb.train({'tree_method': 'hist',
|
||||
'objective': 'binary:logistic'},
|
||||
dtrain=xgb.DMatrix(X, y),
|
||||
num_boost_round=4)
|
||||
predt_0 = booster.predict(xgb.DMatrix(X))
|
||||
booster.save_model(model_path)
|
||||
cls = xgb.XGBClassifier()
|
||||
cls.load_model(model_path)
|
||||
predt_1 = cls.predict(X)
|
||||
assert np.allclose(predt_0, predt_1)
|
||||
|
||||
cls = xgb.XGBModel()
|
||||
cls.load_model(model_path)
|
||||
predt_1 = cls.predict(X)
|
||||
assert np.allclose(predt_0, predt_1)
|
||||
|
||||
|
||||
def test_RFECV():
|
||||
|
||||
Reference in New Issue
Block a user