[dask] Ensure model can be pickled. (#6651)

This commit is contained in:
Jiaming Yuan 2021-01-28 21:47:57 +08:00 committed by GitHub
parent 0ad6e18a2a
commit d167892c7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 5 deletions

View File

@ -1359,6 +1359,12 @@ class DaskScikitLearnBase(XGBModel):
return self return self
return self.client.sync(_).__await__() return self.client.sync(_).__await__()
def __getstate__(self):
    '''Return the instance state used for pickling, without the ``_client`` entry.

    The result is a new shallow dictionary built from ``self.__dict__``; the
    instance's own attributes are left untouched.  When no ``_client``
    attribute is present, the result is simply a copy of ``self.__dict__``.
    '''
    # Filter into a new dict rather than copy-then-delete; key order is kept.
    return {
        attr: value
        for attr, value in self.__dict__.items()
        if attr != "_client"
    }
@property @property
def client(self) -> "distributed.Client": def client(self) -> "distributed.Client":
'''The dask client used in this model.''' '''The dask client used in this model.'''

View File

@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
import pickle
import testing as tm import testing as tm
import pytest import pytest
import xgboost as xgb import xgboost as xgb
@ -1104,23 +1104,32 @@ class TestWithDask:
predt_0 = cls.predict(X) predt_0 = cls.predict(X)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "model.pkl")
with open(path, "wb") as fd:
pickle.dump(cls, fd)
with open(path, "rb") as fd:
cls = pickle.load(fd)
predt_1 = cls.predict(X)
np.testing.assert_allclose(predt_0.compute(), predt_1.compute())
path = os.path.join(tmpdir, 'cls.json') path = os.path.join(tmpdir, 'cls.json')
cls.save_model(path) cls.save_model(path)
cls = xgb.dask.DaskXGBClassifier() cls = xgb.dask.DaskXGBClassifier()
cls.load_model(path) cls.load_model(path)
assert cls.n_classes_ == 10 assert cls.n_classes_ == 10
predt_1 = cls.predict(X) predt_2 = cls.predict(X)
np.testing.assert_allclose(predt_0.compute(), predt_1.compute()) np.testing.assert_allclose(predt_0.compute(), predt_2.compute())
# Use single node to load # Use single node to load
cls = xgb.XGBClassifier() cls = xgb.XGBClassifier()
cls.load_model(path) cls.load_model(path)
assert cls.n_classes_ == 10 assert cls.n_classes_ == 10
predt_2 = cls.predict(X_) predt_3 = cls.predict(X_)
np.testing.assert_allclose(predt_0.compute(), predt_2) np.testing.assert_allclose(predt_0.compute(), predt_3)
class TestDaskCallbacks: class TestDaskCallbacks: