diff --git a/doc/tutorials/spark_estimator.rst b/doc/tutorials/spark_estimator.rst index 545403a34..44bdd7733 100644 --- a/doc/tutorials/spark_estimator.rst +++ b/doc/tutorials/spark_estimator.rst @@ -35,13 +35,13 @@ We can create a ``SparkXGBRegressor`` estimator like: ) -The above snippet creates a spark estimator which can fit on a spark dataset, -and return a spark model that can transform a spark dataset and generate dataset -with prediction column. We can set almost all of xgboost sklearn estimator parameters -as ``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden -in spark estimator, and some parameters are replaced with pyspark specific parameters -such as ``weight_col``, ``validation_indicator_col``, ``use_gpu``, for details please see -``SparkXGBRegressor`` doc. +The above snippet creates a spark estimator which can fit on a spark dataset, and return a +spark model that can transform a spark dataset and generate dataset with prediction +column. We can set almost all of xgboost sklearn estimator parameters as +``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden in +spark estimator, and some parameters are replaced with pyspark specific parameters such as +``weight_col``, ``validation_indicator_col``, for details please see ``SparkXGBRegressor`` +doc. The following code snippet shows how to train a spark xgboost regressor model, first we need to prepare a training dataset as a spark dataframe contains @@ -88,7 +88,7 @@ XGBoost PySpark fully supports GPU acceleration. Users are not only able to enab efficient training but also utilize their GPUs for the whole PySpark pipeline including ETL and inference. In below sections, we will walk through an example of training on a PySpark standalone GPU cluster. To get started, first we need to install some additional -packages, then we can set the ``use_gpu`` parameter to ``True``. 
+packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``. Prepare the necessary packages ============================== @@ -128,7 +128,7 @@ Write your PySpark application ============================== Below snippet is a small example for training xgboost model with PySpark. Notice that we are -using a list of feature names and the additional parameter ``use_gpu``: +using a list of feature names and the additional parameter ``device``: .. code-block:: python @@ -148,12 +148,12 @@ using a list of feature names and the additional parameter ``use_gpu``: # get a list with feature column names feature_names = [x.name for x in train_df.schema if x.name != label_name] - # create a xgboost pyspark regressor estimator and set use_gpu=True + # create a xgboost pyspark regressor estimator and set device="cuda" regressor = SparkXGBRegressor( features_col=feature_names, label_col=label_name, num_workers=2, - use_gpu=True, + device="cuda", ) # train and return the model @@ -163,6 +163,7 @@ using a list of feature names and the additional parameter ``use_gpu``: predict_df = model.transform(test_df) predict_df.show() +Like other distributed interfaces, the ``device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``). 
Submit the PySpark application ============================== diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index d41976e8b..4cacd61f3 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -276,6 +276,27 @@ def _check_call(ret: int) -> None: raise XGBoostError(py_str(_LIB.XGBGetLastError())) +def _check_distributed_params(kwargs: Dict[str, Any]) -> None: + """Validate parameters in distributed environments.""" + device = kwargs.get("device", None) + if device and not isinstance(device, str): + msg = "Invalid type for the `device` parameter" + msg += _expect((str,), type(device)) + raise TypeError(msg) + + if device and device.find(":") != -1: + raise ValueError( + "Distributed training doesn't support selecting device ordinal as GPUs are" + " managed by the distributed framework. use `device=cuda` or `device=gpu`" + " instead." + ) + + if kwargs.get("booster", None) == "gblinear": + raise NotImplementedError( + f"booster `{kwargs['booster']}` is not supported for distributed training." + ) + + def build_info() -> dict: """Build information of XGBoost. The returned value format is not stable. Also, please note that build time dependency is not the same as runtime dependency. For diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 32dd2a4a7..271a5e458 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -70,6 +70,7 @@ from .core import ( Metric, Objective, QuantileDMatrix, + _check_distributed_params, _deprecate_positional_args, _expect, ) @@ -924,17 +925,7 @@ async def _train_async( ) -> Optional[TrainReturnT]: workers = _get_workers_from_data(dtrain, evals) _rabit_args = await _get_rabit_args(len(workers), dconfig, client) - - if params.get("booster", None) == "gblinear": - raise NotImplementedError( - f"booster `{params['booster']}` is not yet supported for dask." 
- ) - device = params.get("device", None) - if device and device.find(":") != -1: - raise ValueError( - "The dask interface for XGBoost doesn't support selecting specific device" - " ordinal. Use `device=cpu` or `device=cuda` instead." - ) + _check_distributed_params(params) def dispatched_train( parameters: Dict, diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index d69cb3a01..46a3ffa4a 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1004,13 +1004,17 @@ class XGBModel(XGBModelBase): Validation metrics will help us track the performance of the model. eval_metric : str, list of str, or callable, optional + .. deprecated:: 1.6.0 - Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead. + + Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead. early_stopping_rounds : int + .. deprecated:: 1.6.0 - Use `early_stopping_rounds` in :py:meth:`__init__` or - :py:meth:`set_params` instead. + + Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params` + instead. verbose : If `verbose` is True and an evaluation set is used, the evaluation metric measured on the validation set is printed to stdout at each boosting stage. 
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 283999c6d..998afbf77 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -60,7 +60,7 @@ from scipy.special import expit, softmax # pylint: disable=no-name-in-module import xgboost from xgboost import XGBClassifier from xgboost.compat import is_cudf_available -from xgboost.core import Booster +from xgboost.core import Booster, _check_distributed_params from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm from xgboost.training import train as worker_train @@ -92,6 +92,7 @@ from .utils import ( get_class_name, get_logger, serialize_booster, + use_cuda, ) # Put pyspark specific params here, they won't be passed to XGBoost. @@ -108,7 +109,6 @@ _pyspark_specific_params = [ "arbitrary_params_dict", "force_repartition", "num_workers", - "use_gpu", "feature_names", "features_cols", "enable_sparse_data_optim", @@ -132,8 +132,7 @@ _pyspark_param_alias_map = { _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()} _unsupported_xgb_params = [ - "gpu_id", # we have "use_gpu" pyspark param instead. - "device", # we have "use_gpu" pyspark param instead. + "gpu_id", # we have "device" pyspark param instead. "enable_categorical", # Use feature_types param to specify categorical feature instead "use_label_encoder", "n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead. @@ -198,11 +197,24 @@ class _SparkXGBParams( "The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.", TypeConverters.toInt, ) + device = Param( + Params._dummy(), + "device", + ( + "The device type for XGBoost executors. Available options are `cpu`,`cuda`" + " and `gpu`. Set `device` to `cuda` or `gpu` if the executors are running " + "on GPU instances. Currently, only one GPU per task is supported." 
+ ), + TypeConverters.toString, + ) use_gpu = Param( Params._dummy(), "use_gpu", - "A boolean variable. Set use_gpu=true if the executors " - + "are running on GPU instances. Currently, only one GPU per task is supported.", + ( + "Deprecated, use `device` instead. A boolean variable. Set use_gpu=true " + "if the executors are running on GPU instances. Currently, only one GPU per" + " task is supported." + ), TypeConverters.toBoolean, ) force_repartition = Param( @@ -336,10 +348,20 @@ class _SparkXGBParams( f"It cannot be less than 1 [Default is 1]" ) + tree_method = self.getOrDefault(self.getParam("tree_method")) + if ( + self.getOrDefault(self.use_gpu) or use_cuda(self.getOrDefault(self.device)) + ) and not _can_use_qdm(tree_method): + raise ValueError( + f"The `{tree_method}` tree method is not supported on GPU." + ) + if self.getOrDefault(self.features_cols): - if not self.getOrDefault(self.use_gpu): + if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault( + self.use_gpu + ): raise ValueError( - "features_col param with list value requires enabling use_gpu." + "features_col param with list value requires `device=cuda`." ) if self.getOrDefault("objective") is not None: @@ -392,17 +414,7 @@ class _SparkXGBParams( "`pyspark.ml.linalg.Vector` type." ) - if self.getOrDefault(self.use_gpu): - tree_method = self.getParam("tree_method") - if ( - self.getOrDefault(tree_method) is not None - and self.getOrDefault(tree_method) != "gpu_hist" - ): - raise ValueError( - f"tree_method should be 'gpu_hist' or None when use_gpu is True," - f"found {self.getOrDefault(tree_method)}." - ) - + if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): gpu_per_task = ( _get_spark_session() .sparkContext.getConf() @@ -424,8 +436,8 @@ class _SparkXGBParams( # so it's okay for printing the below warning instead of checking the real # gpu numbers and raising the exception. 
get_logger(self.__class__.__name__).warning( - "You enabled use_gpu in spark local mode. Please make sure your local node " - "has at least %d GPUs", + "You enabled GPU in spark local mode. Please make sure your local " + "node has at least %d GPUs", self.getOrDefault(self.num_workers), ) else: @@ -558,6 +570,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): # they are added in `setParams`. self._setDefault( num_workers=1, + device="cpu", use_gpu=False, force_repartition=False, repartition_random_shuffle=False, @@ -566,9 +579,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): arbitrary_params_dict={}, ) - def setParams( - self, **kwargs: Dict[str, Any] - ) -> None: # pylint: disable=invalid-name + def setParams(self, **kwargs: Any) -> None: # pylint: disable=invalid-name """ Set params for the estimator. """ @@ -613,6 +624,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): ) raise ValueError(err_msg) _extra_params[k] = v + + _check_distributed_params(kwargs) _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict) self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params}) @@ -709,9 +722,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): # TODO: support "num_parallel_tree" for random forest params["num_boost_round"] = self.getOrDefault("n_estimators") - if self.getOrDefault(self.use_gpu): - params["tree_method"] = "gpu_hist" - return params @classmethod @@ -883,8 +893,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): dmatrix_kwargs, ) = self._get_xgb_parameters(dataset) - use_gpu = self.getOrDefault(self.use_gpu) - + run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault( + self.use_gpu + ) is_local = _is_local(_get_spark_session().sparkContext) num_workers = self.getOrDefault(self.num_workers) @@ -903,7 +914,7 @@ class _SparkXGBEstimator(Estimator, 
_SparkXGBParams, MLReadable, MLWritable): dev_ordinal = None use_qdm = _can_use_qdm(booster_params.get("tree_method", None)) - if use_gpu: + if run_on_gpu: dev_ordinal = ( context.partitionId() if is_local else _get_gpu_id(context) ) diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index ba75aca7f..f11a0eda8 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -3,8 +3,8 @@ # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name # pylint: disable=unused-argument, too-many-locals - -from typing import Any, Dict, List, Optional, Type, Union +import warnings +from typing import Any, List, Optional, Type, Union import numpy as np from pyspark import keyword_only @@ -77,27 +77,35 @@ def _set_pyspark_xgb_cls_param_attrs( set_param_attrs(name, param_obj) +def _deprecated_use_gpu() -> None: + warnings.warn( + "`use_gpu` is deprecated since 2.0.0, use `device` instead", FutureWarning + ) + + class SparkXGBRegressor(_SparkXGBEstimator): """SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression algorithm based on XGBoost python library, and it can be used in PySpark Pipeline - and PySpark ML meta algorithms like :py:class:`~pyspark.ml.tuning.CrossValidator`/ - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/ - :py:class:`~pyspark.ml.classification.OneVsRest` + and PySpark ML meta algorithms like + - :py:class:`~pyspark.ml.tuning.CrossValidator`/ + - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/ + - :py:class:`~pyspark.ml.classification.OneVsRest` SparkXGBRegressor automatically supports most of the parameters in :py:class:`xgboost.XGBRegressor` constructor and most of the parameters used in - :py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict` method. + :py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict` + method. 
- SparkXGBRegressor doesn't support setting `device` but supports another param - `use_gpu`, see doc below for more details. + To enable GPU support, set `device` to `cuda` or `gpu`. - SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but support - another param called `base_margin_col`. see doc below for more details. + SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but + support another param called `base_margin_col`. see doc below for more details. SparkXGBRegressor doesn't support `validate_features` and `output_margin` param. - SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread` - param for each xgboost worker will be set equal to `spark.task.cpus` config value. + SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the + `nthread` param for each xgboost worker will be set equal to `spark.task.cpus` + config value. Parameters @@ -133,8 +141,11 @@ class SparkXGBRegressor(_SparkXGBEstimator): How many XGBoost workers to be used to train. Each XGBoost worker corresponds to one spark task. use_gpu: - Boolean value to specify whether the executors are running on GPU - instances. + .. deprecated:: 2.0.0 + + Use `device` instead. + device: + Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`. force_repartition: Boolean value to specify if forcing the input dataset to be repartitioned before XGBoost training. 
@@ -193,14 +204,17 @@ class SparkXGBRegressor(_SparkXGBEstimator): weight_col: Optional[str] = None, base_margin_col: Optional[str] = None, num_workers: int = 1, - use_gpu: bool = False, + use_gpu: Optional[bool] = None, + device: Optional[str] = None, force_repartition: bool = False, repartition_random_shuffle: bool = False, enable_sparse_data_optim: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__() input_kwargs = self._input_kwargs + if use_gpu: + _deprecated_use_gpu() self.setParams(**input_kwargs) @classmethod @@ -238,27 +252,29 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction """SparkXGBClassifier is a PySpark ML estimator. It implements the XGBoost classification algorithm based on XGBoost python library, and it can be used in PySpark Pipeline and PySpark ML meta algorithms like - :py:class:`~pyspark.ml.tuning.CrossValidator`/ - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/ - :py:class:`~pyspark.ml.classification.OneVsRest` + - :py:class:`~pyspark.ml.tuning.CrossValidator`/ + - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/ + - :py:class:`~pyspark.ml.classification.OneVsRest` SparkXGBClassifier automatically supports most of the parameters in :py:class:`xgboost.XGBClassifier` constructor and most of the parameters used in - :py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict` method. + :py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict` + method. - SparkXGBClassifier doesn't support setting `device` but support another param - `use_gpu`, see doc below for more details. + To enable GPU support, set `device` to `cuda` or `gpu`. - SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but support - another param called `base_margin_col`. see doc below for more details. 
+ SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but + support another param called `base_margin_col`. see doc below for more details. - SparkXGBClassifier doesn't support setting `output_margin`, but we can get output margin - from the raw prediction column. See `raw_prediction_col` param doc below for more details. + SparkXGBClassifier doesn't support setting `output_margin`, but we can get output + margin from the raw prediction column. See `raw_prediction_col` param doc below for + more details. SparkXGBClassifier doesn't support `validate_features` and `output_margin` param. - SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the `nthread` - param for each xgboost worker will be set equal to `spark.task.cpus` config value. + SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the + `nthread` param for each xgboost worker will be set equal to `spark.task.cpus` + config value. Parameters @@ -300,8 +316,11 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction How many XGBoost workers to be used to train. Each XGBoost worker corresponds to one spark task. use_gpu: - Boolean value to specify whether the executors are running on GPU - instances. + .. deprecated:: 2.0.0 + + Use `device` instead. + device: + Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`. force_repartition: Boolean value to specify if forcing the input dataset to be repartitioned before XGBoost training. 
@@ -360,11 +379,12 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction weight_col: Optional[str] = None, base_margin_col: Optional[str] = None, num_workers: int = 1, - use_gpu: bool = False, + use_gpu: Optional[bool] = None, + device: Optional[str] = None, force_repartition: bool = False, repartition_random_shuffle: bool = False, enable_sparse_data_optim: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__() # The default 'objective' param value comes from sklearn `XGBClassifier` ctor, @@ -372,6 +392,8 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction # binary or multinomial input dataset, and we need to remove the fixed default # param value as well to avoid causing ambiguity. input_kwargs = self._input_kwargs + if use_gpu: + _deprecated_use_gpu() self.setParams(**input_kwargs) self._setDefault(objective=None) @@ -422,19 +444,20 @@ class SparkXGBRanker(_SparkXGBEstimator): :py:class:`xgboost.XGBRanker` constructor and most of the parameters used in :py:meth:`xgboost.XGBRanker.fit` and :py:meth:`xgboost.XGBRanker.predict` method. - SparkXGBRanker doesn't support setting `device` but support another param `use_gpu`, - see doc below for more details. + To enable GPU support, set `device` to `cuda` or `gpu`. SparkXGBRanker doesn't support setting `base_margin` explicitly as well, but support another param called `base_margin_col`. see doc below for more details. SparkXGBRanker doesn't support setting `output_margin`, but we can get output margin - from the raw prediction column. See `raw_prediction_col` param doc below for more details. + from the raw prediction column. See `raw_prediction_col` param doc below for more + details. SparkXGBRanker doesn't support `validate_features` and `output_margin` param. 
- SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the `nthread` - param for each xgboost worker will be set equal to `spark.task.cpus` config value. + SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the + `nthread` param for each xgboost worker will be set equal to `spark.task.cpus` + config value. Parameters @@ -467,13 +490,15 @@ class SparkXGBRanker(_SparkXGBEstimator): :py:class:`xgboost.XGBRanker` fit method. qid_col: Query id column name. - num_workers: How many XGBoost workers to be used to train. Each XGBoost worker corresponds to one spark task. use_gpu: - Boolean value to specify whether the executors are running on GPU - instances. + .. deprecated:: 2.0.0 + + Use `device` instead. + device: + Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`. force_repartition: Boolean value to specify if forcing the input dataset to be repartitioned before XGBoost training. @@ -538,14 +563,17 @@ class SparkXGBRanker(_SparkXGBEstimator): base_margin_col: Optional[str] = None, qid_col: Optional[str] = None, num_workers: int = 1, - use_gpu: bool = False, + use_gpu: Optional[bool] = None, + device: Optional[str] = None, force_repartition: bool = False, repartition_random_shuffle: bool = False, enable_sparse_data_optim: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__() input_kwargs = self._input_kwargs + if use_gpu: + _deprecated_use_gpu() self.setParams(**input_kwargs) @classmethod diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 46e465dde..5f3bb19ba 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -7,7 +7,7 @@ import os import sys import uuid from threading import Thread -from typing import Any, Callable, Dict, Set, Type +from typing import Any, Callable, Dict, Optional, Set, Type import pyspark from pyspark import BarrierTaskContext, SparkContext, SparkFiles @@ 
-186,3 +186,8 @@ def deserialize_booster(model: str) -> Booster: f.write(model) booster.load_model(tmp_file_name) return booster + + +def use_cuda(device: Optional[str]) -> bool: + """Whether xgboost is using CUDA workers.""" + return device in ("cuda", "gpu") diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index e97b27665..0806c13a7 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -98,8 +98,8 @@ void MismatchedDevices(Context const* booster, Context const* data) { - Use a data structure that matches the device ordinal in the booster. - Set the device for booster before call to inplace_predict. -This warning will only be shown once, and subsequent warnings made by the current thread will be -suppressed. +This warning will only be shown once for each thread. Subsequent warnings made by the +current thread will be suppressed. )"; logged = true; } diff --git a/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py b/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py index 1f986f96e..a962f778e 100644 --- a/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py +++ b/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py @@ -154,7 +154,7 @@ def spark_diabetes_dataset_feature_cols(spark_session_with_gpu): def test_sparkxgb_classifier_with_gpu(spark_iris_dataset): from pyspark.ml.evaluation import MulticlassClassificationEvaluator - classifier = SparkXGBClassifier(use_gpu=True, num_workers=num_workers) + classifier = SparkXGBClassifier(device="cuda", num_workers=num_workers) train_df, test_df = spark_iris_dataset model = classifier.fit(train_df) pred_result_df = model.transform(test_df) @@ -169,7 +169,7 @@ def test_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_co train_df, test_df, feature_names = spark_iris_dataset_feature_cols classifier = SparkXGBClassifier( - features_col=feature_names, use_gpu=True, num_workers=num_workers + features_col=feature_names, device="cuda", num_workers=num_workers ) model = 
classifier.fit(train_df) @@ -185,7 +185,7 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature train_df, test_df, feature_names = spark_iris_dataset_feature_cols classifier = SparkXGBClassifier( - features_col=feature_names, use_gpu=True, num_workers=num_workers + features_col=feature_names, device="cuda", num_workers=num_workers ) grid = ParamGridBuilder().addGrid(classifier.max_depth, [6, 8]).build() evaluator = MulticlassClassificationEvaluator(metricName="f1") @@ -197,11 +197,24 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature f1 = evaluator.evaluate(pred_result_df) assert f1 >= 0.97 + clf = SparkXGBClassifier( + features_col=feature_names, use_gpu=True, num_workers=num_workers + ) + grid = ParamGridBuilder().addGrid(clf.max_depth, [6, 8]).build() + evaluator = MulticlassClassificationEvaluator(metricName="f1") + cv = CrossValidator( + estimator=clf, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3 + ) + cvModel = cv.fit(train_df) + pred_result_df = cvModel.transform(test_df) + f1 = evaluator.evaluate(pred_result_df) + assert f1 >= 0.97 + def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset): from pyspark.ml.evaluation import RegressionEvaluator - regressor = SparkXGBRegressor(use_gpu=True, num_workers=num_workers) + regressor = SparkXGBRegressor(device="cuda", num_workers=num_workers) train_df, test_df = spark_diabetes_dataset model = regressor.fit(train_df) pred_result_df = model.transform(test_df) @@ -215,7 +228,7 @@ def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature train_df, test_df, feature_names = spark_diabetes_dataset_feature_cols regressor = SparkXGBRegressor( - features_col=feature_names, use_gpu=True, num_workers=num_workers + features_col=feature_names, device="cuda", num_workers=num_workers ) model = regressor.fit(train_df) diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py 
b/tests/test_distributed/test_with_spark/test_spark_local.py index 124f36d02..50eafb0a1 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -741,11 +741,6 @@ class TestPySparkLocal: with pytest.raises(ValueError, match="early_stopping_rounds"): classifier.fit(clf_data.cls_df_train) - def test_gpu_param_setting(self, clf_data: ClfData) -> None: - py_cls = SparkXGBClassifier(use_gpu=True) - train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train) - assert train_params["tree_method"] == "gpu_hist" - def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None: classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"]) model = classifier.fit(clf_data.cls_df_train) @@ -756,6 +751,53 @@ class TestPySparkLocal: model = classifier.fit(clf_data.cls_df_train) model.transform(clf_data.cls_df_test).collect() + def test_regressor_params_basic(self) -> None: + py_reg = SparkXGBRegressor() + assert hasattr(py_reg, "n_estimators") + assert py_reg.n_estimators.parent == py_reg.uid + assert not hasattr(py_reg, "gpu_id") + assert hasattr(py_reg, "device") + assert py_reg.getOrDefault(py_reg.n_estimators) == 100 + assert py_reg.getOrDefault(getattr(py_reg, "objective")), "reg:squarederror" + py_reg2 = SparkXGBRegressor(n_estimators=200) + assert py_reg2.getOrDefault(getattr(py_reg2, "n_estimators")), 200 + py_reg3 = py_reg2.copy({getattr(py_reg2, "max_depth"): 10}) + assert py_reg3.getOrDefault(getattr(py_reg3, "n_estimators")), 200 + assert py_reg3.getOrDefault(getattr(py_reg3, "max_depth")), 10 + + def test_classifier_params_basic(self) -> None: + py_clf = SparkXGBClassifier() + assert hasattr(py_clf, "n_estimators") + assert py_clf.n_estimators.parent == py_clf.uid + assert not hasattr(py_clf, "gpu_id") + assert hasattr(py_clf, "device") + + assert py_clf.getOrDefault(py_clf.n_estimators) == 100 + assert py_clf.getOrDefault(getattr(py_clf, "objective")) is None 
+ py_clf2 = SparkXGBClassifier(n_estimators=200) + assert py_clf2.getOrDefault(getattr(py_clf2, "n_estimators")) == 200 + py_clf3 = py_clf2.copy({getattr(py_clf2, "max_depth"): 10}) + assert py_clf3.getOrDefault(getattr(py_clf3, "n_estimators")) == 200 + assert py_clf3.getOrDefault(getattr(py_clf3, "max_depth")), 10 + + def test_classifier_kwargs_basic(self, clf_data: ClfData) -> None: + py_clf = SparkXGBClassifier(**clf_data.cls_params) + assert hasattr(py_clf, "n_estimators") + assert py_clf.n_estimators.parent == py_clf.uid + assert not hasattr(py_clf, "gpu_id") + assert hasattr(py_clf, "device") + assert hasattr(py_clf, "arbitrary_params_dict") + assert py_clf.getOrDefault(py_clf.arbitrary_params_dict) == {} + + # Testing overwritten params + py_clf = SparkXGBClassifier() + py_clf.setParams(x=1, y=2) + py_clf.setParams(y=3, z=4) + xgb_params = py_clf._gen_xgb_params_dict() + assert xgb_params["x"] == 1 + assert xgb_params["y"] == 3 + assert xgb_params["z"] == 4 + def test_regressor_model_save_load(self, reg_data: RegData) -> None: with tempfile.TemporaryDirectory() as tmpdir: path = "file:" + tmpdir @@ -826,6 +868,24 @@ class TestPySparkLocal: ) assert_model_compatible(model.stages[0], tmpdir) + def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None: + clf = SparkXGBClassifier(device="cuda", tree_method="exact") + with pytest.raises(ValueError, match="not supported on GPU"): + clf.fit(clf_data.cls_df_train) + regressor = SparkXGBRegressor(device="cuda", tree_method="exact") + with pytest.raises(ValueError, match="not supported on GPU"): + regressor.fit(reg_data.reg_df_train) + + reg = SparkXGBRegressor(device="cuda", tree_method="gpu_hist") + reg._validate_params() + reg = SparkXGBRegressor(device="cuda") + reg._validate_params() + + clf = SparkXGBClassifier(device="cuda", tree_method="gpu_hist") + clf._validate_params() + clf = SparkXGBClassifier(device="cuda") + clf._validate_params() + class XgboostLocalTest(SparkTestCase): def 
setUp(self): @@ -1020,55 +1080,6 @@ class XgboostLocalTest(SparkTestCase): assert sklearn_regressor.max_depth == 3 assert sklearn_regressor.get_params()["sketch_eps"] == 0.5 - def test_regressor_params_basic(self): - py_reg = SparkXGBRegressor() - self.assertTrue(hasattr(py_reg, "n_estimators")) - self.assertEqual(py_reg.n_estimators.parent, py_reg.uid) - self.assertFalse(hasattr(py_reg, "gpu_id")) - self.assertFalse(hasattr(py_reg, "device")) - self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100) - self.assertEqual(py_reg.getOrDefault(py_reg.objective), "reg:squarederror") - py_reg2 = SparkXGBRegressor(n_estimators=200) - self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200) - py_reg3 = py_reg2.copy({py_reg2.max_depth: 10}) - self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200) - self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10) - - def test_classifier_params_basic(self): - py_cls = SparkXGBClassifier() - self.assertTrue(hasattr(py_cls, "n_estimators")) - self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) - self.assertFalse(hasattr(py_cls, "gpu_id")) - self.assertFalse(hasattr(py_cls, "device")) - self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100) - self.assertEqual(py_cls.getOrDefault(py_cls.objective), None) - py_cls2 = SparkXGBClassifier(n_estimators=200) - self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200) - py_cls3 = py_cls2.copy({py_cls2.max_depth: 10}) - self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200) - self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10) - - def test_classifier_kwargs_basic(self): - py_cls = SparkXGBClassifier(**self.cls_params_kwargs) - self.assertTrue(hasattr(py_cls, "n_estimators")) - self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) - self.assertFalse(hasattr(py_cls, "gpu_id")) - self.assertFalse(hasattr(py_cls, "device")) - self.assertTrue(hasattr(py_cls, "arbitrary_params_dict")) - expected_kwargs = {"sketch_eps": 0.03} 
- self.assertEqual( - py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs - ) - - # Testing overwritten params - py_cls = SparkXGBClassifier() - py_cls.setParams(x=1, y=2) - py_cls.setParams(y=3, z=4) - xgb_params = py_cls._gen_xgb_params_dict() - assert xgb_params["x"] == 1 - assert xgb_params["y"] == 3 - assert xgb_params["z"] == 4 - def test_param_alias(self): py_cls = SparkXGBClassifier(features_col="f1", label_col="l1") self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1") @@ -1200,16 +1211,6 @@ class XgboostLocalTest(SparkTestCase): classifier = SparkXGBClassifier(num_workers=0) self.assertRaises(ValueError, classifier._validate_params) - def test_use_gpu_param(self): - classifier = SparkXGBClassifier(use_gpu=True, tree_method="exact") - self.assertRaises(ValueError, classifier._validate_params) - regressor = SparkXGBRegressor(use_gpu=True, tree_method="exact") - self.assertRaises(ValueError, regressor._validate_params) - regressor = SparkXGBRegressor(use_gpu=True, tree_method="gpu_hist") - regressor = SparkXGBRegressor(use_gpu=True) - classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist") - classifier = SparkXGBClassifier(use_gpu=True) - def test_feature_importances(self): reg1 = SparkXGBRegressor(**self.reg_params) model = reg1.fit(self.reg_df_train)