diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 11bf001fa..0ab8d1ba6 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -35,6 +35,7 @@
<maven.compiler.source>1.8</maven.compiler.source>
<flink.version>1.17.0</flink.version>
<spark.version>3.4.0</spark.version>
+ <spark.version.gpu>3.3.2</spark.version.gpu>
<scala.version>2.12.17</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>3.3.5</hadoop.version>
diff --git a/jvm-packages/xgboost4j-spark-gpu/pom.xml b/jvm-packages/xgboost4j-spark-gpu/pom.xml
index bcb7edb2a..57770be5a 100644
--- a/jvm-packages/xgboost4j-spark-gpu/pom.xml
+++ b/jvm-packages/xgboost4j-spark-gpu/pom.xml
@@ -29,19 +29,19 @@
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-core_${scala.binary.version}</artifactId>
- <version>${spark.version}</version>
+ <version>${spark.version.gpu}</version>
  <scope>provided</scope>
</dependency>
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-sql_${scala.binary.version}</artifactId>
- <version>${spark.version}</version>
+ <version>${spark.version.gpu}</version>
  <scope>provided</scope>
</dependency>
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-mllib_${scala.binary.version}</artifactId>
- <version>${spark.version}</version>
+ <version>${spark.version.gpu}</version>
  <scope>provided</scope>
</dependency>
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index ec47a8c23..19a7c6cff 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -1,13 +1,17 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
+import base64
+
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
import json
+import os
from collections import namedtuple
from typing import Iterator, List, Optional, Tuple
import numpy as np
import pandas as pd
+from pyspark import cloudpickle
from pyspark.ml import Estimator, Model
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.linalg import VectorUDT
@@ -21,7 +25,14 @@ from pyspark.ml.param.shared import (
HasValidationIndicatorCol,
HasWeightCol,
)
-from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.ml.util import (
+ DefaultParamsReader,
+ DefaultParamsWriter,
+ MLReadable,
+ MLReader,
+ MLWritable,
+ MLWriter,
+)
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
from pyspark.sql.types import (
@@ -36,7 +47,7 @@ from pyspark.sql.types import (
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
import xgboost
-from xgboost import XGBClassifier, XGBRanker, XGBRegressor
+from xgboost import XGBClassifier
from xgboost.compat import is_cudf_available
from xgboost.core import Booster
from xgboost.sklearn import DEFAULT_N_ESTIMATORS
@@ -49,12 +60,6 @@ from .data import (
pred_contribs,
stack_series,
)
-from .model import (
- SparkXGBModelReader,
- SparkXGBModelWriter,
- SparkXGBReader,
- SparkXGBWriter,
-)
from .params import (
HasArbitraryParamsDict,
HasBaseMarginCol,
@@ -71,8 +76,11 @@ from .utils import (
_get_rabit_args,
_get_spark_session,
_is_local,
+ deserialize_booster,
+ deserialize_xgb_model,
get_class_name,
get_logger,
+ serialize_booster,
)
# Put pyspark specific params here, they won't be passed to XGBoost.
@@ -156,6 +164,8 @@ Pred = namedtuple(
)
pred = Pred("prediction", "rawPrediction", "probability", "predContrib")
+_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
+
class _SparkXGBParams(
HasFeaturesCol,
@@ -1122,31 +1132,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
return dataset
-class SparkXGBRegressorModel(_SparkXGBModel):
- """
- The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
-
- .. Note:: This API is experimental.
- """
-
- @classmethod
- def _xgb_cls(cls):
- return XGBRegressor
-
-
-class SparkXGBRankerModel(_SparkXGBModel):
- """
- The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
-
- .. Note:: This API is experimental.
- """
-
- @classmethod
- def _xgb_cls(cls):
- return XGBRanker
-
-
-class SparkXGBClassifierModel(
+class _ClassificationModel( # pylint: disable=abstract-method
_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol, HasContribPredictionCol
):
"""
@@ -1155,10 +1141,6 @@ class SparkXGBClassifierModel(
.. Note:: This API is experimental.
"""
- @classmethod
- def _xgb_cls(cls):
- return XGBClassifier
-
def _transform(self, dataset):
# pylint: disable=too-many-statements, too-many-locals
# Save xgb_sklearn_model and predict_params to be local variable
@@ -1286,53 +1268,178 @@ class SparkXGBClassifierModel(
return dataset.drop(pred_struct_col)
-def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class):
- params_dict = pyspark_estimator_class._get_xgb_params_default()
+class _SparkXGBSharedReadWrite:
+ @staticmethod
+ def saveMetadata(instance, path, sc, logger, extraMetadata=None):
+ """
+ Save the metadata of an xgboost.spark._SparkXGBEstimator or
+ xgboost.spark._SparkXGBModel.
+ """
+ instance._validate_params()
+ skipParams = ["callbacks", "xgb_model"]
+ jsonParams = {}
+ for p, v in instance._paramMap.items(): # pylint: disable=protected-access
+ if p.name not in skipParams:
+ jsonParams[p.name] = v
- def param_value_converter(v):
- if isinstance(v, np.generic):
- # convert numpy scalar values to corresponding python scalar values
- return np.array(v).item()
- if isinstance(v, dict):
- return {k: param_value_converter(nv) for k, nv in v.items()}
- if isinstance(v, list):
- return [param_value_converter(nv) for nv in v]
- return v
-
- def set_param_attrs(attr_name, param_obj_):
- param_obj_.typeConverter = param_value_converter
- setattr(pyspark_estimator_class, attr_name, param_obj_)
- setattr(pyspark_model_class, attr_name, param_obj_)
-
- for name in params_dict.keys():
- doc = (
- f"Refer to XGBoost doc of "
- f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}"
- )
-
- param_obj = Param(Params._dummy(), name=name, doc=doc)
- set_param_attrs(name, param_obj)
-
- fit_params_dict = pyspark_estimator_class._get_fit_params_default()
- for name in fit_params_dict.keys():
- doc = (
- f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
- f".fit() for this param {name}"
- )
- if name == "callbacks":
- doc += (
- "The callbacks can be arbitrary functions. It is saved using cloudpickle "
- "which is not a fully self-contained format. It may fail to load with "
- "different versions of dependencies."
+ extraMetadata = extraMetadata or {}
+ callbacks = instance.getOrDefault(instance.callbacks)
+ if callbacks is not None:
+ logger.warning(
+ "The callbacks parameter is saved using cloudpickle and it "
+ "is not a fully self-contained format. It may fail to load "
+ "with different versions of dependencies."
)
- param_obj = Param(Params._dummy(), name=name, doc=doc)
- set_param_attrs(name, param_obj)
-
- predict_params_dict = pyspark_estimator_class._get_predict_params_default()
- for name in predict_params_dict.keys():
- doc = (
- f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
- f".predict() for this param {name}"
+ serialized_callbacks = base64.encodebytes(
+ cloudpickle.dumps(callbacks)
+ ).decode("ascii")
+ extraMetadata["serialized_callbacks"] = serialized_callbacks
+ init_booster = instance.getOrDefault(instance.xgb_model)
+ if init_booster is not None:
+ extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
+ DefaultParamsWriter.saveMetadata(
+ instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
)
- param_obj = Param(Params._dummy(), name=name, doc=doc)
- set_param_attrs(name, param_obj)
+ if init_booster is not None:
+ ser_init_booster = serialize_booster(init_booster)
+ save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
+ _get_spark_session().createDataFrame(
+ [(ser_init_booster,)], ["init_booster"]
+ ).write.parquet(save_path)
+
+ @staticmethod
+ def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
+ """
+ Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
+ xgboost.spark._SparkXGBModel.
+
+ :return: a tuple of (metadata, instance)
+ """
+ metadata = DefaultParamsReader.loadMetadata(
+ path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
+ )
+ pyspark_xgb = pyspark_xgb_cls()
+ DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
+
+ if "serialized_callbacks" in metadata:
+ serialized_callbacks = metadata["serialized_callbacks"]
+ try:
+ callbacks = cloudpickle.loads(
+ base64.decodebytes(serialized_callbacks.encode("ascii"))
+ )
+ pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
+ except Exception as e: # pylint: disable=W0703
+ logger.warning(
+ f"Fails to load the callbacks param due to {e}. Please set the "
+ "callbacks param manually for the loaded estimator."
+ )
+
+ if "init_booster" in metadata:
+ load_path = os.path.join(path, metadata["init_booster"])
+ ser_init_booster = (
+ _get_spark_session().read.parquet(load_path).collect()[0].init_booster
+ )
+ init_booster = deserialize_booster(ser_init_booster)
+ pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
+
+ pyspark_xgb._resetUid(metadata["uid"]) # pylint: disable=protected-access
+ return metadata, pyspark_xgb
+
+
+class SparkXGBWriter(MLWriter):
+ """
+ Spark Xgboost estimator writer.
+ """
+
+ def __init__(self, instance):
+ super().__init__()
+ self.instance = instance
+ self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+ def saveImpl(self, path):
+ """
+ Save metadata (including any init booster) for a :py:class:`_SparkXGBEstimator`.
+ """
+ _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
+
+
+class SparkXGBReader(MLReader):
+ """
+ Spark Xgboost estimator reader.
+ """
+
+ def __init__(self, cls):
+ super().__init__()
+ self.cls = cls
+ self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+ def load(self, path):
+ """
+ Load a :py:class:`_SparkXGBEstimator` instance from the given path.
+ """
+ _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
+ self.cls, path, self.sc, self.logger
+ )
+ return pyspark_xgb
+
+
+class SparkXGBModelWriter(MLWriter):
+ """
+ Spark Xgboost model writer.
+ """
+
+ def __init__(self, instance):
+ super().__init__()
+ self.instance = instance
+ self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+ def saveImpl(self, path):
+ """
+ Save metadata and model for a :py:class:`_SparkXGBModel`
+ - save metadata to path/metadata
+ - save model to path/model
+ """
+ xgb_model = self.instance._xgb_sklearn_model
+ _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
+ model_save_path = os.path.join(path, "model")
+ booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
+ _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
+ model_save_path
+ )
+
+
+class SparkXGBModelReader(MLReader):
+ """
+ Spark Xgboost model reader.
+ """
+
+ def __init__(self, cls):
+ super().__init__()
+ self.cls = cls
+ self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+ def load(self, path):
+ """
+ Load metadata and model for a :py:class:`_SparkXGBModel`
+
+ :return: a SparkXGBRegressorModel, SparkXGBClassifierModel or SparkXGBRankerModel instance
+ """
+ _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
+ self.cls, path, self.sc, self.logger
+ )
+
+ xgb_sklearn_params = py_model._gen_xgb_params_dict(
+ gen_xgb_sklearn_estimator_param=True
+ )
+ model_load_path = os.path.join(path, "model")
+
+ ser_xgb_model = (
+ _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
+ )
+
+ def create_xgb_model():
+ return self.cls._xgb_cls()(**xgb_sklearn_params)
+
+ xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
+ py_model._xgb_sklearn_model = xgb_model
+ return py_model
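The reader and writer classes above hook into PySpark's standard ML persistence machinery, so a fitted model round-trips through the usual `save`/`load` calls. A minimal sketch of the expected flow, assuming a local Spark session and an illustrative output path:

```python
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel

spark = SparkSession.builder.getOrCreate()

# Toy two-row dataset; "features"/"label" are the default column names.
df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)],
    ["features", "label"],
)

model = SparkXGBClassifier(n_estimators=2).fit(df)

# SparkXGBModelWriter writes metadata under <path>/metadata and the booster
# JSON under <path>/model; SparkXGBModelReader reverses the process.
model.write().overwrite().save("/tmp/xgboost-spark-demo")
loaded = SparkXGBClassifierModel.load("/tmp/xgboost-spark-demo")
loaded.transform(df).show()
```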
diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py
index 2fe113ad4..283e906a7 100644
--- a/python-package/xgboost/spark/estimator.py
+++ b/python-package/xgboost/spark/estimator.py
@@ -1,18 +1,77 @@
"""Xgboost pyspark integration submodule for estimator API."""
# pylint: disable=too-many-ancestors
+# pylint: disable=fixme, protected-access, no-member, invalid-name
+
from typing import Any, Type
+import numpy as np
+from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
from xgboost import XGBClassifier, XGBRanker, XGBRegressor
from .core import ( # type: ignore
- SparkXGBClassifierModel,
- SparkXGBRankerModel,
- SparkXGBRegressorModel,
- _set_pyspark_xgb_cls_param_attrs,
+ _ClassificationModel,
_SparkXGBEstimator,
+ _SparkXGBModel,
)
+from .utils import get_class_name
+
+
+def _set_pyspark_xgb_cls_param_attrs(
+ estimator: Type[_SparkXGBEstimator], model: Type[_SparkXGBModel]
+) -> None:
+ """Automatically infer the XGBoost parameters of the given estimator and
+ model classes, and attach them as pyspark ``Param`` attributes."""
+ params_dict = estimator._get_xgb_params_default()
+
+ def param_value_converter(v: Any) -> Any:
+ if isinstance(v, np.generic):
+ # convert numpy scalar values to corresponding python scalar values
+ return np.array(v).item()
+ if isinstance(v, dict):
+ return {k: param_value_converter(nv) for k, nv in v.items()}
+ if isinstance(v, list):
+ return [param_value_converter(nv) for nv in v]
+ return v
+
+ def set_param_attrs(attr_name: str, param: Param) -> None:
+ param.typeConverter = param_value_converter
+ setattr(estimator, attr_name, param)
+ setattr(model, attr_name, param)
+
+ for name in params_dict.keys():
+ doc = (
+ f"Refer to XGBoost doc of "
+ f"{get_class_name(estimator._xgb_cls())} for this param {name}"
+ )
+
+ param_obj: Param = Param(Params._dummy(), name=name, doc=doc)
+ set_param_attrs(name, param_obj)
+
+ fit_params_dict = estimator._get_fit_params_default()
+ for name in fit_params_dict.keys():
+ doc = (
+ f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}"
+ f".fit() for this param {name}"
+ )
+ if name == "callbacks":
+ doc += (
+ "The callbacks can be arbitrary functions. It is saved using cloudpickle "
+ "which is not a fully self-contained format. It may fail to load with "
+ "different versions of dependencies."
+ )
+ param_obj = Param(Params._dummy(), name=name, doc=doc)
+ set_param_attrs(name, param_obj)
+
+ predict_params_dict = estimator._get_predict_params_default()
+ for name in predict_params_dict.keys():
+ doc = (
+ f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}"
+ f".predict() for this param {name}"
+ )
+ param_obj = Param(Params._dummy(), name=name, doc=doc)
+ set_param_attrs(name, param_obj)
class SparkXGBRegressor(_SparkXGBEstimator):
@@ -105,7 +164,7 @@ class SparkXGBRegressor(_SparkXGBEstimator):
return XGBRegressor
@classmethod
- def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]:
+ def _pyspark_model_cls(cls) -> Type["SparkXGBRegressorModel"]:
return SparkXGBRegressorModel
def _validate_params(self) -> None:
@@ -116,6 +175,18 @@ class SparkXGBRegressor(_SparkXGBEstimator):
)
+class SparkXGBRegressorModel(_SparkXGBModel):
+ """
+ The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
+
+ .. Note:: This API is experimental.
+ """
+
+ @classmethod
+ def _xgb_cls(cls) -> Type[XGBRegressor]:
+ return XGBRegressor
+
+
_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
@@ -224,7 +295,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
return XGBClassifier
@classmethod
- def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]:
+ def _pyspark_model_cls(cls) -> Type["SparkXGBClassifierModel"]:
return SparkXGBClassifierModel
def _validate_params(self) -> None:
@@ -239,6 +310,18 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
)
+class SparkXGBClassifierModel(_ClassificationModel):
+ """
+ The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
+
+ .. Note:: This API is experimental.
+ """
+
+ @classmethod
+ def _xgb_cls(cls) -> Type[XGBClassifier]:
+ return XGBClassifier
+
+
_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
@@ -352,7 +435,7 @@ class SparkXGBRanker(_SparkXGBEstimator):
return XGBRanker
@classmethod
- def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]:
+ def _pyspark_model_cls(cls) -> Type["SparkXGBRankerModel"]:
return SparkXGBRankerModel
def _validate_params(self) -> None:
@@ -363,4 +446,16 @@ class SparkXGBRanker(_SparkXGBEstimator):
)
+class SparkXGBRankerModel(_SparkXGBModel):
+ """
+ The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
+
+ .. Note:: This API is experimental.
+ """
+
+ @classmethod
+ def _xgb_cls(cls) -> Type[XGBRanker]:
+ return XGBRanker
+
+
_set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel)
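Since `_set_pyspark_xgb_cls_param_attrs` runs at import time for each estimator/model pair, every XGBoost parameter name becomes an ordinary pyspark `Param` attribute on both classes. A small sketch of what that enables (the parameter values are illustrative):

```python
from xgboost.spark import SparkXGBRegressor

reg = SparkXGBRegressor(max_depth=6, n_estimators=100)

# max_depth/n_estimators are Param objects attached by
# _set_pyspark_xgb_cls_param_attrs, so the standard accessors work.
print(reg.getOrDefault(reg.max_depth))     # should print 6
print(reg.explainParam(reg.n_estimators))  # doc generated in the loop above
```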
diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py
deleted file mode 100644
index 888bc9cc5..000000000
--- a/python-package/xgboost/spark/model.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# type: ignore
-"""Xgboost pyspark integration submodule for model API."""
-# pylint: disable=fixme, invalid-name, protected-access, too-few-public-methods
-import base64
-import os
-import uuid
-
-from pyspark import SparkFiles, cloudpickle
-from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
-from pyspark.sql import SparkSession
-
-from xgboost.core import Booster
-
-from .utils import get_class_name, get_logger
-
-
-def _get_or_create_tmp_dir():
- root_dir = SparkFiles.getRootDirectory()
- xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
- if not os.path.exists(xgb_tmp_dir):
- os.makedirs(xgb_tmp_dir)
- return xgb_tmp_dir
-
-
-def deserialize_xgb_model(model_string, xgb_model_creator):
- """
- Deserialize an xgboost.XGBModel instance from the input model_string.
- """
- xgb_model = xgb_model_creator()
- xgb_model.load_model(bytearray(model_string.encode("utf-8")))
- return xgb_model
-
-
-def serialize_booster(booster):
- """
- Serialize the input booster to a string.
-
- Parameters
- ----------
- booster:
- an xgboost.core.Booster instance
- """
- # TODO: change to use string io
- tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
- booster.save_model(tmp_file_name)
- with open(tmp_file_name, encoding="utf-8") as f:
- ser_model_string = f.read()
- return ser_model_string
-
-
-def deserialize_booster(ser_model_string):
- """
- Deserialize an xgboost.core.Booster from the input ser_model_string.
- """
- booster = Booster()
- # TODO: change to use string io
- tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
- with open(tmp_file_name, "w", encoding="utf-8") as f:
- f.write(ser_model_string)
- booster.load_model(tmp_file_name)
- return booster
-
-
-_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
-
-
-def _get_spark_session():
- return SparkSession.builder.getOrCreate()
-
-
-class _SparkXGBSharedReadWrite:
- @staticmethod
- def saveMetadata(instance, path, sc, logger, extraMetadata=None):
- """
- Save the metadata of an xgboost.spark._SparkXGBEstimator or
- xgboost.spark._SparkXGBModel.
- """
- instance._validate_params()
- skipParams = ["callbacks", "xgb_model"]
- jsonParams = {}
- for p, v in instance._paramMap.items(): # pylint: disable=protected-access
- if p.name not in skipParams:
- jsonParams[p.name] = v
-
- extraMetadata = extraMetadata or {}
- callbacks = instance.getOrDefault(instance.callbacks)
- if callbacks is not None:
- logger.warning(
- "The callbacks parameter is saved using cloudpickle and it "
- "is not a fully self-contained format. It may fail to load "
- "with different versions of dependencies."
- )
- serialized_callbacks = base64.encodebytes(
- cloudpickle.dumps(callbacks)
- ).decode("ascii")
- extraMetadata["serialized_callbacks"] = serialized_callbacks
- init_booster = instance.getOrDefault(instance.xgb_model)
- if init_booster is not None:
- extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
- DefaultParamsWriter.saveMetadata(
- instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
- )
- if init_booster is not None:
- ser_init_booster = serialize_booster(init_booster)
- save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
- _get_spark_session().createDataFrame(
- [(ser_init_booster,)], ["init_booster"]
- ).write.parquet(save_path)
-
- @staticmethod
- def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
- """
- Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
- xgboost.spark._SparkXGBModel.
-
- :return: a tuple of (metadata, instance)
- """
- metadata = DefaultParamsReader.loadMetadata(
- path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
- )
- pyspark_xgb = pyspark_xgb_cls()
- DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
-
- if "serialized_callbacks" in metadata:
- serialized_callbacks = metadata["serialized_callbacks"]
- try:
- callbacks = cloudpickle.loads(
- base64.decodebytes(serialized_callbacks.encode("ascii"))
- )
- pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
- except Exception as e: # pylint: disable=W0703
- logger.warning(
- f"Fails to load the callbacks param due to {e}. Please set the "
- "callbacks param manually for the loaded estimator."
- )
-
- if "init_booster" in metadata:
- load_path = os.path.join(path, metadata["init_booster"])
- ser_init_booster = (
- _get_spark_session().read.parquet(load_path).collect()[0].init_booster
- )
- init_booster = deserialize_booster(ser_init_booster)
- pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
-
- pyspark_xgb._resetUid(metadata["uid"]) # pylint: disable=protected-access
- return metadata, pyspark_xgb
-
-
-class SparkXGBWriter(MLWriter):
- """
- Spark Xgboost estimator writer.
- """
-
- def __init__(self, instance):
- super().__init__()
- self.instance = instance
- self.logger = get_logger(self.__class__.__name__, level="WARN")
-
- def saveImpl(self, path):
- """
- save model.
- """
- _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
-
-
-class SparkXGBReader(MLReader):
- """
- Spark Xgboost estimator reader.
- """
-
- def __init__(self, cls):
- super().__init__()
- self.cls = cls
- self.logger = get_logger(self.__class__.__name__, level="WARN")
-
- def load(self, path):
- """
- load model.
- """
- _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
- self.cls, path, self.sc, self.logger
- )
- return pyspark_xgb
-
-
-class SparkXGBModelWriter(MLWriter):
- """
- Spark Xgboost model writer.
- """
-
- def __init__(self, instance):
- super().__init__()
- self.instance = instance
- self.logger = get_logger(self.__class__.__name__, level="WARN")
-
- def saveImpl(self, path):
- """
- Save metadata and model for a :py:class:`_SparkXGBModel`
- - save metadata to path/metadata
- - save model to path/model.json
- """
- xgb_model = self.instance._xgb_sklearn_model
- _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
- model_save_path = os.path.join(path, "model")
- booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
- _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
- model_save_path
- )
-
-
-class SparkXGBModelReader(MLReader):
- """
- Spark Xgboost model reader.
- """
-
- def __init__(self, cls):
- super().__init__()
- self.cls = cls
- self.logger = get_logger(self.__class__.__name__, level="WARN")
-
- def load(self, path):
- """
- Load metadata and model for a :py:class:`_SparkXGBModel`
-
- :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
- """
- _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
- self.cls, path, self.sc, self.logger
- )
-
- xgb_sklearn_params = py_model._gen_xgb_params_dict(
- gen_xgb_sklearn_estimator_param=True
- )
- model_load_path = os.path.join(path, "model")
-
- ser_xgb_model = (
- _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
- )
-
- def create_xgb_model():
- return self.cls._xgb_cls()(**xgb_sklearn_params)
-
- xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
- py_model._xgb_sklearn_model = xgb_model
- return py_model
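The replacement helper in utils.py (below) keeps the same contract as the deleted `deserialize_xgb_model`: it takes the JSON text of a saved sklearn-style model plus a factory callable. A hedged round-trip sketch with toy data; this particular helper touches no Spark APIs, so it can be exercised locally:

```python
import numpy as np
from xgboost import XGBClassifier
from xgboost.spark.utils import deserialize_xgb_model

X = np.random.rand(20, 3)
y = np.random.randint(2, size=20)
clf = XGBClassifier(n_estimators=2).fit(X, y)

# This mirrors SparkXGBModelWriter: dump the booster as JSON text, then
# rebuild an equivalent sklearn-style model from that text.
json_text = clf.get_booster().save_raw("json").decode("utf-8")
restored = deserialize_xgb_model(json_text, lambda: XGBClassifier())
np.testing.assert_array_equal(clf.predict(X), restored.predict(X))
```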
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index 979c40ea9..46e465dde 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -1,15 +1,19 @@
"""Xgboost pyspark integration submodule for helper functions."""
+# pylint: disable=fixme
+
import inspect
import logging
+import os
import sys
+import uuid
from threading import Thread
from typing import Any, Callable, Dict, Set, Type
import pyspark
-from pyspark import BarrierTaskContext, SparkContext
+from pyspark import BarrierTaskContext, SparkContext, SparkFiles
from pyspark.sql.session import SparkSession
-from xgboost import collective
+from xgboost import Booster, XGBModel, collective
from xgboost.tracker import RabitTracker
@@ -133,3 +137,52 @@ def _get_gpu_id(task_context: BarrierTaskContext) -> int:
)
# return the first gpu id.
return int(resources["gpu"].addresses[0].strip())
+
+
+def _get_or_create_tmp_dir() -> str:
+ root_dir = SparkFiles.getRootDirectory()
+ xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
+ if not os.path.exists(xgb_tmp_dir):
+ os.makedirs(xgb_tmp_dir)
+ return xgb_tmp_dir
+
+
+def deserialize_xgb_model(
+ model: str, xgb_model_creator: Callable[[], XGBModel]
+) -> XGBModel:
+ """
+ Deserialize an xgboost.XGBModel instance from the input model string.
+ """
+ xgb_model = xgb_model_creator()
+ xgb_model.load_model(bytearray(model.encode("utf-8")))
+ return xgb_model
+
+
+def serialize_booster(booster: Booster) -> str:
+ """
+ Serialize the input booster to a string.
+
+ Parameters
+ ----------
+ booster:
+ an xgboost.core.Booster instance
+ """
+ # TODO: change to use string io
+ tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
+ booster.save_model(tmp_file_name)
+ with open(tmp_file_name, encoding="utf-8") as f:
+ ser_model_string = f.read()
+ return ser_model_string
+
+
+def deserialize_booster(model: str) -> Booster:
+ """
+ Deserialize an xgboost.core.Booster from the input model string.
+ """
+ booster = Booster()
+ # TODO: change to use string io
+ tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
+ with open(tmp_file_name, "w", encoding="utf-8") as f:
+ f.write(model)
+ booster.load_model(tmp_file_name)
+ return booster
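`serialize_booster` and `deserialize_booster` round-trip a `Booster` through its JSON text form via a temp file under the SparkFiles root, so they need an active Spark session on the driver. A minimal sketch, with illustrative training data:

```python
import numpy as np
import xgboost as xgb
from pyspark.sql import SparkSession
from xgboost.spark.utils import deserialize_booster, serialize_booster

spark = SparkSession.builder.getOrCreate()  # SparkFiles needs a live context

X = np.random.rand(16, 4)
y = np.random.randint(2, size=16)
booster = xgb.train(
    {"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=2
)

# serialize_booster saves the model as JSON and returns the file's text;
# deserialize_booster writes that text back out and loads a new Booster.
model_str = serialize_booster(booster)
restored = deserialize_booster(model_str)
assert restored.num_boosted_rounds() == booster.num_boosted_rounds()
```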