[pyspark] make the model saved by pyspark compatible (#8219)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -21,34 +21,12 @@ def _get_or_create_tmp_dir():
|
||||
return xgb_tmp_dir
|
||||
|
||||
|
||||
def serialize_xgb_model(model):
|
||||
def deserialize_xgb_model(model_string, xgb_model_creator):
|
||||
"""
|
||||
Serialize the input model to a string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model:
|
||||
an xgboost.XGBModel instance, such as
|
||||
xgboost.XGBClassifier or xgboost.XGBRegressor instance
|
||||
"""
|
||||
# TODO: change to use string io
|
||||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
|
||||
model.save_model(tmp_file_name)
|
||||
with open(tmp_file_name, "r", encoding="utf-8") as f:
|
||||
ser_model_string = f.read()
|
||||
return ser_model_string
|
||||
|
||||
|
||||
def deserialize_xgb_model(ser_model_string, xgb_model_creator):
|
||||
"""
|
||||
Deserialize an xgboost.XGBModel instance from the input ser_model_string.
|
||||
Deserialize an xgboost.XGBModel instance from the input model_string.
|
||||
"""
|
||||
xgb_model = xgb_model_creator()
|
||||
# TODO: change to use string io
|
||||
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
|
||||
with open(tmp_file_name, "w", encoding="utf-8") as f:
|
||||
f.write(ser_model_string)
|
||||
xgb_model.load_model(tmp_file_name)
|
||||
xgb_model.load_model(bytearray(model_string.encode("utf-8")))
|
||||
return xgb_model
|
||||
|
||||
|
||||
@@ -222,11 +200,11 @@ class SparkXGBModelWriter(MLWriter):
|
||||
"""
|
||||
xgb_model = self.instance._xgb_sklearn_model
|
||||
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
|
||||
model_save_path = os.path.join(path, "model.json")
|
||||
ser_xgb_model = serialize_xgb_model(xgb_model)
|
||||
_get_spark_session().createDataFrame(
|
||||
[(ser_xgb_model,)], ["xgb_sklearn_model"]
|
||||
).write.parquet(model_save_path)
|
||||
model_save_path = os.path.join(path, "model")
|
||||
booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
|
||||
_get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
|
||||
model_save_path
|
||||
)
|
||||
|
||||
|
||||
class SparkXGBModelReader(MLReader):
|
||||
@@ -252,13 +230,10 @@ class SparkXGBModelReader(MLReader):
|
||||
xgb_sklearn_params = py_model._gen_xgb_params_dict(
|
||||
gen_xgb_sklearn_estimator_param=True
|
||||
)
|
||||
model_load_path = os.path.join(path, "model.json")
|
||||
model_load_path = os.path.join(path, "model")
|
||||
|
||||
ser_xgb_model = (
|
||||
_get_spark_session()
|
||||
.read.parquet(model_load_path)
|
||||
.collect()[0]
|
||||
.xgb_sklearn_model
|
||||
_get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
|
||||
)
|
||||
|
||||
def create_xgb_model():
|
||||
|
||||
Reference in New Issue
Block a user