Merge branch 'master' into sync-condition-2023May15

commit 7663d47383
Author: amdsc21
Date:   2023-05-19 20:30:35 +02:00
2 changed files with 178 additions and 115 deletions


@@ -1,17 +1,28 @@
-# type: ignore
 """Xgboost pyspark integration submodule for core code."""
 import base64
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
+import logging
 import os
 from collections import namedtuple
-from typing import Iterator, List, Optional, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)

 import numpy as np
 import pandas as pd
-from pyspark import cloudpickle
+from pyspark import SparkContext, cloudpickle
 from pyspark.ml import Estimator, Model
 from pyspark.ml.functions import array_to_vector, vector_to_array
 from pyspark.ml.linalg import VectorUDT
@@ -33,7 +44,7 @@ from pyspark.ml.util import (
     MLWritable,
     MLWriter,
 )
-from pyspark.sql import DataFrame
+from pyspark.sql import Column, DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
     ArrayType,
@@ -50,7 +61,7 @@ import xgboost
 from xgboost import XGBClassifier
 from xgboost.compat import is_cudf_available
 from xgboost.core import Booster
-from xgboost.sklearn import DEFAULT_N_ESTIMATORS
+from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel
 from xgboost.training import train as worker_train

 from .data import (
@@ -191,6 +202,7 @@ class _SparkXGBParams(
         "use_gpu",
         "A boolean variable. Set use_gpu=true if the executors "
         + "are running on GPU instances. Currently, only one GPU per task is supported.",
+        TypeConverters.toBoolean,
     )
     force_repartition = Param(
         Params._dummy(),
@@ -199,19 +211,24 @@
         + "want to force the input dataset to be repartitioned before XGBoost training."
         + "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
         + "to have force_repartition be True.",
+        TypeConverters.toBoolean,
     )
     repartition_random_shuffle = Param(
         Params._dummy(),
         "repartition_random_shuffle",
         "A boolean variable. Set repartition_random_shuffle=true if you want to random shuffle "
         "dataset when repartitioning is required. By default is True.",
+        TypeConverters.toBoolean,
     )
     feature_names = Param(
-        Params._dummy(), "feature_names", "A list of str to specify feature names."
+        Params._dummy(),
+        "feature_names",
+        "A list of str to specify feature names.",
+        TypeConverters.toList,
     )

     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBModel]:
         """
         Subclasses should override this method and
         returns an xgboost.XGBModel subclass
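The `TypeConverters` arguments added to each `Param` above are what validate and coerce values at assignment time. A minimal standalone sketch of the two converters this hunk touches (illustration only, not part of the commit):

from pyspark.ml.param import TypeConverters

# toList coerces list-like values (tuples, numpy arrays) into a plain
# Python list, which is what feature_names now receives on assignment.
assert TypeConverters.toList(("f0", "f1")) == ["f0", "f1"]

# toBoolean passes real bools through and raises TypeError for anything
# else, so a stray string can no longer slip into use_gpu.
assert TypeConverters.toBoolean(True) is True
try:
    TypeConverters.toBoolean("yes")
except TypeError as err:
    print(f"rejected: {err}")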
@@ -220,7 +237,8 @@ class _SparkXGBParams(
     # Parameters for xgboost.XGBModel()
     @classmethod
-    def _get_xgb_params_default(cls):
+    def _get_xgb_params_default(cls) -> Dict[str, Any]:
+        """Get the xgboost.sklearn.XGBModel default parameters and filter out some"""
         xgb_model_default = cls._xgb_cls()()
         params_dict = xgb_model_default.get_params()
         filtered_params_dict = {
@@ -229,11 +247,15 @@
         filtered_params_dict["n_estimators"] = DEFAULT_N_ESTIMATORS
         return filtered_params_dict

-    def _set_xgb_params_default(self):
+    def _set_xgb_params_default(self) -> None:
+        """Set xgboost parameters into spark parameters"""
         filtered_params_dict = self._get_xgb_params_default()
         self._setDefault(**filtered_params_dict)

-    def _gen_xgb_params_dict(self, gen_xgb_sklearn_estimator_param=False):
+    def _gen_xgb_params_dict(
+        self, gen_xgb_sklearn_estimator_param: bool = False
+    ) -> Dict[str, Any]:
+        """Generate the xgboost parameters which will be passed into xgboost library"""
         xgb_params = {}
         non_xgb_params = (
             set(_pyspark_specific_params)
@@ -254,20 +276,20 @@
     # Parameters for xgboost.XGBModel().fit()
     @classmethod
-    def _get_fit_params_default(cls):
+    def _get_fit_params_default(cls) -> Dict[str, Any]:
+        """Get the xgboost.XGBModel().fit() parameters"""
         fit_params = _get_default_params_from_func(
             cls._xgb_cls().fit, _unsupported_fit_params
         )
         return fit_params

-    def _set_fit_params_default(self):
+    def _set_fit_params_default(self) -> None:
+        """Get the xgboost.XGBModel().fit() parameters and set them to spark parameters"""
         filtered_params_dict = self._get_fit_params_default()
         self._setDefault(**filtered_params_dict)

-    def _gen_fit_params_dict(self):
-        """
-        Returns a dict of params for .fit()
-        """
+    def _gen_fit_params_dict(self) -> Dict[str, Any]:
+        """Generate the fit parameters which will be passed into fit function"""
         fit_params_keys = self._get_fit_params_default().keys()
         fit_params = {}
         for param in self.extractParamMap():
@@ -275,22 +297,22 @@
                 fit_params[param.name] = self.getOrDefault(param)
         return fit_params

-    # Parameters for xgboost.XGBModel().predict()
     @classmethod
-    def _get_predict_params_default(cls):
+    def _get_predict_params_default(cls) -> Dict[str, Any]:
+        """Get the parameters from xgboost.XGBModel().predict()"""
         predict_params = _get_default_params_from_func(
             cls._xgb_cls().predict, _unsupported_predict_params
         )
         return predict_params

-    def _set_predict_params_default(self):
+    def _set_predict_params_default(self) -> None:
+        """Get the parameters from xgboost.XGBModel().predict() and
+        set them into spark parameters"""
         filtered_params_dict = self._get_predict_params_default()
         self._setDefault(**filtered_params_dict)

-    def _gen_predict_params_dict(self):
-        """
-        Returns a dict of params for .predict()
-        """
+    def _gen_predict_params_dict(self) -> Dict[str, Any]:
+        """Generate predict parameters which will be passed into xgboost.XGBModel().predict()"""
         predict_params_keys = self._get_predict_params_default().keys()
         predict_params = {}
         for param in self.extractParamMap():
@@ -298,9 +320,9 @@
                 predict_params[param.name] = self.getOrDefault(param)
         return predict_params

-    def _validate_params(self):
+    def _validate_params(self) -> None:
         # pylint: disable=too-many-branches
-        init_model = self.getOrDefault(self.xgb_model)
+        init_model = self.getOrDefault("xgb_model")
         if init_model is not None and not isinstance(init_model, Booster):
             raise ValueError(
                 "The xgb_model param must be set with a `xgboost.core.Booster` "
@@ -321,18 +343,19 @@
                 "If features_cols param set, then features_col param is ignored."
             )

-        if self.getOrDefault(self.objective) is not None:
-            if not isinstance(self.getOrDefault(self.objective), str):
+        if self.getOrDefault("objective") is not None:
+            if not isinstance(self.getOrDefault("objective"), str):
                 raise ValueError("Only string type 'objective' param is allowed.")

-        if self.getOrDefault(self.eval_metric) is not None:
+        eval_metric = "eval_metric"
+        if self.getOrDefault(eval_metric) is not None:
             if not (
-                isinstance(self.getOrDefault(self.eval_metric), str)
+                isinstance(self.getOrDefault(eval_metric), str)
                 or (
-                    isinstance(self.getOrDefault(self.eval_metric), List)
+                    isinstance(self.getOrDefault(eval_metric), List)
                     and all(
                         isinstance(metric, str)
-                        for metric in self.getOrDefault(self.eval_metric)
+                        for metric in self.getOrDefault(eval_metric)
                     )
                 )
             ):
@@ -340,10 +363,10 @@
                 "Only string type or list of string type 'eval_metric' param is allowed."
             )

-        if self.getOrDefault(self.early_stopping_rounds) is not None:
+        if self.getOrDefault("early_stopping_rounds") is not None:
             if not (
                 self.isDefined(self.validationIndicatorCol)
-                and self.getOrDefault(self.validationIndicatorCol)
+                and self.getOrDefault(self.validationIndicatorCol) != ""
             ):
                 raise ValueError(
                     "If 'early_stopping_rounds' param is set, you need to set "
@@ -351,7 +374,7 @@
             )

         if self.getOrDefault(self.enable_sparse_data_optim):
-            if self.getOrDefault(self.missing) != 0.0:
+            if self.getOrDefault("missing") != 0.0:
                 # If DMatrix is constructed from csr / csc matrix, then inactive elements
                 # in csr / csc matrix are regarded as missing value, but, in pyspark, we
                 # are hard to control elements to be active or inactive in sparse vector column,
@@ -424,8 +447,8 @@

 def _validate_and_convert_feature_col_as_float_col_list(
-    dataset, features_col_names: list
-) -> list:
+    dataset: DataFrame, features_col_names: List[str]
+) -> List[Column]:
     """Values in feature columns must be integral types or float/double types"""
     feature_cols = []
     for c in features_col_names:
@@ -440,7 +463,12 @@ def _validate_and_convert_feature_col_as_float_col_list(
     return feature_cols

-def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
+def _validate_and_convert_feature_col_as_array_col(
+    dataset: DataFrame, features_col_name: str
+) -> Column:
+    """It handles
+    1. Convert vector type to array type
+    2. Cast to Array(Float32)"""
     features_col_datatype = dataset.schema[features_col_name].dataType
     features_col = col(features_col_name)
     if isinstance(features_col_datatype, ArrayType):
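The helper's vector-to-array path rests on two public pyspark.ml functions. A runnable local sketch of the same conversion and float32 cast (assuming pyspark>=3.1 for `array_to_vector`):

from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([([1.0, 0.0, 2.5],)], ["features"])

# array -> MLlib vector -> array<float>, mirroring the cast the helper
# applies before the data reaches DMatrix construction.
vec_df = df.select(array_to_vector(col("features")).alias("features"))
arr_df = vec_df.select(
    vector_to_array(col("features")).cast("array<float>").alias("features")
)
arr_df.show(truncate=False)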
@@ -466,7 +494,7 @@ def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
     return features_array_col

-def _get_unwrap_udt_fn():
+def _get_unwrap_udt_fn() -> Callable[[Union[Column, str]], Column]:
     try:
         from pyspark.sql.functions import unwrap_udt

@@ -475,9 +503,9 @@
         pass

     try:
-        from pyspark.databricks.sql.functions import unwrap_udt
+        from pyspark.databricks.sql.functions import unwrap_udt as databricks_unwrap_udt

-        return unwrap_udt
+        return databricks_unwrap_udt
     except ImportError as exc:
         raise RuntimeError(
             "Cannot import pyspark `unwrap_udt` function. Please install pyspark>=3.4 "
@@ -485,7 +513,7 @@
         ) from exc

-def _get_unwrapped_vec_cols(feature_col):
+def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
     unwrap_udt = _get_unwrap_udt_fn()
     features_unwrapped_vec_col = unwrap_udt(feature_col)
@@ -519,7 +547,7 @@ FeatureProp = namedtuple(

 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self._set_xgb_params_default()
         self._set_fit_params_default()
@@ -537,7 +565,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             arbitrary_params_dict={},
         )

-    def setParams(self, **kwargs):  # pylint: disable=invalid-name
+    def setParams(
+        self, **kwargs: Dict[str, Any]
+    ) -> None:  # pylint: disable=invalid-name
         """
         Set params for the estimator.
         """
@@ -578,17 +608,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})

     @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
         """
         Subclasses should override this method and
         returns a _SparkXGBModel subclass
         """
         raise NotImplementedError()

-    def _create_pyspark_model(self, xgb_model):
+    def _create_pyspark_model(self, xgb_model: XGBModel) -> "_SparkXGBModel":
         return self._pyspark_model_cls()(xgb_model)

-    def _convert_to_sklearn_model(self, booster: bytearray, config: str):
+    def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel:
         xgb_sklearn_params = self._gen_xgb_params_dict(
             gen_xgb_sklearn_estimator_param=True
         )
@@ -597,7 +627,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         sklearn_model._Booster.load_config(config)
         return sklearn_model

-    def _query_plan_contains_valid_repartition(self, dataset):
+    def _query_plan_contains_valid_repartition(self, dataset: DataFrame) -> bool:
         """
         Returns true if the latest element in the logical plan is a valid repartition
         The logic plan string format is like:
@@ -613,6 +643,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         repartition the dataframe again.
         """
         num_partitions = dataset.rdd.getNumPartitions()
+        assert dataset._sc._jvm is not None
         query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
             dataset._jdf.queryExecution(), "extended"
         )
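The new assert records for mypy that the JVM gateway is live before the private `PythonSQLUtils.explainString` hook is called. A sketch of the same plan inspection (this is private API, so treat it as illustrative only):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(100).repartition(4)

# Same private hook DataFrame.explain() uses under the hood; the estimator
# greps this text for a trailing Repartition node.
assert df._sc._jvm is not None
plan = df._sc._jvm.PythonSQLUtils.explainString(
    df._jdf.queryExecution(), "extended"
)
print("Repartition" in plan)  # True for this dataframe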
@@ -626,7 +657,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             return True
         return False

-    def _repartition_needed(self, dataset):
+    def _repartition_needed(self, dataset: DataFrame) -> bool:
         """
         We repartition the dataset if the number of workers is not equal to the number of
         partitions. There is also a check to make sure there was "active partitioning"
@@ -641,7 +672,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             pass
         return True

-    def _get_distributed_train_params(self, dataset):
+    def _get_distributed_train_params(self, dataset: DataFrame) -> Dict[str, Any]:
         """
         This just gets the configuration params for distributed xgboost
         """
@@ -664,10 +695,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         else:
             # use user specified objective or default objective.
             # e.g., the default objective for Regressor is 'reg:squarederror'
-            params["objective"] = self.getOrDefault(self.objective)
+            params["objective"] = self.getOrDefault("objective")

         # TODO: support "num_parallel_tree" for random forest
-        params["num_boost_round"] = self.getOrDefault(self.n_estimators)
+        params["num_boost_round"] = self.getOrDefault("n_estimators")

         if self.getOrDefault(self.use_gpu):
             params["tree_method"] = "gpu_hist"
@@ -675,7 +706,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         return params

     @classmethod
-    def _get_xgb_train_call_args(cls, train_params):
+    def _get_xgb_train_call_args(
+        cls, train_params: Dict[str, Any]
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         xgb_train_default_args = _get_default_params_from_func(
             xgboost.train, _unsupported_train_params
         )
@@ -693,7 +726,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

     def _prepare_input_columns_and_feature_prop(
         self, dataset: DataFrame
-    ) -> Tuple[List[str], FeatureProp]:
+    ) -> Tuple[List[Column], FeatureProp]:
         label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)
         select_cols = [label_col]
@@ -721,14 +754,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             )
             select_cols.append(features_array_col)

-        if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
+        if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol) != "":
             select_cols.append(
                 col(self.getOrDefault(self.weightCol)).alias(alias.weight)
             )

         has_validation_col = False
-        if self.isDefined(self.validationIndicatorCol) and self.getOrDefault(
-            self.validationIndicatorCol
+        if (
+            self.isDefined(self.validationIndicatorCol)
+            and self.getOrDefault(self.validationIndicatorCol) != ""
         ):
             select_cols.append(
                 col(self.getOrDefault(self.validationIndicatorCol)).alias(alias.valid)
@@ -738,14 +772,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             # which will cause exception or hanging issue when creating DMatrix.
             has_validation_col = True

-        if self.isDefined(self.base_margin_col) and self.getOrDefault(
-            self.base_margin_col
+        if (
+            self.isDefined(self.base_margin_col)
+            and self.getOrDefault(self.base_margin_col) != ""
         ):
             select_cols.append(
                 col(self.getOrDefault(self.base_margin_col)).alias(alias.margin)
             )

-        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
+        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col) != "":
             select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))

         feature_prop = FeatureProp(
@@ -777,7 +812,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         if self._repartition_needed(dataset) or (
             self.isDefined(self.validationIndicatorCol)
-            and self.getOrDefault(self.validationIndicatorCol)
+            and self.getOrDefault(self.validationIndicatorCol) != ""
         ):
             # If validationIndicatorCol defined, we always repartition dataset
             # to balance data, because user might unionise train and validation dataset,
@@ -790,13 +825,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         else:
             dataset = dataset.repartition(num_workers)

-        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
+        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col) != "":
             # XGBoost requires qid to be sorted for each partition
             dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)

         return dataset, feature_prop

-    def _get_xgb_parameters(self, dataset: DataFrame):
+    def _get_xgb_parameters(
+        self, dataset: DataFrame
+    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
             train_params
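For ranking inputs the estimator repartitions to one partition per worker and then sorts each partition by qid, since XGBoost requires contiguous query groups. The two DataFrame calls in isolation (toy data, local master):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame(
    [(2, 0.1), (1, 0.3), (2, 0.2), (1, 0.4)], ["qid", "label"]
)

# Repartition to the worker count, then sort inside each partition so
# every partition hands XGBoost its qid groups in contiguous runs.
df = df.repartition(2).sortWithinPartitions("qid", ascending=True)
df.show()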
@@ -807,10 +844,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

         dmatrix_kwargs = {
             "nthread": cpu_per_task,
-            "feature_types": self.getOrDefault(self.feature_types),
-            "feature_names": self.getOrDefault(self.feature_names),
-            "feature_weights": self.getOrDefault(self.feature_weights),
-            "missing": float(self.getOrDefault(self.missing)),
+            "feature_types": self.getOrDefault("feature_types"),
+            "feature_names": self.getOrDefault("feature_names"),
+            "feature_weights": self.getOrDefault("feature_weights"),
+            "missing": float(self.getOrDefault("missing")),
         }
         if dmatrix_kwargs["feature_types"] is not None:
             dmatrix_kwargs["enable_categorical"] = True
@@ -825,7 +862,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         return booster_params, train_call_kwargs_params, dmatrix_kwargs

-    def _fit(self, dataset):
+    def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
         # pylint: disable=too-many-statements, too-many-locals
         self._validate_params()
@@ -843,7 +880,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         num_workers = self.getOrDefault(self.num_workers)

-        def _train_booster(pandas_df_iter):
+        def _train_booster(
+            pandas_df_iter: Iterator[pd.DataFrame],
+        ) -> Iterator[pd.DataFrame]:
             """Takes in an RDD partition and outputs a booster for that partition after
             going through the Rabit Ring protocol
@@ -893,7 +932,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             _rabit_args = json.loads(messages[0])["rabit_msg"]

-            evals_result = {}
+            evals_result: Dict[str, Any] = {}
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
                     pandas_df_iter,
@@ -925,10 +964,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 }
             )

-        def _run_job():
+        def _run_job() -> Tuple[str, str]:
             ret = (
                 dataset.mapInPandas(
-                    _train_booster, schema="config string, booster string"
+                    _train_booster, schema="config string, booster string"  # type: ignore
                 )
                 .rdd.barrier()
                 .mapPartitions(lambda x: x)
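`_run_job` is the heart of the distributed fit: `mapInPandas` feeds each partition to `_train_booster`, and `barrier()` forces all tasks to launch together so the Rabit/collective handshake can complete. A stripped-down sketch of the same pattern, with a placeholder worker function standing in for `_train_booster`:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.range(8).repartition(2)


def work(pdf_iter):
    # Stand-in for _train_booster: consume the partition, emit one row.
    total = sum(pdf["id"].sum() for pdf in pdf_iter)
    yield pd.DataFrame({"config": ["{}"], "booster": [str(total)]})


rows = (
    df.mapInPandas(work, schema="config string, booster string")
    .rdd.barrier()               # all-or-nothing task scheduling
    .mapPartitions(lambda it: it)
    .collect()                   # per-partition results back to the driver
)
print(rows)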
@@ -947,14 +986,14 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         spark_model._resetUid(self.uid)
         return self._copyValues(spark_model)

-    def write(self):
+    def write(self) -> "SparkXGBWriter":
         """
         Return the writer for saving the estimator.
         """
         return SparkXGBWriter(self)

     @classmethod
-    def read(cls):
+    def read(cls) -> "SparkXGBReader":
         """
         Return the reader for loading the estimator.
         """
@@ -962,21 +1001,24 @@

 class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
-    def __init__(self, xgb_sklearn_model=None):
+    def __init__(self, xgb_sklearn_model: Optional[XGBModel] = None) -> None:
         super().__init__()
         self._xgb_sklearn_model = xgb_sklearn_model

     @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBModel]:
         raise NotImplementedError()

-    def get_booster(self):
+    def get_booster(self) -> Booster:
         """
         Return the `xgboost.core.Booster` instance.
         """
+        assert self._xgb_sklearn_model is not None
         return self._xgb_sklearn_model.get_booster()

-    def get_feature_importances(self, importance_type="weight"):
+    def get_feature_importances(
+        self, importance_type: str = "weight"
+    ) -> Dict[str, Union[float, List[float]]]:
         """Get feature importance of each feature.
         Importance type can be defined as:
@@ -993,20 +1035,22 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         """
         return self.get_booster().get_score(importance_type=importance_type)

-    def write(self):
+    def write(self) -> "SparkXGBModelWriter":
         """
         Return the writer for saving the model.
         """
         return SparkXGBModelWriter(self)

     @classmethod
-    def read(cls):
+    def read(cls) -> "SparkXGBModelReader":
         """
         Return the reader for loading the model.
         """
         return SparkXGBModelReader(cls)

-    def _get_feature_col(self, dataset) -> (list, Optional[list]):
+    def _get_feature_col(
+        self, dataset: DataFrame
+    ) -> Tuple[List[Column], Optional[List[str]]]:
         """XGBoost model trained with features_cols parameter can also predict
         vector or array feature type. But first we need to check features_cols
         and then featuresCol
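`get_feature_importances` is a thin wrapper over the booster's `get_score`. A local sklearn-API sketch of the call it delegates to:

import numpy as np
from xgboost import XGBClassifier

X = np.random.rand(64, 3)
y = (X[:, 0] > 0.5).astype(int)
clf = XGBClassifier(n_estimators=4).fit(X, y)

# Keys are feature names (f0, f1, ... when none were supplied); values
# depend on importance_type: "weight", "gain", "cover", and so on.
print(clf.get_booster().get_score(importance_type="weight"))
print(clf.get_booster().get_score(importance_type="gain"))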
@@ -1040,7 +1084,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         )
         return features_col, feature_col_names

-    def _transform(self, dataset):
+    def _transform(self, dataset: DataFrame) -> DataFrame:
         # pylint: disable=too-many-statements, too-many-locals
         # Save xgb_sklearn_model and predict_params to be local variable
         # to avoid the `self` object to be pickled to remote.
@@ -1048,8 +1092,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         predict_params = self._gen_predict_params_dict()

         has_base_margin = False
-        if self.isDefined(self.base_margin_col) and self.getOrDefault(
-            self.base_margin_col
+        if (
+            self.isDefined(self.base_margin_col)
+            and self.getOrDefault(self.base_margin_col) != ""
         ):
             has_base_margin = True
             base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
@@ -1060,8 +1105,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

         pred_contrib_col_name = None
-        if self.isDefined(self.pred_contrib_col) and self.getOrDefault(
-            self.pred_contrib_col
+        if (
+            self.isDefined(self.pred_contrib_col)
+            and self.getOrDefault(self.pred_contrib_col) != ""
         ):
             pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
@@ -1071,8 +1117,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
             single_pred = False
             schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"

-        @pandas_udf(schema)
+        @pandas_udf(schema)  # type: ignore
         def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
+            assert xgb_sklearn_model is not None
             model = xgb_sklearn_model
             for data in iterator:
                 if enable_sparse_data_optim:
@@ -1141,7 +1188,7 @@ class _ClassificationModel(  # pylint: disable=abstract-method
     .. Note:: This API is experimental.
     """

-    def _transform(self, dataset):
+    def _transform(self, dataset: DataFrame) -> DataFrame:
         # pylint: disable=too-many-statements, too-many-locals
         # Save xgb_sklearn_model and predict_params to be local variable
         # to avoid the `self` object to be pickled to remote.
@@ -1149,15 +1196,16 @@ class _ClassificationModel(  # pylint: disable=abstract-method
         predict_params = self._gen_predict_params_dict()

         has_base_margin = False
-        if self.isDefined(self.base_margin_col) and self.getOrDefault(
-            self.base_margin_col
+        if (
+            self.isDefined(self.base_margin_col)
+            and self.getOrDefault(self.base_margin_col) != ""
         ):
             has_base_margin = True
             base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
                 alias.margin
             )

-        def transform_margin(margins: np.ndarray):
+        def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
             if margins.ndim == 1:
                 # binomial case
                 classone_probs = expit(margins)
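`transform_margin` converts raw margins into the probability column: sigmoid for the binomial case, softmax over the class axis for the multinomial case. The arithmetic in isolation:

import numpy as np
from scipy.special import expit, softmax

# Binomial: a 1-D margin vector becomes [P(y=0), P(y=1)] via the sigmoid.
margins = np.array([-1.2, 0.0, 2.3])
classone = expit(margins)
binary_probs = np.column_stack((1.0 - classone, classone))

# Multinomial: one margin per class, normalised with softmax per row.
multi_margins = np.array([[0.1, 1.0, -0.3], [2.0, 0.5, 0.5]])
multi_probs = softmax(multi_margins, axis=1)
print(binary_probs, multi_probs, sep="\n")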
@@ -1174,8 +1222,9 @@ class _ClassificationModel(  # pylint: disable=abstract-method
         enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

         pred_contrib_col_name = None
-        if self.isDefined(self.pred_contrib_col) and self.getOrDefault(
-            self.pred_contrib_col
+        if (
+            self.isDefined(self.pred_contrib_col)
+            and self.getOrDefault(self.pred_contrib_col) != ""
         ):
             pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
@@ -1188,17 +1237,18 @@ class _ClassificationModel(  # pylint: disable=abstract-method
             # So, it will also output 3-D shape result.
             schema = f"{schema}, {pred.pred_contrib} array<array<double>>"

-        @pandas_udf(schema)
+        @pandas_udf(schema)  # type: ignore
         def predict_udf(
             iterator: Iterator[Tuple[pd.Series, ...]]
         ) -> Iterator[pd.DataFrame]:
+            assert xgb_sklearn_model is not None
             model = xgb_sklearn_model
             for data in iterator:
                 if enable_sparse_data_optim:
                     X = _read_csr_matrix_from_unwrapped_spark_vec(data)
                 else:
                     if feature_col_names is not None:
-                        X = data[feature_col_names]
+                        X = data[feature_col_names]  # type: ignore
                     else:
                         X = stack_series(data[alias.data])
@@ -1219,7 +1269,7 @@ class _ClassificationModel(  # pylint: disable=abstract-method
                 # It seems that they use argmax of class probs,
                 # not of margin to get the prediction (Note: scala implementation)
                 preds = np.argmax(class_probs, axis=1)
-                data = {
+                result: Dict[str, pd.Series] = {
                     pred.raw_prediction: pd.Series(list(raw_preds)),
                     pred.prediction: pd.Series(preds),
                     pred.probability: pd.Series(list(class_probs)),
@@ -1227,9 +1277,9 @@ class _ClassificationModel(  # pylint: disable=abstract-method
                 if pred_contrib_col_name:
                     contribs = pred_contribs(model, X, base_margin, strict_shape=True)
-                    data[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
+                    result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))

-                yield pd.DataFrame(data=data)
+                yield pd.DataFrame(data=result)

         if has_base_margin:
             pred_struct = predict_udf(struct(*features_col, base_margin_col))
@@ -1270,7 +1320,13 @@ class _ClassificationModel(  # pylint: disable=abstract-method

 class _SparkXGBSharedReadWrite:
     @staticmethod
-    def saveMetadata(instance, path, sc, logger, extraMetadata=None):
+    def saveMetadata(
+        instance: Union[_SparkXGBEstimator, _SparkXGBModel],
+        path: str,
+        sc: SparkContext,
+        logger: logging.Logger,
+        extraMetadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
         """
         Save the metadata of an xgboost.spark._SparkXGBEstimator or
         xgboost.spark._SparkXGBModel.
@@ -1283,7 +1339,7 @@ class _SparkXGBSharedReadWrite:
                 jsonParams[p.name] = v

         extraMetadata = extraMetadata or {}
-        callbacks = instance.getOrDefault(instance.callbacks)
+        callbacks = instance.getOrDefault("callbacks")
         if callbacks is not None:
             logger.warning(
                 "The callbacks parameter is saved using cloudpickle and it "
@@ -1294,7 +1350,7 @@ class _SparkXGBSharedReadWrite:
                 cloudpickle.dumps(callbacks)
             ).decode("ascii")
             extraMetadata["serialized_callbacks"] = serialized_callbacks
-        init_booster = instance.getOrDefault(instance.xgb_model)
+        init_booster = instance.getOrDefault("xgb_model")
         if init_booster is not None:
             extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
         DefaultParamsWriter.saveMetadata(
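Callbacks are arbitrary Python callables, so they are serialized with cloudpickle and base64-encoded to survive inside the JSON metadata. The round trip in isolation (the `callback` function here is a hypothetical stand-in):

import base64

from pyspark import cloudpickle


def callback(env):
    # Hypothetical stand-in for an xgboost TrainingCallback.
    return False


serialized = base64.encodebytes(cloudpickle.dumps([callback])).decode("ascii")
restored = cloudpickle.loads(base64.decodebytes(serialized.encode("ascii")))
assert restored[0].__name__ == "callback"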
@@ -1308,7 +1364,12 @@ class _SparkXGBSharedReadWrite:
             ).write.parquet(save_path)

     @staticmethod
-    def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
+    def loadMetadataAndInstance(
+        pyspark_xgb_cls: Union[Type[_SparkXGBEstimator], Type[_SparkXGBModel]],
+        path: str,
+        sc: SparkContext,
+        logger: logging.Logger,
+    ) -> Tuple[Dict[str, Any], Union[_SparkXGBEstimator, _SparkXGBModel]]:
         """
         Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
         xgboost.spark._SparkXGBModel.
@@ -1327,7 +1388,7 @@ class _SparkXGBSharedReadWrite:
                 callbacks = cloudpickle.loads(
                     base64.decodebytes(serialized_callbacks.encode("ascii"))
                 )
-                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
+                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)  # type: ignore
             except Exception as e:  # pylint: disable=W0703
                 logger.warning(
                     f"Fails to load the callbacks param due to {e}. Please set the "
@@ -1340,7 +1401,7 @@ class _SparkXGBSharedReadWrite:
                 _get_spark_session().read.parquet(load_path).collect()[0].init_booster
             )
             init_booster = deserialize_booster(ser_init_booster)
-            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
+            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)  # type: ignore
         pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
         return metadata, pyspark_xgb
@@ -1351,12 +1412,12 @@ class SparkXGBWriter(MLWriter):
     Spark Xgboost estimator writer.
     """

-    def __init__(self, instance):
+    def __init__(self, instance: "_SparkXGBEstimator") -> None:
         super().__init__()
         self.instance = instance
         self.logger = get_logger(self.__class__.__name__, level="WARN")

-    def saveImpl(self, path):
+    def saveImpl(self, path: str) -> None:
         """
         save model.
         """
@@ -1368,19 +1429,19 @@ class SparkXGBReader(MLReader):
     Spark Xgboost estimator reader.
     """

-    def __init__(self, cls):
+    def __init__(self, cls: Type["_SparkXGBEstimator"]) -> None:
         super().__init__()
         self.cls = cls
         self.logger = get_logger(self.__class__.__name__, level="WARN")

-    def load(self, path):
+    def load(self, path: str) -> "_SparkXGBEstimator":
         """
         load model.
         """
         _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
             self.cls, path, self.sc, self.logger
         )
-        return pyspark_xgb
+        return cast("_SparkXGBEstimator", pyspark_xgb)


 class SparkXGBModelWriter(MLWriter):
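`loadMetadataAndInstance` is typed to return the `Union` of estimator and model, so each concrete reader narrows the result with `typing.cast`, which is a no-op at runtime. A toy illustration of the pattern:

from typing import Union, cast


class Base: ...
class Alpha(Base): ...
class Beta(Base): ...


def load_either(want_beta: bool) -> Union[Alpha, Beta]:
    return Beta() if want_beta else Alpha()


# cast() only informs the type checker; no isinstance check happens,
# which is exactly how load() promises its concrete return type.
alpha = cast(Alpha, load_either(want_beta=False))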
@@ -1388,18 +1449,19 @@ class SparkXGBModelWriter(MLWriter):
     Spark Xgboost model writer.
     """

-    def __init__(self, instance):
+    def __init__(self, instance: _SparkXGBModel) -> None:
         super().__init__()
         self.instance = instance
         self.logger = get_logger(self.__class__.__name__, level="WARN")

-    def saveImpl(self, path):
+    def saveImpl(self, path: str) -> None:
         """
         Save metadata and model for a :py:class:`_SparkXGBModel`
         - save metadata to path/metadata
         - save model to path/model.json
         """
         xgb_model = self.instance._xgb_sklearn_model
+        assert xgb_model is not None
         _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
         model_save_path = os.path.join(path, "model")
         booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
@@ -1413,12 +1475,12 @@ class SparkXGBModelReader(MLReader):
     Spark Xgboost model reader.
     """

-    def __init__(self, cls):
+    def __init__(self, cls: Type["_SparkXGBModel"]) -> None:
         super().__init__()
         self.cls = cls
         self.logger = get_logger(self.__class__.__name__, level="WARN")

-    def load(self, path):
+    def load(self, path: str) -> "_SparkXGBModel":
         """
         Load metadata and model for a :py:class:`_SparkXGBModel`
@@ -1427,6 +1489,7 @@ class SparkXGBModelReader(MLReader):
         _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
             self.cls, path, self.sc, self.logger
         )
+        py_model = cast("_SparkXGBModel", py_model)

         xgb_sklearn_params = py_model._gen_xgb_params_dict(
             gen_xgb_sklearn_estimator_param=True
@@ -1437,7 +1500,7 @@ class SparkXGBModelReader(MLReader):
             _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
         )

-        def create_xgb_model():
+        def create_xgb_model() -> "XGBModel":
             return self.cls._xgb_cls()(**xgb_sklearn_params)

         xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
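The writer stores the booster as JSON text via `save_raw("json")`; loading feeds those bytes back through `load_model`. A minimal round trip:

import numpy as np
import xgboost

X, y = np.random.rand(32, 4), np.random.randint(0, 2, 32)
booster = xgboost.train(
    {"objective": "binary:logistic"}, xgboost.DMatrix(X, label=y), 2
)

# save_raw("json") yields the bytes the model writer decodes to utf-8 and
# stores; load_model on a bytearray restores an equivalent booster.
raw = booster.save_raw("json")
restored = xgboost.Booster()
restored.load_model(bytearray(raw))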


@@ -19,7 +19,7 @@ from .utils import get_class_name

 def _set_pyspark_xgb_cls_param_attrs(
-    estimator: _SparkXGBEstimator, model: _SparkXGBModel
+    estimator: Type[_SparkXGBEstimator], model: Type[_SparkXGBModel]
 ) -> None:
     """This function automatically infer to xgboost parameters and set them
     into corresponding pyspark estimators and models"""
@@ -304,7 +304,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
             raise ValueError(
                 "Spark Xgboost classifier estimator does not support `qid_col` param."
             )
-        if self.getOrDefault(self.objective):  # pylint: disable=no-member
+        if self.getOrDefault("objective"):  # pylint: disable=no-member
             raise ValueError(
                 "Setting custom 'objective' param is not allowed in 'SparkXGBClassifier'."
             )