[backport][sklearn] Fix loading model attributes. (#9808) (#9880)

This commit is contained in:
Jiaming Yuan
2023-12-13 14:20:04 +08:00
committed by GitHub
parent 41ce8f28b2
commit e4ee4e79dc
4 changed files with 48 additions and 42 deletions

View File

@@ -1932,6 +1932,7 @@ class TestWithDask:
cls.client = client
cls.fit(X, y)
predt_0 = cls.predict(X)
proba_0 = cls.predict_proba(X)
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "model.pkl")
@@ -1941,7 +1942,9 @@ class TestWithDask:
with open(path, "rb") as fd:
cls = pickle.load(fd)
predt_1 = cls.predict(X)
proba_1 = cls.predict_proba(X)
np.testing.assert_allclose(predt_0.compute(), predt_1.compute())
np.testing.assert_allclose(proba_0.compute(), proba_1.compute())
path = os.path.join(tmpdir, "cls.json")
cls.save_model(path)
@@ -1950,16 +1953,20 @@ class TestWithDask:
cls.load_model(path)
assert cls.n_classes_ == 10
predt_2 = cls.predict(X)
proba_2 = cls.predict_proba(X)
np.testing.assert_allclose(predt_0.compute(), predt_2.compute())
np.testing.assert_allclose(proba_0.compute(), proba_2.compute())
# Use single node to load
cls = xgb.XGBClassifier()
cls.load_model(path)
assert cls.n_classes_ == 10
predt_3 = cls.predict(X_)
proba_3 = cls.predict_proba(X_)
np.testing.assert_allclose(predt_0.compute(), predt_3)
np.testing.assert_allclose(proba_0.compute(), proba_3)
def test_dask_unsupported_features(client: "Client") -> None: