Merge branch 'master' into sync-condition-2023May15

commit 88fc8badfa
@@ -35,6 +35,7 @@
    <maven.compiler.target>1.8</maven.compiler.target>
    <flink.version>1.17.0</flink.version>
    <spark.version>3.4.0</spark.version>
    <spark.version.gpu>3.3.2</spark.version.gpu>
    <scala.version>2.12.17</scala.version>
    <scala.binary.version>2.12</scala.binary.version>
    <hadoop.version>3.3.5</hadoop.version>
@@ -29,19 +29,19 @@
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_${scala.binary.version}</artifactId>
-       <version>${spark.version}</version>
+       <version>${spark.version.gpu}</version>
        <scope>provided</scope>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_${scala.binary.version}</artifactId>
-       <version>${spark.version}</version>
+       <version>${spark.version.gpu}</version>
        <scope>provided</scope>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_${scala.binary.version}</artifactId>
-       <version>${spark.version}</version>
+       <version>${spark.version.gpu}</version>
        <scope>provided</scope>
    </dependency>
    <dependency>
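The three Spark artifacts above drop from `${spark.version}` (3.4.0) to `${spark.version.gpu}` (3.3.2), presumably pinning the GPU build to the Spark line its GPU stack supports. Because the dependencies are `provided`, the running cluster must supply a matching Spark; a minimal, purely illustrative runtime check from the PySpark side (the pinned string is an assumption mirroring the pom, not part of the diff):

# Illustrative sketch, not part of the diff: verify the cluster's Spark
# matches the pom pin before running a provided-scope artifact against it.
import pyspark

PINNED_SPARK = "3.3.2"  # mirrors spark.version.gpu above (assumption)
if not pyspark.__version__.startswith(PINNED_SPARK):
    raise RuntimeError(f"Spark {pyspark.__version__} does not match pin {PINNED_SPARK}")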
@@ -1,13 +1,17 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
import base64

# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
import json
import os
from collections import namedtuple
from typing import Iterator, List, Optional, Tuple

import numpy as np
import pandas as pd
from pyspark import cloudpickle
from pyspark.ml import Estimator, Model
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.linalg import VectorUDT
@@ -21,7 +25,14 @@ from pyspark.ml.param.shared import (
    HasValidationIndicatorCol,
    HasWeightCol,
)
-from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.ml.util import (
+    DefaultParamsReader,
+    DefaultParamsWriter,
+    MLReadable,
+    MLReader,
+    MLWritable,
+    MLWriter,
+)
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
from pyspark.sql.types import (
@@ -36,7 +47,7 @@ from pyspark.sql.types import (
from scipy.special import expit, softmax  # pylint: disable=no-name-in-module

import xgboost
-from xgboost import XGBClassifier, XGBRanker, XGBRegressor
+from xgboost import XGBClassifier
from xgboost.compat import is_cudf_available
from xgboost.core import Booster
from xgboost.sklearn import DEFAULT_N_ESTIMATORS
@@ -49,12 +60,6 @@ from .data import (
    pred_contribs,
    stack_series,
)
-from .model import (
-    SparkXGBModelReader,
-    SparkXGBModelWriter,
-    SparkXGBReader,
-    SparkXGBWriter,
-)
from .params import (
    HasArbitraryParamsDict,
    HasBaseMarginCol,
@@ -71,8 +76,11 @@ from .utils import (
    _get_rabit_args,
    _get_spark_session,
    _is_local,
+    deserialize_booster,
+    deserialize_xgb_model,
    get_class_name,
    get_logger,
+    serialize_booster,
)

# Put pyspark specific params here, they won't be passed to XGBoost.
@@ -156,6 +164,8 @@ Pred = namedtuple(
)
pred = Pred("prediction", "rawPrediction", "probability", "predContrib")

+_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
+

class _SparkXGBParams(
    HasFeaturesCol,
@@ -1122,31 +1132,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
        return dataset


-class SparkXGBRegressorModel(_SparkXGBModel):
-    """
-    The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
-
-    .. Note:: This API is experimental.
-    """
-
-    @classmethod
-    def _xgb_cls(cls):
-        return XGBRegressor
-
-
-class SparkXGBRankerModel(_SparkXGBModel):
-    """
-    The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
-
-    .. Note:: This API is experimental.
-    """
-
-    @classmethod
-    def _xgb_cls(cls):
-        return XGBRanker
-
-
-class SparkXGBClassifierModel(
+class _ClassificationModel(  # pylint: disable=abstract-method
    _SparkXGBModel, HasProbabilityCol, HasRawPredictionCol, HasContribPredictionCol
):
    """
@@ -1155,10 +1141,6 @@ class SparkXGBClassifierModel(
    .. Note:: This API is experimental.
    """

-    @classmethod
-    def _xgb_cls(cls):
-        return XGBClassifier
-
    def _transform(self, dataset):
        # pylint: disable=too-many-statements, too-many-locals
        # Save xgb_sklearn_model and predict_params to be local variable
@@ -1286,53 +1268,178 @@ class SparkXGBClassifierModel(
        return dataset.drop(pred_struct_col)


-def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class):
-    params_dict = pyspark_estimator_class._get_xgb_params_default()
-
-    def param_value_converter(v):
-        if isinstance(v, np.generic):
-            # convert numpy scalar values to corresponding python scalar values
-            return np.array(v).item()
-        if isinstance(v, dict):
-            return {k: param_value_converter(nv) for k, nv in v.items()}
-        if isinstance(v, list):
-            return [param_value_converter(nv) for nv in v]
-        return v
-
-    def set_param_attrs(attr_name, param_obj_):
-        param_obj_.typeConverter = param_value_converter
-        setattr(pyspark_estimator_class, attr_name, param_obj_)
-        setattr(pyspark_model_class, attr_name, param_obj_)
-
-    for name in params_dict.keys():
-        doc = (
-            f"Refer to XGBoost doc of "
-            f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}"
-        )
-
-        param_obj = Param(Params._dummy(), name=name, doc=doc)
-        set_param_attrs(name, param_obj)
-
-    fit_params_dict = pyspark_estimator_class._get_fit_params_default()
-    for name in fit_params_dict.keys():
-        doc = (
-            f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
-            f".fit() for this param {name}"
-        )
-        if name == "callbacks":
-            doc += (
-                "The callbacks can be arbitrary functions. It is saved using cloudpickle "
-                "which is not a fully self-contained format. It may fail to load with "
-                "different versions of dependencies."
-            )
-        param_obj = Param(Params._dummy(), name=name, doc=doc)
-        set_param_attrs(name, param_obj)
-
-    predict_params_dict = pyspark_estimator_class._get_predict_params_default()
-    for name in predict_params_dict.keys():
-        doc = (
-            f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
-            f".predict() for this param {name}"
-        )
-        param_obj = Param(Params._dummy(), name=name, doc=doc)
-        set_param_attrs(name, param_obj)
+class _SparkXGBSharedReadWrite:
+    @staticmethod
+    def saveMetadata(instance, path, sc, logger, extraMetadata=None):
+        """
+        Save the metadata of an xgboost.spark._SparkXGBEstimator or
+        xgboost.spark._SparkXGBModel.
+        """
+        instance._validate_params()
+        skipParams = ["callbacks", "xgb_model"]
+        jsonParams = {}
+        for p, v in instance._paramMap.items():  # pylint: disable=protected-access
+            if p.name not in skipParams:
+                jsonParams[p.name] = v
+
+        extraMetadata = extraMetadata or {}
+        callbacks = instance.getOrDefault(instance.callbacks)
+        if callbacks is not None:
+            logger.warning(
+                "The callbacks parameter is saved using cloudpickle and it "
+                "is not a fully self-contained format. It may fail to load "
+                "with different versions of dependencies."
+            )
+            serialized_callbacks = base64.encodebytes(
+                cloudpickle.dumps(callbacks)
+            ).decode("ascii")
+            extraMetadata["serialized_callbacks"] = serialized_callbacks
+        init_booster = instance.getOrDefault(instance.xgb_model)
+        if init_booster is not None:
+            extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
+        DefaultParamsWriter.saveMetadata(
+            instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
+        )
+        if init_booster is not None:
+            ser_init_booster = serialize_booster(init_booster)
+            save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
+            _get_spark_session().createDataFrame(
+                [(ser_init_booster,)], ["init_booster"]
+            ).write.parquet(save_path)
+
+    @staticmethod
+    def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
+        """
+        Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
+        xgboost.spark._SparkXGBModel.
+
+        :return: a tuple of (metadata, instance)
+        """
+        metadata = DefaultParamsReader.loadMetadata(
+            path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
+        )
+        pyspark_xgb = pyspark_xgb_cls()
+        DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
+
+        if "serialized_callbacks" in metadata:
+            serialized_callbacks = metadata["serialized_callbacks"]
+            try:
+                callbacks = cloudpickle.loads(
+                    base64.decodebytes(serialized_callbacks.encode("ascii"))
+                )
+                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
+            except Exception as e:  # pylint: disable=W0703
+                logger.warning(
+                    f"Fails to load the callbacks param due to {e}. Please set the "
+                    "callbacks param manually for the loaded estimator."
+                )
+
+        if "init_booster" in metadata:
+            load_path = os.path.join(path, metadata["init_booster"])
+            ser_init_booster = (
+                _get_spark_session().read.parquet(load_path).collect()[0].init_booster
+            )
+            init_booster = deserialize_booster(ser_init_booster)
+            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
+
+        pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
+        return metadata, pyspark_xgb
+
+
+class SparkXGBWriter(MLWriter):
+    """
+    Spark Xgboost estimator writer.
+    """
+
+    def __init__(self, instance):
+        super().__init__()
+        self.instance = instance
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def saveImpl(self, path):
+        """
+        save model.
+        """
+        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
+
+
+class SparkXGBReader(MLReader):
+    """
+    Spark Xgboost estimator reader.
+    """
+
+    def __init__(self, cls):
+        super().__init__()
+        self.cls = cls
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def load(self, path):
+        """
+        load model.
+        """
+        _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
+            self.cls, path, self.sc, self.logger
+        )
+        return pyspark_xgb
+
+
+class SparkXGBModelWriter(MLWriter):
+    """
+    Spark Xgboost model writer.
+    """
+
+    def __init__(self, instance):
+        super().__init__()
+        self.instance = instance
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def saveImpl(self, path):
+        """
+        Save metadata and model for a :py:class:`_SparkXGBModel`
+        - save metadata to path/metadata
+        - save model to path/model.json
+        """
+        xgb_model = self.instance._xgb_sklearn_model
+        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
+        model_save_path = os.path.join(path, "model")
+        booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
+        _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
+            model_save_path
+        )
+
+
+class SparkXGBModelReader(MLReader):
+    """
+    Spark Xgboost model reader.
+    """
+
+    def __init__(self, cls):
+        super().__init__()
+        self.cls = cls
+        self.logger = get_logger(self.__class__.__name__, level="WARN")
+
+    def load(self, path):
+        """
+        Load metadata and model for a :py:class:`_SparkXGBModel`
+
+        :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
+        """
+        _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
+            self.cls, path, self.sc, self.logger
+        )
+
+        xgb_sklearn_params = py_model._gen_xgb_params_dict(
+            gen_xgb_sklearn_estimator_param=True
+        )
+        model_load_path = os.path.join(path, "model")
+
+        ser_xgb_model = (
+            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
+        )
+
+        def create_xgb_model():
+            return self.cls._xgb_cls()(**xgb_sklearn_params)
+
+        xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
+        py_model._xgb_sklearn_model = xgb_model
+        return py_model
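One detail worth pulling out of `saveMetadata`/`loadMetadataAndInstance` above: arbitrary callback objects travel through the JSON metadata as base64-encoded cloudpickle bytes. A minimal sketch of that round-trip, reusing exactly the calls that appear in the diff (the sample callback is a hypothetical stand-in):

# Sketch of the callbacks round-trip: cloudpickle -> bytes, base64 -> JSON-safe text.
import base64

from pyspark import cloudpickle

def roundtrip(cb):
    encoded = base64.encodebytes(cloudpickle.dumps(cb)).decode("ascii")    # save side
    return cloudpickle.loads(base64.decodebytes(encoded.encode("ascii")))  # load side

restored = roundtrip(lambda env: print(env))  # hypothetical callback object

As the warning in the code says, cloudpickle is not a fully self-contained format, so the load side can fail under different dependency versions; that is why the loader wraps it in a try/except and asks the user to reset the param manually.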
@@ -1,18 +1,77 @@
"""Xgboost pyspark integration submodule for estimator API."""
-# pylint: disable=too-many-ancestors
+# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name

from typing import Any, Type

+import numpy as np
+from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol

+from xgboost import XGBClassifier, XGBRanker, XGBRegressor
+
from .core import (  # type: ignore
-    SparkXGBClassifierModel,
-    SparkXGBRankerModel,
-    SparkXGBRegressorModel,
-    _set_pyspark_xgb_cls_param_attrs,
+    _ClassificationModel,
    _SparkXGBEstimator,
+    _SparkXGBModel,
)
+from .utils import get_class_name
+
+
+def _set_pyspark_xgb_cls_param_attrs(
+    estimator: _SparkXGBEstimator, model: _SparkXGBModel
+) -> None:
+    """Automatically infer the XGBoost parameters and set them as attributes
+    on the corresponding pyspark estimator and model."""
+    params_dict = estimator._get_xgb_params_default()
+
+    def param_value_converter(v: Any) -> Any:
+        if isinstance(v, np.generic):
+            # convert numpy scalar values to corresponding python scalar values
+            return np.array(v).item()
+        if isinstance(v, dict):
+            return {k: param_value_converter(nv) for k, nv in v.items()}
+        if isinstance(v, list):
+            return [param_value_converter(nv) for nv in v]
+        return v
+
+    def set_param_attrs(attr_name: str, param: Param) -> None:
+        param.typeConverter = param_value_converter
+        setattr(estimator, attr_name, param)
+        setattr(model, attr_name, param)
+
+    for name in params_dict.keys():
+        doc = (
+            f"Refer to XGBoost doc of "
+            f"{get_class_name(estimator._xgb_cls())} for this param {name}"
+        )
+
+        param_obj: Param = Param(Params._dummy(), name=name, doc=doc)
+        set_param_attrs(name, param_obj)
+
+    fit_params_dict = estimator._get_fit_params_default()
+    for name in fit_params_dict.keys():
+        doc = (
+            f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}"
+            f".fit() for this param {name}"
+        )
+        if name == "callbacks":
+            doc += (
+                "The callbacks can be arbitrary functions. It is saved using cloudpickle "
+                "which is not a fully self-contained format. It may fail to load with "
+                "different versions of dependencies."
+            )
+        param_obj = Param(Params._dummy(), name=name, doc=doc)
+        set_param_attrs(name, param_obj)
+
+    predict_params_dict = estimator._get_predict_params_default()
+    for name in predict_params_dict.keys():
+        doc = (
+            f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}"
+            f".predict() for this param {name}"
+        )
+        param_obj = Param(Params._dummy(), name=name, doc=doc)
+        set_param_attrs(name, param_obj)
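Since the real `param_value_converter` above is a closure inside `_set_pyspark_xgb_cls_param_attrs`, here is a hypothetical standalone copy to show what it does once wired in as a `typeConverter`: numpy scalars are recursively unwrapped so that params serialize as plain Python values.

# Hypothetical standalone check mirroring param_value_converter's logic.
import numpy as np

def convert(v):
    if isinstance(v, np.generic):
        return np.array(v).item()  # numpy scalar -> python scalar
    if isinstance(v, dict):
        return {k: convert(nv) for k, nv in v.items()}
    if isinstance(v, list):
        return [convert(nv) for nv in v]
    return v

assert convert(np.float32(0.5)) == 0.5
assert convert({"n_estimators": np.int64(100)}) == {"n_estimators": 100}
assert convert([np.bool_(True)]) == [True]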
class SparkXGBRegressor(_SparkXGBEstimator):

@@ -105,7 +164,7 @@ class SparkXGBRegressor(_SparkXGBEstimator):
        return XGBRegressor

    @classmethod
-    def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]:
+    def _pyspark_model_cls(cls) -> Type["SparkXGBRegressorModel"]:
        return SparkXGBRegressorModel

    def _validate_params(self) -> None:
@@ -116,6 +175,18 @@ class SparkXGBRegressor(_SparkXGBEstimator):
        )


+class SparkXGBRegressorModel(_SparkXGBModel):
+    """
+    The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
+
+    .. Note:: This API is experimental.
+    """
+
+    @classmethod
+    def _xgb_cls(cls) -> Type[XGBRegressor]:
+        return XGBRegressor
+
+
+_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
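After `_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)` runs at import time, every XGBoost constructor/fit/predict parameter is exposed as a pyspark `Param` on both classes, so the estimator accepts them directly as keyword arguments. A hedged usage sketch (the DataFrame and its column names are assumptions, not part of this diff):

# Hedged usage sketch of the public API defined above.
from xgboost.spark import SparkXGBRegressor

regressor = SparkXGBRegressor(
    features_col="features",  # vector or array column (assumed to exist)
    label_col="label",
    max_depth=5,              # plain XGBoost param, routed through the Param attrs
)
model = regressor.fit(train_df)          # train_df: a Spark DataFrame (assumption)
predictions = model.transform(train_df)  # returns SparkXGBRegressorModel output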
@@ -224,7 +295,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
        return XGBClassifier

    @classmethod
-    def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]:
+    def _pyspark_model_cls(cls) -> Type["SparkXGBClassifierModel"]:
        return SparkXGBClassifierModel

    def _validate_params(self) -> None:
@@ -239,6 +310,18 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
        )


+class SparkXGBClassifierModel(_ClassificationModel):
+    """
+    The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
+
+    .. Note:: This API is experimental.
+    """
+
+    @classmethod
+    def _xgb_cls(cls) -> Type[XGBClassifier]:
+        return XGBClassifier
+
+
+_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
@@ -352,7 +435,7 @@ class SparkXGBRanker(_SparkXGBEstimator):
        return XGBRanker

    @classmethod
-    def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]:
+    def _pyspark_model_cls(cls) -> Type["SparkXGBRankerModel"]:
        return SparkXGBRankerModel

    def _validate_params(self) -> None:
@@ -363,4 +446,16 @@ class SparkXGBRanker(_SparkXGBEstimator):
        )


+class SparkXGBRankerModel(_SparkXGBModel):
+    """
+    The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
+
+    .. Note:: This API is experimental.
+    """
+
+    @classmethod
+    def _xgb_cls(cls) -> Type[XGBRanker]:
+        return XGBRanker
+
+
+_set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel)
@@ -1,245 +0,0 @@
-# type: ignore
-"""Xgboost pyspark integration submodule for model API."""
-# pylint: disable=fixme, invalid-name, protected-access, too-few-public-methods
-import base64
-import os
-import uuid
-
-from pyspark import SparkFiles, cloudpickle
-from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
-from pyspark.sql import SparkSession
-
-from xgboost.core import Booster
-
-from .utils import get_class_name, get_logger
-
-
-def _get_or_create_tmp_dir():
-    root_dir = SparkFiles.getRootDirectory()
-    xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
-    if not os.path.exists(xgb_tmp_dir):
-        os.makedirs(xgb_tmp_dir)
-    return xgb_tmp_dir
-
-
-def deserialize_xgb_model(model_string, xgb_model_creator):
-    """
-    Deserialize an xgboost.XGBModel instance from the input model_string.
-    """
-    xgb_model = xgb_model_creator()
-    xgb_model.load_model(bytearray(model_string.encode("utf-8")))
-    return xgb_model
-
-
-def serialize_booster(booster):
-    """
-    Serialize the input booster to a string.
-
-    Parameters
-    ----------
-    booster:
-        an xgboost.core.Booster instance
-    """
-    # TODO: change to use string io
-    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    booster.save_model(tmp_file_name)
-    with open(tmp_file_name, encoding="utf-8") as f:
-        ser_model_string = f.read()
-    return ser_model_string
-
-
-def deserialize_booster(ser_model_string):
-    """
-    Deserialize an xgboost.core.Booster from the input ser_model_string.
-    """
-    booster = Booster()
-    # TODO: change to use string io
-    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
-    with open(tmp_file_name, "w", encoding="utf-8") as f:
-        f.write(ser_model_string)
-    booster.load_model(tmp_file_name)
-    return booster
-
-
-_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
-
-
-def _get_spark_session():
-    return SparkSession.builder.getOrCreate()
-
-
-class _SparkXGBSharedReadWrite:
-    @staticmethod
-    def saveMetadata(instance, path, sc, logger, extraMetadata=None):
-        """
-        Save the metadata of an xgboost.spark._SparkXGBEstimator or
-        xgboost.spark._SparkXGBModel.
-        """
-        instance._validate_params()
-        skipParams = ["callbacks", "xgb_model"]
-        jsonParams = {}
-        for p, v in instance._paramMap.items():  # pylint: disable=protected-access
-            if p.name not in skipParams:
-                jsonParams[p.name] = v
-
-        extraMetadata = extraMetadata or {}
-        callbacks = instance.getOrDefault(instance.callbacks)
-        if callbacks is not None:
-            logger.warning(
-                "The callbacks parameter is saved using cloudpickle and it "
-                "is not a fully self-contained format. It may fail to load "
-                "with different versions of dependencies."
-            )
-            serialized_callbacks = base64.encodebytes(
-                cloudpickle.dumps(callbacks)
-            ).decode("ascii")
-            extraMetadata["serialized_callbacks"] = serialized_callbacks
-        init_booster = instance.getOrDefault(instance.xgb_model)
-        if init_booster is not None:
-            extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
-        DefaultParamsWriter.saveMetadata(
-            instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
-        )
-        if init_booster is not None:
-            ser_init_booster = serialize_booster(init_booster)
-            save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
-            _get_spark_session().createDataFrame(
-                [(ser_init_booster,)], ["init_booster"]
-            ).write.parquet(save_path)
-
-    @staticmethod
-    def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
-        """
-        Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
-        xgboost.spark._SparkXGBModel.
-
-        :return: a tuple of (metadata, instance)
-        """
-        metadata = DefaultParamsReader.loadMetadata(
-            path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
-        )
-        pyspark_xgb = pyspark_xgb_cls()
-        DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
-
-        if "serialized_callbacks" in metadata:
-            serialized_callbacks = metadata["serialized_callbacks"]
-            try:
-                callbacks = cloudpickle.loads(
-                    base64.decodebytes(serialized_callbacks.encode("ascii"))
-                )
-                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
-            except Exception as e:  # pylint: disable=W0703
-                logger.warning(
-                    f"Fails to load the callbacks param due to {e}. Please set the "
-                    "callbacks param manually for the loaded estimator."
-                )
-
-        if "init_booster" in metadata:
-            load_path = os.path.join(path, metadata["init_booster"])
-            ser_init_booster = (
-                _get_spark_session().read.parquet(load_path).collect()[0].init_booster
-            )
-            init_booster = deserialize_booster(ser_init_booster)
-            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
-
-        pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
-        return metadata, pyspark_xgb
-
-
-class SparkXGBWriter(MLWriter):
-    """
-    Spark Xgboost estimator writer.
-    """
-
-    def __init__(self, instance):
-        super().__init__()
-        self.instance = instance
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def saveImpl(self, path):
-        """
-        save model.
-        """
-        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
-
-
-class SparkXGBReader(MLReader):
-    """
-    Spark Xgboost estimator reader.
-    """
-
-    def __init__(self, cls):
-        super().__init__()
-        self.cls = cls
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def load(self, path):
-        """
-        load model.
-        """
-        _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
-            self.cls, path, self.sc, self.logger
-        )
-        return pyspark_xgb
-
-
-class SparkXGBModelWriter(MLWriter):
-    """
-    Spark Xgboost model writer.
-    """
-
-    def __init__(self, instance):
-        super().__init__()
-        self.instance = instance
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def saveImpl(self, path):
-        """
-        Save metadata and model for a :py:class:`_SparkXGBModel`
-        - save metadata to path/metadata
-        - save model to path/model.json
-        """
-        xgb_model = self.instance._xgb_sklearn_model
-        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
-        model_save_path = os.path.join(path, "model")
-        booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
-        _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
-            model_save_path
-        )
-
-
-class SparkXGBModelReader(MLReader):
-    """
-    Spark Xgboost model reader.
-    """
-
-    def __init__(self, cls):
-        super().__init__()
-        self.cls = cls
-        self.logger = get_logger(self.__class__.__name__, level="WARN")
-
-    def load(self, path):
-        """
-        Load metadata and model for a :py:class:`_SparkXGBModel`
-
-        :return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
-        """
-        _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
-            self.cls, path, self.sc, self.logger
-        )
-
-        xgb_sklearn_params = py_model._gen_xgb_params_dict(
-            gen_xgb_sklearn_estimator_param=True
-        )
-        model_load_path = os.path.join(path, "model")
-
-        ser_xgb_model = (
-            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
-        )
-
-        def create_xgb_model():
-            return self.cls._xgb_cls()(**xgb_sklearn_params)
-
-        xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
-        py_model._xgb_sklearn_model = xgb_model
-        return py_model
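`SparkXGBModelWriter`/`SparkXGBModelReader` (deleted here, re-added in core.py above) rely on the booster JSON being a single one-line string: writing it with `parallelize([booster], 1).saveAsTextFile(...)` lets the reader recover the whole document with `textFile(...).collect()[0]`. A self-contained sketch of that symmetry (the path and payload are illustrative; a payload containing newlines would break the `collect()[0]` assumption):

# Sketch of the one-partition text save/load symmetry used by the model writer/reader.
from pyspark.sql import SparkSession

sc = SparkSession.builder.getOrCreate().sparkContext
payload = '{"learner": {}}'  # stands in for the booster's raw JSON (one line)
sc.parallelize([payload], 1).saveAsTextFile("/tmp/xgb-model-demo")  # single partition
assert sc.textFile("/tmp/xgb-model-demo").collect()[0] == payload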
@@ -1,15 +1,19 @@
"""Xgboost pyspark integration submodule for helper functions."""
# pylint: disable=fixme

import inspect
import logging
import os
import sys
+import uuid
from threading import Thread
from typing import Any, Callable, Dict, Set, Type

import pyspark
-from pyspark import BarrierTaskContext, SparkContext
+from pyspark import BarrierTaskContext, SparkContext, SparkFiles
from pyspark.sql.session import SparkSession

-from xgboost import collective
+from xgboost import Booster, XGBModel, collective
from xgboost.tracker import RabitTracker
@@ -133,3 +137,52 @@ def _get_gpu_id(task_context: BarrierTaskContext) -> int:
        )
    # return the first gpu id.
    return int(resources["gpu"].addresses[0].strip())
+
+
+def _get_or_create_tmp_dir() -> str:
+    root_dir = SparkFiles.getRootDirectory()
+    xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
+    if not os.path.exists(xgb_tmp_dir):
+        os.makedirs(xgb_tmp_dir)
+    return xgb_tmp_dir
+
+
+def deserialize_xgb_model(
+    model: str, xgb_model_creator: Callable[[], XGBModel]
+) -> XGBModel:
+    """
+    Deserialize an xgboost.XGBModel instance from the input model.
+    """
+    xgb_model = xgb_model_creator()
+    xgb_model.load_model(bytearray(model.encode("utf-8")))
+    return xgb_model
+
+
+def serialize_booster(booster: Booster) -> str:
+    """
+    Serialize the input booster to a string.
+
+    Parameters
+    ----------
+    booster:
+        an xgboost.core.Booster instance
+    """
+    # TODO: change to use string io
+    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
+    booster.save_model(tmp_file_name)
+    with open(tmp_file_name, encoding="utf-8") as f:
+        ser_model_string = f.read()
+    return ser_model_string
+
+
+def deserialize_booster(model: str) -> Booster:
+    """
+    Deserialize an xgboost.core.Booster from the input model string.
+    """
+    booster = Booster()
+    # TODO: change to use string io
+    tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
+    with open(tmp_file_name, "w", encoding="utf-8") as f:
+        f.write(model)
+    booster.load_model(tmp_file_name)
+    return booster
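A small end-to-end sketch of the two booster helpers above. It assumes a live SparkContext (needed because `_get_or_create_tmp_dir` resolves `SparkFiles.getRootDirectory()`) and a tiny synthetic dataset:

# Round-trip sketch for serialize_booster/deserialize_booster above.
import numpy as np
from xgboost import XGBRegressor

X = np.random.rand(8, 3)
y = np.random.rand(8)
booster = XGBRegressor(n_estimators=2, max_depth=2).fit(X, y).get_booster()

text = serialize_booster(booster)      # the booster's JSON document as one string
restored = deserialize_booster(text)   # a fresh Booster loaded from that string
assert restored.num_boosted_rounds() == 2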