Merge branch 'master' into sync-condition-2023May15

This commit is contained in:
amdsc21 2023-05-17 19:55:50 +02:00
commit 88fc8badfa
6 changed files with 352 additions and 341 deletions

View File

@@ -35,6 +35,7 @@
<maven.compiler.target>1.8</maven.compiler.target>
<flink.version>1.17.0</flink.version>
<spark.version>3.4.0</spark.version>
<spark.version.gpu>3.3.2</spark.version.gpu>
<scala.version>2.12.17</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>3.3.5</hadoop.version>

View File

@@ -29,19 +29,19 @@
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<version>${spark.version.gpu}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<version>${spark.version.gpu}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<version>${spark.version.gpu}</version>
<scope>provided</scope>
</dependency>
<dependency>

View File

@@ -1,13 +1,17 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
import base64
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
import json
import os
from collections import namedtuple
from typing import Iterator, List, Optional, Tuple
import numpy as np
import pandas as pd
from pyspark import cloudpickle
from pyspark.ml import Estimator, Model
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.linalg import VectorUDT
@@ -21,7 +25,14 @@ from pyspark.ml.param.shared import (
HasValidationIndicatorCol,
HasWeightCol,
)
from pyspark.ml.util import MLReadable, MLWritable
from pyspark.ml.util import (
DefaultParamsReader,
DefaultParamsWriter,
MLReadable,
MLReader,
MLWritable,
MLWriter,
)
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
from pyspark.sql.types import (
@@ -36,7 +47,7 @@ from pyspark.sql.types import (
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
import xgboost
from xgboost import XGBClassifier, XGBRanker, XGBRegressor
from xgboost import XGBClassifier
from xgboost.compat import is_cudf_available
from xgboost.core import Booster
from xgboost.sklearn import DEFAULT_N_ESTIMATORS
@@ -49,12 +60,6 @@ from .data import (
pred_contribs,
stack_series,
)
from .model import (
SparkXGBModelReader,
SparkXGBModelWriter,
SparkXGBReader,
SparkXGBWriter,
)
from .params import (
HasArbitraryParamsDict,
HasBaseMarginCol,
@@ -71,8 +76,11 @@ from .utils import (
_get_rabit_args,
_get_spark_session,
_is_local,
deserialize_booster,
deserialize_xgb_model,
get_class_name,
get_logger,
serialize_booster,
)
# Put pyspark specific params here, they won't be passed to XGBoost.
@@ -156,6 +164,8 @@ Pred = namedtuple(
)
pred = Pred("prediction", "rawPrediction", "probability", "predContrib")
_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
class _SparkXGBParams(
HasFeaturesCol,
@@ -1122,31 +1132,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
return dataset
class SparkXGBRegressorModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls):
return XGBRegressor
class SparkXGBRankerModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls):
return XGBRanker
class SparkXGBClassifierModel(
class _ClassificationModel( # pylint: disable=abstract-method
_SparkXGBModel, HasProbabilityCol, HasRawPredictionCol, HasContribPredictionCol
):
"""
@@ -1155,10 +1141,6 @@ class SparkXGBClassifierModel(
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls):
return XGBClassifier
def _transform(self, dataset):
# pylint: disable=too-many-statements, too-many-locals
# Save xgb_sklearn_model and predict_params to be local variable
@@ -1286,53 +1268,178 @@ class SparkXGBClassifierModel(
return dataset.drop(pred_struct_col)
def _set_pyspark_xgb_cls_param_attrs(pyspark_estimator_class, pyspark_model_class):
params_dict = pyspark_estimator_class._get_xgb_params_default()
class _SparkXGBSharedReadWrite:
@staticmethod
def saveMetadata(instance, path, sc, logger, extraMetadata=None):
"""
Save the metadata of an xgboost.spark._SparkXGBEstimator or
xgboost.spark._SparkXGBModel.
"""
instance._validate_params()
skipParams = ["callbacks", "xgb_model"]
jsonParams = {}
for p, v in instance._paramMap.items(): # pylint: disable=protected-access
if p.name not in skipParams:
jsonParams[p.name] = v
def param_value_converter(v):
if isinstance(v, np.generic):
# convert numpy scalar values to corresponding python scalar values
return np.array(v).item()
if isinstance(v, dict):
return {k: param_value_converter(nv) for k, nv in v.items()}
if isinstance(v, list):
return [param_value_converter(nv) for nv in v]
return v
extraMetadata = extraMetadata or {}
callbacks = instance.getOrDefault(instance.callbacks)
if callbacks is not None:
logger.warning(
"The callbacks parameter is saved using cloudpickle and it "
"is not a fully self-contained format. It may fail to load "
"with different versions of dependencies."
)
serialized_callbacks = base64.encodebytes(
cloudpickle.dumps(callbacks)
).decode("ascii")
extraMetadata["serialized_callbacks"] = serialized_callbacks
init_booster = instance.getOrDefault(instance.xgb_model)
if init_booster is not None:
extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
DefaultParamsWriter.saveMetadata(
instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
)
if init_booster is not None:
ser_init_booster = serialize_booster(init_booster)
save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
_get_spark_session().createDataFrame(
[(ser_init_booster,)], ["init_booster"]
).write.parquet(save_path)
def set_param_attrs(attr_name, param_obj_):
param_obj_.typeConverter = param_value_converter
setattr(pyspark_estimator_class, attr_name, param_obj_)
setattr(pyspark_model_class, attr_name, param_obj_)
@staticmethod
def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
"""
Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
xgboost.spark._SparkXGBModel.
for name in params_dict.keys():
doc = (
f"Refer to XGBoost doc of "
f"{get_class_name(pyspark_estimator_class._xgb_cls())} for this param {name}"
:return: a tuple of (metadata, instance)
"""
metadata = DefaultParamsReader.loadMetadata(
path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
)
pyspark_xgb = pyspark_xgb_cls()
DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
if "serialized_callbacks" in metadata:
serialized_callbacks = metadata["serialized_callbacks"]
try:
callbacks = cloudpickle.loads(
base64.decodebytes(serialized_callbacks.encode("ascii"))
)
pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
except Exception as e: # pylint: disable=W0703
logger.warning(
f"Fails to load the callbacks param due to {e}. Please set the "
"callbacks param manually for the loaded estimator."
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
if "init_booster" in metadata:
load_path = os.path.join(path, metadata["init_booster"])
ser_init_booster = (
_get_spark_session().read.parquet(load_path).collect()[0].init_booster
)
init_booster = deserialize_booster(ser_init_booster)
pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
fit_params_dict = pyspark_estimator_class._get_fit_params_default()
for name in fit_params_dict.keys():
doc = (
f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
f".fit() for this param {name}"
)
if name == "callbacks":
doc += (
"The callbacks can be arbitrary functions. It is saved using cloudpickle "
"which is not a fully self-contained format. It may fail to load with "
"different versions of dependencies."
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
pyspark_xgb._resetUid(metadata["uid"]) # pylint: disable=protected-access
return metadata, pyspark_xgb
predict_params_dict = pyspark_estimator_class._get_predict_params_default()
for name in predict_params_dict.keys():
doc = (
f"Refer to XGBoost doc of {get_class_name(pyspark_estimator_class._xgb_cls())}"
f".predict() for this param {name}"
class SparkXGBWriter(MLWriter):
"""
Spark Xgboost estimator writer.
"""
def __init__(self, instance):
super().__init__()
self.instance = instance
self.logger = get_logger(self.__class__.__name__, level="WARN")
def saveImpl(self, path):
"""
Save the estimator metadata to the given path.
"""
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
class SparkXGBReader(MLReader):
"""
Spark Xgboost estimator reader.
"""
def __init__(self, cls):
super().__init__()
self.cls = cls
self.logger = get_logger(self.__class__.__name__, level="WARN")
def load(self, path):
"""
Load the estimator from the given path.
"""
_, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
self.cls, path, self.sc, self.logger
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
return pyspark_xgb
class SparkXGBModelWriter(MLWriter):
"""
Spark Xgboost model writer.
"""
def __init__(self, instance):
super().__init__()
self.instance = instance
self.logger = get_logger(self.__class__.__name__, level="WARN")
def saveImpl(self, path):
"""
Save metadata and model for a :py:class:`_SparkXGBModel`
- save metadata to path/metadata
- save model to path/model (booster JSON written as a text file)
"""
xgb_model = self.instance._xgb_sklearn_model
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
model_save_path = os.path.join(path, "model")
booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
_get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
model_save_path
)
class SparkXGBModelReader(MLReader):
"""
Spark Xgboost model reader.
"""
def __init__(self, cls):
super().__init__()
self.cls = cls
self.logger = get_logger(self.__class__.__name__, level="WARN")
def load(self, path):
"""
Load metadata and model for a :py:class:`_SparkXGBModel`
:return: a SparkXGBRegressorModel, SparkXGBClassifierModel, or SparkXGBRankerModel instance
"""
_, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
self.cls, path, self.sc, self.logger
)
xgb_sklearn_params = py_model._gen_xgb_params_dict(
gen_xgb_sklearn_estimator_param=True
)
model_load_path = os.path.join(path, "model")
ser_xgb_model = (
_get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
)
def create_xgb_model():
return self.cls._xgb_cls()(**xgb_sklearn_params)
xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
py_model._xgb_sklearn_model = xgb_model
return py_model
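
Not part of this diff: a minimal local sketch of how the writer/reader classes above get exercised through the standard pyspark.ml persistence entry points (MLWritable.save / MLReadable.load). The toy DataFrame and the /tmp path are made up for illustration.

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel

spark = SparkSession.builder.master("local[1]").getOrCreate()
train_df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)] * 16,
    ["features", "label"],
)

model = SparkXGBClassifier(features_col="features", label_col="label").fit(train_df)
model.write().overwrite().save("/tmp/xgb-spark-model")         # dispatches to the SparkXGBModelWriter defined above
loaded = SparkXGBClassifierModel.load("/tmp/xgb-spark-model")  # dispatches to the SparkXGBModelReader defined above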

View File

@@ -1,18 +1,77 @@
"""Xgboost pyspark integration submodule for estimator API."""
# pylint: disable=too-many-ancestors
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
from typing import Any, Type
import numpy as np
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
from xgboost import XGBClassifier, XGBRanker, XGBRegressor
from .core import ( # type: ignore
SparkXGBClassifierModel,
SparkXGBRankerModel,
SparkXGBRegressorModel,
_set_pyspark_xgb_cls_param_attrs,
_ClassificationModel,
_SparkXGBEstimator,
_SparkXGBModel,
)
from .utils import get_class_name
def _set_pyspark_xgb_cls_param_attrs(
estimator: Type[_SparkXGBEstimator], model: Type[_SparkXGBModel]
) -> None:
"""Automatically infer the XGBoost parameters and set them as Spark ML params
on the corresponding pyspark estimator and model classes."""
params_dict = estimator._get_xgb_params_default()
def param_value_converter(v: Any) -> Any:
if isinstance(v, np.generic):
# convert numpy scalar values to corresponding python scalar values
return np.array(v).item()
if isinstance(v, dict):
return {k: param_value_converter(nv) for k, nv in v.items()}
if isinstance(v, list):
return [param_value_converter(nv) for nv in v]
return v
def set_param_attrs(attr_name: str, param: Param) -> None:
param.typeConverter = param_value_converter
setattr(estimator, attr_name, param)
setattr(model, attr_name, param)
for name in params_dict.keys():
doc = (
f"Refer to XGBoost doc of "
f"{get_class_name(estimator._xgb_cls())} for this param {name}"
)
param_obj: Param = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
fit_params_dict = estimator._get_fit_params_default()
for name in fit_params_dict.keys():
doc = (
f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}"
f".fit() for this param {name}"
)
if name == "callbacks":
doc += (
"The callbacks can be arbitrary functions. It is saved using cloudpickle "
"which is not a fully self-contained format. It may fail to load with "
"different versions of dependencies."
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
predict_params_dict = estimator._get_predict_params_default()
for name in predict_params_dict.keys():
doc = (
f"Refer to XGBoost doc of {get_class_name(estimator._xgb_cls())}"
f".predict() for this param {name}"
)
param_obj = Param(Params._dummy(), name=name, doc=doc)
set_param_attrs(name, param_obj)
class SparkXGBRegressor(_SparkXGBEstimator):
@@ -105,7 +164,7 @@ class SparkXGBRegressor(_SparkXGBEstimator):
return XGBRegressor
@classmethod
def _pyspark_model_cls(cls) -> Type[SparkXGBRegressorModel]:
def _pyspark_model_cls(cls) -> Type["SparkXGBRegressorModel"]:
return SparkXGBRegressorModel
def _validate_params(self) -> None:
@@ -116,6 +175,18 @@ class SparkXGBRegressor(_SparkXGBEstimator):
)
class SparkXGBRegressorModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRegressor.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls) -> Type[XGBRegressor]:
return XGBRegressor
_set_pyspark_xgb_cls_param_attrs(SparkXGBRegressor, SparkXGBRegressorModel)
@@ -224,7 +295,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
return XGBClassifier
@classmethod
def _pyspark_model_cls(cls) -> Type[SparkXGBClassifierModel]:
def _pyspark_model_cls(cls) -> Type["SparkXGBClassifierModel"]:
return SparkXGBClassifierModel
def _validate_params(self) -> None:
@@ -239,6 +310,18 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
)
class SparkXGBClassifierModel(_ClassificationModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBClassifier.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls) -> Type[XGBClassifier]:
return XGBClassifier
_set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)
@@ -352,7 +435,7 @@ class SparkXGBRanker(_SparkXGBEstimator):
return XGBRanker
@classmethod
def _pyspark_model_cls(cls) -> Type[SparkXGBRankerModel]:
def _pyspark_model_cls(cls) -> Type["SparkXGBRankerModel"]:
return SparkXGBRankerModel
def _validate_params(self) -> None:
@@ -363,4 +446,16 @@ class SparkXGBRanker(_SparkXGBEstimator):
)
class SparkXGBRankerModel(_SparkXGBModel):
"""
The model returned by :func:`xgboost.spark.SparkXGBRanker.fit`
.. Note:: This API is experimental.
"""
@classmethod
def _xgb_cls(cls) -> Type[XGBRanker]:
return XGBRanker
_set_pyspark_xgb_cls_param_attrs(SparkXGBRanker, SparkXGBRankerModel)
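
One consequence worth a quick illustration (not part of this diff): because _set_pyspark_xgb_cls_param_attrs runs at import time for each estimator/model pair above, the native XGBoost constructor arguments surface as ordinary Spark ML Params. The parameter and value below are only illustrative.

from xgboost.spark import SparkXGBRegressor

reg = SparkXGBRegressor(features_col="features", label_col="label", max_depth=5)
# max_depth is not hand-written on the class; it is one of the Params generated above,
# so it participates in the usual Spark ML param machinery.
print(reg.getOrDefault(reg.max_depth))  # 5
print(reg.explainParam(reg.max_depth))  # doc text pointing at the XGBRegressor documentation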

View File

@@ -1,245 +0,0 @@
# type: ignore
"""Xgboost pyspark integration submodule for model API."""
# pylint: disable=fixme, invalid-name, protected-access, too-few-public-methods
import base64
import os
import uuid
from pyspark import SparkFiles, cloudpickle
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
from pyspark.sql import SparkSession
from xgboost.core import Booster
from .utils import get_class_name, get_logger
def _get_or_create_tmp_dir():
root_dir = SparkFiles.getRootDirectory()
xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
if not os.path.exists(xgb_tmp_dir):
os.makedirs(xgb_tmp_dir)
return xgb_tmp_dir
def deserialize_xgb_model(model_string, xgb_model_creator):
"""
Deserialize an xgboost.XGBModel instance from the input model_string.
"""
xgb_model = xgb_model_creator()
xgb_model.load_model(bytearray(model_string.encode("utf-8")))
return xgb_model
def serialize_booster(booster):
"""
Serialize the input booster to a string.
Parameters
----------
booster:
an xgboost.core.Booster instance
"""
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
booster.save_model(tmp_file_name)
with open(tmp_file_name, encoding="utf-8") as f:
ser_model_string = f.read()
return ser_model_string
def deserialize_booster(ser_model_string):
"""
Deserialize an xgboost.core.Booster from the input ser_model_string.
"""
booster = Booster()
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
with open(tmp_file_name, "w", encoding="utf-8") as f:
f.write(ser_model_string)
booster.load_model(tmp_file_name)
return booster
_INIT_BOOSTER_SAVE_PATH = "init_booster.json"
def _get_spark_session():
return SparkSession.builder.getOrCreate()
class _SparkXGBSharedReadWrite:
@staticmethod
def saveMetadata(instance, path, sc, logger, extraMetadata=None):
"""
Save the metadata of an xgboost.spark._SparkXGBEstimator or
xgboost.spark._SparkXGBModel.
"""
instance._validate_params()
skipParams = ["callbacks", "xgb_model"]
jsonParams = {}
for p, v in instance._paramMap.items(): # pylint: disable=protected-access
if p.name not in skipParams:
jsonParams[p.name] = v
extraMetadata = extraMetadata or {}
callbacks = instance.getOrDefault(instance.callbacks)
if callbacks is not None:
logger.warning(
"The callbacks parameter is saved using cloudpickle and it "
"is not a fully self-contained format. It may fail to load "
"with different versions of dependencies."
)
serialized_callbacks = base64.encodebytes(
cloudpickle.dumps(callbacks)
).decode("ascii")
extraMetadata["serialized_callbacks"] = serialized_callbacks
init_booster = instance.getOrDefault(instance.xgb_model)
if init_booster is not None:
extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
DefaultParamsWriter.saveMetadata(
instance, path, sc, extraMetadata=extraMetadata, paramMap=jsonParams
)
if init_booster is not None:
ser_init_booster = serialize_booster(init_booster)
save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
_get_spark_session().createDataFrame(
[(ser_init_booster,)], ["init_booster"]
).write.parquet(save_path)
@staticmethod
def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
"""
Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
xgboost.spark._SparkXGBModel.
:return: a tuple of (metadata, instance)
"""
metadata = DefaultParamsReader.loadMetadata(
path, sc, expectedClassName=get_class_name(pyspark_xgb_cls)
)
pyspark_xgb = pyspark_xgb_cls()
DefaultParamsReader.getAndSetParams(pyspark_xgb, metadata)
if "serialized_callbacks" in metadata:
serialized_callbacks = metadata["serialized_callbacks"]
try:
callbacks = cloudpickle.loads(
base64.decodebytes(serialized_callbacks.encode("ascii"))
)
pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
except Exception as e: # pylint: disable=W0703
logger.warning(
f"Fails to load the callbacks param due to {e}. Please set the "
"callbacks param manually for the loaded estimator."
)
if "init_booster" in metadata:
load_path = os.path.join(path, metadata["init_booster"])
ser_init_booster = (
_get_spark_session().read.parquet(load_path).collect()[0].init_booster
)
init_booster = deserialize_booster(ser_init_booster)
pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
pyspark_xgb._resetUid(metadata["uid"]) # pylint: disable=protected-access
return metadata, pyspark_xgb
class SparkXGBWriter(MLWriter):
"""
Spark Xgboost estimator writer.
"""
def __init__(self, instance):
super().__init__()
self.instance = instance
self.logger = get_logger(self.__class__.__name__, level="WARN")
def saveImpl(self, path):
"""
save model.
"""
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
class SparkXGBReader(MLReader):
"""
Spark Xgboost estimator reader.
"""
def __init__(self, cls):
super().__init__()
self.cls = cls
self.logger = get_logger(self.__class__.__name__, level="WARN")
def load(self, path):
"""
load model.
"""
_, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
self.cls, path, self.sc, self.logger
)
return pyspark_xgb
class SparkXGBModelWriter(MLWriter):
"""
Spark Xgboost model writer.
"""
def __init__(self, instance):
super().__init__()
self.instance = instance
self.logger = get_logger(self.__class__.__name__, level="WARN")
def saveImpl(self, path):
"""
Save metadata and model for a :py:class:`_SparkXGBModel`
- save metadata to path/metadata
- save model to path/model.json
"""
xgb_model = self.instance._xgb_sklearn_model
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
model_save_path = os.path.join(path, "model")
booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
_get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
model_save_path
)
class SparkXGBModelReader(MLReader):
"""
Spark Xgboost model reader.
"""
def __init__(self, cls):
super().__init__()
self.cls = cls
self.logger = get_logger(self.__class__.__name__, level="WARN")
def load(self, path):
"""
Load metadata and model for a :py:class:`_SparkXGBModel`
:return: SparkXGBRegressorModel or SparkXGBClassifierModel instance
"""
_, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
self.cls, path, self.sc, self.logger
)
xgb_sklearn_params = py_model._gen_xgb_params_dict(
gen_xgb_sklearn_estimator_param=True
)
model_load_path = os.path.join(path, "model")
ser_xgb_model = (
_get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
)
def create_xgb_model():
return self.cls._xgb_cls()(**xgb_sklearn_params)
xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
py_model._xgb_sklearn_model = xgb_model
return py_model

View File

@@ -1,15 +1,19 @@
"""Xgboost pyspark integration submodule for helper functions."""
# pylint: disable=fixme
import inspect
import logging
import os
import sys
import uuid
from threading import Thread
from typing import Any, Callable, Dict, Set, Type
import pyspark
from pyspark import BarrierTaskContext, SparkContext
from pyspark import BarrierTaskContext, SparkContext, SparkFiles
from pyspark.sql.session import SparkSession
from xgboost import collective
from xgboost import Booster, XGBModel, collective
from xgboost.tracker import RabitTracker
@@ -133,3 +137,52 @@ def _get_gpu_id(task_context: BarrierTaskContext) -> int:
)
# return the first gpu id.
return int(resources["gpu"].addresses[0].strip())
def _get_or_create_tmp_dir() -> str:
root_dir = SparkFiles.getRootDirectory()
xgb_tmp_dir = os.path.join(root_dir, "xgboost-tmp")
if not os.path.exists(xgb_tmp_dir):
os.makedirs(xgb_tmp_dir)
return xgb_tmp_dir
def deserialize_xgb_model(
model: str, xgb_model_creator: Callable[[], XGBModel]
) -> XGBModel:
"""
Deserialize an xgboost.XGBModel instance from the input model.
"""
xgb_model = xgb_model_creator()
xgb_model.load_model(bytearray(model.encode("utf-8")))
return xgb_model
def serialize_booster(booster: Booster) -> str:
"""
Serialize the input booster to a string.
Parameters
----------
booster:
an xgboost.core.Booster instance
"""
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
booster.save_model(tmp_file_name)
with open(tmp_file_name, encoding="utf-8") as f:
ser_model_string = f.read()
return ser_model_string
def deserialize_booster(model: str) -> Booster:
"""
Deserialize an xgboost.core.Booster from the input model string.
"""
booster = Booster()
# TODO: change to use string io
tmp_file_name = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
with open(tmp_file_name, "w", encoding="utf-8") as f:
f.write(model)
booster.load_model(tmp_file_name)
return booster
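
The two booster helpers above round-trip a Booster through its JSON text form, staging it in a temporary file under SparkFiles.getRootDirectory(), so a live SparkContext is needed even when running locally. A rough round-trip sketch (not part of this diff; the toy training data is made up, and the import path follows the file shown above):

import numpy as np
import xgboost as xgb
from pyspark.sql import SparkSession
from xgboost.spark.utils import deserialize_booster, serialize_booster

spark = SparkSession.builder.master("local[1]").getOrCreate()  # SparkFiles needs an active context

X = np.random.rand(64, 4)
y = np.random.randint(0, 2, size=64)
booster = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=5)

json_text = serialize_booster(booster)     # Booster -> JSON string via a temp file
restored = deserialize_booster(json_text)  # JSON string -> fresh Booster
assert restored.num_boosted_rounds() == booster.num_boosted_rounds()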