Merge branch 'master' into sync-condition-2023May15

commit 7663d47383
@@ -1,17 +1,28 @@
# type: ignore
"""Xgboost pyspark integration submodule for core code."""
import base64

# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
import json
import logging
import os
from collections import namedtuple
-from typing import Iterator, List, Optional, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)

import numpy as np
import pandas as pd
-from pyspark import cloudpickle
+from pyspark import SparkContext, cloudpickle
from pyspark.ml import Estimator, Model
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.linalg import VectorUDT
@@ -33,7 +44,7 @@ from pyspark.ml.util import (
    MLWritable,
    MLWriter,
)
-from pyspark.sql import DataFrame
+from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
from pyspark.sql.types import (
    ArrayType,
@@ -50,7 +61,7 @@ import xgboost
from xgboost import XGBClassifier
from xgboost.compat import is_cudf_available
from xgboost.core import Booster
-from xgboost.sklearn import DEFAULT_N_ESTIMATORS
+from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel
from xgboost.training import train as worker_train

from .data import (
@@ -191,6 +202,7 @@ class _SparkXGBParams(
        "use_gpu",
        "A boolean variable. Set use_gpu=true if the executors "
        + "are running on GPU instances. Currently, only one GPU per task is supported.",
+        TypeConverters.toBoolean,
    )
    force_repartition = Param(
        Params._dummy(),
@@ -199,19 +211,24 @@ class _SparkXGBParams(
        + "want to force the input dataset to be repartitioned before XGBoost training."
        + "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
        + "to have force_repartition be True.",
+        TypeConverters.toBoolean,
    )
    repartition_random_shuffle = Param(
        Params._dummy(),
        "repartition_random_shuffle",
        "A boolean variable. Set repartition_random_shuffle=true if you want to random shuffle "
        "dataset when repartitioning is required. By default is True.",
+        TypeConverters.toBoolean,
    )
    feature_names = Param(
-        Params._dummy(), "feature_names", "A list of str to specify feature names."
+        Params._dummy(),
+        "feature_names",
+        "A list of str to specify feature names.",
+        TypeConverters.toList,
    )

    @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBModel]:
        """
        Subclasses should override this method and
        returns an xgboost.XGBModel subclass
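The hunks above attach `TypeConverters` to the `Param` definitions, so values assigned to `use_gpu`, `force_repartition`, `repartition_random_shuffle` and `feature_names` are now validated by pyspark instead of being stored verbatim. A minimal sketch of the same pattern using only the public `pyspark.ml.param` API; the `MyParams` class and its `verbose` param are made up for illustration:

```python
from pyspark.ml.param import Param, Params, TypeConverters


class MyParams(Params):
    """Toy params holder; not part of xgboost.spark."""

    verbose = Param(
        Params._dummy(),
        "verbose",
        "A boolean flag controlling extra logging.",
        TypeConverters.toBoolean,
    )

    def __init__(self) -> None:
        super().__init__()
        self._setDefault(verbose=False)


p = MyParams()
p._set(verbose=True)              # accepted: the converter passes bools through
print(p.getOrDefault("verbose"))  # True
# p._set(verbose="yes")           # would raise: TypeConverters.toBoolean rejects non-bools
```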
@@ -220,7 +237,8 @@ class _SparkXGBParams(

    # Parameters for xgboost.XGBModel()
    @classmethod
-    def _get_xgb_params_default(cls):
+    def _get_xgb_params_default(cls) -> Dict[str, Any]:
        """Get the xgboost.sklearn.XGBModel default parameters and filter out some"""
        xgb_model_default = cls._xgb_cls()()
        params_dict = xgb_model_default.get_params()
        filtered_params_dict = {
@@ -229,11 +247,15 @@ class _SparkXGBParams(
        filtered_params_dict["n_estimators"] = DEFAULT_N_ESTIMATORS
        return filtered_params_dict

-    def _set_xgb_params_default(self):
+    def _set_xgb_params_default(self) -> None:
        """Set xgboost parameters into spark parameters"""
        filtered_params_dict = self._get_xgb_params_default()
        self._setDefault(**filtered_params_dict)

-    def _gen_xgb_params_dict(self, gen_xgb_sklearn_estimator_param=False):
+    def _gen_xgb_params_dict(
+        self, gen_xgb_sklearn_estimator_param: bool = False
+    ) -> Dict[str, Any]:
        """Generate the xgboost parameters which will be passed into xgboost library"""
        xgb_params = {}
        non_xgb_params = (
            set(_pyspark_specific_params)
@@ -254,20 +276,20 @@ class _SparkXGBParams(

    # Parameters for xgboost.XGBModel().fit()
    @classmethod
-    def _get_fit_params_default(cls):
+    def _get_fit_params_default(cls) -> Dict[str, Any]:
        """Get the xgboost.XGBModel().fit() parameters"""
        fit_params = _get_default_params_from_func(
            cls._xgb_cls().fit, _unsupported_fit_params
        )
        return fit_params

-    def _set_fit_params_default(self):
+    def _set_fit_params_default(self) -> None:
        """Get the xgboost.XGBModel().fit() parameters and set them to spark parameters"""
        filtered_params_dict = self._get_fit_params_default()
        self._setDefault(**filtered_params_dict)

-    def _gen_fit_params_dict(self):
-        """
-        Returns a dict of params for .fit()
-        """
+    def _gen_fit_params_dict(self) -> Dict[str, Any]:
+        """Generate the fit parameters which will be passed into fit function"""
        fit_params_keys = self._get_fit_params_default().keys()
        fit_params = {}
        for param in self.extractParamMap():
@@ -275,22 +297,22 @@ class _SparkXGBParams(
                fit_params[param.name] = self.getOrDefault(param)
        return fit_params

    # Parameters for xgboost.XGBModel().predict()
    @classmethod
-    def _get_predict_params_default(cls):
+    def _get_predict_params_default(cls) -> Dict[str, Any]:
        """Get the parameters from xgboost.XGBModel().predict()"""
        predict_params = _get_default_params_from_func(
            cls._xgb_cls().predict, _unsupported_predict_params
        )
        return predict_params

-    def _set_predict_params_default(self):
+    def _set_predict_params_default(self) -> None:
        """Get the parameters from xgboost.XGBModel().predict() and
        set them into spark parameters"""
        filtered_params_dict = self._get_predict_params_default()
        self._setDefault(**filtered_params_dict)

-    def _gen_predict_params_dict(self):
-        """
-        Returns a dict of params for .predict()
-        """
+    def _gen_predict_params_dict(self) -> Dict[str, Any]:
+        """Generate predict parameters which will be passed into xgboost.XGBModel().predict()"""
        predict_params_keys = self._get_predict_params_default().keys()
        predict_params = {}
        for param in self.extractParamMap():
@@ -298,9 +320,9 @@ class _SparkXGBParams(
                predict_params[param.name] = self.getOrDefault(param)
        return predict_params

-    def _validate_params(self):
+    def _validate_params(self) -> None:
        # pylint: disable=too-many-branches
-        init_model = self.getOrDefault(self.xgb_model)
+        init_model = self.getOrDefault("xgb_model")
        if init_model is not None and not isinstance(init_model, Booster):
            raise ValueError(
                "The xgb_model param must be set with a `xgboost.core.Booster` "
@@ -321,18 +343,19 @@ class _SparkXGBParams(
                "If features_cols param set, then features_col param is ignored."
            )

-        if self.getOrDefault(self.objective) is not None:
-            if not isinstance(self.getOrDefault(self.objective), str):
+        if self.getOrDefault("objective") is not None:
+            if not isinstance(self.getOrDefault("objective"), str):
                raise ValueError("Only string type 'objective' param is allowed.")

-        if self.getOrDefault(self.eval_metric) is not None:
+        eval_metric = "eval_metric"
+        if self.getOrDefault(eval_metric) is not None:
            if not (
-                isinstance(self.getOrDefault(self.eval_metric), str)
+                isinstance(self.getOrDefault(eval_metric), str)
                or (
-                    isinstance(self.getOrDefault(self.eval_metric), List)
+                    isinstance(self.getOrDefault(eval_metric), List)
                    and all(
                        isinstance(metric, str)
-                        for metric in self.getOrDefault(self.eval_metric)
+                        for metric in self.getOrDefault(eval_metric)
                    )
                )
            ):
@@ -340,10 +363,10 @@ class _SparkXGBParams(
                    "Only string type or list of string type 'eval_metric' param is allowed."
                )

-        if self.getOrDefault(self.early_stopping_rounds) is not None:
+        if self.getOrDefault("early_stopping_rounds") is not None:
            if not (
                self.isDefined(self.validationIndicatorCol)
-                and self.getOrDefault(self.validationIndicatorCol)
+                and self.getOrDefault(self.validationIndicatorCol) != ""
            ):
                raise ValueError(
                    "If 'early_stopping_rounds' param is set, you need to set "
@@ -351,7 +374,7 @@ class _SparkXGBParams(
                )

        if self.getOrDefault(self.enable_sparse_data_optim):
-            if self.getOrDefault(self.missing) != 0.0:
+            if self.getOrDefault("missing") != 0.0:
                # If DMatrix is constructed from csr / csc matrix, then inactive elements
                # in csr / csc matrix are regarded as missing value, but, in pyspark, we
                # are hard to control elements to be active or inactive in sparse vector column,
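This `_validate_params` hunk (and several below) switches `self.getOrDefault(self.some_param)` to `self.getOrDefault("some_param")`. pyspark's `Params.getOrDefault` resolves either a `Param` object or the parameter name string, so both spellings return the same value; the string form mainly avoids attribute lookups that static type checkers cannot see through. A small sketch, assuming the xgboost PySpark estimator can be constructed locally (no running cluster is needed just to read params):

```python
from xgboost.spark import SparkXGBClassifier

clf = SparkXGBClassifier(num_workers=2)

# Param object and param name resolve to the same underlying value.
assert clf.getOrDefault(clf.num_workers) == clf.getOrDefault("num_workers") == 2
```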
@@ -424,8 +447,8 @@ class _SparkXGBParams(


def _validate_and_convert_feature_col_as_float_col_list(
-    dataset, features_col_names: list
-) -> list:
+    dataset: DataFrame, features_col_names: List[str]
+) -> List[Column]:
    """Values in feature columns must be integral types or float/double types"""
    feature_cols = []
    for c in features_col_names:
@@ -440,7 +463,12 @@ def _validate_and_convert_feature_col_as_float_col_list(
    return feature_cols


-def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
+def _validate_and_convert_feature_col_as_array_col(
+    dataset: DataFrame, features_col_name: str
+) -> Column:
    """It handles
    1. Convert vector type to array type
    2. Cast to Array(Float32)"""
    features_col_datatype = dataset.schema[features_col_name].dataType
    features_col = col(features_col_name)
    if isinstance(features_col_datatype, ArrayType):
@@ -466,7 +494,7 @@ def _validate_and_convert_feature_col_as_array_col(dataset, features_col_name):
    return features_array_col


-def _get_unwrap_udt_fn():
+def _get_unwrap_udt_fn() -> Callable[[Union[Column, str]], Column]:
    try:
        from pyspark.sql.functions import unwrap_udt

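`_validate_and_convert_feature_col_as_array_col` above normalizes the features column to `array<float>` before the data reaches the DMatrix builders. A standalone sketch of the same conversion with the public `pyspark.ml.functions.vector_to_array`; the `features` column name and the local SparkSession are illustrative:

```python
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense([1.0, 2.0]),), (Vectors.sparse(2, {1: 3.0}),)], ["features"]
)

# Vector column -> array<double> -> array<float>, mirroring the cast in the helper above.
features_array = vector_to_array(col("features")).cast("array<float>").alias("values")
df.select(features_array).show()
```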
@@ -475,9 +503,9 @@ def _get_unwrap_udt_fn():
        pass

    try:
-        from pyspark.databricks.sql.functions import unwrap_udt
+        from pyspark.databricks.sql.functions import unwrap_udt as databricks_unwrap_udt

-        return unwrap_udt
+        return databricks_unwrap_udt
    except ImportError as exc:
        raise RuntimeError(
            "Cannot import pyspark `unwrap_udt` function. Please install pyspark>=3.4 "
@@ -485,7 +513,7 @@ def _get_unwrap_udt_fn():
        ) from exc


-def _get_unwrapped_vec_cols(feature_col):
+def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
    unwrap_udt = _get_unwrap_udt_fn()
    features_unwrapped_vec_col = unwrap_udt(feature_col)

@@ -519,7 +547,7 @@ FeatureProp = namedtuple(


class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__()
        self._set_xgb_params_default()
        self._set_fit_params_default()
@@ -537,7 +565,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            arbitrary_params_dict={},
        )

-    def setParams(self, **kwargs):  # pylint: disable=invalid-name
+    def setParams(
+        self, **kwargs: Dict[str, Any]
+    ) -> None:  # pylint: disable=invalid-name
        """
        Set params for the estimator.
        """
@@ -578,17 +608,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
        self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})

    @classmethod
-    def _pyspark_model_cls(cls):
+    def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
        """
        Subclasses should override this method and
        returns a _SparkXGBModel subclass
        """
        raise NotImplementedError()

-    def _create_pyspark_model(self, xgb_model):
+    def _create_pyspark_model(self, xgb_model: XGBModel) -> "_SparkXGBModel":
        return self._pyspark_model_cls()(xgb_model)

-    def _convert_to_sklearn_model(self, booster: bytearray, config: str):
+    def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel:
        xgb_sklearn_params = self._gen_xgb_params_dict(
            gen_xgb_sklearn_estimator_param=True
        )
@@ -597,7 +627,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
        sklearn_model._Booster.load_config(config)
        return sklearn_model

-    def _query_plan_contains_valid_repartition(self, dataset):
+    def _query_plan_contains_valid_repartition(self, dataset: DataFrame) -> bool:
        """
        Returns true if the latest element in the logical plan is a valid repartition
        The logic plan string format is like:
@@ -613,6 +643,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
        repartition the dataframe again.
        """
        num_partitions = dataset.rdd.getNumPartitions()
+        assert dataset._sc._jvm is not None
        query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
            dataset._jdf.queryExecution(), "extended"
        )
@@ -626,7 +657,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            return True
        return False

-    def _repartition_needed(self, dataset):
+    def _repartition_needed(self, dataset: DataFrame) -> bool:
        """
        We repartition the dataset if the number of workers is not equal to the number of
        partitions. There is also a check to make sure there was "active partitioning"
@@ -641,7 +672,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            pass
        return True

-    def _get_distributed_train_params(self, dataset):
+    def _get_distributed_train_params(self, dataset: DataFrame) -> Dict[str, Any]:
        """
        This just gets the configuration params for distributed xgboost
        """
@@ -664,10 +695,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
        else:
            # use user specified objective or default objective.
            # e.g., the default objective for Regressor is 'reg:squarederror'
-            params["objective"] = self.getOrDefault(self.objective)
+            params["objective"] = self.getOrDefault("objective")

        # TODO: support "num_parallel_tree" for random forest
-        params["num_boost_round"] = self.getOrDefault(self.n_estimators)
+        params["num_boost_round"] = self.getOrDefault("n_estimators")

        if self.getOrDefault(self.use_gpu):
            params["tree_method"] = "gpu_hist"
@@ -675,7 +706,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
        return params

    @classmethod
-    def _get_xgb_train_call_args(cls, train_params):
+    def _get_xgb_train_call_args(
+        cls, train_params: Dict[str, Any]
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        xgb_train_default_args = _get_default_params_from_func(
            xgboost.train, _unsupported_train_params
        )
@@ -693,7 +726,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

    def _prepare_input_columns_and_feature_prop(
        self, dataset: DataFrame
-    ) -> Tuple[List[str], FeatureProp]:
+    ) -> Tuple[List[Column], FeatureProp]:
        label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)

        select_cols = [label_col]
@@ -721,14 +754,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            )
            select_cols.append(features_array_col)

-        if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol):
+        if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol) != "":
            select_cols.append(
                col(self.getOrDefault(self.weightCol)).alias(alias.weight)
            )

        has_validation_col = False
-        if self.isDefined(self.validationIndicatorCol) and self.getOrDefault(
-            self.validationIndicatorCol
+        if (
+            self.isDefined(self.validationIndicatorCol)
+            and self.getOrDefault(self.validationIndicatorCol) != ""
        ):
            select_cols.append(
                col(self.getOrDefault(self.validationIndicatorCol)).alias(alias.valid)
@@ -738,14 +772,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            # which will cause exception or hanging issue when creating DMatrix.
            has_validation_col = True

-        if self.isDefined(self.base_margin_col) and self.getOrDefault(
-            self.base_margin_col
+        if (
+            self.isDefined(self.base_margin_col)
+            and self.getOrDefault(self.base_margin_col) != ""
        ):
            select_cols.append(
                col(self.getOrDefault(self.base_margin_col)).alias(alias.margin)
            )

-        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
+        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col) != "":
            select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))

        feature_prop = FeatureProp(
@@ -777,7 +812,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

        if self._repartition_needed(dataset) or (
            self.isDefined(self.validationIndicatorCol)
-            and self.getOrDefault(self.validationIndicatorCol)
+            and self.getOrDefault(self.validationIndicatorCol) != ""
        ):
            # If validationIndicatorCol defined, we always repartition dataset
            # to balance data, because user might unionise train and validation dataset,
@@ -790,13 +825,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
            else:
                dataset = dataset.repartition(num_workers)

-        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
+        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col) != "":
            # XGBoost requires qid to be sorted for each partition
            dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)

        return dataset, feature_prop

-    def _get_xgb_parameters(self, dataset: DataFrame):
+    def _get_xgb_parameters(
+        self, dataset: DataFrame
+    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        train_params = self._get_distributed_train_params(dataset)
        booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
            train_params
@@ -807,10 +844,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

        dmatrix_kwargs = {
            "nthread": cpu_per_task,
-            "feature_types": self.getOrDefault(self.feature_types),
-            "feature_names": self.getOrDefault(self.feature_names),
-            "feature_weights": self.getOrDefault(self.feature_weights),
-            "missing": float(self.getOrDefault(self.missing)),
+            "feature_types": self.getOrDefault("feature_types"),
+            "feature_names": self.getOrDefault("feature_names"),
+            "feature_weights": self.getOrDefault("feature_weights"),
+            "missing": float(self.getOrDefault("missing")),
        }
        if dmatrix_kwargs["feature_types"] is not None:
            dmatrix_kwargs["enable_categorical"] = True
@@ -825,7 +862,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

        return booster_params, train_call_kwargs_params, dmatrix_kwargs

-    def _fit(self, dataset):
+    def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
        # pylint: disable=too-many-statements, too-many-locals
        self._validate_params()

@@ -843,7 +880,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

        num_workers = self.getOrDefault(self.num_workers)

-        def _train_booster(pandas_df_iter):
+        def _train_booster(
+            pandas_df_iter: Iterator[pd.DataFrame],
+        ) -> Iterator[pd.DataFrame]:
            """Takes in an RDD partition and outputs a booster for that partition after
            going through the Rabit Ring protocol

@@ -893,7 +932,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

            _rabit_args = json.loads(messages[0])["rabit_msg"]

-            evals_result = {}
+            evals_result: Dict[str, Any] = {}
            with CommunicatorContext(context, **_rabit_args):
                dtrain, dvalid = create_dmatrix_from_partitions(
                    pandas_df_iter,
@@ -925,10 +964,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                    }
                )

-        def _run_job():
+        def _run_job() -> Tuple[str, str]:
            ret = (
                dataset.mapInPandas(
-                    _train_booster, schema="config string, booster string"
+                    _train_booster, schema="config string, booster string"  # type: ignore
                )
                .rdd.barrier()
                .mapPartitions(lambda x: x)
@@ -947,14 +986,14 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
        spark_model._resetUid(self.uid)
        return self._copyValues(spark_model)

-    def write(self):
+    def write(self) -> "SparkXGBWriter":
        """
        Return the writer for saving the estimator.
        """
        return SparkXGBWriter(self)

    @classmethod
-    def read(cls):
+    def read(cls) -> "SparkXGBReader":
        """
        Return the reader for loading the estimator.
        """
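`_fit` above runs `_train_booster` on barrier-mode partitions via `mapInPandas(...).rdd.barrier().mapPartitions(...)` and brings the serialized booster back to the driver, but from the caller's side all of it sits behind the usual `Estimator.fit` / `Model.transform` API. An illustrative usage sketch; the local master, toy data, and column names are assumptions, not taken from the diff:

```python
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.master("local[2]").getOrCreate()

# "features" is an array<double> column, "label" a 0/1 integer column.
df = spark.createDataFrame(
    [([1.0, 2.0], 0), ([2.0, 1.0], 1), ([3.0, 0.5], 1), ([0.5, 3.0], 0)],
    ["features", "label"],
)

clf = SparkXGBClassifier(num_workers=2, n_estimators=10)
model = clf.fit(df)           # distributed training through barrier execution
model.transform(df).show()    # adds prediction / probability columns
```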
@@ -962,21 +1001,24 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):


class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
-    def __init__(self, xgb_sklearn_model=None):
+    def __init__(self, xgb_sklearn_model: Optional[XGBModel] = None) -> None:
        super().__init__()
        self._xgb_sklearn_model = xgb_sklearn_model

    @classmethod
-    def _xgb_cls(cls):
+    def _xgb_cls(cls) -> Type[XGBModel]:
        raise NotImplementedError()

-    def get_booster(self):
+    def get_booster(self) -> Booster:
        """
        Return the `xgboost.core.Booster` instance.
        """
+        assert self._xgb_sklearn_model is not None
        return self._xgb_sklearn_model.get_booster()

-    def get_feature_importances(self, importance_type="weight"):
+    def get_feature_importances(
+        self, importance_type: str = "weight"
+    ) -> Dict[str, Union[float, List[float]]]:
        """Get feature importance of each feature.
        Importance type can be defined as:

@@ -993,20 +1035,22 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
        """
        return self.get_booster().get_score(importance_type=importance_type)

-    def write(self):
+    def write(self) -> "SparkXGBModelWriter":
        """
        Return the writer for saving the model.
        """
        return SparkXGBModelWriter(self)

    @classmethod
-    def read(cls):
+    def read(cls) -> "SparkXGBModelReader":
        """
        Return the reader for loading the model.
        """
        return SparkXGBModelReader(cls)

-    def _get_feature_col(self, dataset) -> (list, Optional[list]):
+    def _get_feature_col(
+        self, dataset: DataFrame
+    ) -> Tuple[List[Column], Optional[List[str]]]:
        """XGBoost model trained with features_cols parameter can also predict
        vector or array feature type. But first we need to check features_cols
        and then featuresCol
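`get_feature_importances` simply forwards to `Booster.get_score`, so the result is the usual `{feature_name: score}` mapping. A hedged sketch, assuming a fitted model such as the one from the usage sketch earlier:

```python
# `model` is a fitted classification model obtained from SparkXGBClassifier.fit(df).
importances = model.get_feature_importances(importance_type="gain")
print(importances)  # e.g. {"f0": 1.23, "f1": 0.47}; keys follow the booster's feature names
```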
@@ -1040,7 +1084,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
            )
        return features_col, feature_col_names

-    def _transform(self, dataset):
+    def _transform(self, dataset: DataFrame) -> DataFrame:
        # pylint: disable=too-many-statements, too-many-locals
        # Save xgb_sklearn_model and predict_params to be local variable
        # to avoid the `self` object to be pickled to remote.
@@ -1048,8 +1092,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
        predict_params = self._gen_predict_params_dict()

        has_base_margin = False
-        if self.isDefined(self.base_margin_col) and self.getOrDefault(
-            self.base_margin_col
+        if (
+            self.isDefined(self.base_margin_col)
+            and self.getOrDefault(self.base_margin_col) != ""
        ):
            has_base_margin = True
            base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
@@ -1060,8 +1105,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
        enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

        pred_contrib_col_name = None
-        if self.isDefined(self.pred_contrib_col) and self.getOrDefault(
-            self.pred_contrib_col
+        if (
+            self.isDefined(self.pred_contrib_col)
+            and self.getOrDefault(self.pred_contrib_col) != ""
        ):
            pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)

@@ -1071,8 +1117,9 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
            single_pred = False
            schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"

-        @pandas_udf(schema)
+        @pandas_udf(schema)  # type: ignore
        def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
+            assert xgb_sklearn_model is not None
            model = xgb_sklearn_model
            for data in iterator:
                if enable_sparse_data_optim:
@@ -1141,7 +1188,7 @@ class _ClassificationModel(  # pylint: disable=abstract-method
    .. Note:: This API is experimental.
    """

-    def _transform(self, dataset):
+    def _transform(self, dataset: DataFrame) -> DataFrame:
        # pylint: disable=too-many-statements, too-many-locals
        # Save xgb_sklearn_model and predict_params to be local variable
        # to avoid the `self` object to be pickled to remote.
@@ -1149,15 +1196,16 @@ class _ClassificationModel(  # pylint: disable=abstract-method
        predict_params = self._gen_predict_params_dict()

        has_base_margin = False
-        if self.isDefined(self.base_margin_col) and self.getOrDefault(
-            self.base_margin_col
+        if (
+            self.isDefined(self.base_margin_col)
+            and self.getOrDefault(self.base_margin_col) != ""
        ):
            has_base_margin = True
            base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
                alias.margin
            )

-        def transform_margin(margins: np.ndarray):
+        def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
            if margins.ndim == 1:
                # binomial case
                classone_probs = expit(margins)
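`transform_margin` converts raw booster margins into the `(rawPrediction, probability)` pair Spark expects. The binomial branch shown above is just the logistic sigmoid; a standalone numpy/scipy sketch of that branch follows (the exact stacking into two columns mirrors Spark's convention and should be read as illustrative; per the surrounding code, the multinomial branch uses `softmax` instead):

```python
import numpy as np
from scipy.special import expit  # logistic sigmoid

margins = np.array([-1.5, 0.0, 2.0])      # one raw margin per row
classone_probs = expit(margins)           # P(y = 1)
classzero_probs = 1.0 - classone_probs    # P(y = 0)

raw_preds = np.vstack((-margins, margins)).transpose()
class_probs = np.vstack((classzero_probs, classone_probs)).transpose()
print(class_probs)  # each row sums to 1.0
```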
@@ -1174,8 +1222,9 @@ class _ClassificationModel(  # pylint: disable=abstract-method
        enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

        pred_contrib_col_name = None
-        if self.isDefined(self.pred_contrib_col) and self.getOrDefault(
-            self.pred_contrib_col
+        if (
+            self.isDefined(self.pred_contrib_col)
+            and self.getOrDefault(self.pred_contrib_col) != ""
        ):
            pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)

@@ -1188,17 +1237,18 @@ class _ClassificationModel(  # pylint: disable=abstract-method
            # So, it will also output 3-D shape result.
            schema = f"{schema}, {pred.pred_contrib} array<array<double>>"

-        @pandas_udf(schema)
+        @pandas_udf(schema)  # type: ignore
        def predict_udf(
            iterator: Iterator[Tuple[pd.Series, ...]]
        ) -> Iterator[pd.DataFrame]:
+            assert xgb_sklearn_model is not None
            model = xgb_sklearn_model
            for data in iterator:
                if enable_sparse_data_optim:
                    X = _read_csr_matrix_from_unwrapped_spark_vec(data)
                else:
                    if feature_col_names is not None:
-                        X = data[feature_col_names]
+                        X = data[feature_col_names]  # type: ignore
                    else:
                        X = stack_series(data[alias.data])

@@ -1219,7 +1269,7 @@ class _ClassificationModel(  # pylint: disable=abstract-method
                # It seems that they use argmax of class probs,
                # not of margin to get the prediction (Note: scala implementation)
                preds = np.argmax(class_probs, axis=1)
-                data = {
+                result: Dict[str, pd.Series] = {
                    pred.raw_prediction: pd.Series(list(raw_preds)),
                    pred.prediction: pd.Series(preds),
                    pred.probability: pd.Series(list(class_probs)),
@@ -1227,9 +1277,9 @@ class _ClassificationModel(  # pylint: disable=abstract-method

                if pred_contrib_col_name:
                    contribs = pred_contribs(model, X, base_margin, strict_shape=True)
-                    data[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
+                    result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))

-                yield pd.DataFrame(data=data)
+                yield pd.DataFrame(data=result)

        if has_base_margin:
            pred_struct = predict_udf(struct(*features_col, base_margin_col))
@@ -1270,7 +1320,13 @@ class _ClassificationModel(  # pylint: disable=abstract-method

class _SparkXGBSharedReadWrite:
    @staticmethod
-    def saveMetadata(instance, path, sc, logger, extraMetadata=None):
+    def saveMetadata(
+        instance: Union[_SparkXGBEstimator, _SparkXGBModel],
+        path: str,
+        sc: SparkContext,
+        logger: logging.Logger,
+        extraMetadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
        """
        Save the metadata of an xgboost.spark._SparkXGBEstimator or
        xgboost.spark._SparkXGBModel.
@@ -1283,7 +1339,7 @@ class _SparkXGBSharedReadWrite:
                jsonParams[p.name] = v

        extraMetadata = extraMetadata or {}
-        callbacks = instance.getOrDefault(instance.callbacks)
+        callbacks = instance.getOrDefault("callbacks")
        if callbacks is not None:
            logger.warning(
                "The callbacks parameter is saved using cloudpickle and it "
@@ -1294,7 +1350,7 @@ class _SparkXGBSharedReadWrite:
                cloudpickle.dumps(callbacks)
            ).decode("ascii")
            extraMetadata["serialized_callbacks"] = serialized_callbacks
-        init_booster = instance.getOrDefault(instance.xgb_model)
+        init_booster = instance.getOrDefault("xgb_model")
        if init_booster is not None:
            extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
        DefaultParamsWriter.saveMetadata(
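Because the `callbacks` param cannot be written into the JSON metadata directly, `saveMetadata` cloudpickles it and stores the bytes as a base64 ASCII string; `loadMetadataAndInstance` further down reverses the same steps. A minimal round-trip sketch; the `my_callback` object is only a stand-in:

```python
import base64

from pyspark import cloudpickle

my_callback = {"period": 10}  # stand-in for an xgboost TrainingCallback instance

# Arbitrary Python object -> bytes -> ASCII-safe string suitable for JSON metadata.
serialized = base64.encodebytes(cloudpickle.dumps(my_callback)).decode("ascii")

# The inverse, as done when the estimator or model is loaded back.
restored = cloudpickle.loads(base64.decodebytes(serialized.encode("ascii")))
assert restored == my_callback
```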
@@ -1308,7 +1364,12 @@ class _SparkXGBSharedReadWrite:
            ).write.parquet(save_path)

    @staticmethod
-    def loadMetadataAndInstance(pyspark_xgb_cls, path, sc, logger):
+    def loadMetadataAndInstance(
+        pyspark_xgb_cls: Union[Type[_SparkXGBEstimator], Type[_SparkXGBModel]],
+        path: str,
+        sc: SparkContext,
+        logger: logging.Logger,
+    ) -> Tuple[Dict[str, Any], Union[_SparkXGBEstimator, _SparkXGBModel]]:
        """
        Load the metadata and the instance of an xgboost.spark._SparkXGBEstimator or
        xgboost.spark._SparkXGBModel.
@@ -1327,7 +1388,7 @@ class _SparkXGBSharedReadWrite:
                callbacks = cloudpickle.loads(
                    base64.decodebytes(serialized_callbacks.encode("ascii"))
                )
-                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)
+                pyspark_xgb.set(pyspark_xgb.callbacks, callbacks)  # type: ignore
            except Exception as e:  # pylint: disable=W0703
                logger.warning(
                    f"Fails to load the callbacks param due to {e}. Please set the "
@@ -1340,7 +1401,7 @@ class _SparkXGBSharedReadWrite:
                _get_spark_session().read.parquet(load_path).collect()[0].init_booster
            )
            init_booster = deserialize_booster(ser_init_booster)
-            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)
+            pyspark_xgb.set(pyspark_xgb.xgb_model, init_booster)  # type: ignore

        pyspark_xgb._resetUid(metadata["uid"])  # pylint: disable=protected-access
        return metadata, pyspark_xgb
@@ -1351,12 +1412,12 @@ class SparkXGBWriter(MLWriter):
    Spark Xgboost estimator writer.
    """

-    def __init__(self, instance):
+    def __init__(self, instance: "_SparkXGBEstimator") -> None:
        super().__init__()
        self.instance = instance
        self.logger = get_logger(self.__class__.__name__, level="WARN")

-    def saveImpl(self, path):
+    def saveImpl(self, path: str) -> None:
        """
        save model.
        """
@@ -1368,19 +1429,19 @@ class SparkXGBReader(MLReader):
    Spark Xgboost estimator reader.
    """

-    def __init__(self, cls):
+    def __init__(self, cls: Type["_SparkXGBEstimator"]) -> None:
        super().__init__()
        self.cls = cls
        self.logger = get_logger(self.__class__.__name__, level="WARN")

-    def load(self, path):
+    def load(self, path: str) -> "_SparkXGBEstimator":
        """
        load model.
        """
        _, pyspark_xgb = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
            self.cls, path, self.sc, self.logger
        )
-        return pyspark_xgb
+        return cast("_SparkXGBEstimator", pyspark_xgb)


class SparkXGBModelWriter(MLWriter):
@@ -1388,18 +1449,19 @@ class SparkXGBModelWriter(MLWriter):
    Spark Xgboost model writer.
    """

-    def __init__(self, instance):
+    def __init__(self, instance: _SparkXGBModel) -> None:
        super().__init__()
        self.instance = instance
        self.logger = get_logger(self.__class__.__name__, level="WARN")

-    def saveImpl(self, path):
+    def saveImpl(self, path: str) -> None:
        """
        Save metadata and model for a :py:class:`_SparkXGBModel`
        - save metadata to path/metadata
        - save model to path/model.json
        """
        xgb_model = self.instance._xgb_sklearn_model
+        assert xgb_model is not None
        _SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
        model_save_path = os.path.join(path, "model")
        booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
@@ -1413,12 +1475,12 @@ class SparkXGBModelReader(MLReader):
    Spark Xgboost model reader.
    """

-    def __init__(self, cls):
+    def __init__(self, cls: Type["_SparkXGBModel"]) -> None:
        super().__init__()
        self.cls = cls
        self.logger = get_logger(self.__class__.__name__, level="WARN")

-    def load(self, path):
+    def load(self, path: str) -> "_SparkXGBModel":
        """
        Load metadata and model for a :py:class:`_SparkXGBModel`

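`SparkXGBModelWriter.saveImpl` above serializes the booster with `Booster.save_raw("json")` and stores the decoded JSON text next to the metadata. A hedged standalone sketch of the same round trip, kept entirely in memory (recent xgboost versions accept a `bytearray` in `Booster.load_model`):

```python
import numpy as np
import xgboost as xgb

X, y = np.random.rand(32, 4), np.random.randint(0, 2, 32)
booster = xgb.train(
    {"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=2
)

# Booster -> JSON string, as SparkXGBModelWriter does before writing it out.
json_str = booster.save_raw(raw_format="json").decode("utf-8")

# JSON string -> Booster, without touching the filesystem.
restored = xgb.Booster()
restored.load_model(bytearray(json_str, "utf-8"))
```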
@@ -1427,6 +1489,7 @@ class SparkXGBModelReader(MLReader):
        _, py_model = _SparkXGBSharedReadWrite.loadMetadataAndInstance(
            self.cls, path, self.sc, self.logger
        )
+        py_model = cast("_SparkXGBModel", py_model)

        xgb_sklearn_params = py_model._gen_xgb_params_dict(
            gen_xgb_sklearn_estimator_param=True
@@ -1437,7 +1500,7 @@ class SparkXGBModelReader(MLReader):
            _get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
        )

-        def create_xgb_model():
+        def create_xgb_model() -> "XGBModel":
            return self.cls._xgb_cls()(**xgb_sklearn_params)

        xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)

@@ -19,7 +19,7 @@ from .utils import get_class_name


def _set_pyspark_xgb_cls_param_attrs(
-    estimator: _SparkXGBEstimator, model: _SparkXGBModel
+    estimator: Type[_SparkXGBEstimator], model: Type[_SparkXGBModel]
) -> None:
    """This function automatically infer to xgboost parameters and set them
    into corresponding pyspark estimators and models"""
@@ -304,7 +304,7 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
            raise ValueError(
                "Spark Xgboost classifier estimator does not support `qid_col` param."
            )
-        if self.getOrDefault(self.objective):  # pylint: disable=no-member
+        if self.getOrDefault("objective"):  # pylint: disable=no-member
            raise ValueError(
                "Setting custom 'objective' param is not allowed in 'SparkXGBClassifier'."
            )
