[pyspark] add parameters in the ctor of all estimators. (#9202)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
03bc6e6427
commit
320323f533
@ -337,11 +337,9 @@ class _SparkXGBParams(
|
||||
|
||||
if self.getOrDefault(self.features_cols):
|
||||
if not self.getOrDefault(self.use_gpu):
|
||||
raise ValueError("features_cols param requires enabling use_gpu.")
|
||||
|
||||
get_logger(self.__class__.__name__).warning(
|
||||
"If features_cols param set, then features_col param is ignored."
|
||||
)
|
||||
raise ValueError(
|
||||
"features_col param with list value requires enabling use_gpu."
|
||||
)
|
||||
|
||||
if self.getOrDefault("objective") is not None:
|
||||
if not isinstance(self.getOrDefault("objective"), str):
|
||||
@ -547,6 +545,8 @@ FeatureProp = namedtuple(
|
||||
|
||||
|
||||
class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
_input_kwargs: Dict[str, Any]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._set_xgb_params_default()
|
||||
@ -576,6 +576,11 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
raise ValueError("Invalid param name: 'arbitrary_params_dict'.")
|
||||
|
||||
for k, v in kwargs.items():
|
||||
# We're not allowing user use features_cols directly.
|
||||
if k == self.features_cols.name:
|
||||
raise ValueError(
|
||||
f"Unsupported param '{k}' please use features_col instead."
|
||||
)
|
||||
if k in _inverse_pyspark_param_alias_map:
|
||||
raise ValueError(
|
||||
f"Please use param name {_inverse_pyspark_param_alias_map[k]} instead."
|
||||
@ -591,7 +596,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
k = real_k
|
||||
|
||||
if self.hasParam(k):
|
||||
self._set(**{str(k): v})
|
||||
if k == "features_col" and isinstance(v, list):
|
||||
self._set(**{"features_cols": v})
|
||||
else:
|
||||
self._set(**{str(k): v})
|
||||
else:
|
||||
if (
|
||||
k in _unsupported_xgb_params
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
"""Xgboost pyspark integration submodule for estimator API."""
|
||||
# pylint: disable=too-many-ancestors
|
||||
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
|
||||
# pylint: disable=unused-argument, too-many-locals
|
||||
|
||||
from typing import Any, Type
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
from pyspark import keyword_only
|
||||
from pyspark.ml.param import Param, Params
|
||||
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
|
||||
|
||||
@ -83,8 +86,8 @@ class SparkXGBRegressor(_SparkXGBEstimator):
|
||||
:py:class:`~pyspark.ml.classification.OneVsRest`
|
||||
|
||||
SparkXGBRegressor automatically supports most of the parameters in
|
||||
`xgboost.XGBRegressor` constructor and most of the parameters used in
|
||||
:py:class:`xgboost.XGBRegressor` fit and predict method.
|
||||
:py:class:`xgboost.XGBRegressor` constructor and most of the parameters used in
|
||||
:py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict` method.
|
||||
|
||||
SparkXGBRegressor doesn't support setting `gpu_id` but support another param `use_gpu`,
|
||||
see doc below for more details.
|
||||
@ -97,13 +100,23 @@ class SparkXGBRegressor(_SparkXGBEstimator):
|
||||
SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread`
|
||||
param for each xgboost worker will be set equal to `spark.task.cpus` config value.
|
||||
|
||||
callbacks:
|
||||
The export and import of the callback functions are at best effort.
|
||||
For details, see :py:attr:`xgboost.spark.SparkXGBRegressor.callbacks` param doc.
|
||||
validation_indicator_col
|
||||
For params related to `xgboost.XGBRegressor` training
|
||||
with evaluation dataset's supervision, set
|
||||
:py:attr:`xgboost.spark.SparkXGBRegressor.validation_indicator_col`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
features_col:
|
||||
When the value is string, it requires the features column name to be vector type.
|
||||
When the value is a list of string, it requires all the feature columns to be numeric types.
|
||||
label_col:
|
||||
Label column name. Default to "label".
|
||||
prediction_col:
|
||||
Prediction column name. Default to "prediction"
|
||||
pred_contrib_col:
|
||||
Contribution prediction column name.
|
||||
validation_indicator_col:
|
||||
For params related to `xgboost.XGBRegressor` training with
|
||||
evaluation dataset's supervision,
|
||||
set :py:attr:`xgboost.spark.SparkXGBRegressor.validation_indicator_col`
|
||||
parameter instead of setting the `eval_set` parameter in `xgboost.XGBRegressor`
|
||||
fit method.
|
||||
weight_col:
|
||||
@ -111,26 +124,40 @@ class SparkXGBRegressor(_SparkXGBEstimator):
|
||||
:py:attr:`xgboost.spark.SparkXGBRegressor.weight_col` parameter instead of setting
|
||||
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRegressor`
|
||||
fit method.
|
||||
xgb_model:
|
||||
Set the value to be the instance returned by
|
||||
:func:`xgboost.spark.SparkXGBRegressorModel.get_booster`.
|
||||
num_workers:
|
||||
Integer that specifies the number of XGBoost workers to use.
|
||||
Each XGBoost worker corresponds to one spark task.
|
||||
use_gpu:
|
||||
Boolean that specifies whether the executors are running on GPU
|
||||
instances.
|
||||
base_margin_col:
|
||||
To specify the base margins of the training and validation
|
||||
dataset, set :py:attr:`xgboost.spark.SparkXGBRegressor.base_margin_col` parameter
|
||||
instead of setting `base_margin` and `base_margin_eval_set` in the
|
||||
`xgboost.XGBRegressor` fit method. Note: this isn't available for distributed
|
||||
training.
|
||||
`xgboost.XGBRegressor` fit method.
|
||||
|
||||
.. Note:: The Parameters chart above contains parameters that need special handling.
|
||||
For a full list of parameters, see entries with `Param(parent=...` below.
|
||||
num_workers:
|
||||
How many XGBoost workers to be used to train.
|
||||
Each XGBoost worker corresponds to one spark task.
|
||||
use_gpu:
|
||||
Boolean value to specify whether the executors are running on GPU
|
||||
instances.
|
||||
force_repartition:
|
||||
Boolean value to specify if forcing the input dataset to be repartitioned
|
||||
before XGBoost training.
|
||||
repartition_random_shuffle:
|
||||
Boolean value to specify if randomly shuffling the dataset when repartitioning is required.
|
||||
enable_sparse_data_optim:
|
||||
Boolean value to specify if enabling sparse data optimization, if True,
|
||||
Xgboost DMatrix object will be constructed from sparse matrix instead of
|
||||
dense matrix.
|
||||
|
||||
kwargs:
|
||||
A dictionary of xgboost parameters, please refer to
|
||||
https://xgboost.readthedocs.io/en/stable/parameter.html
|
||||
|
||||
Note
|
||||
----
|
||||
|
||||
The Parameters chart above contains parameters that need special handling.
|
||||
For a full list of parameters, see entries with `Param(parent=...` below.
|
||||
|
||||
This API is experimental.
|
||||
|
||||
.. Note:: This API is experimental.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@ -155,9 +182,27 @@ class SparkXGBRegressor(_SparkXGBEstimator):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
@keyword_only
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
features_col: Union[str, List[str]] = "features",
|
||||
label_col: str = "label",
|
||||
prediction_col: str = "prediction",
|
||||
pred_contrib_col: Optional[str] = None,
|
||||
validation_indicator_col: Optional[str] = None,
|
||||
weight_col: Optional[str] = None,
|
||||
base_margin_col: Optional[str] = None,
|
||||
num_workers: int = 1,
|
||||
use_gpu: bool = False,
|
||||
force_repartition: bool = False,
|
||||
repartition_random_shuffle: bool = False,
|
||||
enable_sparse_data_optim: bool = False,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.setParams(**kwargs)
|
||||
input_kwargs = self._input_kwargs
|
||||
self.setParams(**input_kwargs)
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls) -> Type[XGBRegressor]:
|
||||
@ -199,8 +244,8 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
|
||||
:py:class:`~pyspark.ml.classification.OneVsRest`
|
||||
|
||||
SparkXGBClassifier automatically supports most of the parameters in
|
||||
`xgboost.XGBClassifier` constructor and most of the parameters used in
|
||||
:py:class:`xgboost.XGBClassifier` fit and predict method.
|
||||
:py:class:`xgboost.XGBClassifier` constructor and most of the parameters used in
|
||||
:py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict` method.
|
||||
|
||||
SparkXGBClassifier doesn't support setting `gpu_id` but support another param `use_gpu`,
|
||||
see doc below for more details.
|
||||
@ -220,13 +265,21 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
|
||||
Parameters
|
||||
----------
|
||||
|
||||
callbacks:
|
||||
The export and import of the callback functions are at best effort. For
|
||||
details, see :py:attr:`xgboost.spark.SparkXGBClassifier.callbacks` param doc.
|
||||
features_col:
|
||||
When the value is string, it requires the features column name to be vector type.
|
||||
When the value is a list of string, it requires all the feature columns to be numeric types.
|
||||
label_col:
|
||||
Label column name. Default to "label".
|
||||
prediction_col:
|
||||
Prediction column name. Default to "prediction"
|
||||
probability_col:
|
||||
Column name for predicted class conditional probabilities. Default to probabilityCol
|
||||
raw_prediction_col:
|
||||
The `output_margin=True` is implicitly supported by the
|
||||
`rawPredictionCol` output column, which is always returned with the predicted margin
|
||||
values.
|
||||
pred_contrib_col:
|
||||
Contribution prediction column name.
|
||||
validation_indicator_col:
|
||||
For params related to `xgboost.XGBClassifier` training with
|
||||
evaluation dataset's supervision,
|
||||
@ -238,26 +291,39 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
|
||||
:py:attr:`xgboost.spark.SparkXGBClassifier.weight_col` parameter instead of setting
|
||||
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBClassifier`
|
||||
fit method.
|
||||
xgb_model:
|
||||
Set the value to be the instance returned by
|
||||
:func:`xgboost.spark.SparkXGBClassifierModel.get_booster`.
|
||||
num_workers:
|
||||
Integer that specifies the number of XGBoost workers to use.
|
||||
Each XGBoost worker corresponds to one spark task.
|
||||
use_gpu:
|
||||
Boolean that specifies whether the executors are running on GPU
|
||||
instances.
|
||||
base_margin_col:
|
||||
To specify the base margins of the training and validation
|
||||
dataset, set :py:attr:`xgboost.spark.SparkXGBClassifier.base_margin_col` parameter
|
||||
instead of setting `base_margin` and `base_margin_eval_set` in the
|
||||
`xgboost.XGBClassifier` fit method. Note: this isn't available for distributed
|
||||
training.
|
||||
`xgboost.XGBClassifier` fit method.
|
||||
|
||||
.. Note:: The Parameters chart above contains parameters that need special handling.
|
||||
For a full list of parameters, see entries with `Param(parent=...` below.
|
||||
num_workers:
|
||||
How many XGBoost workers to be used to train.
|
||||
Each XGBoost worker corresponds to one spark task.
|
||||
use_gpu:
|
||||
Boolean value to specify whether the executors are running on GPU
|
||||
instances.
|
||||
force_repartition:
|
||||
Boolean value to specify if forcing the input dataset to be repartitioned
|
||||
before XGBoost training.
|
||||
repartition_random_shuffle:
|
||||
Boolean value to specify if randomly shuffling the dataset when repartitioning is required.
|
||||
enable_sparse_data_optim:
|
||||
Boolean value to specify if enabling sparse data optimization, if True,
|
||||
Xgboost DMatrix object will be constructed from sparse matrix instead of
|
||||
dense matrix.
|
||||
|
||||
.. Note:: This API is experimental.
|
||||
kwargs:
|
||||
A dictionary of xgboost parameters, please refer to
|
||||
https://xgboost.readthedocs.io/en/stable/parameter.html
|
||||
|
||||
Note
|
||||
----
|
||||
|
||||
The Parameters chart above contains parameters that need special handling.
|
||||
For a full list of parameters, see entries with `Param(parent=...` below.
|
||||
|
||||
This API is experimental.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@ -281,14 +347,34 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
@keyword_only
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
features_col: Union[str, List[str]] = "features",
|
||||
label_col: str = "label",
|
||||
prediction_col: str = "prediction",
|
||||
probability_col: str = "probability",
|
||||
raw_prediction_col: str = "rawPrediction",
|
||||
pred_contrib_col: Optional[str] = None,
|
||||
validation_indicator_col: Optional[str] = None,
|
||||
weight_col: Optional[str] = None,
|
||||
base_margin_col: Optional[str] = None,
|
||||
num_workers: int = 1,
|
||||
use_gpu: bool = False,
|
||||
force_repartition: bool = False,
|
||||
repartition_random_shuffle: bool = False,
|
||||
enable_sparse_data_optim: bool = False,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
|
||||
# but in pyspark we will automatically set objective param depending on
|
||||
# binary or multinomial input dataset, and we need to remove the fixed default
|
||||
# param value as well to avoid causing ambiguity.
|
||||
input_kwargs = self._input_kwargs
|
||||
self.setParams(**input_kwargs)
|
||||
self._setDefault(objective=None)
|
||||
self.setParams(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls) -> Type[XGBClassifier]:
|
||||
@ -334,8 +420,8 @@ class SparkXGBRanker(_SparkXGBEstimator):
|
||||
:py:class:`~pyspark.ml.classification.OneVsRest`
|
||||
|
||||
SparkXGBRanker automatically supports most of the parameters in
|
||||
`xgboost.XGBRanker` constructor and most of the parameters used in
|
||||
:py:class:`xgboost.XGBRanker` fit and predict method.
|
||||
:py:class:`xgboost.XGBRanker` constructor and most of the parameters used in
|
||||
:py:meth:`xgboost.XGBRanker.fit` and :py:meth:`xgboost.XGBRanker.predict` method.
|
||||
|
||||
SparkXGBRanker doesn't support setting `gpu_id` but support another param `use_gpu`,
|
||||
see doc below for more details.
|
||||
@ -355,39 +441,53 @@ class SparkXGBRanker(_SparkXGBEstimator):
|
||||
Parameters
|
||||
----------
|
||||
|
||||
callbacks:
|
||||
The export and import of the callback functions are at best effort. For
|
||||
details, see :py:attr:`xgboost.spark.SparkXGBRanker.callbacks` param doc.
|
||||
features_col:
|
||||
When the value is string, it requires the features column name to be vector type.
|
||||
When the value is a list of string, it requires all the feature columns to be numeric types.
|
||||
label_col:
|
||||
Label column name. Default to "label".
|
||||
prediction_col:
|
||||
Prediction column name. Default to "prediction"
|
||||
pred_contrib_col:
|
||||
Contribution prediction column name.
|
||||
validation_indicator_col:
|
||||
For params related to `xgboost.XGBRanker` training with
|
||||
evaluation dataset's supervision,
|
||||
set :py:attr:`xgboost.spark.XGBRanker.validation_indicator_col`
|
||||
parameter instead of setting the `eval_set` parameter in `xgboost.XGBRanker`
|
||||
set :py:attr:`xgboost.spark.SparkXGBRanker.validation_indicator_col`
|
||||
parameter instead of setting the `eval_set` parameter in :py:class:`xgboost.XGBRanker`
|
||||
fit method.
|
||||
weight_col:
|
||||
To specify the weight of the training and validation dataset, set
|
||||
:py:attr:`xgboost.spark.SparkXGBRanker.weight_col` parameter instead of setting
|
||||
`sample_weight` and `sample_weight_eval_set` parameter in `xgboost.XGBRanker`
|
||||
`sample_weight` and `sample_weight_eval_set` parameter in :py:class:`xgboost.XGBRanker`
|
||||
fit method.
|
||||
xgb_model:
|
||||
Set the value to be the instance returned by
|
||||
:func:`xgboost.spark.SparkXGBRankerModel.get_booster`.
|
||||
num_workers:
|
||||
Integer that specifies the number of XGBoost workers to use.
|
||||
Each XGBoost worker corresponds to one spark task.
|
||||
use_gpu:
|
||||
Boolean that specifies whether the executors are running on GPU
|
||||
instances.
|
||||
base_margin_col:
|
||||
To specify the base margins of the training and validation
|
||||
dataset, set :py:attr:`xgboost.spark.SparkXGBRanker.base_margin_col` parameter
|
||||
instead of setting `base_margin` and `base_margin_eval_set` in the
|
||||
`xgboost.XGBRanker` fit method.
|
||||
:py:class:`xgboost.XGBRanker` fit method.
|
||||
qid_col:
|
||||
To specify the qid of the training and validation
|
||||
dataset, set :py:attr:`xgboost.spark.SparkXGBRanker.qid_col` parameter
|
||||
instead of setting `qid` / `group`, `eval_qid` / `eval_group` in the
|
||||
`xgboost.XGBRanker` fit method.
|
||||
Query id column name.
|
||||
|
||||
num_workers:
|
||||
How many XGBoost workers to be used to train.
|
||||
Each XGBoost worker corresponds to one spark task.
|
||||
use_gpu:
|
||||
Boolean value to specify whether the executors are running on GPU
|
||||
instances.
|
||||
force_repartition:
|
||||
Boolean value to specify if forcing the input dataset to be repartitioned
|
||||
before XGBoost training.
|
||||
repartition_random_shuffle:
|
||||
Boolean value to specify if randomly shuffling the dataset when repartitioning is required.
|
||||
enable_sparse_data_optim:
|
||||
Boolean value to specify if enabling sparse data optimization, if True,
|
||||
Xgboost DMatrix object will be constructed from sparse matrix instead of
|
||||
dense matrix.
|
||||
|
||||
kwargs:
|
||||
A dictionary of xgboost parameters, please refer to
|
||||
https://xgboost.readthedocs.io/en/stable/parameter.html
|
||||
|
||||
.. Note:: The Parameters chart above contains parameters that need special handling.
|
||||
For a full list of parameters, see entries with `Param(parent=...` below.
|
||||
@ -426,9 +526,28 @@ class SparkXGBRanker(_SparkXGBEstimator):
|
||||
>>> model.transform(df_test).show()
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
@keyword_only
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
features_col: Union[str, List[str]] = "features",
|
||||
label_col: str = "label",
|
||||
prediction_col: str = "prediction",
|
||||
pred_contrib_col: Optional[str] = None,
|
||||
validation_indicator_col: Optional[str] = None,
|
||||
weight_col: Optional[str] = None,
|
||||
base_margin_col: Optional[str] = None,
|
||||
qid_col: Optional[str] = None,
|
||||
num_workers: int = 1,
|
||||
use_gpu: bool = False,
|
||||
force_repartition: bool = False,
|
||||
repartition_random_shuffle: bool = False,
|
||||
enable_sparse_data_optim: bool = False,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.setParams(**kwargs)
|
||||
input_kwargs = self._input_kwargs
|
||||
self.setParams(**input_kwargs)
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls) -> Type[XGBRanker]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user