[pyspark] Handle the device parameter in pyspark. (#9390)

- Handle the new `device` parameter in PySpark.
- Deprecate the old `use_gpu` parameter.

parent 2a0ff209ff
commit 6e18d3a290
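
In user-facing terms the change looks like the following (a hedged before/after sketch; every argument other than `use_gpu`/`device` is incidental):

.. code-block:: python

    from xgboost.spark import SparkXGBRegressor

    # Before (deprecated since 2.0.0): emits a FutureWarning at construction.
    regressor = SparkXGBRegressor(num_workers=2, use_gpu=True)

    # After: name the device type directly. Ordinals such as "cuda:0" are
    # rejected because Spark, not XGBoost, assigns GPUs to tasks.
    regressor = SparkXGBRegressor(num_workers=2, device="cuda")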
@@ -35,13 +35,13 @@ We can create a ``SparkXGBRegressor`` estimator like:
     )

-The above snippet creates a spark estimator which can fit on a spark dataset,
-and return a spark model that can transform a spark dataset and generate dataset
-with prediction column. We can set almost all of xgboost sklearn estimator parameters
-as ``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden
-in spark estimator, and some parameters are replaced with pyspark specific parameters
-such as ``weight_col``, ``validation_indicator_col``, ``use_gpu``, for details please see
-``SparkXGBRegressor`` doc.
+The above snippet creates a spark estimator which can fit on a spark dataset, and return a
+spark model that can transform a spark dataset and generate dataset with prediction
+column. We can set almost all of xgboost sklearn estimator parameters as
+``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden in
+spark estimator, and some parameters are replaced with pyspark specific parameters such as
+``weight_col``, ``validation_indicator_col``, for details please see ``SparkXGBRegressor``
+doc.

 The following code snippet shows how to train a spark xgboost regressor model,
 first we need to prepare a training dataset as a spark dataframe contains
@@ -88,7 +88,7 @@ XGBoost PySpark fully supports GPU acceleration. Users are not only able to enab
 efficient training but also utilize their GPUs for the whole PySpark pipeline including
 ETL and inference. In below sections, we will walk through an example of training on a
 PySpark standalone GPU cluster. To get started, first we need to install some additional
-packages, then we can set the ``use_gpu`` parameter to ``True``.
+packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.

 Prepare the necessary packages
 ==============================
@@ -128,7 +128,7 @@ Write your PySpark application
 ==============================

 Below snippet is a small example for training xgboost model with PySpark. Notice that we are
-using a list of feature names and the additional parameter ``use_gpu``:
+using a list of feature names and the additional parameter ``device``:

 .. code-block:: python

@@ -148,12 +148,12 @@ using a list of feature names and the additional parameter ``use_gpu``:
     # get a list with feature column names
     feature_names = [x.name for x in train_df.schema if x.name != label_name]

-    # create a xgboost pyspark regressor estimator and set use_gpu=True
+    # create a xgboost pyspark regressor estimator and set device="cuda"
     regressor = SparkXGBRegressor(
         features_col=feature_names,
         label_col=label_name,
         num_workers=2,
-        use_gpu=True,
+        device="cuda",
     )

     # train and return the model
@@ -163,6 +163,7 @@ using a list of feature names and the additional parameter ``use_gpu``:
     predict_df = model.transform(test_df)
     predict_df.show()

+Like other distributed interfaces, the ``device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).

 Submit the PySpark application
 ==============================
@@ -276,6 +276,27 @@ def _check_call(ret: int) -> None:
         raise XGBoostError(py_str(_LIB.XGBGetLastError()))


+def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
+    """Validate parameters in distributed environments."""
+    device = kwargs.get("device", None)
+    if device and not isinstance(device, str):
+        msg = "Invalid type for the `device` parameter"
+        msg += _expect((str,), type(device))
+        raise TypeError(msg)
+
+    if device and device.find(":") != -1:
+        raise ValueError(
+            "Distributed training doesn't support selecting device ordinal as GPUs are"
+            " managed by the distributed framework. use `device=cuda` or `device=gpu`"
+            " instead."
+        )
+
+    if kwargs.get("booster", None) == "gblinear":
+        raise NotImplementedError(
+            f"booster `{kwargs['booster']}` is not supported for distributed training."
+        )
+
+
 def build_info() -> dict:
     """Build information of XGBoost. The returned value format is not stable. Also,
     please note that build time dependency is not the same as runtime dependency. For
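
To illustrate the new validator (a minimal sketch; `_check_distributed_params` is a private helper, imported here the same way the PySpark module below does):

.. code-block:: python

    from xgboost.core import _check_distributed_params

    _check_distributed_params({"device": "cuda"})  # passes: plain device type

    # Device ordinals are rejected; the distributed framework owns GPU assignment.
    try:
        _check_distributed_params({"device": "cuda:0"})
    except ValueError as err:
        print(err)

    # gblinear is rejected up front for distributed training.
    try:
        _check_distributed_params({"booster": "gblinear"})
    except NotImplementedError as err:
        print(err)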
@@ -70,6 +70,7 @@ from .core import (
     Metric,
     Objective,
     QuantileDMatrix,
+    _check_distributed_params,
     _deprecate_positional_args,
     _expect,
 )
@@ -924,17 +925,7 @@ async def _train_async(
 ) -> Optional[TrainReturnT]:
     workers = _get_workers_from_data(dtrain, evals)
     _rabit_args = await _get_rabit_args(len(workers), dconfig, client)

-    if params.get("booster", None) == "gblinear":
-        raise NotImplementedError(
-            f"booster `{params['booster']}` is not yet supported for dask."
-        )
-
-    device = params.get("device", None)
-    if device and device.find(":") != -1:
-        raise ValueError(
-            "The dask interface for XGBoost doesn't support selecting specific device"
-            " ordinal. Use `device=cpu` or `device=cuda` instead."
-        )
+    _check_distributed_params(params)

     def dispatched_train(
         parameters: Dict,
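
The net effect on the dask side is behavioral parity with PySpark: both interfaces now funnel through the same checks (a sketch of the rule, assuming an otherwise valid training setup):

.. code-block:: python

    # Both the dask and PySpark interfaces now apply the same parameter checks:
    ok_params = {"tree_method": "hist", "device": "cuda"}      # accepted
    bad_ordinal = {"tree_method": "hist", "device": "cuda:0"}  # ValueError at train time
    bad_booster = {"booster": "gblinear"}                      # NotImplementedError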
@@ -1004,13 +1004,17 @@ class XGBModel(XGBModelBase):
             Validation metrics will help us track the performance of the model.

         eval_metric : str, list of str, or callable, optional

             .. deprecated:: 1.6.0
-                Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
+
+            Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
+
         early_stopping_rounds : int

             .. deprecated:: 1.6.0
-                Use `early_stopping_rounds` in :py:meth:`__init__` or
-                :py:meth:`set_params` instead.
+
+            Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params`
+            instead.
+
         verbose :
             If `verbose` is True and an evaluation set is used, the evaluation metric
             measured on the validation set is printed to stdout at each boosting stage.
@@ -60,7 +60,7 @@ from scipy.special import expit, softmax  # pylint: disable=no-name-in-module
 import xgboost
 from xgboost import XGBClassifier
 from xgboost.compat import is_cudf_available
-from xgboost.core import Booster
+from xgboost.core import Booster, _check_distributed_params
 from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
 from xgboost.training import train as worker_train

@@ -92,6 +92,7 @@ from .utils import (
     get_class_name,
     get_logger,
     serialize_booster,
+    use_cuda,
 )

 # Put pyspark specific params here, they won't be passed to XGBoost.
@@ -108,7 +109,6 @@ _pyspark_specific_params = [
     "arbitrary_params_dict",
     "force_repartition",
     "num_workers",
-    "use_gpu",
     "feature_names",
     "features_cols",
     "enable_sparse_data_optim",
@@ -132,8 +132,7 @@ _pyspark_param_alias_map = {
 _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()}

 _unsupported_xgb_params = [
-    "gpu_id",  # we have "use_gpu" pyspark param instead.
-    "device",  # we have "use_gpu" pyspark param instead.
+    "gpu_id",  # we have "device" pyspark param instead.
     "enable_categorical",  # Use feature_types param to specify categorical feature instead
     "use_label_encoder",
     "n_jobs",  # Do not allow user to set it, will use `spark.task.cpus` value instead.
@@ -198,11 +197,24 @@ class _SparkXGBParams(
         "The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.",
         TypeConverters.toInt,
     )
+    device = Param(
+        Params._dummy(),
+        "device",
+        (
+            "The device type for XGBoost executors. Available options are `cpu`,`cuda`"
+            " and `gpu`. Set `device` to `cuda` or `gpu` if the executors are running "
+            "on GPU instances. Currently, only one GPU per task is supported."
+        ),
+        TypeConverters.toString,
+    )
     use_gpu = Param(
         Params._dummy(),
         "use_gpu",
-        "A boolean variable. Set use_gpu=true if the executors "
-        + "are running on GPU instances. Currently, only one GPU per task is supported.",
+        (
+            "Deprecated, use `device` instead. A boolean variable. Set use_gpu=true "
+            "if the executors are running on GPU instances. Currently, only one GPU per"
+            " task is supported."
+        ),
         TypeConverters.toBoolean,
     )
     force_repartition = Param(
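
Because `device` is a regular pyspark ``Param``, it takes part in the standard ML Param machinery (a sketch; the ``"cpu"`` default comes from the ``_setDefault`` call added later in this diff):

.. code-block:: python

    from xgboost.spark import SparkXGBClassifier

    clf = SparkXGBClassifier()
    assert clf.getOrDefault(clf.device) == "cpu"  # default set via _setDefault

    clf2 = clf.copy({clf.device: "cuda"})         # ordinary Param override
    assert clf2.getOrDefault(clf2.device) == "cuda"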
@@ -336,10 +348,20 @@ class _SparkXGBParams(
                 f"It cannot be less than 1 [Default is 1]"
             )

+        tree_method = self.getOrDefault(self.getParam("tree_method"))
+        if (
+            self.getOrDefault(self.use_gpu) or use_cuda(self.getOrDefault(self.device))
+        ) and not _can_use_qdm(tree_method):
+            raise ValueError(
+                f"The `{tree_method}` tree method is not supported on GPU."
+            )
+
         if self.getOrDefault(self.features_cols):
-            if not self.getOrDefault(self.use_gpu):
+            if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault(
+                self.use_gpu
+            ):
                 raise ValueError(
-                    "features_col param with list value requires enabling use_gpu."
+                    "features_col param with list value requires `device=cuda`."
                 )

         if self.getOrDefault("objective") is not None:
@@ -392,17 +414,7 @@ class _SparkXGBParams(
                     "`pyspark.ml.linalg.Vector` type."
                 )

-        if self.getOrDefault(self.use_gpu):
-            tree_method = self.getParam("tree_method")
-            if (
-                self.getOrDefault(tree_method) is not None
-                and self.getOrDefault(tree_method) != "gpu_hist"
-            ):
-                raise ValueError(
-                    f"tree_method should be 'gpu_hist' or None when use_gpu is True,"
-                    f"found {self.getOrDefault(tree_method)}."
-                )
-
+        if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
             gpu_per_task = (
                 _get_spark_session()
                 .sparkContext.getConf()
@@ -424,8 +436,8 @@ class _SparkXGBParams(
                 # so it's okay for printing the below warning instead of checking the real
                 # gpu numbers and raising the exception.
                 get_logger(self.__class__.__name__).warning(
-                    "You enabled use_gpu in spark local mode. Please make sure your local node "
-                    "has at least %d GPUs",
+                    "You enabled GPU in spark local mode. Please make sure your local "
+                    "node has at least %d GPUs",
                     self.getOrDefault(self.num_workers),
                 )
             else:
@@ -558,6 +570,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         # they are added in `setParams`.
         self._setDefault(
             num_workers=1,
+            device="cpu",
             use_gpu=False,
             force_repartition=False,
             repartition_random_shuffle=False,
@@ -566,9 +579,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             arbitrary_params_dict={},
         )

-    def setParams(
-        self, **kwargs: Dict[str, Any]
-    ) -> None:  # pylint: disable=invalid-name
+    def setParams(self, **kwargs: Any) -> None:  # pylint: disable=invalid-name
         """
         Set params for the estimator.
         """
@@ -613,6 +624,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                     )
                     raise ValueError(err_msg)
                 _extra_params[k] = v
+
+        _check_distributed_params(kwargs)
         _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
         self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
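
Since every estimator constructor funnels through ``setParams``, the validation now fires eagerly, before any Spark job is launched (a hedged sketch of the observable effect):

.. code-block:: python

    import pytest

    from xgboost.spark import SparkXGBRegressor

    # The ordinal check raises at construction time.
    with pytest.raises(ValueError, match="device ordinal"):
        SparkXGBRegressor(device="cuda:0")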
@@ -709,9 +722,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         # TODO: support "num_parallel_tree" for random forest
         params["num_boost_round"] = self.getOrDefault("n_estimators")

-        if self.getOrDefault(self.use_gpu):
-            params["tree_method"] = "gpu_hist"
-
         return params

     @classmethod
@@ -883,8 +893,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             dmatrix_kwargs,
         ) = self._get_xgb_parameters(dataset)

-        use_gpu = self.getOrDefault(self.use_gpu)
+        run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(
+            self.use_gpu
+        )

         is_local = _is_local(_get_spark_session().sparkContext)

         num_workers = self.getOrDefault(self.num_workers)
@@ -903,7 +914,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             dev_ordinal = None
             use_qdm = _can_use_qdm(booster_params.get("tree_method", None))

-            if use_gpu:
+            if run_on_gpu:
                 dev_ordinal = (
                     context.partitionId() if is_local else _get_gpu_id(context)
                 )
@@ -3,8 +3,8 @@
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=unused-argument, too-many-locals

-from typing import Any, Dict, List, Optional, Type, Union
+import warnings
+from typing import Any, List, Optional, Type, Union

 import numpy as np
 from pyspark import keyword_only
@@ -77,27 +77,35 @@ def _set_pyspark_xgb_cls_param_attrs(
     set_param_attrs(name, param_obj)


+def _deprecated_use_gpu() -> None:
+    warnings.warn(
+        "`use_gpu` is deprecated since 2.0.0, use `device` instead", FutureWarning
+    )
+
+
 class SparkXGBRegressor(_SparkXGBEstimator):
     """SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression
     algorithm based on XGBoost python library, and it can be used in PySpark Pipeline
-    and PySpark ML meta algorithms like :py:class:`~pyspark.ml.tuning.CrossValidator`/
-    :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
-    :py:class:`~pyspark.ml.classification.OneVsRest`
+    and PySpark ML meta algorithms like
+    - :py:class:`~pyspark.ml.tuning.CrossValidator`/
+    - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
+    - :py:class:`~pyspark.ml.classification.OneVsRest`

     SparkXGBRegressor automatically supports most of the parameters in
     :py:class:`xgboost.XGBRegressor` constructor and most of the parameters used in
-    :py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict` method.
+    :py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict`
+    method.

-    SparkXGBRegressor doesn't support setting `device` but supports another param
-    `use_gpu`, see doc below for more details.
+    To enable GPU support, set `device` to `cuda` or `gpu`.

-    SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but support
-    another param called `base_margin_col`. see doc below for more details.
+    SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but
+    support another param called `base_margin_col`. see doc below for more details.

     SparkXGBRegressor doesn't support `validate_features` and `output_margin` param.

-    SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread`
-    param for each xgboost worker will be set equal to `spark.task.cpus` config value.
+    SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the
+    `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
+    config value.


     Parameters
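
To see the deprecation path end to end (a sketch; constructing the estimator should be enough to trigger the warning, since ``__init__`` calls ``_deprecated_use_gpu()``):

.. code-block:: python

    import warnings

    from xgboost.spark import SparkXGBRegressor

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        SparkXGBRegressor(use_gpu=True)

    assert any(issubclass(w.category, FutureWarning) for w in caught)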
@@ -133,8 +141,11 @@ class SparkXGBRegressor(_SparkXGBEstimator):
         How many XGBoost workers to be used to train.
         Each XGBoost worker corresponds to one spark task.
     use_gpu:
-        Boolean value to specify whether the executors are running on GPU
-        instances.
+        .. deprecated:: 2.0.0
+
+        Use `device` instead.
+    device:
+        Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
     force_repartition:
         Boolean value to specify if forcing the input dataset to be repartitioned
         before XGBoost training.
@@ -193,14 +204,17 @@ class SparkXGBRegressor(_SparkXGBEstimator):
         weight_col: Optional[str] = None,
         base_margin_col: Optional[str] = None,
         num_workers: int = 1,
-        use_gpu: bool = False,
+        use_gpu: Optional[bool] = None,
+        device: Optional[str] = None,
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         super().__init__()
         input_kwargs = self._input_kwargs
+        if use_gpu:
+            _deprecated_use_gpu()
         self.setParams(**input_kwargs)

     @classmethod
@@ -238,27 +252,29 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
     """SparkXGBClassifier is a PySpark ML estimator. It implements the XGBoost
     classification algorithm based on XGBoost python library, and it can be used in
     PySpark Pipeline and PySpark ML meta algorithms like
-    :py:class:`~pyspark.ml.tuning.CrossValidator`/
-    :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
-    :py:class:`~pyspark.ml.classification.OneVsRest`
+    - :py:class:`~pyspark.ml.tuning.CrossValidator`/
+    - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
+    - :py:class:`~pyspark.ml.classification.OneVsRest`

     SparkXGBClassifier automatically supports most of the parameters in
     :py:class:`xgboost.XGBClassifier` constructor and most of the parameters used in
-    :py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict` method.
+    :py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict`
+    method.

-    SparkXGBClassifier doesn't support setting `device` but support another param
-    `use_gpu`, see doc below for more details.
+    To enable GPU support, set `device` to `cuda` or `gpu`.

-    SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but support
-    another param called `base_margin_col`. see doc below for more details.
+    SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but
+    support another param called `base_margin_col`. see doc below for more details.

-    SparkXGBClassifier doesn't support setting `output_margin`, but we can get output margin
-    from the raw prediction column. See `raw_prediction_col` param doc below for more details.
+    SparkXGBClassifier doesn't support setting `output_margin`, but we can get output
+    margin from the raw prediction column. See `raw_prediction_col` param doc below for
+    more details.

     SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.

-    SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the `nthread`
-    param for each xgboost worker will be set equal to `spark.task.cpus` config value.
+    SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the
+    `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
+    config value.


     Parameters
@@ -300,8 +316,11 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
         How many XGBoost workers to be used to train.
         Each XGBoost worker corresponds to one spark task.
     use_gpu:
-        Boolean value to specify whether the executors are running on GPU
-        instances.
+        .. deprecated:: 2.0.0
+
+        Use `device` instead.
+    device:
+        Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
     force_repartition:
         Boolean value to specify if forcing the input dataset to be repartitioned
         before XGBoost training.
@@ -360,11 +379,12 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
         weight_col: Optional[str] = None,
         base_margin_col: Optional[str] = None,
         num_workers: int = 1,
-        use_gpu: bool = False,
+        use_gpu: Optional[bool] = None,
+        device: Optional[str] = None,
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         super().__init__()
         # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
@@ -372,6 +392,8 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
         # binary or multinomial input dataset, and we need to remove the fixed default
         # param value as well to avoid causing ambiguity.
         input_kwargs = self._input_kwargs
+        if use_gpu:
+            _deprecated_use_gpu()
         self.setParams(**input_kwargs)
         self._setDefault(objective=None)
@@ -422,19 +444,20 @@ class SparkXGBRanker(_SparkXGBEstimator):
     :py:class:`xgboost.XGBRanker` constructor and most of the parameters used in
     :py:meth:`xgboost.XGBRanker.fit` and :py:meth:`xgboost.XGBRanker.predict` method.

-    SparkXGBRanker doesn't support setting `device` but support another param `use_gpu`,
-    see doc below for more details.
+    To enable GPU support, set `device` to `cuda` or `gpu`.

     SparkXGBRanker doesn't support setting `base_margin` explicitly as well, but support
     another param called `base_margin_col`. see doc below for more details.

     SparkXGBRanker doesn't support setting `output_margin`, but we can get output margin
-    from the raw prediction column. See `raw_prediction_col` param doc below for more details.
+    from the raw prediction column. See `raw_prediction_col` param doc below for more
+    details.

     SparkXGBRanker doesn't support `validate_features` and `output_margin` param.

-    SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the `nthread`
-    param for each xgboost worker will be set equal to `spark.task.cpus` config value.
+    SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the
+    `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
+    config value.


     Parameters
@@ -467,13 +490,15 @@ class SparkXGBRanker(_SparkXGBEstimator):
         :py:class:`xgboost.XGBRanker` fit method.
     qid_col:
         Query id column name.

     num_workers:
         How many XGBoost workers to be used to train.
         Each XGBoost worker corresponds to one spark task.
     use_gpu:
-        Boolean value to specify whether the executors are running on GPU
-        instances.
+        .. deprecated:: 2.0.0
+
+        Use `device` instead.
+    device:
+        Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
     force_repartition:
         Boolean value to specify if forcing the input dataset to be repartitioned
         before XGBoost training.
@@ -538,14 +563,17 @@ class SparkXGBRanker(_SparkXGBEstimator):
         base_margin_col: Optional[str] = None,
         qid_col: Optional[str] = None,
         num_workers: int = 1,
-        use_gpu: bool = False,
+        use_gpu: Optional[bool] = None,
+        device: Optional[str] = None,
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         super().__init__()
         input_kwargs = self._input_kwargs
+        if use_gpu:
+            _deprecated_use_gpu()
         self.setParams(**input_kwargs)

     @classmethod
@@ -7,7 +7,7 @@ import os
 import sys
 import uuid
 from threading import Thread
-from typing import Any, Callable, Dict, Set, Type
+from typing import Any, Callable, Dict, Optional, Set, Type

 import pyspark
 from pyspark import BarrierTaskContext, SparkContext, SparkFiles
@@ -186,3 +186,8 @@ def deserialize_booster(model: str) -> Booster:
         f.write(model)
     booster.load_model(tmp_file_name)
     return booster
+
+
+def use_cuda(device: Optional[str]) -> bool:
+    """Whether xgboost is using CUDA workers."""
+    return device in ("cuda", "gpu")
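
The new helper is deliberately tiny; its whole contract fits in a few assertions (assuming it is imported from ``xgboost.spark.utils``, where the hunk above adds it):

.. code-block:: python

    from xgboost.spark.utils import use_cuda

    assert use_cuda("cuda")
    assert use_cuda("gpu")
    assert not use_cuda("cpu")
    assert not use_cuda(None)  # device left unset means CPU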
@@ -98,8 +98,8 @@ void MismatchedDevices(Context const* booster, Context const* data) {
 - Use a data structure that matches the device ordinal in the booster.
 - Set the device for booster before call to inplace_predict.

-This warning will only be shown once, and subsequent warnings made by the current thread will be
-suppressed.
+This warning will only be shown once for each thread. Subsequent warnings made by the
+current thread will be suppressed.
 )";
     logged = true;
   }
@@ -154,7 +154,7 @@ def spark_diabetes_dataset_feature_cols(spark_session_with_gpu):
 def test_sparkxgb_classifier_with_gpu(spark_iris_dataset):
     from pyspark.ml.evaluation import MulticlassClassificationEvaluator

-    classifier = SparkXGBClassifier(use_gpu=True, num_workers=num_workers)
+    classifier = SparkXGBClassifier(device="cuda", num_workers=num_workers)
     train_df, test_df = spark_iris_dataset
     model = classifier.fit(train_df)
     pred_result_df = model.transform(test_df)
@@ -169,7 +169,7 @@ def test_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_co
     train_df, test_df, feature_names = spark_iris_dataset_feature_cols

     classifier = SparkXGBClassifier(
-        features_col=feature_names, use_gpu=True, num_workers=num_workers
+        features_col=feature_names, device="cuda", num_workers=num_workers
     )

     model = classifier.fit(train_df)
@@ -185,7 +185,7 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature
     train_df, test_df, feature_names = spark_iris_dataset_feature_cols

     classifier = SparkXGBClassifier(
-        features_col=feature_names, use_gpu=True, num_workers=num_workers
+        features_col=feature_names, device="cuda", num_workers=num_workers
     )
     grid = ParamGridBuilder().addGrid(classifier.max_depth, [6, 8]).build()
     evaluator = MulticlassClassificationEvaluator(metricName="f1")
@@ -197,11 +197,24 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature
     f1 = evaluator.evaluate(pred_result_df)
     assert f1 >= 0.97

+    clf = SparkXGBClassifier(
+        features_col=feature_names, use_gpu=True, num_workers=num_workers
+    )
+    grid = ParamGridBuilder().addGrid(clf.max_depth, [6, 8]).build()
+    evaluator = MulticlassClassificationEvaluator(metricName="f1")
+    cv = CrossValidator(
+        estimator=clf, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3
+    )
+    cvModel = cv.fit(train_df)
+    pred_result_df = cvModel.transform(test_df)
+    f1 = evaluator.evaluate(pred_result_df)
+    assert f1 >= 0.97
+

 def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset):
     from pyspark.ml.evaluation import RegressionEvaluator

-    regressor = SparkXGBRegressor(use_gpu=True, num_workers=num_workers)
+    regressor = SparkXGBRegressor(device="cuda", num_workers=num_workers)
     train_df, test_df = spark_diabetes_dataset
     model = regressor.fit(train_df)
     pred_result_df = model.transform(test_df)
@@ -215,7 +228,7 @@ def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature

     train_df, test_df, feature_names = spark_diabetes_dataset_feature_cols
     regressor = SparkXGBRegressor(
-        features_col=feature_names, use_gpu=True, num_workers=num_workers
+        features_col=feature_names, device="cuda", num_workers=num_workers
     )

     model = regressor.fit(train_df)
@@ -741,11 +741,6 @@ class TestPySparkLocal:
         with pytest.raises(ValueError, match="early_stopping_rounds"):
            classifier.fit(clf_data.cls_df_train)

-    def test_gpu_param_setting(self, clf_data: ClfData) -> None:
-        py_cls = SparkXGBClassifier(use_gpu=True)
-        train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
-        assert train_params["tree_method"] == "gpu_hist"
-
     def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
         classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"])
         model = classifier.fit(clf_data.cls_df_train)
@@ -756,6 +751,53 @@ class TestPySparkLocal:
         model = classifier.fit(clf_data.cls_df_train)
         model.transform(clf_data.cls_df_test).collect()

+    def test_regressor_params_basic(self) -> None:
+        py_reg = SparkXGBRegressor()
+        assert hasattr(py_reg, "n_estimators")
+        assert py_reg.n_estimators.parent == py_reg.uid
+        assert not hasattr(py_reg, "gpu_id")
+        assert hasattr(py_reg, "device")
+        assert py_reg.getOrDefault(py_reg.n_estimators) == 100
+        assert py_reg.getOrDefault(getattr(py_reg, "objective")), "reg:squarederror"
+        py_reg2 = SparkXGBRegressor(n_estimators=200)
+        assert py_reg2.getOrDefault(getattr(py_reg2, "n_estimators")), 200
+        py_reg3 = py_reg2.copy({getattr(py_reg2, "max_depth"): 10})
+        assert py_reg3.getOrDefault(getattr(py_reg3, "n_estimators")), 200
+        assert py_reg3.getOrDefault(getattr(py_reg3, "max_depth")), 10
+
+    def test_classifier_params_basic(self) -> None:
+        py_clf = SparkXGBClassifier()
+        assert hasattr(py_clf, "n_estimators")
+        assert py_clf.n_estimators.parent == py_clf.uid
+        assert not hasattr(py_clf, "gpu_id")
+        assert hasattr(py_clf, "device")
+
+        assert py_clf.getOrDefault(py_clf.n_estimators) == 100
+        assert py_clf.getOrDefault(getattr(py_clf, "objective")) is None
+        py_clf2 = SparkXGBClassifier(n_estimators=200)
+        assert py_clf2.getOrDefault(getattr(py_clf2, "n_estimators")) == 200
+        py_clf3 = py_clf2.copy({getattr(py_clf2, "max_depth"): 10})
+        assert py_clf3.getOrDefault(getattr(py_clf3, "n_estimators")) == 200
+        assert py_clf3.getOrDefault(getattr(py_clf3, "max_depth")), 10
+
+    def test_classifier_kwargs_basic(self, clf_data: ClfData) -> None:
+        py_clf = SparkXGBClassifier(**clf_data.cls_params)
+        assert hasattr(py_clf, "n_estimators")
+        assert py_clf.n_estimators.parent == py_clf.uid
+        assert not hasattr(py_clf, "gpu_id")
+        assert hasattr(py_clf, "device")
+        assert hasattr(py_clf, "arbitrary_params_dict")
+        assert py_clf.getOrDefault(py_clf.arbitrary_params_dict) == {}
+
+        # Testing overwritten params
+        py_clf = SparkXGBClassifier()
+        py_clf.setParams(x=1, y=2)
+        py_clf.setParams(y=3, z=4)
+        xgb_params = py_clf._gen_xgb_params_dict()
+        assert xgb_params["x"] == 1
+        assert xgb_params["y"] == 3
+        assert xgb_params["z"] == 4
+
     def test_regressor_model_save_load(self, reg_data: RegData) -> None:
         with tempfile.TemporaryDirectory() as tmpdir:
             path = "file:" + tmpdir
@@ -826,6 +868,24 @@ class TestPySparkLocal:
         )
         assert_model_compatible(model.stages[0], tmpdir)

+    def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None:
+        clf = SparkXGBClassifier(device="cuda", tree_method="exact")
+        with pytest.raises(ValueError, match="not supported on GPU"):
+            clf.fit(clf_data.cls_df_train)
+        regressor = SparkXGBRegressor(device="cuda", tree_method="exact")
+        with pytest.raises(ValueError, match="not supported on GPU"):
+            regressor.fit(reg_data.reg_df_train)
+
+        reg = SparkXGBRegressor(device="cuda", tree_method="gpu_hist")
+        reg._validate_params()
+        reg = SparkXGBRegressor(device="cuda")
+        reg._validate_params()
+
+        clf = SparkXGBClassifier(device="cuda", tree_method="gpu_hist")
+        clf._validate_params()
+        clf = SparkXGBClassifier(device="cuda")
+        clf._validate_params()
+

 class XgboostLocalTest(SparkTestCase):
     def setUp(self):
@@ -1020,55 +1080,6 @@ class XgboostLocalTest(SparkTestCase):
         assert sklearn_regressor.max_depth == 3
         assert sklearn_regressor.get_params()["sketch_eps"] == 0.5

-    def test_regressor_params_basic(self):
-        py_reg = SparkXGBRegressor()
-        self.assertTrue(hasattr(py_reg, "n_estimators"))
-        self.assertEqual(py_reg.n_estimators.parent, py_reg.uid)
-        self.assertFalse(hasattr(py_reg, "gpu_id"))
-        self.assertFalse(hasattr(py_reg, "device"))
-        self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100)
-        self.assertEqual(py_reg.getOrDefault(py_reg.objective), "reg:squarederror")
-        py_reg2 = SparkXGBRegressor(n_estimators=200)
-        self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200)
-        py_reg3 = py_reg2.copy({py_reg2.max_depth: 10})
-        self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200)
-        self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10)
-
-    def test_classifier_params_basic(self):
-        py_cls = SparkXGBClassifier()
-        self.assertTrue(hasattr(py_cls, "n_estimators"))
-        self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
-        self.assertFalse(hasattr(py_cls, "gpu_id"))
-        self.assertFalse(hasattr(py_cls, "device"))
-        self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100)
-        self.assertEqual(py_cls.getOrDefault(py_cls.objective), None)
-        py_cls2 = SparkXGBClassifier(n_estimators=200)
-        self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200)
-        py_cls3 = py_cls2.copy({py_cls2.max_depth: 10})
-        self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200)
-        self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10)
-
-    def test_classifier_kwargs_basic(self):
-        py_cls = SparkXGBClassifier(**self.cls_params_kwargs)
-        self.assertTrue(hasattr(py_cls, "n_estimators"))
-        self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
-        self.assertFalse(hasattr(py_cls, "gpu_id"))
-        self.assertFalse(hasattr(py_cls, "device"))
-        self.assertTrue(hasattr(py_cls, "arbitrary_params_dict"))
-        expected_kwargs = {"sketch_eps": 0.03}
-        self.assertEqual(
-            py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs
-        )
-
-        # Testing overwritten params
-        py_cls = SparkXGBClassifier()
-        py_cls.setParams(x=1, y=2)
-        py_cls.setParams(y=3, z=4)
-        xgb_params = py_cls._gen_xgb_params_dict()
-        assert xgb_params["x"] == 1
-        assert xgb_params["y"] == 3
-        assert xgb_params["z"] == 4
-
     def test_param_alias(self):
         py_cls = SparkXGBClassifier(features_col="f1", label_col="l1")
         self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1")
@@ -1200,16 +1211,6 @@ class XgboostLocalTest(SparkTestCase):
         classifier = SparkXGBClassifier(num_workers=0)
         self.assertRaises(ValueError, classifier._validate_params)

-    def test_use_gpu_param(self):
-        classifier = SparkXGBClassifier(use_gpu=True, tree_method="exact")
-        self.assertRaises(ValueError, classifier._validate_params)
-        regressor = SparkXGBRegressor(use_gpu=True, tree_method="exact")
-        self.assertRaises(ValueError, regressor._validate_params)
-        regressor = SparkXGBRegressor(use_gpu=True, tree_method="gpu_hist")
-        regressor = SparkXGBRegressor(use_gpu=True)
-        classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist")
-        classifier = SparkXGBClassifier(use_gpu=True)
-
     def test_feature_importances(self):
         reg1 = SparkXGBRegressor(**self.reg_params)
         model = reg1.fit(self.reg_df_train)