diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 11bf001fa..0ab8d1ba6 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -35,6 +35,7 @@
         <maven.compiler.target>1.8</maven.compiler.target>
         <flink.version>1.17.0</flink.version>
         <spark.version>3.4.0</spark.version>
+        <spark.version.gpu>3.3.2</spark.version.gpu>
         <scala.version>2.12.17</scala.version>
         <scala.binary.version>2.12</scala.binary.version>
         <hadoop.version>3.3.5</hadoop.version>
diff --git a/jvm-packages/xgboost4j-spark-gpu/pom.xml b/jvm-packages/xgboost4j-spark-gpu/pom.xml
index bcb7edb2a..57770be5a 100644
--- a/jvm-packages/xgboost4j-spark-gpu/pom.xml
+++ b/jvm-packages/xgboost4j-spark-gpu/pom.xml
@@ -29,19 +29,19 @@
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>
-      <version>${spark.version}</version>
+      <version>${spark.version.gpu}</version>
       <scope>provided</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
-      <version>${spark.version}</version>
+      <version>${spark.version.gpu}</version>
       <scope>provided</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_${scala.binary.version}</artifactId>
-      <version>${spark.version}</version>
+      <version>${spark.version.gpu}</version>
       <scope>provided</scope>
    </dependency>
    <dependency>
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index ec47a8c23..19a7c6cff 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -1,13 +1,17 @@
 # type: ignore
 """Xgboost pyspark integration submodule for core code."""
+import base64
+
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
+import os
 from collections import namedtuple
 from typing import Iterator, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
+from pyspark import cloudpickle
 from pyspark.ml import Estimator, Model
 from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml.linalg import VectorUDT
@@ -21,7 +25,14 @@ from pyspark.ml.param.shared import (
     HasValidationIndicatorCol,
     HasWeightCol,
 )
-from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.ml.util import (
+    DefaultParamsReader,
+    DefaultParamsWriter,
+    MLReadable,
+    MLReader,
+    MLWritable,
+    MLWriter,
+)
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
@@ -36,7 +47,7 @@ from pyspark.sql.types import (
 )
 from scipy.special import expit, softmax  # pylint: disable=no-name-in-module
 import xgboost
-from xgboost import XGBClassifier, XGBRanker, XGBRegressor
+from xgboost import XGBClassifier
 from xgboost.compat import is_cudf_available
 from xgboost.core import Booster
 from xgboost.sklearn import DEFAULT_N_ESTIMATORS
@@ -49,12 +60,6 @@ from .data import (
     pred_contribs,
     stack_series,
 )
-from .model import (
-    SparkXGBModelReader,
-    SparkXGBModelWriter,
-    SparkXGBReader,
-    SparkXGBWriter,
-)
 from .params import (
     HasArbitraryParamsDict,
     HasBaseMarginCol,
@@ -71,8 +76,11 @@ from .utils import (
     _get_rabit_args,
     _get_spark_session,
     _is_local,
+    deserialize_booster,
+    deserialize_xgb_model,
     get_class_name,
     get_logger,
+    serialize_booster,
 )
 
 # Put pyspark specific params here, they won't be passed to XGBoost.
@@ -156,6 +164,8 @@ Pred = namedtuple(
 )
 pred = Pred("prediction", "rawPrediction", "probability", "predContrib")
 
+_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
+
 
 class _SparkXGBParams(
     HasFeaturesCol,
@@ -1122,31 +1132,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
 
         return dataset
 
 
-class SparkXGBRegressorModel(_SparkXGBModel):
-    """
-    The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
-
-    .. Note:: This API is experimental.
- """ - - @classmethod - def _xgb_cls(cls): - return XGBRanker - - -class SparkXGBClassifierModel( +class _ClassificationModel( # pylint: disable=abstract-method _SparkXGBModel, HasProbabilityCol, HasRawPredictionCol, HasContribPredictionCol ): """ @@ -1155,10 +1141,6 @@ class SparkXGBClassifierModel( .. Note:: This API is experimental. """ - @classmethod - def _xgb_cls(cls): - return XGBClassifier - def _transform(self, dataset): # pylint: disable=too-many-statements, too-many-locals # Save xgb_sklearn_model and predict_params to be local variable @@ -1286,53 +1268,178 @@ class SparkXGBClassifierModel( return dataset.drop(pred_struct_col) -def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class): - params_dict = pyspark_estimator_class._get_xgb_params_default() +class _SparkXGBSharedReadWrite: + @staticmethod + def saveMetadata(instance, path, sc, logger, extraMetadata=None): + """ + Save the metadata of an xgboost.spark._SparkXGBEstimator or + xgboost.spark._SparkXGBModel. + """ + instance._validate_params() + skipParams = ["callbacks", "xgb_model"] + jsonParams = {} + for p, v in instance._paramMap.items(): # pylint: disable=protected-access + if p.name not in skipParams: + jsonParams[p.name] = v - def param_value_converter(v): - if isinstance(v, np.generic): - # convert numpy scalar values to corresponding python scalar values - return np.array(v).item() - if isinstance(v, dict): - return {k: param_value_converter(nv) for k, nv in v.items()} - if isinstance(v, list): - return [param_value_converter(nv) for nv in v] - return v - - def set_param_attrs(attr_name, param_obj_): - param_obj_.typeConverter = param_value_converter - setattr(pyspark_estimator_class, attr_name, param_obj_) - setattr(pyspark_model_class, attr_name, param_obj_) - - for name in params_dict.keys(): - doc = ( - f"Refer to XGBoost doc of " - f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}" - ) - - param_obj = Param(Params._dummy(), name=name, doc=doc) - set_param_attrs(name, param_obj) - - fit_params_dict = pyspark_estimator_class._get_fit_params_default() - for name in fit_params_dict.keys(): - doc = ( - f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}" - f".fit() for this param {name}" - ) - if name == "callbacks": - doc += ( - "The callbacks can be arbitrary functions. It is saved using cloudpickle " - "which is not a fully self-contained format. It may fail to load with " - "different versions of dependencies." + extraMetadata = extraMetadata or {} + callbacks = instance.getOrDefault(instance.callbacks) + if callbacks is not None: + logger.warning( + "The callbacks parameter is saved using cloudpickle and it " + "is not a fully self-contained format. It may fail to load " + "with different versions of dependencies." 
             )
-        param_obj = Param(Params._dummy(), name=name, doc=doc)
-        set_param_attrs(name, param_obj)
-
-    predict_params_dict = pyspark_estimator_class._get_predict_params_default()
-    for name in predict_params_dict.keys():
-        doc = (
-            f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
-            f".predict() for this param {name}"
+            serialized_callbacks = base64.encodebytes(
+                cloudpickle.dumps(callbacks)
+            ).decode("ascii")
+            extraMetadata["serialized_callbacks"] = serialized_callbacks
+        init_booster = instance.getOrDefault(instance.xgb_model)
+        if init_booster is not None:
+            extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
+        DefaultParamsWriter.saveMetadata(
+            instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
         )
-        param_obj = Param(Params._dummy(), name=name, doc=doc)
-        set_param_attrs(name, param_obj)
+        if init_booster is not None:
+            ser_init_booster = serialize_booster(init_booster)
+            save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
+            _get_spark_session().createDataFrame(
+                [(ser_init_booster,)], ["init_booster"]
+            ).write.parquet(save_path)
+
+    @staticmethod
+    def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
+        """
+        Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
+        xgboost.spark._SparkXGBModel.
+
+        :return: a tuple of (metadata, instance)
+        """
+        metadata = DefaultParamsReader.loadMetadata(
+            path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
+        )
+        pyspark_xgb = pyspark_xgb_cls()
+        DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
+
+        if "serialized_callbacks" in metadata:
+            serialized_callbacks = metadata["serialized_callbacks"]
+            try:
+                callbacks = cloudpickle.loads(
+                    base64.decodebytes(serialized_callbacks.encode("ascii"))
+                )
+                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
+            except Exception as e:  # pylint: disable=W0703
+                logger.warning(
+                    f"Failed to load the callbacks param due to {e}. Please set the "
+                    "callbacks param manually for the loaded estimator."
+                )
+
+        if "init_booster" in metadata:
+            load_path = os.path.join(path, metadata["init_booster"])
+            ser_init_booster = (
+                _get_spark_session().read.parquet(load_path).collect()[0].init_booster
+            )
+            init_booster = deserialize_booster(ser_init_booster)
+            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
+
+        pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
+        return metadata, pyspark_xgb
+
+
+class SparkXGBWriter(MLWriter):
+    """
+    Spark Xgboost estimator writer.
+    """
+
+    def __init__(self, instance):
+        super().__init__()
+        self.instance = instance
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def saveImpl(self, path):
+        """
+        Save metadata for a :py:class:`_SparkXGBEstimator`.
+        """
+        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
+
+
+class SparkXGBReader(MLReader):
+    """
+    Spark Xgboost estimator reader.
+    """
+
+    def __init__(self, cls):
+        super().__init__()
+        self.cls = cls
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def load(self, path):
+        """
+        Load a :py:class:`_SparkXGBEstimator` instance from the given path.
+        """
+        _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
+            self.cls, path, self.sc, self.logger
+        )
+        return pyspark_xgb
+
+
+class SparkXGBModelWriter(MLWriter):
+    """
+    Spark Xgboost model writer.
+ """ + + def __init__(self, instance): + super().__init__() + self.instance = instance + self.logger = get_logger(self.__class__.__name__, level="WARN") + + def saveImpl(self, path): + """ + Save metadata and model for a :py:class:`_SparkXGBModel` + - save metadata to path/metadata + - save model to path/model.json + """ + xgb_model = self.instance._xgb_sklearn_model + _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) + model_save_path = os.path.join(path, "model") + booster = xgb_model.get_booster().save_raw("json").decode("utf-8") + _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile( + model_save_path + ) + + +class SparkXGBModelReader(MLReader): + """ + Spark Xgboost model reader. + """ + + def __init__(self, cls): + super().__init__() + self.cls = cls + self.logger = get_logger(self.__class__.__name__, level="WARN") + + def load(self, path): + """ + Load metadata and model for a :py:class:`_SparkXGBModel` + + :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance + """ + _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance( + self.cls, path, self.sc, self.logger + ) + + xgb_sklearn_params = py_model._gen_xgb_params_dict( + gen_xgb_sklearn_estimator_param=True + ) + model_load_path = os.path.join(path, "model") + + ser_xgb_model = ( + _get_spark_session().sparkContext.textFile(model_load_path).collect()[0] + ) + + def create_xgb_model(): + return self.cls._xgb_cls()(**xgb_sklearn_params) + + xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model) + py_model._xgb_sklearn_model = xgb_model + return py_model diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index 2fe113ad4..283e906a7 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -1,18 +1,77 @@ """Xgboost pyspark integration submodule for estimator API.""" # pylint: disable=too-many-ancestors +# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name + from typing import Any, Type +import numpy as np +from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol from xgboost import XGBClassifier, XGBRanker, XGBRegressor from .core import ( # type: ignore - SparkXGBClassifierModel, - SparkXGBRankerModel, - SparkXGBRegressorModel, - _set_pyspark_xgb_cls_param_attrs, + _ClassificationModel, _SparkXGBEstimator, + _SparkXGBModel, ) +from .utils import get_class_name + + +def _set_pyspark_xgb_cls_param_attrs( + estimator: _SparkXGBEstimator, model: _SparkXGBModel +) -> None: + """This function automatically infer to xgboost parameters and set them + into corresponding pyspark estimators and models""" + params_dict = estimator._get_xgb_params_default() + + def param_value_converter(v: Any) -> Any: + if isinstance(v, np.generic): + # convert numpy scalar values to corresponding python scalar values + return np.array(v).item() + if isinstance(v, dict): + return {k: param_value_converter(nv) for k, nv in v.items()} + if isinstance(v, list): + return [param_value_converter(nv) for nv in v] + return v + + def set_param_attrs(attr_name: str, param: Param) -> None: + param.typeConverter = param_value_converter + setattr(estimator, attr_name, param) + setattr(model, attr_name, param) + + for name in params_dict.keys(): + doc = ( + f"Refer to XGBoost doc of " + f"{get_class_name(estimator._xgb_cls())} for this param {name}" + ) + + param_obj: Param = Param(Params._dummy(), 
name=name, doc=doc) + set_param_attrs(name, param_obj) + + fit_params_dict = estimator._get_fit_params_default() + for name in fit_params_dict.keys(): + doc = ( + f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}" + f".fit() for this param {name}" + ) + if name == "callbacks": + doc += ( + "The callbacks can be arbitrary functions. It is saved using cloudpickle " + "which is not a fully self-contained format. It may fail to load with " + "different versions of dependencies." + ) + param_obj = Param(Params._dummy(), name=name, doc=doc) + set_param_attrs(name, param_obj) + + predict_params_dict = estimator._get_predict_params_default() + for name in predict_params_dict.keys(): + doc = ( + f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}" + f".predict() for this param {name}" + ) + param_obj = Param(Params._dummy(), name=name, doc=doc) + set_param_attrs(name, param_obj) class SparkXGBRegressor(_SparkXGBEstimator): @@ -105,7 +164,7 @@ class SparkXGBRegressor(_SparkXGBEstimator): return XGBRegressor @classmethod - def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]: + def _pyspark_model_cls(cls) -> Type["SparkXGBRegressorModel"]: return SparkXGBRegressorModel def _validate_params(self) -> None: @@ -116,6 +175,18 @@ class SparkXGBRegressor(_SparkXGBEstimator): ) +class SparkXGBRegressorModel(_SparkXGBModel): + """ + The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit` + + .. Note:: This API is experimental. + """ + + @classmethod + def _xgb_cls(cls) -> Type[XGBRegressor]: + return XGBRegressor + + _set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel) @@ -224,7 +295,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction return XGBClassifier @classmethod - def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]: + def _pyspark_model_cls(cls) -> Type["SparkXGBClassifierModel"]: return SparkXGBClassifierModel def _validate_params(self) -> None: @@ -239,6 +310,18 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction ) +class SparkXGBClassifierModel(_ClassificationModel): + """ + The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit` + + .. Note:: This API is experimental. + """ + + @classmethod + def _xgb_cls(cls) -> Type[XGBClassifier]: + return XGBClassifier + + _set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel) @@ -352,7 +435,7 @@ class SparkXGBRanker(_SparkXGBEstimator): return XGBRanker @classmethod - def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]: + def _pyspark_model_cls(cls) -> Type["SparkXGBRankerModel"]: return SparkXGBRankerModel def _validate_params(self) -> None: @@ -363,4 +446,16 @@ class SparkXGBRanker(_SparkXGBEstimator): ) +class SparkXGBRankerModel(_SparkXGBModel): + """ + The model returned by :func:`xgboost.spark.SparkXGBRanker.fit` + + .. Note:: This API is experimental. 
+ """ + + @classmethod + def _xgb_cls(cls) -> Type[XGBRanker]: + return XGBRanker + + _set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel) diff --git a/python-package/xgboost/spark/model.py b/python-package/xgboost/spark/model.py deleted file mode 100644 index 888bc9cc5..000000000 --- a/python-package/xgboost/spark/model.py +++ /dev/null @@ -1,245 +0,0 @@ -# type: ignore -"""Xgboost pyspark integration submodule for model API.""" -# pylint: disable=fixme, invalid-name, protected-access, too-few-public-methods -import base64 -import os -import uuid - -from pyspark import SparkFiles, cloudpickle -from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter -from pyspark.sql import SparkSession - -from xgboost.core import Booster - -from .utils import get_class_name, get_logger - - -def _get_or_create_tmp_dir(): - root_dir = SparkFiles.getRootDirectory() - xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp") - if not os.path.exists(xgb_tmp_dir): - os.makedirs(xgb_tmp_dir) - return xgb_tmp_dir - - -def deserialize_xgb_model(model_string, xgb_model_creator): - """ - Deserialize an xgboost.XGBModel instance from the input model_string. - """ - xgb_model = xgb_model_creator() - xgb_model.load_model(bytearray(model_string.encode("utf-8"))) - return xgb_model - - -def serialize_booster(booster): - """ - Serialize the input booster to a string. - - Parameters - ---------- - booster: - an xgboost.core.Booster instance - """ - # TODO: change to use string io - tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") - booster.save_model(tmp_file_name) - with open(tmp_file_name, encoding="utf-8") as f: - ser_model_string = f.read() - return ser_model_string - - -def deserialize_booster(ser_model_string): - """ - Deserialize an xgboost.core.Booster from the input ser_model_string. - """ - booster = Booster() - # TODO: change to use string io - tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json") - with open(tmp_file_name, "w", encoding="utf-8") as f: - f.write(ser_model_string) - booster.load_model(tmp_file_name) - return booster - - -_INIT_BOOSTER_SAVE_PATH = "init_booster.json" - - -def _get_spark_session(): - return SparkSession.builder.getOrCreate() - - -class _SparkXGBSharedReadWrite: - @staticmethod - def saveMetadata(instance, path, sc, logger, extraMetadata=None): - """ - Save the metadata of an xgboost.spark._SparkXGBEstimator or - xgboost.spark._SparkXGBModel. - """ - instance._validate_params() - skipParams = ["callbacks", "xgb_model"] - jsonParams = {} - for p, v in instance._paramMap.items(): # pylint: disable=protected-access - if p.name not in skipParams: - jsonParams[p.name] = v - - extraMetadata = extraMetadata or {} - callbacks = instance.getOrDefault(instance.callbacks) - if callbacks is not None: - logger.warning( - "The callbacks parameter is saved using cloudpickle and it " - "is not a fully self-contained format. It may fail to load " - "with different versions of dependencies." 
-            )
-            serialized_callbacks = base64.encodebytes(
-                cloudpickle.dumps(callbacks)
-            ).decode("ascii")
-            extraMetadata["serialized_callbacks"] = serialized_callbacks
-        init_booster = instance.getOrDefault(instance.xgb_model)
-        if init_booster is not None:
-            extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
-        DefaultParamsWriter.saveMetadata(
-            instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
-        )
-        if init_booster is not None:
-            ser_init_booster = serialize_booster(init_booster)
-            save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
-            _get_spark_session().createDataFrame(
-                [(ser_init_booster,)], ["init_booster"]
-            ).write.parquet(save_path)
-
-    @staticmethod
-    def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
-        """
-        Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
-        xgboost.spark._SparkXGBModel.
-
-        :return: a tuple of (metadata, instance)
-        """
-        metadata = DefaultParamsReader.loadMetadata(
-            path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
-        )
-        pyspark_xgb = pyspark_xgb_cls()
-        DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
-
-        if "serialized_callbacks" in metadata:
-            serialized_callbacks = metadata["serialized_callbacks"]
-            try:
-                callbacks = cloudpickle.loads(
-                    base64.decodebytes(serialized_callbacks.encode("ascii"))
-                )
-                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
-            except Exception as e:  # pylint: disable=W0703
-                logger.warning(
-                    f"Fails to load the callbacks param due to {e}. Please set the "
-                    "callbacks param manually for the loaded estimator."
-                )
-
-        if "init_booster" in metadata:
-            load_path = os.path.join(path, metadata["init_booster"])
-            ser_init_booster = (
-                _get_spark_session().read.parquet(load_path).collect()[0].init_booster
-            )
-            init_booster = deserialize_booster(ser_init_booster)
-            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
-
-        pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
-        return metadata, pyspark_xgb
-
-
-class SparkXGBWriter(MLWriter):
-    """
-    Spark Xgboost estimator writer.
-    """
-
-    def __init__(self, instance):
-        super().__init__()
-        self.instance = instance
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def saveImpl(self, path):
-        """
-        save model.
-        """
-        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
-
-
-class SparkXGBReader(MLReader):
-    """
-    Spark Xgboost estimator reader.
-    """
-
-    def __init__(self, cls):
-        super().__init__()
-        self.cls = cls
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def load(self, path):
-        """
-        load model.
-        """
-        _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
-            self.cls, path, self.sc, self.logger
-        )
-        return pyspark_xgb
-
-
-class SparkXGBModelWriter(MLWriter):
-    """
-    Spark Xgboost model writer.
- """ - - def __init__(self, instance): - super().__init__() - self.instance = instance - self.logger = get_logger(self.__class__.__name__, level="WARN") - - def saveImpl(self, path): - """ - Save metadata and model for a :py:class:`_SparkXGBModel` - - save metadata to path/metadata - - save model to path/model.json - """ - xgb_model = self.instance._xgb_sklearn_model - _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger) - model_save_path = os.path.join(path, "model") - booster = xgb_model.get_booster().save_raw("json").decode("utf-8") - _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile( - model_save_path - ) - - -class SparkXGBModelReader(MLReader): - """ - Spark Xgboost model reader. - """ - - def __init__(self, cls): - super().__init__() - self.cls = cls - self.logger = get_logger(self.__class__.__name__, level="WARN") - - def load(self, path): - """ - Load metadata and model for a :py:class:`_SparkXGBModel` - - :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance - """ - _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance( - self.cls, path, self.sc, self.logger - ) - - xgb_sklearn_params = py_model._gen_xgb_params_dict( - gen_xgb_sklearn_estimator_param=True - ) - model_load_path = os.path.join(path, "model") - - ser_xgb_model = ( - _get_spark_session().sparkContext.textFile(model_load_path).collect()[0] - ) - - def create_xgb_model(): - return self.cls._xgb_cls()(**xgb_sklearn_params) - - xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model) - py_model._xgb_sklearn_model = xgb_model - return py_model diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 979c40ea9..46e465dde 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -1,15 +1,19 @@ """Xgboost pyspark integration submodule for helper functions.""" +# pylint: disable=fixme + import inspect import logging +import os import sys +import uuid from threading import Thread from typing import Any, Callable, Dict, Set, Type import pyspark -from pyspark import BarrierTaskContext, SparkContext +from pyspark import BarrierTaskContext, SparkContext, SparkFiles from pyspark.sql.session import SparkSession -from xgboost import collective +from xgboost import Booster, XGBModel, collective from xgboost.tracker import RabitTracker @@ -133,3 +137,52 @@ def _get_gpu_id(task_context: BarrierTaskContext) -> int: ) # return the first gpu id. return int(resources["gpu"].addresses[0].strip()) + + +def _get_or_create_tmp_dir() -> str: + root_dir = SparkFiles.getRootDirectory() + xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp") + if not os.path.exists(xgb_tmp_dir): + os.makedirs(xgb_tmp_dir) + return xgb_tmp_dir + + +def deserialize_xgb_model( + model: str, xgb_model_creator: Callable[[], XGBModel] +) -> XGBModel: + """ + Deserialize an xgboost.XGBModel instance from the input model. + """ + xgb_model = xgb_model_creator() + xgb_model.load_model(bytearray(model.encode("utf-8"))) + return xgb_model + + +def serialize_booster(booster: Booster) -> str: + """ + Serialize the input booster to a string. 
+
+    Parameters
+    ----------
+    booster:
+        an xgboost.core.Booster instance
+    """
+    # TODO: change to use string io
+    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
+    booster.save_model(tmp_file_name)
+    with open(tmp_file_name, encoding="utf-8") as f:
+        ser_model_string = f.read()
+    return ser_model_string
+
+
+def deserialize_booster(model: str) -> Booster:
+    """
+    Deserialize an xgboost.core.Booster from the input model string.
+    """
+    booster = Booster()
+    # TODO: change to use string io
+    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
+    with open(tmp_file_name, "w", encoding="utf-8") as f:
+        f.write(model)
+    booster.load_model(tmp_file_name)
+    return booster
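Reviewer note: a minimal sketch of the save/load round trip this patch consolidates into core.py and utils.py. The classes and helpers used here are the ones moved above; the toy DataFrame, the local master, the /tmp path, and the parameter values are illustrative assumptions only, and the snippet assumes an active Spark session.

```python
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)],
    ["features", "label"],
)

model = SparkXGBClassifier(n_estimators=2).fit(df)

# SparkXGBModelWriter (now defined in core.py) writes the params as JSON
# metadata and the booster as a JSON text file under <path>/model.
model.write().overwrite().save("/tmp/spark-xgb-demo")

# SparkXGBModelReader rebuilds the sklearn-style model from the saved booster.
loaded = SparkXGBClassifierModel.load("/tmp/spark-xgb-demo")
loaded.transform(df).show()
```

The booster helpers moved into utils.py round-trip a Booster through a JSON string. A sketch under the same assumptions (they rely on SparkFiles for scratch space, so a live SparkContext is required):

```python
import numpy as np
import xgboost as xgb

from xgboost.spark.utils import deserialize_booster, serialize_booster

dtrain = xgb.DMatrix(np.array([[1.0, 2.0], [3.0, 4.0]]), label=np.array([0, 1]))
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=2)

# serialize_booster dumps the model to a JSON string; deserialize_booster
# reconstructs an equivalent Booster from that string.
restored = deserialize_booster(serialize_booster(booster))
assert np.allclose(restored.predict(dtrain), booster.predict(dtrain))
```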