Merge branch 'master' into sync-condition-2023May15

amdsc21 2023-05-17 19:55:50 +02:00
commit 88fc8badfa
6 changed files with 352 additions and 341 deletions

View File

@@ -35,6 +35,7 @@
         <maven.compiler.target>1.8</maven.compiler.target>
         <flink.version>1.17.0</flink.version>
         <spark.version>3.4.0</spark.version>
+        <spark.version.gpu>3.3.2</spark.version.gpu>
         <scala.version>2.12.17</scala.version>
         <scala.binary.version>2.12</scala.binary.version>
         <hadoop.version>3.3.5</hadoop.version>

View File

@@ -29,19 +29,19 @@
         <dependency>
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-core_${scala.binary.version}</artifactId>
-            <version>${spark.version}</version>
+            <version>${spark.version.gpu}</version>
             <scope>provided</scope>
         </dependency>
         <dependency>
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-sql_${scala.binary.version}</artifactId>
-            <version>${spark.version}</version>
+            <version>${spark.version.gpu}</version>
             <scope>provided</scope>
         </dependency>
         <dependency>
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-mllib_${scala.binary.version}</artifactId>
-            <version>${spark.version}</version>
+            <version>${spark.version.gpu}</version>
             <scope>provided</scope>
         </dependency>
         <dependency>

View File

@@ -1,13 +1,17 @@
 # type: ignore
 """Xgboost pyspark integration submodule for core code."""
+import base64
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
+import os
 from collections import namedtuple
 from typing import Iterator, List, Optional, Tuple

 import numpy as np
 import pandas as pd
+from pyspark import cloudpickle
 from pyspark.ml import Estimator, Model
 from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml.linalg import VectorUDT
@@ -21,7 +25,14 @@ from pyspark.ml.param.shared import (
     HasValidationIndicatorCol,
     HasWeightCol,
 )
-from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.ml.util import (
+    DefaultParamsReader,
+    DefaultParamsWriter,
+    MLReadable,
+    MLReader,
+    MLWritable,
+    MLWriter,
+)
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
@@ -36,7 +47,7 @@ from pyspark.sql.types import (
 from scipy.special import expit, softmax  # pylint: disable=no-name-in-module

 import xgboost
-from xgboost import XGBClassifier, XGBRanker, XGBRegressor
+from xgboost import XGBClassifier
 from xgboost.compat import is_cudf_available
 from xgboost.core import Booster
 from xgboost.sklearn import DEFAULT_N_ESTIMATORS
@@ -49,12 +60,6 @@ from .data import (
     pred_contribs,
     stack_series,
 )
-from .model import (
-    SparkXGBModelReader,
-    SparkXGBModelWriter,
-    SparkXGBReader,
-    SparkXGBWriter,
-)
 from .params import (
     HasArbitraryParamsDict,
     HasBaseMarginCol,
@@ -71,8 +76,11 @@ from .utils import (
     _get_rabit_args,
     _get_spark_session,
     _is_local,
+    deserialize_booster,
+    deserialize_xgb_model,
     get_class_name,
     get_logger,
+    serialize_booster,
 )

 # Put pyspark specific params here, they won't be passed to XGBoost.
@@ -156,6 +164,8 @@ Pred = namedtuple(
 )
 pred = Pred("prediction", "rawPrediction", "probability", "predContrib")

+_INIT_BOOSTER_SAVE_PATH = "init_booster.json"

 class _SparkXGBParams(
     HasFeaturesCol,
@@ -1122,31 +1132,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         return dataset


-class SparkXGBRegressorModel(_SparkXGBModel):
-    """
-    The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
-
-    .. Note:: This API is experimental.
-    """
-
-    @classmethod
-    def _xgb_cls(cls):
-        return XGBRegressor
-
-
-class SparkXGBRankerModel(_SparkXGBModel):
-    """
-    The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
-
-    .. Note:: This API is experimental.
-    """
-
-    @classmethod
-    def _xgb_cls(cls):
-        return XGBRanker
-
-
-class SparkXGBClassifierModel(
+class _ClassificationModel(  # pylint: disable=abstract-method
     _SparkXGBModel, HasProbabilityCol, HasRawPredictionCol, HasContribPredictionCol
 ):
     """
@@ -1155,10 +1141,6 @@ class SparkXGBClassifierModel(
     .. Note:: This API is experimental.
     """

-    @classmethod
-    def _xgb_cls(cls):
-        return XGBClassifier
-
     def _transform(self, dataset):
         # pylint: disable=too-many-statements, too-many-locals
         # Save xgb_sklearn_model and predict_params to be local variable
@@ -1286,53 +1268,178 @@ class SparkXGBClassifierModel(
         return dataset.drop(pred_struct_col)


-def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class):
-    params_dict = pyspark_estimator_class._get_xgb_params_default()
-
-    def param_value_converter(v):
-        if isinstance(v, np.generic):
-            # convert numpy scalar values to corresponding python scalar values
-            return np.array(v).item()
-        if isinstance(v, dict):
-            return {k: param_value_converter(nv) for k, nv in v.items()}
-        if isinstance(v, list):
-            return [param_value_converter(nv) for nv in v]
-        return v
-
-    def set_param_attrs(attr_name, param_obj_):
-        param_obj_.typeConverter = param_value_converter
-        setattr(pyspark_estimator_class, attr_name, param_obj_)
-        setattr(pyspark_model_class, attr_name, param_obj_)
-
-    for name in params_dict.keys():
-        doc = (
-            f"Refer to XGBoost doc of "
-            f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}"
-        )
-        param_obj = Param(Params._dummy(), name=name, doc=doc)
-        set_param_attrs(name, param_obj)
-
-    fit_params_dict = pyspark_estimator_class._get_fit_params_default()
-    for name in fit_params_dict.keys():
-        doc = (
-            f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
-            f".fit() for this param {name}"
-        )
-        if name == "callbacks":
-            doc += (
-                "The callbacks can be arbitrary functions. It is saved using cloudpickle "
-                "which is not a fully self-contained format. It may fail to load with "
-                "different versions of dependencies."
-            )
-        param_obj = Param(Params._dummy(), name=name, doc=doc)
-        set_param_attrs(name, param_obj)
-
-    predict_params_dict = pyspark_estimator_class._get_predict_params_default()
-    for name in predict_params_dict.keys():
-        doc = (
-            f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
-            f".predict() for this param {name}"
-        )
-        param_obj = Param(Params._dummy(), name=name, doc=doc)
-        set_param_attrs(name, param_obj)
+class _SparkXGBSharedReadWrite:
+    @staticmethod
+    def saveMetadata(instance, path, sc, logger, extraMetadata=None):
+        """
+        Save the metadata of an xgboost.spark._SparkXGBEstimator or
+        xgboost.spark._SparkXGBModel.
+        """
+        instance._validate_params()
+        skipParams = ["callbacks", "xgb_model"]
+        jsonParams = {}
+        for p, v in instance._paramMap.items():  # pylint: disable=protected-access
+            if p.name not in skipParams:
+                jsonParams[p.name] = v
+
+        extraMetadata = extraMetadata or {}
+        callbacks = instance.getOrDefault(instance.callbacks)
+        if callbacks is not None:
+            logger.warning(
+                "The callbacks parameter is saved using cloudpickle and it "
+                "is not a fully self-contained format. It may fail to load "
+                "with different versions of dependencies."
+            )
+            serialized_callbacks = base64.encodebytes(
+                cloudpickle.dumps(callbacks)
+            ).decode("ascii")
+            extraMetadata["serialized_callbacks"] = serialized_callbacks
+        init_booster = instance.getOrDefault(instance.xgb_model)
+        if init_booster is not None:
+            extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
+        DefaultParamsWriter.saveMetadata(
+            instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
+        )
+        if init_booster is not None:
+            ser_init_booster = serialize_booster(init_booster)
+            save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
+            _get_spark_session().createDataFrame(
+                [(ser_init_booster,)], ["init_booster"]
+            ).write.parquet(save_path)
+
+    @staticmethod
+    def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
+        """
+        Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
+        xgboost.spark._SparkXGBModel.
+
+        :return: a tuple of (metadata, instance)
+        """
+        metadata = DefaultParamsReader.loadMetadata(
+            path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
+        )
+        pyspark_xgb = pyspark_xgb_cls()
+        DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
+        if "serialized_callbacks" in metadata:
+            serialized_callbacks = metadata["serialized_callbacks"]
+            try:
+                callbacks = cloudpickle.loads(
+                    base64.decodebytes(serialized_callbacks.encode("ascii"))
+                )
+                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
+            except Exception as e:  # pylint: disable=W0703
+                logger.warning(
+                    f"Fails to load the callbacks param due to {e}. Please set the "
+                    "callbacks param manually for the loaded estimator."
+                )
+        if "init_booster" in metadata:
+            load_path = os.path.join(path, metadata["init_booster"])
+            ser_init_booster = (
+                _get_spark_session().read.parquet(load_path).collect()[0].init_booster
+            )
+            init_booster = deserialize_booster(ser_init_booster)
+            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
+        pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
+        return metadata, pyspark_xgb
+
+
+class SparkXGBWriter(MLWriter):
+    """
+    Spark Xgboost estimator writer.
+    """
+
+    def __init__(self, instance):
+        super().__init__()
+        self.instance = instance
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def saveImpl(self, path):
+        """
+        save model.
+        """
+        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
+
+
+class SparkXGBReader(MLReader):
+    """
+    Spark Xgboost estimator reader.
+    """
+
+    def __init__(self, cls):
+        super().__init__()
+        self.cls = cls
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def load(self, path):
+        """
+        load model.
+        """
+        _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
+            self.cls, path, self.sc, self.logger
+        )
+        return pyspark_xgb
+
+
+class SparkXGBModelWriter(MLWriter):
+    """
+    Spark Xgboost model writer.
+    """
+
+    def __init__(self, instance):
+        super().__init__()
+        self.instance = instance
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def saveImpl(self, path):
+        """
+        Save metadata and model for a :py:class:`_SparkXGBModel`
+        - save metadata to path/metadata
+        - save model to path/model.json
+        """
+        xgb_model = self.instance._xgb_sklearn_model
+        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
+        model_save_path = os.path.join(path, "model")
+        booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
+        _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
+            model_save_path
+        )
+
+
+class SparkXGBModelReader(MLReader):
+    """
+    Spark Xgboost model reader.
+    """
+
+    def __init__(self, cls):
+        super().__init__()
+        self.cls = cls
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def load(self, path):
+        """
+        Load metadata and model for a :py:class:`_SparkXGBModel`
+
+        :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
+        """
+        _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
+            self.cls, path, self.sc, self.logger
+        )
+        xgb_sklearn_params = py_model._gen_xgb_params_dict(
+            gen_xgb_sklearn_estimator_param=True
+        )
+        model_load_path = os.path.join(path, "model")
+
+        ser_xgb_model = (
+            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
+        )
+
+        def create_xgb_model():
+            return self.cls._xgb_cls()(**xgb_sklearn_params)
+
+        xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
+        py_model._xgb_sklearn_model = xgb_model
+        return py_model

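The reader and writer classes above are wired into the estimators and models through pyspark's MLReadable/MLWritable interfaces, so persistence is driven by the usual save/load calls. A minimal sketch of that round trip, assuming an active SparkSession and an illustrative save path; the tiny DataFrame and hyperparameter values are arbitrary:

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel

spark = SparkSession.builder.getOrCreate()
# Tiny two-class dataset; "features"/"label" are the default column names.
train_df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)],
    ["features", "label"],
)

clf = SparkXGBClassifier(n_estimators=5, max_depth=2)
model = clf.fit(train_df)

# SparkXGBWriter / SparkXGBModelWriter handle these calls under the hood.
clf.write().overwrite().save("/tmp/xgb_clf")
model.write().overwrite().save("/tmp/xgb_clf_model")

# SparkXGBReader / SparkXGBModelReader restore the params and the booster.
loaded_clf = SparkXGBClassifier.load("/tmp/xgb_clf")
loaded_model = SparkXGBClassifierModel.load("/tmp/xgb_clf_model")
loaded_model.transform(train_df).show()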
View File

@@ -1,18 +1,77 @@
 """Xgboost pyspark integration submodule for estimator API."""
 # pylint: disable=too-many-ancestors
+# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 from typing import Any, Type

+import numpy as np
+from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol

 from xgboost import XGBClassifier, XGBRanker, XGBRegressor

 from .core import (  # type: ignore
-    SparkXGBClassifierModel,
-    SparkXGBRankerModel,
-    SparkXGBRegressorModel,
-    _set_pyspark_xgb_cls_param_attrs,
+    _ClassificationModel,
     _SparkXGBEstimator,
+    _SparkXGBModel,
 )
+from .utils import get_class_name
+
+
+def _set_pyspark_xgb_cls_param_attrs(
+    estimator: _SparkXGBEstimator, model: _SparkXGBModel
+) -> None:
+    """This function automatically infer to xgboost parameters and set them
+    into corresponding pyspark estimators and models"""
+    params_dict = estimator._get_xgb_params_default()
+
+    def param_value_converter(v: Any) -> Any:
+        if isinstance(v, np.generic):
+            # convert numpy scalar values to corresponding python scalar values
+            return np.array(v).item()
+        if isinstance(v, dict):
+            return {k: param_value_converter(nv) for k, nv in v.items()}
+        if isinstance(v, list):
+            return [param_value_converter(nv) for nv in v]
+        return v
+
+    def set_param_attrs(attr_name: str, param: Param) -> None:
+        param.typeConverter = param_value_converter
+        setattr(estimator, attr_name, param)
+        setattr(model, attr_name, param)
+
+    for name in params_dict.keys():
+        doc = (
+            f"Refer to XGBoost doc of "
+            f"{get_class_name(estimator._xgb_cls())} for this param {name}"
+        )
+        param_obj: Param = Param(Params._dummy(), name=name, doc=doc)
+        set_param_attrs(name, param_obj)
+
+    fit_params_dict = estimator._get_fit_params_default()
+    for name in fit_params_dict.keys():
+        doc = (
+            f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}"
+            f".fit() for this param {name}"
+        )
+        if name == "callbacks":
+            doc += (
+                "The callbacks can be arbitrary functions. It is saved using cloudpickle "
+                "which is not a fully self-contained format. It may fail to load with "
+                "different versions of dependencies."
+            )
+        param_obj = Param(Params._dummy(), name=name, doc=doc)
+        set_param_attrs(name, param_obj)
+
+    predict_params_dict = estimator._get_predict_params_default()
+    for name in predict_params_dict.keys():
+        doc = (
+            f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}"
+            f".predict() for this param {name}"
+        )
+        param_obj = Param(Params._dummy(), name=name, doc=doc)
+        set_param_attrs(name, param_obj)


 class SparkXGBRegressor(_SparkXGBEstimator):
@@ -105,7 +164,7 @@ class SparkXGBRegressor(_SparkXGBEstimator):
         return XGBRegressor

     @classmethod
-    def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]:
+    def _pyspark_model_cls(cls) -> Type["SparkXGBRegressorModel"]:
         return SparkXGBRegressorModel

     def _validate_params(self) -> None:
@@ -116,6 +175,18 @@ class SparkXGBRegressor(_SparkXGBEstimator):
             )


+class SparkXGBRegressorModel(_SparkXGBModel):
+    """
+    The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
+
+    .. Note:: This API is experimental.
+    """
+
+    @classmethod
+    def _xgb_cls(cls) -> Type[XGBRegressor]:
+        return XGBRegressor
+
+
 _set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
@@ -224,7 +295,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
         return XGBClassifier

     @classmethod
-    def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]:
+    def _pyspark_model_cls(cls) -> Type["SparkXGBClassifierModel"]:
         return SparkXGBClassifierModel

     def _validate_params(self) -> None:
@@ -239,6 +310,18 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
             )


+class SparkXGBClassifierModel(_ClassificationModel):
+    """
+    The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
+
+    .. Note:: This API is experimental.
+    """
+
+    @classmethod
+    def _xgb_cls(cls) -> Type[XGBClassifier]:
+        return XGBClassifier
+
+
 _set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
@@ -352,7 +435,7 @@ class SparkXGBRanker(_SparkXGBEstimator):
         return XGBRanker

     @classmethod
-    def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]:
+    def _pyspark_model_cls(cls) -> Type["SparkXGBRankerModel"]:
         return SparkXGBRankerModel

     def _validate_params(self) -> None:
@@ -363,4 +446,16 @@ class SparkXGBRanker(_SparkXGBEstimator):
             )


+class SparkXGBRankerModel(_SparkXGBModel):
+    """
+    The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
+
+    .. Note:: This API is experimental.
+    """
+
+    @classmethod
+    def _xgb_cls(cls) -> Type[XGBRanker]:
+        return XGBRanker
+
+
 _set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel)

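Because _set_pyspark_xgb_cls_param_attrs attaches one pyspark Param per XGBoost constructor, fit, and predict argument, those names become regular pyspark params on both the estimator and the model classes. A small sketch of that behaviour; the parameter values are arbitrary:

from xgboost.spark import SparkXGBRegressor

reg = SparkXGBRegressor(max_depth=6, n_estimators=50)

# Each keyword above is backed by a generated pyspark Param attribute...
print(reg.getOrDefault(reg.max_depth))      # 6
print(reg.getOrDefault(reg.n_estimators))   # 50

# ...whose doc string points back at the underlying XGBRegressor parameter.
print(reg.max_depth.doc)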
View File

@@ -1,245 +0,0 @@
-# type: ignore
-"""Xgboost pyspark integration submodule for model API."""
-# pylint: disable=fixme, invalid-name, protected-access, too-few-public-methods
-import base64
-import os
-import uuid
-
-from pyspark import SparkFiles, cloudpickle
-from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
-from pyspark.sql import SparkSession
-from xgboost.core import Booster
-
-from .utils import get_class_name, get_logger
-
-
-def _get_or_create_tmp_dir():
-    root_dir = SparkFiles.getRootDirectory()
-    xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
-    if not os.path.exists(xgb_tmp_dir):
-        os.makedirs(xgb_tmp_dir)
-    return xgb_tmp_dir
-
-
-def deserialize_xgb_model(model_string, xgb_model_creator):
-    """
-    Deserialize an xgboost.XGBModel instance from the input model_string.
-    """
-    xgb_model = xgb_model_creator()
-    xgb_model.load_model(bytearray(model_string.encode("utf-8")))
-    return xgb_model
-
-
-def serialize_booster(booster):
-    """
-    Serialize the input booster to a string.
-
-    Parameters
-    ----------
-    booster:
-        an xgboost.core.Booster instance
-    """
-    # TODO: change to use string io
-    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    booster.save_model(tmp_file_name)
-    with open(tmp_file_name, encoding="utf-8") as f:
-        ser_model_string = f.read()
-    return ser_model_string
-
-
-def deserialize_booster(ser_model_string):
-    """
-    Deserialize an xgboost.core.Booster from the input ser_model_string.
-    """
-    booster = Booster()
-    # TODO: change to use string io
-    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    with open(tmp_file_name, "w", encoding="utf-8") as f:
-        f.write(ser_model_string)
-    booster.load_model(tmp_file_name)
-    return booster
-
-
-_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
-
-
-def _get_spark_session():
-    return SparkSession.builder.getOrCreate()
-
-
-class _SparkXGBSharedReadWrite:
-    @staticmethod
-    def saveMetadata(instance, path, sc, logger, extraMetadata=None):
-        """
-        Save the metadata of an xgboost.spark._SparkXGBEstimator or
-        xgboost.spark._SparkXGBModel.
-        """
-        instance._validate_params()
-        skipParams = ["callbacks", "xgb_model"]
-        jsonParams = {}
-        for p, v in instance._paramMap.items():  # pylint: disable=protected-access
-            if p.name not in skipParams:
-                jsonParams[p.name] = v
-
-        extraMetadata = extraMetadata or {}
-        callbacks = instance.getOrDefault(instance.callbacks)
-        if callbacks is not None:
-            logger.warning(
-                "The callbacks parameter is saved using cloudpickle and it "
-                "is not a fully self-contained format. It may fail to load "
-                "with different versions of dependencies."
-            )
-            serialized_callbacks = base64.encodebytes(
-                cloudpickle.dumps(callbacks)
-            ).decode("ascii")
-            extraMetadata["serialized_callbacks"] = serialized_callbacks
-        init_booster = instance.getOrDefault(instance.xgb_model)
-        if init_booster is not None:
-            extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
-        DefaultParamsWriter.saveMetadata(
-            instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
-        )
-        if init_booster is not None:
-            ser_init_booster = serialize_booster(init_booster)
-            save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
-            _get_spark_session().createDataFrame(
-                [(ser_init_booster,)], ["init_booster"]
-            ).write.parquet(save_path)
-
-    @staticmethod
-    def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
-        """
-        Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
-        xgboost.spark._SparkXGBModel.
-
-        :return: a tuple of (metadata, instance)
-        """
-        metadata = DefaultParamsReader.loadMetadata(
-            path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
-        )
-        pyspark_xgb = pyspark_xgb_cls()
-        DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
-        if "serialized_callbacks" in metadata:
-            serialized_callbacks = metadata["serialized_callbacks"]
-            try:
-                callbacks = cloudpickle.loads(
-                    base64.decodebytes(serialized_callbacks.encode("ascii"))
-                )
-                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
-            except Exception as e:  # pylint: disable=W0703
-                logger.warning(
-                    f"Fails to load the callbacks param due to {e}. Please set the "
-                    "callbacks param manually for the loaded estimator."
-                )
-        if "init_booster" in metadata:
-            load_path = os.path.join(path, metadata["init_booster"])
-            ser_init_booster = (
-                _get_spark_session().read.parquet(load_path).collect()[0].init_booster
-            )
-            init_booster = deserialize_booster(ser_init_booster)
-            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
-        pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
-        return metadata, pyspark_xgb
-
-
-class SparkXGBWriter(MLWriter):
-    """
-    Spark Xgboost estimator writer.
-    """
-
-    def __init__(self, instance):
-        super().__init__()
-        self.instance = instance
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def saveImpl(self, path):
-        """
-        save model.
-        """
-        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
-
-
-class SparkXGBReader(MLReader):
-    """
-    Spark Xgboost estimator reader.
-    """
-
-    def __init__(self, cls):
-        super().__init__()
-        self.cls = cls
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def load(self, path):
-        """
-        load model.
-        """
-        _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
-            self.cls, path, self.sc, self.logger
-        )
-        return pyspark_xgb
-
-
-class SparkXGBModelWriter(MLWriter):
-    """
-    Spark Xgboost model writer.
-    """
-
-    def __init__(self, instance):
-        super().__init__()
-        self.instance = instance
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def saveImpl(self, path):
-        """
-        Save metadata and model for a :py:class:`_SparkXGBModel`
-        - save metadata to path/metadata
-        - save model to path/model.json
-        """
-        xgb_model = self.instance._xgb_sklearn_model
-        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
-        model_save_path = os.path.join(path, "model")
-        booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
-        _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
-            model_save_path
-        )
-
-
-class SparkXGBModelReader(MLReader):
-    """
-    Spark Xgboost model reader.
-    """
-
-    def __init__(self, cls):
-        super().__init__()
-        self.cls = cls
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def load(self, path):
-        """
-        Load metadata and model for a :py:class:`_SparkXGBModel`
-
-        :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
-        """
-        _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
-            self.cls, path, self.sc, self.logger
-        )
-        xgb_sklearn_params = py_model._gen_xgb_params_dict(
-            gen_xgb_sklearn_estimator_param=True
-        )
-        model_load_path = os.path.join(path, "model")
-
-        ser_xgb_model = (
-            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
-        )
-
-        def create_xgb_model():
-            return self.cls._xgb_cls()(**xgb_sklearn_params)
-
-        xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
-        py_model._xgb_sklearn_model = xgb_model
-        return py_model

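The callbacks handling that saveMetadata and loadMetadataAndInstance perform in the deleted module (and in its relocated copy above) is a cloudpickle dump wrapped in base64 so that it can live inside the JSON metadata. A standalone sketch of that round trip, using an EarlyStopping callback purely as an example payload:

import base64

from pyspark import cloudpickle
from xgboost.callback import EarlyStopping

callbacks = [EarlyStopping(rounds=3)]

# What saveMetadata stores in the metadata under "serialized_callbacks".
serialized = base64.encodebytes(cloudpickle.dumps(callbacks)).decode("ascii")

# What loadMetadataAndInstance does when restoring the callbacks param.
restored = cloudpickle.loads(base64.decodebytes(serialized.encode("ascii")))
print(type(restored[0]).__name__)  # EarlyStopping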
View File

@@ -1,15 +1,19 @@
 """Xgboost pyspark integration submodule for helper functions."""
+# pylint: disable=fixme
 import inspect
 import logging
+import os
 import sys
+import uuid
 from threading import Thread
 from typing import Any, Callable, Dict, Set, Type

 import pyspark
-from pyspark import BarrierTaskContext, SparkContext
+from pyspark import BarrierTaskContext, SparkContext, SparkFiles
 from pyspark.sql.session import SparkSession

-from xgboost import collective
+from xgboost import Booster, XGBModel, collective
 from xgboost.tracker import RabitTracker
@@ -133,3 +137,52 @@ def _get_gpu_id(task_context: BarrierTaskContext) -> int:
         )
     # return the first gpu id.
     return int(resources["gpu"].addresses[0].strip())
+
+
+def _get_or_create_tmp_dir() -> str:
+    root_dir = SparkFiles.getRootDirectory()
+    xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
+    if not os.path.exists(xgb_tmp_dir):
+        os.makedirs(xgb_tmp_dir)
+    return xgb_tmp_dir
+
+
+def deserialize_xgb_model(
+    model: str, xgb_model_creator: Callable[[], XGBModel]
+) -> XGBModel:
+    """
+    Deserialize an xgboost.XGBModel instance from the input model.
+    """
+    xgb_model = xgb_model_creator()
+    xgb_model.load_model(bytearray(model.encode("utf-8")))
+    return xgb_model
+
+
+def serialize_booster(booster: Booster) -> str:
+    """
+    Serialize the input booster to a string.
+
+    Parameters
+    ----------
+    booster:
+        an xgboost.core.Booster instance
+    """
+    # TODO: change to use string io
+    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
+    booster.save_model(tmp_file_name)
+    with open(tmp_file_name, encoding="utf-8") as f:
+        ser_model_string = f.read()
+    return ser_model_string
+
+
+def deserialize_booster(model: str) -> Booster:
+    """
+    Deserialize an xgboost.core.Booster from the input ser_model_string.
+    """
+    booster = Booster()
+    # TODO: change to use string io
+    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
+    with open(tmp_file_name, "w", encoding="utf-8") as f:
+        f.write(model)
+    booster.load_model(tmp_file_name)
+    return booster
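serialize_booster and deserialize_booster simply round-trip the booster through a JSON model file under Spark's scratch directory. The same round trip can be sketched without a Spark runtime by pointing the save at an ordinary temporary directory; the toy training data below is arbitrary and only exists to produce a booster to serialize:

import os
import tempfile
import uuid

import numpy as np
import xgboost as xgb
from xgboost.core import Booster

# Train a tiny booster so there is something to serialize.
X = np.random.rand(32, 4)
y = (X[:, 0] > 0.5).astype(int)
booster = xgb.train(
    {"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=2
)

# Mirror serialize_booster: dump to a JSON file, read it back as a string.
tmp_file = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.json")
booster.save_model(tmp_file)
with open(tmp_file, encoding="utf-8") as f:
    ser_model_string = f.read()

# Mirror deserialize_booster: write the string out, load into a fresh Booster.
restored = Booster()
with open(tmp_file, "w", encoding="utf-8") as f:
    f.write(ser_model_string)
restored.load_model(tmp_file)
print(restored.num_boosted_rounds())  # 2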