[pyspark] Handle the device parameter in pyspark. (#9390)

- Handle the new `device` parameter in PySpark.
- Deprecate the old `use_gpu` parameter.
This commit is contained in:
Jiaming Yuan 2023-07-18 08:47:03 +08:00 committed by GitHub
parent 2a0ff209ff
commit 6e18d3a290
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 244 additions and 169 deletions

View File

@ -35,13 +35,13 @@ We can create a ``SparkXGBRegressor`` estimator like:
) )
The above snippet creates a spark estimator which can fit on a spark dataset, The above snippet creates a spark estimator which can fit on a spark dataset, and return a
and return a spark model that can transform a spark dataset and generate dataset spark model that can transform a spark dataset and generate dataset with prediction
with prediction column. We can set almost all of xgboost sklearn estimator parameters column. We can set almost all of xgboost sklearn estimator parameters as
as ``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden ``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden in
in spark estimator, and some parameters are replaced with pyspark specific parameters spark estimator, and some parameters are replaced with pyspark specific parameters such as
such as ``weight_col``, ``validation_indicator_col``, ``use_gpu``, for details please see ``weight_col``, ``validation_indicator_col``, for details please see ``SparkXGBRegressor``
``SparkXGBRegressor`` doc. doc.
The following code snippet shows how to train a spark xgboost regressor model, The following code snippet shows how to train a spark xgboost regressor model,
first we need to prepare a training dataset as a spark dataframe contains first we need to prepare a training dataset as a spark dataframe contains
@ -88,7 +88,7 @@ XGBoost PySpark fully supports GPU acceleration. Users are not only able to enab
efficient training but also utilize their GPUs for the whole PySpark pipeline including efficient training but also utilize their GPUs for the whole PySpark pipeline including
ETL and inference. In below sections, we will walk through an example of training on a ETL and inference. In below sections, we will walk through an example of training on a
PySpark standalone GPU cluster. To get started, first we need to install some additional PySpark standalone GPU cluster. To get started, first we need to install some additional
packages, then we can set the ``use_gpu`` parameter to ``True``. packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.
Prepare the necessary packages Prepare the necessary packages
============================== ==============================
@ -128,7 +128,7 @@ Write your PySpark application
============================== ==============================
Below snippet is a small example for training xgboost model with PySpark. Notice that we are Below snippet is a small example for training xgboost model with PySpark. Notice that we are
using a list of feature names and the additional parameter ``use_gpu``: using a list of feature names and the additional parameter ``device``:
.. code-block:: python .. code-block:: python
@ -148,12 +148,12 @@ using a list of feature names and the additional parameter ``use_gpu``:
# get a list with feature column names # get a list with feature column names
feature_names = [x.name for x in train_df.schema if x.name != label_name] feature_names = [x.name for x in train_df.schema if x.name != label_name]
# create a xgboost pyspark regressor estimator and set use_gpu=True # create a xgboost pyspark regressor estimator and set device="cuda"
regressor = SparkXGBRegressor( regressor = SparkXGBRegressor(
features_col=feature_names, features_col=feature_names,
label_col=label_name, label_col=label_name,
num_workers=2, num_workers=2,
use_gpu=True, device="cuda",
) )
# train and return the model # train and return the model
@ -163,6 +163,7 @@ using a list of feature names and the additional parameter ``use_gpu``:
predict_df = model.transform(test_df) predict_df = model.transform(test_df)
predict_df.show() predict_df.show()
Like other distributed interfaces, the ```device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).
Submit the PySpark application Submit the PySpark application
============================== ==============================

View File

@ -276,6 +276,27 @@ def _check_call(ret: int) -> None:
raise XGBoostError(py_str(_LIB.XGBGetLastError())) raise XGBoostError(py_str(_LIB.XGBGetLastError()))
def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
"""Validate parameters in distributed environments."""
device = kwargs.get("device", None)
if device and not isinstance(device, str):
msg = "Invalid type for the `device` parameter"
msg += _expect((str,), type(device))
raise TypeError(msg)
if device and device.find(":") != -1:
raise ValueError(
"Distributed training doesn't support selecting device ordinal as GPUs are"
" managed by the distributed framework. use `device=cuda` or `device=gpu`"
" instead."
)
if kwargs.get("booster", None) == "gblinear":
raise NotImplementedError(
f"booster `{kwargs['booster']}` is not supported for distributed training."
)
def build_info() -> dict: def build_info() -> dict:
"""Build information of XGBoost. The returned value format is not stable. Also, """Build information of XGBoost. The returned value format is not stable. Also,
please note that build time dependency is not the same as runtime dependency. For please note that build time dependency is not the same as runtime dependency. For

View File

@ -70,6 +70,7 @@ from .core import (
Metric, Metric,
Objective, Objective,
QuantileDMatrix, QuantileDMatrix,
_check_distributed_params,
_deprecate_positional_args, _deprecate_positional_args,
_expect, _expect,
) )
@ -924,17 +925,7 @@ async def _train_async(
) -> Optional[TrainReturnT]: ) -> Optional[TrainReturnT]:
workers = _get_workers_from_data(dtrain, evals) workers = _get_workers_from_data(dtrain, evals)
_rabit_args = await _get_rabit_args(len(workers), dconfig, client) _rabit_args = await _get_rabit_args(len(workers), dconfig, client)
_check_distributed_params(params)
if params.get("booster", None) == "gblinear":
raise NotImplementedError(
f"booster `{params['booster']}` is not yet supported for dask."
)
device = params.get("device", None)
if device and device.find(":") != -1:
raise ValueError(
"The dask interface for XGBoost doesn't support selecting specific device"
" ordinal. Use `device=cpu` or `device=cuda` instead."
)
def dispatched_train( def dispatched_train(
parameters: Dict, parameters: Dict,

View File

@ -1004,13 +1004,17 @@ class XGBModel(XGBModelBase):
Validation metrics will help us track the performance of the model. Validation metrics will help us track the performance of the model.
eval_metric : str, list of str, or callable, optional eval_metric : str, list of str, or callable, optional
.. deprecated:: 1.6.0 .. deprecated:: 1.6.0
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
early_stopping_rounds : int early_stopping_rounds : int
.. deprecated:: 1.6.0 .. deprecated:: 1.6.0
Use `early_stopping_rounds` in :py:meth:`__init__` or
:py:meth:`set_params` instead. Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params`
instead.
verbose : verbose :
If `verbose` is True and an evaluation set is used, the evaluation metric If `verbose` is True and an evaluation set is used, the evaluation metric
measured on the validation set is printed to stdout at each boosting stage. measured on the validation set is printed to stdout at each boosting stage.

View File

@ -60,7 +60,7 @@ from scipy.special import expit, softmax # pylint: disable=no-name-in-module
import xgboost import xgboost
from xgboost import XGBClassifier from xgboost import XGBClassifier
from xgboost.compat import is_cudf_available from xgboost.compat import is_cudf_available
from xgboost.core import Booster from xgboost.core import Booster, _check_distributed_params
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
from xgboost.training import train as worker_train from xgboost.training import train as worker_train
@ -92,6 +92,7 @@ from .utils import (
get_class_name, get_class_name,
get_logger, get_logger,
serialize_booster, serialize_booster,
use_cuda,
) )
# Put pyspark specific params here, they won't be passed to XGBoost. # Put pyspark specific params here, they won't be passed to XGBoost.
@ -108,7 +109,6 @@ _pyspark_specific_params = [
"arbitrary_params_dict", "arbitrary_params_dict",
"force_repartition", "force_repartition",
"num_workers", "num_workers",
"use_gpu",
"feature_names", "feature_names",
"features_cols", "features_cols",
"enable_sparse_data_optim", "enable_sparse_data_optim",
@ -132,8 +132,7 @@ _pyspark_param_alias_map = {
_inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()} _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()}
_unsupported_xgb_params = [ _unsupported_xgb_params = [
"gpu_id", # we have "use_gpu" pyspark param instead. "gpu_id", # we have "device" pyspark param instead.
"device", # we have "use_gpu" pyspark param instead.
"enable_categorical", # Use feature_types param to specify categorical feature instead "enable_categorical", # Use feature_types param to specify categorical feature instead
"use_label_encoder", "use_label_encoder",
"n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead. "n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead.
@ -198,11 +197,24 @@ class _SparkXGBParams(
"The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.", "The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.",
TypeConverters.toInt, TypeConverters.toInt,
) )
device = Param(
Params._dummy(),
"device",
(
"The device type for XGBoost executors. Available options are `cpu`,`cuda`"
" and `gpu`. Set `device` to `cuda` or `gpu` if the executors are running "
"on GPU instances. Currently, only one GPU per task is supported."
),
TypeConverters.toString,
)
use_gpu = Param( use_gpu = Param(
Params._dummy(), Params._dummy(),
"use_gpu", "use_gpu",
"A boolean variable. Set use_gpu=true if the executors " (
+ "are running on GPU instances. Currently, only one GPU per task is supported.", "Deprecated, use `device` instead. A boolean variable. Set use_gpu=true "
"if the executors are running on GPU instances. Currently, only one GPU per"
" task is supported."
),
TypeConverters.toBoolean, TypeConverters.toBoolean,
) )
force_repartition = Param( force_repartition = Param(
@ -336,10 +348,20 @@ class _SparkXGBParams(
f"It cannot be less than 1 [Default is 1]" f"It cannot be less than 1 [Default is 1]"
) )
tree_method = self.getOrDefault(self.getParam("tree_method"))
if (
self.getOrDefault(self.use_gpu) or use_cuda(self.getOrDefault(self.device))
) and not _can_use_qdm(tree_method):
raise ValueError(
f"The `{tree_method}` tree method is not supported on GPU."
)
if self.getOrDefault(self.features_cols): if self.getOrDefault(self.features_cols):
if not self.getOrDefault(self.use_gpu): if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault(
self.use_gpu
):
raise ValueError( raise ValueError(
"features_col param with list value requires enabling use_gpu." "features_col param with list value requires `device=cuda`."
) )
if self.getOrDefault("objective") is not None: if self.getOrDefault("objective") is not None:
@ -392,17 +414,7 @@ class _SparkXGBParams(
"`pyspark.ml.linalg.Vector` type." "`pyspark.ml.linalg.Vector` type."
) )
if self.getOrDefault(self.use_gpu): if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
tree_method = self.getParam("tree_method")
if (
self.getOrDefault(tree_method) is not None
and self.getOrDefault(tree_method) != "gpu_hist"
):
raise ValueError(
f"tree_method should be 'gpu_hist' or None when use_gpu is True,"
f"found {self.getOrDefault(tree_method)}."
)
gpu_per_task = ( gpu_per_task = (
_get_spark_session() _get_spark_session()
.sparkContext.getConf() .sparkContext.getConf()
@ -424,8 +436,8 @@ class _SparkXGBParams(
# so it's okay for printing the below warning instead of checking the real # so it's okay for printing the below warning instead of checking the real
# gpu numbers and raising the exception. # gpu numbers and raising the exception.
get_logger(self.__class__.__name__).warning( get_logger(self.__class__.__name__).warning(
"You enabled use_gpu in spark local mode. Please make sure your local node " "You enabled GPU in spark local mode. Please make sure your local "
"has at least %d GPUs", "node has at least %d GPUs",
self.getOrDefault(self.num_workers), self.getOrDefault(self.num_workers),
) )
else: else:
@ -558,6 +570,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
# they are added in `setParams`. # they are added in `setParams`.
self._setDefault( self._setDefault(
num_workers=1, num_workers=1,
device="cpu",
use_gpu=False, use_gpu=False,
force_repartition=False, force_repartition=False,
repartition_random_shuffle=False, repartition_random_shuffle=False,
@ -566,9 +579,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
arbitrary_params_dict={}, arbitrary_params_dict={},
) )
def setParams( def setParams(self, **kwargs: Any) -> None: # pylint: disable=invalid-name
self, **kwargs: Dict[str, Any]
) -> None: # pylint: disable=invalid-name
""" """
Set params for the estimator. Set params for the estimator.
""" """
@ -613,6 +624,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
) )
raise ValueError(err_msg) raise ValueError(err_msg)
_extra_params[k] = v _extra_params[k] = v
_check_distributed_params(kwargs)
_existing_extra_params = self.getOrDefault(self.arbitrary_params_dict) _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params}) self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
@ -709,9 +722,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
# TODO: support "num_parallel_tree" for random forest # TODO: support "num_parallel_tree" for random forest
params["num_boost_round"] = self.getOrDefault("n_estimators") params["num_boost_round"] = self.getOrDefault("n_estimators")
if self.getOrDefault(self.use_gpu):
params["tree_method"] = "gpu_hist"
return params return params
@classmethod @classmethod
@ -883,8 +893,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
dmatrix_kwargs, dmatrix_kwargs,
) = self._get_xgb_parameters(dataset) ) = self._get_xgb_parameters(dataset)
use_gpu = self.getOrDefault(self.use_gpu) run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(
self.use_gpu
)
is_local = _is_local(_get_spark_session().sparkContext) is_local = _is_local(_get_spark_session().sparkContext)
num_workers = self.getOrDefault(self.num_workers) num_workers = self.getOrDefault(self.num_workers)
@ -903,7 +914,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
dev_ordinal = None dev_ordinal = None
use_qdm = _can_use_qdm(booster_params.get("tree_method", None)) use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
if use_gpu: if run_on_gpu:
dev_ordinal = ( dev_ordinal = (
context.partitionId() if is_local else _get_gpu_id(context) context.partitionId() if is_local else _get_gpu_id(context)
) )

View File

@ -3,8 +3,8 @@
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=unused-argument, too-many-locals # pylint: disable=unused-argument, too-many-locals
import warnings
from typing import Any, Dict, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
import numpy as np import numpy as np
from pyspark import keyword_only from pyspark import keyword_only
@ -77,27 +77,35 @@ def _set_pyspark_xgb_cls_param_attrs(
set_param_attrs(name, param_obj) set_param_attrs(name, param_obj)
def _deprecated_use_gpu() -> None:
warnings.warn(
"`use_gpu` is deprecated since 2.0.0, use `device` instead", FutureWarning
)
class SparkXGBRegressor(_SparkXGBEstimator): class SparkXGBRegressor(_SparkXGBEstimator):
"""SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression """SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression
algorithm based on XGBoost python library, and it can be used in PySpark Pipeline algorithm based on XGBoost python library, and it can be used in PySpark Pipeline
and PySpark ML meta algorithms like :py:class:`~pyspark.ml.tuning.CrossValidator`/ and PySpark ML meta algorithms like
:py:class:`~pyspark.ml.tuning.TrainValidationSplit`/ - :py:class:`~pyspark.ml.tuning.CrossValidator`/
:py:class:`~pyspark.ml.classification.OneVsRest` - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
- :py:class:`~pyspark.ml.classification.OneVsRest`
SparkXGBRegressor automatically supports most of the parameters in SparkXGBRegressor automatically supports most of the parameters in
:py:class:`xgboost.XGBRegressor` constructor and most of the parameters used in :py:class:`xgboost.XGBRegressor` constructor and most of the parameters used in
:py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict` method. :py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict`
method.
SparkXGBRegressor doesn't support setting `device` but supports another param To enable GPU support, set `device` to `cuda` or `gpu`.
`use_gpu`, see doc below for more details.
SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but support SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but
another param called `base_margin_col`. see doc below for more details. support another param called `base_margin_col`. see doc below for more details.
SparkXGBRegressor doesn't support `validate_features` and `output_margin` param. SparkXGBRegressor doesn't support `validate_features` and `output_margin` param.
SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread` SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the
param for each xgboost worker will be set equal to `spark.task.cpus` config value. `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
config value.
Parameters Parameters
@ -133,8 +141,11 @@ class SparkXGBRegressor(_SparkXGBEstimator):
How many XGBoost workers to be used to train. How many XGBoost workers to be used to train.
Each XGBoost worker corresponds to one spark task. Each XGBoost worker corresponds to one spark task.
use_gpu: use_gpu:
Boolean value to specify whether the executors are running on GPU .. deprecated:: 2.0.0
instances.
Use `device` instead.
device:
Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
force_repartition: force_repartition:
Boolean value to specify if forcing the input dataset to be repartitioned Boolean value to specify if forcing the input dataset to be repartitioned
before XGBoost training. before XGBoost training.
@ -193,14 +204,17 @@ class SparkXGBRegressor(_SparkXGBEstimator):
weight_col: Optional[str] = None, weight_col: Optional[str] = None,
base_margin_col: Optional[str] = None, base_margin_col: Optional[str] = None,
num_workers: int = 1, num_workers: int = 1,
use_gpu: bool = False, use_gpu: Optional[bool] = None,
device: Optional[str] = None,
force_repartition: bool = False, force_repartition: bool = False,
repartition_random_shuffle: bool = False, repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False, enable_sparse_data_optim: bool = False,
**kwargs: Dict[str, Any], **kwargs: Any,
) -> None: ) -> None:
super().__init__() super().__init__()
input_kwargs = self._input_kwargs input_kwargs = self._input_kwargs
if use_gpu:
_deprecated_use_gpu()
self.setParams(**input_kwargs) self.setParams(**input_kwargs)
@classmethod @classmethod
@ -238,27 +252,29 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
"""SparkXGBClassifier is a PySpark ML estimator. It implements the XGBoost """SparkXGBClassifier is a PySpark ML estimator. It implements the XGBoost
classification algorithm based on XGBoost python library, and it can be used in classification algorithm based on XGBoost python library, and it can be used in
PySpark Pipeline and PySpark ML meta algorithms like PySpark Pipeline and PySpark ML meta algorithms like
:py:class:`~pyspark.ml.tuning.CrossValidator`/ - :py:class:`~pyspark.ml.tuning.CrossValidator`/
:py:class:`~pyspark.ml.tuning.TrainValidationSplit`/ - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
:py:class:`~pyspark.ml.classification.OneVsRest` - :py:class:`~pyspark.ml.classification.OneVsRest`
SparkXGBClassifier automatically supports most of the parameters in SparkXGBClassifier automatically supports most of the parameters in
:py:class:`xgboost.XGBClassifier` constructor and most of the parameters used in :py:class:`xgboost.XGBClassifier` constructor and most of the parameters used in
:py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict` method. :py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict`
method.
SparkXGBClassifier doesn't support setting `device` but support another param To enable GPU support, set `device` to `cuda` or `gpu`.
`use_gpu`, see doc below for more details.
SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but support SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but
another param called `base_margin_col`. see doc below for more details. support another param called `base_margin_col`. see doc below for more details.
SparkXGBClassifier doesn't support setting `output_margin`, but we can get output margin SparkXGBClassifier doesn't support setting `output_margin`, but we can get output
from the raw prediction column. See `raw_prediction_col` param doc below for more details. margin from the raw prediction column. See `raw_prediction_col` param doc below for
more details.
SparkXGBClassifier doesn't support `validate_features` and `output_margin` param. SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.
SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the `nthread` SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the
param for each xgboost worker will be set equal to `spark.task.cpus` config value. `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
config value.
Parameters Parameters
@ -300,8 +316,11 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
How many XGBoost workers to be used to train. How many XGBoost workers to be used to train.
Each XGBoost worker corresponds to one spark task. Each XGBoost worker corresponds to one spark task.
use_gpu: use_gpu:
Boolean value to specify whether the executors are running on GPU .. deprecated:: 2.0.0
instances.
Use `device` instead.
device:
Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
force_repartition: force_repartition:
Boolean value to specify if forcing the input dataset to be repartitioned Boolean value to specify if forcing the input dataset to be repartitioned
before XGBoost training. before XGBoost training.
@ -360,11 +379,12 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
weight_col: Optional[str] = None, weight_col: Optional[str] = None,
base_margin_col: Optional[str] = None, base_margin_col: Optional[str] = None,
num_workers: int = 1, num_workers: int = 1,
use_gpu: bool = False, use_gpu: Optional[bool] = None,
device: Optional[str] = None,
force_repartition: bool = False, force_repartition: bool = False,
repartition_random_shuffle: bool = False, repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False, enable_sparse_data_optim: bool = False,
**kwargs: Dict[str, Any], **kwargs: Any,
) -> None: ) -> None:
super().__init__() super().__init__()
# The default 'objective' param value comes from sklearn `XGBClassifier` ctor, # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
@ -372,6 +392,8 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
# binary or multinomial input dataset, and we need to remove the fixed default # binary or multinomial input dataset, and we need to remove the fixed default
# param value as well to avoid causing ambiguity. # param value as well to avoid causing ambiguity.
input_kwargs = self._input_kwargs input_kwargs = self._input_kwargs
if use_gpu:
_deprecated_use_gpu()
self.setParams(**input_kwargs) self.setParams(**input_kwargs)
self._setDefault(objective=None) self._setDefault(objective=None)
@ -422,19 +444,20 @@ class SparkXGBRanker(_SparkXGBEstimator):
:py:class:`xgboost.XGBRanker` constructor and most of the parameters used in :py:class:`xgboost.XGBRanker` constructor and most of the parameters used in
:py:meth:`xgboost.XGBRanker.fit` and :py:meth:`xgboost.XGBRanker.predict` method. :py:meth:`xgboost.XGBRanker.fit` and :py:meth:`xgboost.XGBRanker.predict` method.
SparkXGBRanker doesn't support setting `device` but support another param `use_gpu`, To enable GPU support, set `device` to `cuda` or `gpu`.
see doc below for more details.
SparkXGBRanker doesn't support setting `base_margin` explicitly as well, but support SparkXGBRanker doesn't support setting `base_margin` explicitly as well, but support
another param called `base_margin_col`. see doc below for more details. another param called `base_margin_col`. see doc below for more details.
SparkXGBRanker doesn't support setting `output_margin`, but we can get output margin SparkXGBRanker doesn't support setting `output_margin`, but we can get output margin
from the raw prediction column. See `raw_prediction_col` param doc below for more details. from the raw prediction column. See `raw_prediction_col` param doc below for more
details.
SparkXGBRanker doesn't support `validate_features` and `output_margin` param. SparkXGBRanker doesn't support `validate_features` and `output_margin` param.
SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the `nthread` SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the
param for each xgboost worker will be set equal to `spark.task.cpus` config value. `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
config value.
Parameters Parameters
@ -467,13 +490,15 @@ class SparkXGBRanker(_SparkXGBEstimator):
:py:class:`xgboost.XGBRanker` fit method. :py:class:`xgboost.XGBRanker` fit method.
qid_col: qid_col:
Query id column name. Query id column name.
num_workers: num_workers:
How many XGBoost workers to be used to train. How many XGBoost workers to be used to train.
Each XGBoost worker corresponds to one spark task. Each XGBoost worker corresponds to one spark task.
use_gpu: use_gpu:
Boolean value to specify whether the executors are running on GPU .. deprecated:: 2.0.0
instances.
Use `device` instead.
device:
Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
force_repartition: force_repartition:
Boolean value to specify if forcing the input dataset to be repartitioned Boolean value to specify if forcing the input dataset to be repartitioned
before XGBoost training. before XGBoost training.
@ -538,14 +563,17 @@ class SparkXGBRanker(_SparkXGBEstimator):
base_margin_col: Optional[str] = None, base_margin_col: Optional[str] = None,
qid_col: Optional[str] = None, qid_col: Optional[str] = None,
num_workers: int = 1, num_workers: int = 1,
use_gpu: bool = False, use_gpu: Optional[bool] = None,
device: Optional[str] = None,
force_repartition: bool = False, force_repartition: bool = False,
repartition_random_shuffle: bool = False, repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False, enable_sparse_data_optim: bool = False,
**kwargs: Dict[str, Any], **kwargs: Any,
) -> None: ) -> None:
super().__init__() super().__init__()
input_kwargs = self._input_kwargs input_kwargs = self._input_kwargs
if use_gpu:
_deprecated_use_gpu()
self.setParams(**input_kwargs) self.setParams(**input_kwargs)
@classmethod @classmethod

View File

@ -7,7 +7,7 @@ import os
import sys import sys
import uuid import uuid
from threading import Thread from threading import Thread
from typing import Any, Callable, Dict, Set, Type from typing import Any, Callable, Dict, Optional, Set, Type
import pyspark import pyspark
from pyspark import BarrierTaskContext, SparkContext, SparkFiles from pyspark import BarrierTaskContext, SparkContext, SparkFiles
@ -186,3 +186,8 @@ def deserialize_booster(model: str) -> Booster:
f.write(model) f.write(model)
booster.load_model(tmp_file_name) booster.load_model(tmp_file_name)
return booster return booster
def use_cuda(device: Optional[str]) -> bool:
"""Whether xgboost is using CUDA workers."""
return device in ("cuda", "gpu")

View File

@ -98,8 +98,8 @@ void MismatchedDevices(Context const* booster, Context const* data) {
- Use a data structure that matches the device ordinal in the booster. - Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict. - Set the device for booster before call to inplace_predict.
This warning will only be shown once, and subsequent warnings made by the current thread will be This warning will only be shown once for each thread. Subsequent warnings made by the
suppressed. current thread will be suppressed.
)"; )";
logged = true; logged = true;
} }

View File

@ -154,7 +154,7 @@ def spark_diabetes_dataset_feature_cols(spark_session_with_gpu):
def test_sparkxgb_classifier_with_gpu(spark_iris_dataset): def test_sparkxgb_classifier_with_gpu(spark_iris_dataset):
from pyspark.ml.evaluation import MulticlassClassificationEvaluator from pyspark.ml.evaluation import MulticlassClassificationEvaluator
classifier = SparkXGBClassifier(use_gpu=True, num_workers=num_workers) classifier = SparkXGBClassifier(device="cuda", num_workers=num_workers)
train_df, test_df = spark_iris_dataset train_df, test_df = spark_iris_dataset
model = classifier.fit(train_df) model = classifier.fit(train_df)
pred_result_df = model.transform(test_df) pred_result_df = model.transform(test_df)
@ -169,7 +169,7 @@ def test_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_co
train_df, test_df, feature_names = spark_iris_dataset_feature_cols train_df, test_df, feature_names = spark_iris_dataset_feature_cols
classifier = SparkXGBClassifier( classifier = SparkXGBClassifier(
features_col=feature_names, use_gpu=True, num_workers=num_workers features_col=feature_names, device="cuda", num_workers=num_workers
) )
model = classifier.fit(train_df) model = classifier.fit(train_df)
@ -185,7 +185,7 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature
train_df, test_df, feature_names = spark_iris_dataset_feature_cols train_df, test_df, feature_names = spark_iris_dataset_feature_cols
classifier = SparkXGBClassifier( classifier = SparkXGBClassifier(
features_col=feature_names, use_gpu=True, num_workers=num_workers features_col=feature_names, device="cuda", num_workers=num_workers
) )
grid = ParamGridBuilder().addGrid(classifier.max_depth, [6, 8]).build() grid = ParamGridBuilder().addGrid(classifier.max_depth, [6, 8]).build()
evaluator = MulticlassClassificationEvaluator(metricName="f1") evaluator = MulticlassClassificationEvaluator(metricName="f1")
@ -197,11 +197,24 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature
f1 = evaluator.evaluate(pred_result_df) f1 = evaluator.evaluate(pred_result_df)
assert f1 >= 0.97 assert f1 >= 0.97
clf = SparkXGBClassifier(
features_col=feature_names, use_gpu=True, num_workers=num_workers
)
grid = ParamGridBuilder().addGrid(clf.max_depth, [6, 8]).build()
evaluator = MulticlassClassificationEvaluator(metricName="f1")
cv = CrossValidator(
estimator=clf, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3
)
cvModel = cv.fit(train_df)
pred_result_df = cvModel.transform(test_df)
f1 = evaluator.evaluate(pred_result_df)
assert f1 >= 0.97
def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset): def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset):
from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.evaluation import RegressionEvaluator
regressor = SparkXGBRegressor(use_gpu=True, num_workers=num_workers) regressor = SparkXGBRegressor(device="cuda", num_workers=num_workers)
train_df, test_df = spark_diabetes_dataset train_df, test_df = spark_diabetes_dataset
model = regressor.fit(train_df) model = regressor.fit(train_df)
pred_result_df = model.transform(test_df) pred_result_df = model.transform(test_df)
@ -215,7 +228,7 @@ def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature
train_df, test_df, feature_names = spark_diabetes_dataset_feature_cols train_df, test_df, feature_names = spark_diabetes_dataset_feature_cols
regressor = SparkXGBRegressor( regressor = SparkXGBRegressor(
features_col=feature_names, use_gpu=True, num_workers=num_workers features_col=feature_names, device="cuda", num_workers=num_workers
) )
model = regressor.fit(train_df) model = regressor.fit(train_df)

View File

@ -741,11 +741,6 @@ class TestPySparkLocal:
with pytest.raises(ValueError, match="early_stopping_rounds"): with pytest.raises(ValueError, match="early_stopping_rounds"):
classifier.fit(clf_data.cls_df_train) classifier.fit(clf_data.cls_df_train)
def test_gpu_param_setting(self, clf_data: ClfData) -> None:
py_cls = SparkXGBClassifier(use_gpu=True)
train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
assert train_params["tree_method"] == "gpu_hist"
def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None: def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"]) classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"])
model = classifier.fit(clf_data.cls_df_train) model = classifier.fit(clf_data.cls_df_train)
@ -756,6 +751,53 @@ class TestPySparkLocal:
model = classifier.fit(clf_data.cls_df_train) model = classifier.fit(clf_data.cls_df_train)
model.transform(clf_data.cls_df_test).collect() model.transform(clf_data.cls_df_test).collect()
def test_regressor_params_basic(self) -> None:
py_reg = SparkXGBRegressor()
assert hasattr(py_reg, "n_estimators")
assert py_reg.n_estimators.parent == py_reg.uid
assert not hasattr(py_reg, "gpu_id")
assert hasattr(py_reg, "device")
assert py_reg.getOrDefault(py_reg.n_estimators) == 100
assert py_reg.getOrDefault(getattr(py_reg, "objective")), "reg:squarederror"
py_reg2 = SparkXGBRegressor(n_estimators=200)
assert py_reg2.getOrDefault(getattr(py_reg2, "n_estimators")), 200
py_reg3 = py_reg2.copy({getattr(py_reg2, "max_depth"): 10})
assert py_reg3.getOrDefault(getattr(py_reg3, "n_estimators")), 200
assert py_reg3.getOrDefault(getattr(py_reg3, "max_depth")), 10
def test_classifier_params_basic(self) -> None:
py_clf = SparkXGBClassifier()
assert hasattr(py_clf, "n_estimators")
assert py_clf.n_estimators.parent == py_clf.uid
assert not hasattr(py_clf, "gpu_id")
assert hasattr(py_clf, "device")
assert py_clf.getOrDefault(py_clf.n_estimators) == 100
assert py_clf.getOrDefault(getattr(py_clf, "objective")) is None
py_clf2 = SparkXGBClassifier(n_estimators=200)
assert py_clf2.getOrDefault(getattr(py_clf2, "n_estimators")) == 200
py_clf3 = py_clf2.copy({getattr(py_clf2, "max_depth"): 10})
assert py_clf3.getOrDefault(getattr(py_clf3, "n_estimators")) == 200
assert py_clf3.getOrDefault(getattr(py_clf3, "max_depth")), 10
def test_classifier_kwargs_basic(self, clf_data: ClfData) -> None:
py_clf = SparkXGBClassifier(**clf_data.cls_params)
assert hasattr(py_clf, "n_estimators")
assert py_clf.n_estimators.parent == py_clf.uid
assert not hasattr(py_clf, "gpu_id")
assert hasattr(py_clf, "device")
assert hasattr(py_clf, "arbitrary_params_dict")
assert py_clf.getOrDefault(py_clf.arbitrary_params_dict) == {}
# Testing overwritten params
py_clf = SparkXGBClassifier()
py_clf.setParams(x=1, y=2)
py_clf.setParams(y=3, z=4)
xgb_params = py_clf._gen_xgb_params_dict()
assert xgb_params["x"] == 1
assert xgb_params["y"] == 3
assert xgb_params["z"] == 4
def test_regressor_model_save_load(self, reg_data: RegData) -> None: def test_regressor_model_save_load(self, reg_data: RegData) -> None:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
path = "file:" + tmpdir path = "file:" + tmpdir
@ -826,6 +868,24 @@ class TestPySparkLocal:
) )
assert_model_compatible(model.stages[0], tmpdir) assert_model_compatible(model.stages[0], tmpdir)
def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None:
clf = SparkXGBClassifier(device="cuda", tree_method="exact")
with pytest.raises(ValueError, match="not supported on GPU"):
clf.fit(clf_data.cls_df_train)
regressor = SparkXGBRegressor(device="cuda", tree_method="exact")
with pytest.raises(ValueError, match="not supported on GPU"):
regressor.fit(reg_data.reg_df_train)
reg = SparkXGBRegressor(device="cuda", tree_method="gpu_hist")
reg._validate_params()
reg = SparkXGBRegressor(device="cuda")
reg._validate_params()
clf = SparkXGBClassifier(device="cuda", tree_method="gpu_hist")
clf._validate_params()
clf = SparkXGBClassifier(device="cuda")
clf._validate_params()
class XgboostLocalTest(SparkTestCase): class XgboostLocalTest(SparkTestCase):
def setUp(self): def setUp(self):
@ -1020,55 +1080,6 @@ class XgboostLocalTest(SparkTestCase):
assert sklearn_regressor.max_depth == 3 assert sklearn_regressor.max_depth == 3
assert sklearn_regressor.get_params()["sketch_eps"] == 0.5 assert sklearn_regressor.get_params()["sketch_eps"] == 0.5
def test_regressor_params_basic(self):
py_reg = SparkXGBRegressor()
self.assertTrue(hasattr(py_reg, "n_estimators"))
self.assertEqual(py_reg.n_estimators.parent, py_reg.uid)
self.assertFalse(hasattr(py_reg, "gpu_id"))
self.assertFalse(hasattr(py_reg, "device"))
self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100)
self.assertEqual(py_reg.getOrDefault(py_reg.objective), "reg:squarederror")
py_reg2 = SparkXGBRegressor(n_estimators=200)
self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200)
py_reg3 = py_reg2.copy({py_reg2.max_depth: 10})
self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200)
self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10)
def test_classifier_params_basic(self):
py_cls = SparkXGBClassifier()
self.assertTrue(hasattr(py_cls, "n_estimators"))
self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
self.assertFalse(hasattr(py_cls, "gpu_id"))
self.assertFalse(hasattr(py_cls, "device"))
self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100)
self.assertEqual(py_cls.getOrDefault(py_cls.objective), None)
py_cls2 = SparkXGBClassifier(n_estimators=200)
self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200)
py_cls3 = py_cls2.copy({py_cls2.max_depth: 10})
self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200)
self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10)
def test_classifier_kwargs_basic(self):
py_cls = SparkXGBClassifier(**self.cls_params_kwargs)
self.assertTrue(hasattr(py_cls, "n_estimators"))
self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
self.assertFalse(hasattr(py_cls, "gpu_id"))
self.assertFalse(hasattr(py_cls, "device"))
self.assertTrue(hasattr(py_cls, "arbitrary_params_dict"))
expected_kwargs = {"sketch_eps": 0.03}
self.assertEqual(
py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs
)
# Testing overwritten params
py_cls = SparkXGBClassifier()
py_cls.setParams(x=1, y=2)
py_cls.setParams(y=3, z=4)
xgb_params = py_cls._gen_xgb_params_dict()
assert xgb_params["x"] == 1
assert xgb_params["y"] == 3
assert xgb_params["z"] == 4
def test_param_alias(self): def test_param_alias(self):
py_cls = SparkXGBClassifier(features_col="f1", label_col="l1") py_cls = SparkXGBClassifier(features_col="f1", label_col="l1")
self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1") self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1")
@ -1200,16 +1211,6 @@ class XgboostLocalTest(SparkTestCase):
classifier = SparkXGBClassifier(num_workers=0) classifier = SparkXGBClassifier(num_workers=0)
self.assertRaises(ValueError, classifier._validate_params) self.assertRaises(ValueError, classifier._validate_params)
def test_use_gpu_param(self):
classifier = SparkXGBClassifier(use_gpu=True, tree_method="exact")
self.assertRaises(ValueError, classifier._validate_params)
regressor = SparkXGBRegressor(use_gpu=True, tree_method="exact")
self.assertRaises(ValueError, regressor._validate_params)
regressor = SparkXGBRegressor(use_gpu=True, tree_method="gpu_hist")
regressor = SparkXGBRegressor(use_gpu=True)
classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist")
classifier = SparkXGBClassifier(use_gpu=True)
def test_feature_importances(self): def test_feature_importances(self):
reg1 = SparkXGBRegressor(**self.reg_params) reg1 = SparkXGBRegressor(**self.reg_params)
model = reg1.fit(self.reg_df_train) model = reg1.fit(self.reg_df_train)