[pyspark] make the model saved by pyspark compatible (#8219)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
520586ffa7
commit
4f42aa5f12
@ -21,34 +21,12 @@ def _get_or_create_tmp_dir():
|
|||||||
return xgb_tmp_dir
|
return xgb_tmp_dir
|
||||||
|
|
||||||
|
|
||||||
def serialize_xgb_model(model):
|
def deserialize_xgb_model(model_string, xgb_model_creator):
|
||||||
"""
|
"""
|
||||||
Serialize the input model to a string.
|
Deserialize an xgboost.XGBModel instance from the input model_string.
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model:
|
|
||||||
an xgboost.XGBModel instance, such as
|
|
||||||
xgboost.XGBClassifier or xgboost.XGBRegressor instance
|
|
||||||
"""
|
|
||||||
# TODO: change to use string io
|
|
||||||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
|
|
||||||
model.save_model(tmp_file_name)
|
|
||||||
with open(tmp_file_name, "r", encoding="utf-8") as f:
|
|
||||||
ser_model_string = f.read()
|
|
||||||
return ser_model_string
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_xgb_model(ser_model_string, xgb_model_creator):
|
|
||||||
"""
|
|
||||||
Deserialize an xgboost.XGBModel instance from the input ser_model_string.
|
|
||||||
"""
|
"""
|
||||||
xgb_model = xgb_model_creator()
|
xgb_model = xgb_model_creator()
|
||||||
# TODO: change to use string io
|
xgb_model.load_model(bytearray(model_string.encode("utf-8")))
|
||||||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
|
|
||||||
with open(tmp_file_name, "w", encoding="utf-8") as f:
|
|
||||||
f.write(ser_model_string)
|
|
||||||
xgb_model.load_model(tmp_file_name)
|
|
||||||
return xgb_model
|
return xgb_model
|
||||||
|
|
||||||
|
|
||||||
@ -222,11 +200,11 @@ class SparkXGBModelWriter(MLWriter):
|
|||||||
"""
|
"""
|
||||||
xgb_model = self.instance._xgb_sklearn_model
|
xgb_model = self.instance._xgb_sklearn_model
|
||||||
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
|
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
|
||||||
model_save_path = os.path.join(path, "model.json")
|
model_save_path = os.path.join(path, "model")
|
||||||
ser_xgb_model = serialize_xgb_model(xgb_model)
|
booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
|
||||||
_get_spark_session().createDataFrame(
|
_get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
|
||||||
[(ser_xgb_model,)], ["xgb_sklearn_model"]
|
model_save_path
|
||||||
).write.parquet(model_save_path)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SparkXGBModelReader(MLReader):
|
class SparkXGBModelReader(MLReader):
|
||||||
@ -252,13 +230,10 @@ class SparkXGBModelReader(MLReader):
|
|||||||
xgb_sklearn_params = py_model._gen_xgb_params_dict(
|
xgb_sklearn_params = py_model._gen_xgb_params_dict(
|
||||||
gen_xgb_sklearn_estimator_param=True
|
gen_xgb_sklearn_estimator_param=True
|
||||||
)
|
)
|
||||||
model_load_path = os.path.join(path, "model.json")
|
model_load_path = os.path.join(path, "model")
|
||||||
|
|
||||||
ser_xgb_model = (
|
ser_xgb_model = (
|
||||||
_get_spark_session()
|
_get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
|
||||||
.read.parquet(model_load_path)
|
|
||||||
.collect()[0]
|
|
||||||
.xgb_sklearn_model
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_xgb_model():
|
def create_xgb_model():
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
@ -7,6 +8,8 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import testing as tm
|
import testing as tm
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
if tm.no_spark()["condition"]:
|
if tm.no_spark()["condition"]:
|
||||||
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
|
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
|
||||||
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
|
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
|
||||||
@ -31,7 +34,7 @@ from xgboost.spark import (
|
|||||||
)
|
)
|
||||||
from xgboost.spark.core import _non_booster_params
|
from xgboost.spark.core import _non_booster_params
|
||||||
|
|
||||||
from xgboost import XGBClassifier, XGBRegressor
|
from xgboost import XGBClassifier, XGBModel, XGBRegressor
|
||||||
|
|
||||||
from .utils import SparkTestCase
|
from .utils import SparkTestCase
|
||||||
|
|
||||||
@ -63,7 +66,12 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
# >>> reg2.fit(X, y)
|
# >>> reg2.fit(X, y)
|
||||||
# >>> reg2.predict(X, ntree_limit=5)
|
# >>> reg2.predict(X, ntree_limit=5)
|
||||||
# array([0.22185266, 0.77814734], dtype=float32)
|
# array([0.22185266, 0.77814734], dtype=float32)
|
||||||
self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
|
self.reg_params = {
|
||||||
|
"max_depth": 5,
|
||||||
|
"n_estimators": 10,
|
||||||
|
"ntree_limit": 5,
|
||||||
|
"max_bin": 9,
|
||||||
|
}
|
||||||
self.reg_df_train = self.session.createDataFrame(
|
self.reg_df_train = self.session.createDataFrame(
|
||||||
[
|
[
|
||||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||||
@ -428,6 +436,12 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
def get_local_tmp_dir(self):
|
def get_local_tmp_dir(self):
|
||||||
return self.tempdir + str(uuid.uuid4())
|
return self.tempdir + str(uuid.uuid4())
|
||||||
|
|
||||||
|
def assert_model_compatible(self, model: XGBModel, model_path: str):
|
||||||
|
bst = xgb.Booster()
|
||||||
|
path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0]
|
||||||
|
bst.load_model(path)
|
||||||
|
self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json"))
|
||||||
|
|
||||||
def test_regressor_params_basic(self):
|
def test_regressor_params_basic(self):
|
||||||
py_reg = SparkXGBRegressor()
|
py_reg = SparkXGBRegressor()
|
||||||
self.assertTrue(hasattr(py_reg, "n_estimators"))
|
self.assertTrue(hasattr(py_reg, "n_estimators"))
|
||||||
@ -592,7 +606,8 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_regressor_model_save_load(self):
|
def test_regressor_model_save_load(self):
|
||||||
path = "file:" + self.get_local_tmp_dir()
|
tmp_dir = self.get_local_tmp_dir()
|
||||||
|
path = "file:" + tmp_dir
|
||||||
regressor = SparkXGBRegressor(**self.reg_params)
|
regressor = SparkXGBRegressor(**self.reg_params)
|
||||||
model = regressor.fit(self.reg_df_train)
|
model = regressor.fit(self.reg_df_train)
|
||||||
model.save(path)
|
model.save(path)
|
||||||
@ -612,8 +627,11 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
with self.assertRaisesRegex(AssertionError, "Expected class name"):
|
with self.assertRaisesRegex(AssertionError, "Expected class name"):
|
||||||
SparkXGBClassifierModel.load(path)
|
SparkXGBClassifierModel.load(path)
|
||||||
|
|
||||||
|
self.assert_model_compatible(model, tmp_dir)
|
||||||
|
|
||||||
def test_classifier_model_save_load(self):
|
def test_classifier_model_save_load(self):
|
||||||
path = "file:" + self.get_local_tmp_dir()
|
tmp_dir = self.get_local_tmp_dir()
|
||||||
|
path = "file:" + tmp_dir
|
||||||
regressor = SparkXGBClassifier(**self.cls_params)
|
regressor = SparkXGBClassifier(**self.cls_params)
|
||||||
model = regressor.fit(self.cls_df_train)
|
model = regressor.fit(self.cls_df_train)
|
||||||
model.save(path)
|
model.save(path)
|
||||||
@ -633,12 +651,15 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
with self.assertRaisesRegex(AssertionError, "Expected class name"):
|
with self.assertRaisesRegex(AssertionError, "Expected class name"):
|
||||||
SparkXGBRegressorModel.load(path)
|
SparkXGBRegressorModel.load(path)
|
||||||
|
|
||||||
|
self.assert_model_compatible(model, tmp_dir)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_params_map(params_kv, estimator):
|
def _get_params_map(params_kv, estimator):
|
||||||
return {getattr(estimator, k): v for k, v in params_kv.items()}
|
return {getattr(estimator, k): v for k, v in params_kv.items()}
|
||||||
|
|
||||||
def test_regressor_model_pipeline_save_load(self):
|
def test_regressor_model_pipeline_save_load(self):
|
||||||
path = "file:" + self.get_local_tmp_dir()
|
tmp_dir = self.get_local_tmp_dir()
|
||||||
|
path = "file:" + tmp_dir
|
||||||
regressor = SparkXGBRegressor()
|
regressor = SparkXGBRegressor()
|
||||||
pipeline = Pipeline(stages=[regressor])
|
pipeline = Pipeline(stages=[regressor])
|
||||||
pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor))
|
pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor))
|
||||||
@ -656,9 +677,11 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
row.prediction, row.expected_prediction_with_params, atol=1e-3
|
row.prediction, row.expected_prediction_with_params, atol=1e-3
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.assert_model_compatible(model.stages[0], tmp_dir)
|
||||||
|
|
||||||
def test_classifier_model_pipeline_save_load(self):
|
def test_classifier_model_pipeline_save_load(self):
|
||||||
path = "file:" + self.get_local_tmp_dir()
|
tmp_dir = self.get_local_tmp_dir()
|
||||||
|
path = "file:" + tmp_dir
|
||||||
classifier = SparkXGBClassifier()
|
classifier = SparkXGBClassifier()
|
||||||
pipeline = Pipeline(stages=[classifier])
|
pipeline = Pipeline(stages=[classifier])
|
||||||
pipeline = pipeline.copy(
|
pipeline = pipeline.copy(
|
||||||
@ -678,6 +701,7 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
row.probability, row.expected_probability_with_params, atol=1e-3
|
row.probability, row.expected_probability_with_params, atol=1e-3
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.assert_model_compatible(model.stages[0], tmp_dir)
|
||||||
|
|
||||||
def test_classifier_with_cross_validator(self):
|
def test_classifier_with_cross_validator(self):
|
||||||
xgb_classifer = SparkXGBClassifier()
|
xgb_classifer = SparkXGBClassifier()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user