[pyspark] make the model saved by pyspark compatible (#8219)

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Bobby Wang
2022-09-20 16:43:49 +08:00
committed by GitHub
parent 520586ffa7
commit 4f42aa5f12
2 changed files with 40 additions and 41 deletions

View File

@@ -21,34 +21,12 @@ def _get_or_create_tmp_dir():
return xgb_tmp_dir
def serialize_xgb_model(model):
def deserialize_xgb_model(model_string, xgb_model_creator):
"""
Serialize the input model to a string.
Parameters
----------
model:
an xgboost.XGBModel instance, such as
xgboost.XGBClassifier or xgboost.XGBRegressor instance
"""
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
model.save_model(tmp_file_name)
with open(tmp_file_name, "r", encoding="utf-8") as f:
ser_model_string = f.read()
return ser_model_string
def deserialize_xgb_model(ser_model_string, xgb_model_creator):
"""
Deserialize an xgboost.XGBModel instance from the input ser_model_string.
Deserialize an xgboost.XGBModel instance from the input model_string.
"""
xgb_model = xgb_model_creator()
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
with open(tmp_file_name, "w", encoding="utf-8") as f:
f.write(ser_model_string)
xgb_model.load_model(tmp_file_name)
xgb_model.load_model(bytearray(model_string.encode("utf-8")))
return xgb_model
@@ -222,11 +200,11 @@ class SparkXGBModelWriter(MLWriter):
"""
xgb_model = self.instance._xgb_sklearn_model
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
model_save_path = os.path.join(path, "model.json")
ser_xgb_model = serialize_xgb_model(xgb_model)
_get_spark_session().createDataFrame(
[(ser_xgb_model,)], ["xgb_sklearn_model"]
).write.parquet(model_save_path)
model_save_path = os.path.join(path, "model")
booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
_get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
model_save_path
)
class SparkXGBModelReader(MLReader):
@@ -252,13 +230,10 @@ class SparkXGBModelReader(MLReader):
xgb_sklearn_params = py_model._gen_xgb_params_dict(
gen_xgb_sklearn_estimator_param=True
)
model_load_path = os.path.join(path, "model.json")
model_load_path = os.path.join(path, "model")
ser_xgb_model = (
_get_spark_session()
.read.parquet(model_load_path)
.collect()[0]
.xgb_sklearn_model
_get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
)
def create_xgb_model():