[pyspark] Handle the device parameter in pyspark. (#9390)

- Handle the new `device` parameter in PySpark.
- Deprecate the old `use_gpu` parameter.

parent 2a0ff209ff
commit 6e18d3a290
@@ -35,13 +35,13 @@ We can create a ``SparkXGBRegressor`` estimator like:
  )

-The above snippet creates a spark estimator which can fit on a spark dataset,
-and return a spark model that can transform a spark dataset and generate dataset
-with prediction column. We can set almost all of xgboost sklearn estimator parameters
-as ``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden
-in spark estimator, and some parameters are replaced with pyspark specific parameters
-such as ``weight_col``, ``validation_indicator_col``, ``use_gpu``, for details please see
-``SparkXGBRegressor`` doc.
+The above snippet creates a spark estimator which can fit on a spark dataset, and return a
+spark model that can transform a spark dataset and generate dataset with prediction
+column. We can set almost all of xgboost sklearn estimator parameters as
+``SparkXGBRegressor`` parameters, but some parameter such as ``nthread`` is forbidden in
+spark estimator, and some parameters are replaced with pyspark specific parameters such as
+``weight_col``, ``validation_indicator_col``, for details please see ``SparkXGBRegressor``
+doc.

 The following code snippet shows how to train a spark xgboost regressor model,
 first we need to prepare a training dataset as a spark dataframe contains
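As an aside, the PySpark-specific column parameters mentioned above stand in for the sklearn-style ``sample_weight`` and ``eval_set`` arguments. A minimal sketch (the ``weight`` and ``is_val`` column names are hypothetical columns in the training dataframe):

.. code-block:: python

  from xgboost.spark import SparkXGBRegressor

  regressor = SparkXGBRegressor(
      features_col="features",
      label_col="label",
      weight_col="weight",  # per-row weights, instead of sklearn's sample_weight
      validation_indicator_col="is_val",  # marks validation rows, instead of eval_set
  )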
@@ -88,7 +88,7 @@ XGBoost PySpark fully supports GPU acceleration. Users are not only able to enable
 efficient training but also utilize their GPUs for the whole PySpark pipeline including
 ETL and inference. In below sections, we will walk through an example of training on a
 PySpark standalone GPU cluster. To get started, first we need to install some additional
-packages, then we can set the ``use_gpu`` parameter to ``True``.
+packages, then we can set the ``device`` parameter to ``cuda`` or ``gpu``.

 Prepare the necessary packages
 ==============================
@@ -128,7 +128,7 @@ Write your PySpark application
 ==============================

 Below snippet is a small example for training xgboost model with PySpark. Notice that we are
-using a list of feature names and the additional parameter ``use_gpu``:
+using a list of feature names and the additional parameter ``device``:

 .. code-block:: python

@@ -148,12 +148,12 @@ using a list of feature names and the additional parameter ``use_gpu``:
   # get a list with feature column names
   feature_names = [x.name for x in train_df.schema if x.name != label_name]

-  # create a xgboost pyspark regressor estimator and set use_gpu=True
+  # create a xgboost pyspark regressor estimator and set device="cuda"
   regressor = SparkXGBRegressor(
     features_col=feature_names,
     label_col=label_name,
     num_workers=2,
-    use_gpu=True,
+    device="cuda",
   )

   # train and return the model
@@ -163,6 +163,7 @@ using a list of feature names and the additional parameter ``use_gpu``:
   predict_df = model.transform(test_df)
   predict_df.show()

+Like other distributed interfaces, the ``device`` parameter doesn't support specifying ordinal as GPUs are managed by Spark instead of XGBoost (good: ``device=cuda``, bad: ``device=cuda:0``).

 Submit the PySpark application
 ==============================
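A short sketch of the rule stated in the added paragraph; the ordinal check runs when parameters are set, so the bad form fails as soon as the estimator is constructed:

.. code-block:: python

  from xgboost.spark import SparkXGBRegressor

  SparkXGBRegressor(device="cuda")    # OK: Spark assigns a GPU to each task
  SparkXGBRegressor(device="cuda:0")  # ValueError: device ordinals are rejected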
@@ -276,6 +276,27 @@ def _check_call(ret: int) -> None:
         raise XGBoostError(py_str(_LIB.XGBGetLastError()))


+def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
+    """Validate parameters in distributed environments."""
+    device = kwargs.get("device", None)
+    if device and not isinstance(device, str):
+        msg = "Invalid type for the `device` parameter"
+        msg += _expect((str,), type(device))
+        raise TypeError(msg)
+
+    if device and device.find(":") != -1:
+        raise ValueError(
+            "Distributed training doesn't support selecting device ordinal as GPUs are"
+            " managed by the distributed framework. use `device=cuda` or `device=gpu`"
+            " instead."
+        )
+
+    if kwargs.get("booster", None) == "gblinear":
+        raise NotImplementedError(
+            f"booster `{kwargs['booster']}` is not supported for distributed training."
+        )
+
+
 def build_info() -> dict:
     """Build information of XGBoost. The returned value format is not stable. Also,
     please note that build time dependency is not the same as runtime dependency. For
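The new helper consolidates checks that were previously duplicated across the distributed interfaces. A sketch of its behavior (``_check_distributed_params`` is private; shown for illustration only):

.. code-block:: python

  from xgboost.core import _check_distributed_params

  _check_distributed_params({"device": "cuda"})       # passes
  _check_distributed_params({"device": 0})            # TypeError: `device` must be a str
  _check_distributed_params({"device": "cuda:1"})     # ValueError: no device ordinals
  _check_distributed_params({"booster": "gblinear"})  # NotImplementedError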
@@ -70,6 +70,7 @@ from .core import (
     Metric,
     Objective,
     QuantileDMatrix,
+    _check_distributed_params,
     _deprecate_positional_args,
     _expect,
 )
@@ -924,17 +925,7 @@ async def _train_async(
 ) -> Optional[TrainReturnT]:
     workers = _get_workers_from_data(dtrain, evals)
     _rabit_args = await _get_rabit_args(len(workers), dconfig, client)
-    if params.get("booster", None) == "gblinear":
-        raise NotImplementedError(
-            f"booster `{params['booster']}` is not yet supported for dask."
-        )
-    device = params.get("device", None)
-    if device and device.find(":") != -1:
-        raise ValueError(
-            "The dask interface for XGBoost doesn't support selecting specific device"
-            " ordinal. Use `device=cpu` or `device=cuda` instead."
-        )
+    _check_distributed_params(params)

     def dispatched_train(
         parameters: Dict,
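With the inline checks replaced by the shared helper, the dask interface rejects the same inputs through the same code path. A hedged sketch (assumes a running ``dask.distributed`` client and dask arrays ``X`` and ``y``):

.. code-block:: python

  import xgboost as xgb

  def fit_on_cluster(client, X, y):
      dtrain = xgb.dask.DaskDMatrix(client, X, y)
      # "cuda" is accepted; "cuda:0" would raise ValueError from
      # _check_distributed_params before any training starts.
      return xgb.dask.train(client, {"device": "cuda"}, dtrain, num_boost_round=10)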
@@ -1004,13 +1004,17 @@ class XGBModel(XGBModelBase):
         Validation metrics will help us track the performance of the model.

     eval_metric : str, list of str, or callable, optional

         .. deprecated:: 1.6.0
-            Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
+
+            Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
+
     early_stopping_rounds : int

         .. deprecated:: 1.6.0
-            Use `early_stopping_rounds` in :py:meth:`__init__` or
-            :py:meth:`set_params` instead.
+
+            Use `early_stopping_rounds` in :py:meth:`__init__` or :py:meth:`set_params`
+            instead.
+
     verbose :
         If `verbose` is True and an evaluation set is used, the evaluation metric
         measured on the validation set is printed to stdout at each boosting stage.
@@ -60,7 +60,7 @@ from scipy.special import expit, softmax  # pylint: disable=no-name-in-module
 import xgboost
 from xgboost import XGBClassifier
 from xgboost.compat import is_cudf_available
-from xgboost.core import Booster
+from xgboost.core import Booster, _check_distributed_params
 from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
 from xgboost.training import train as worker_train

@@ -92,6 +92,7 @@ from .utils import (
     get_class_name,
     get_logger,
     serialize_booster,
+    use_cuda,
 )

 # Put pyspark specific params here, they won't be passed to XGBoost.
@@ -108,7 +109,6 @@ _pyspark_specific_params = [
     "arbitrary_params_dict",
     "force_repartition",
     "num_workers",
-    "use_gpu",
     "feature_names",
     "features_cols",
     "enable_sparse_data_optim",
@@ -132,8 +132,7 @@ _pyspark_param_alias_map = {
 _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.items()}

 _unsupported_xgb_params = [
-    "gpu_id",  # we have "use_gpu" pyspark param instead.
-    "device",  # we have "use_gpu" pyspark param instead.
+    "gpu_id",  # we have "device" pyspark param instead.
     "enable_categorical",  # Use feature_types param to specify categorical feature instead
     "use_label_encoder",
     "n_jobs",  # Do not allow user to set it, will use `spark.task.cpus` value instead.
@@ -198,11 +197,24 @@ class _SparkXGBParams(
         "The number of XGBoost workers. Each XGBoost worker corresponds to one spark task.",
         TypeConverters.toInt,
     )
+    device = Param(
+        Params._dummy(),
+        "device",
+        (
+            "The device type for XGBoost executors. Available options are `cpu`,`cuda`"
+            " and `gpu`. Set `device` to `cuda` or `gpu` if the executors are running "
+            "on GPU instances. Currently, only one GPU per task is supported."
+        ),
+        TypeConverters.toString,
+    )
     use_gpu = Param(
         Params._dummy(),
         "use_gpu",
-        "A boolean variable. Set use_gpu=true if the executors "
-        + "are running on GPU instances. Currently, only one GPU per task is supported.",
+        (
+            "Deprecated, use `device` instead. A boolean variable. Set use_gpu=true "
+            "if the executors are running on GPU instances. Currently, only one GPU per"
+            " task is supported."
+        ),
         TypeConverters.toBoolean,
     )
     force_repartition = Param(
@@ -336,10 +348,20 @@ class _SparkXGBParams(
                 f"It cannot be less than 1 [Default is 1]"
             )

+        tree_method = self.getOrDefault(self.getParam("tree_method"))
+        if (
+            self.getOrDefault(self.use_gpu) or use_cuda(self.getOrDefault(self.device))
+        ) and not _can_use_qdm(tree_method):
+            raise ValueError(
+                f"The `{tree_method}` tree method is not supported on GPU."
+            )
+
         if self.getOrDefault(self.features_cols):
-            if not self.getOrDefault(self.use_gpu):
+            if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault(
+                self.use_gpu
+            ):
                 raise ValueError(
-                    "features_col param with list value requires enabling use_gpu."
+                    "features_col param with list value requires `device=cuda`."
                 )

         if self.getOrDefault("objective") is not None:
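The effect of this validation, sketched with the estimator API (this mirrors the ``test_device_param`` test added later in this diff): requesting a GPU with a tree method that cannot run there fails before any training starts.

.. code-block:: python

  from xgboost.spark import SparkXGBRegressor

  reg = SparkXGBRegressor(device="cuda", tree_method="exact")
  # ValueError: The `exact` tree method is not supported on GPU.
  reg._validate_params()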
@@ -392,17 +414,7 @@ class _SparkXGBParams(
                 "`pyspark.ml.linalg.Vector` type."
             )

-        if self.getOrDefault(self.use_gpu):
-            tree_method = self.getParam("tree_method")
-            if (
-                self.getOrDefault(tree_method) is not None
-                and self.getOrDefault(tree_method) != "gpu_hist"
-            ):
-                raise ValueError(
-                    f"tree_method should be 'gpu_hist' or None when use_gpu is True,"
-                    f"found {self.getOrDefault(tree_method)}."
-                )
-
+        if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
             gpu_per_task = (
                 _get_spark_session()
                 .sparkContext.getConf()
@@ -424,8 +436,8 @@ class _SparkXGBParams(
             # so it's okay for printing the below warning instead of checking the real
             # gpu numbers and raising the exception.
             get_logger(self.__class__.__name__).warning(
-                "You enabled use_gpu in spark local mode. Please make sure your local node "
-                "has at least %d GPUs",
+                "You enabled GPU in spark local mode. Please make sure your local "
+                "node has at least %d GPUs",
                 self.getOrDefault(self.num_workers),
             )
         else:
@@ -558,6 +570,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         # they are added in `setParams`.
         self._setDefault(
             num_workers=1,
+            device="cpu",
             use_gpu=False,
             force_repartition=False,
             repartition_random_shuffle=False,
@@ -566,9 +579,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             arbitrary_params_dict={},
         )

-    def setParams(
-        self, **kwargs: Dict[str, Any]
-    ) -> None:  # pylint: disable=invalid-name
+    def setParams(self, **kwargs: Any) -> None:  # pylint: disable=invalid-name
         """
         Set params for the estimator.
         """
@@ -613,6 +624,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 )
                 raise ValueError(err_msg)
             _extra_params[k] = v

+        _check_distributed_params(kwargs)
+
         _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
         self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
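Because the estimator constructors funnel their keyword arguments through ``setParams``, the shared validation now fires at construction time. A sketch:

.. code-block:: python

  from xgboost.spark import SparkXGBClassifier

  SparkXGBClassifier(device="gpu")        # OK
  SparkXGBClassifier(device="cuda:0")     # ValueError: no device ordinals
  SparkXGBClassifier(booster="gblinear")  # NotImplementedError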
@@ -709,9 +722,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         # TODO: support "num_parallel_tree" for random forest
         params["num_boost_round"] = self.getOrDefault("n_estimators")

-        if self.getOrDefault(self.use_gpu):
-            params["tree_method"] = "gpu_hist"
-
         return params

     @classmethod
@@ -883,8 +893,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             dmatrix_kwargs,
         ) = self._get_xgb_parameters(dataset)

-        use_gpu = self.getOrDefault(self.use_gpu)
+        run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(
+            self.use_gpu
+        )
         is_local = _is_local(_get_spark_session().sparkContext)

         num_workers = self.getOrDefault(self.num_workers)
@@ -903,7 +914,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             dev_ordinal = None
             use_qdm = _can_use_qdm(booster_params.get("tree_method", None))

-            if use_gpu:
+            if run_on_gpu:
                 dev_ordinal = (
                     context.partitionId() if is_local else _get_gpu_id(context)
                 )
@@ -3,8 +3,8 @@
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=unused-argument, too-many-locals

-from typing import Any, Dict, List, Optional, Type, Union
+import warnings
+from typing import Any, List, Optional, Type, Union

 import numpy as np
 from pyspark import keyword_only
@@ -77,27 +77,35 @@ def _set_pyspark_xgb_cls_param_attrs(
     set_param_attrs(name, param_obj)


+def _deprecated_use_gpu() -> None:
+    warnings.warn(
+        "`use_gpu` is deprecated since 2.0.0, use `device` instead", FutureWarning
+    )
+
+
 class SparkXGBRegressor(_SparkXGBEstimator):
     """SparkXGBRegressor is a PySpark ML estimator. It implements the XGBoost regression
     algorithm based on XGBoost python library, and it can be used in PySpark Pipeline
-    and PySpark ML meta algorithms like :py:class:`~pyspark.ml.tuning.CrossValidator`/
-    :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
-    :py:class:`~pyspark.ml.classification.OneVsRest`
+    and PySpark ML meta algorithms like
+    - :py:class:`~pyspark.ml.tuning.CrossValidator`/
+    - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
+    - :py:class:`~pyspark.ml.classification.OneVsRest`

     SparkXGBRegressor automatically supports most of the parameters in
     :py:class:`xgboost.XGBRegressor` constructor and most of the parameters used in
-    :py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict` method.
+    :py:meth:`xgboost.XGBRegressor.fit` and :py:meth:`xgboost.XGBRegressor.predict`
+    method.

-    SparkXGBRegressor doesn't support setting `device` but supports another param
-    `use_gpu`, see doc below for more details.
+    To enable GPU support, set `device` to `cuda` or `gpu`.

-    SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but support
-    another param called `base_margin_col`. see doc below for more details.
+    SparkXGBRegressor doesn't support setting `base_margin` explicitly as well, but
+    support another param called `base_margin_col`. see doc below for more details.

     SparkXGBRegressor doesn't support `validate_features` and `output_margin` param.

-    SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the `nthread`
-    param for each xgboost worker will be set equal to `spark.task.cpus` config value.
+    SparkXGBRegressor doesn't support setting `nthread` xgboost param, instead, the
+    `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
+    config value.


     Parameters
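A sketch of the resulting deprecation behavior: the old flag still works, but constructing an estimator with it now emits a ``FutureWarning``:

.. code-block:: python

  import warnings

  from xgboost.spark import SparkXGBRegressor

  with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter("always")
      SparkXGBRegressor(use_gpu=True)  # accepted, but deprecated
  assert any(issubclass(w.category, FutureWarning) for w in caught)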
@@ -133,8 +141,11 @@ class SparkXGBRegressor(_SparkXGBEstimator):
         How many XGBoost workers to be used to train.
         Each XGBoost worker corresponds to one spark task.
     use_gpu:
-        Boolean value to specify whether the executors are running on GPU
-        instances.
+        .. deprecated:: 2.0.0
+
+        Use `device` instead.
+    device:
+        Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
     force_repartition:
         Boolean value to specify if forcing the input dataset to be repartitioned
         before XGBoost training.
@@ -193,14 +204,17 @@ class SparkXGBRegressor(_SparkXGBEstimator):
         weight_col: Optional[str] = None,
         base_margin_col: Optional[str] = None,
         num_workers: int = 1,
-        use_gpu: bool = False,
+        use_gpu: Optional[bool] = None,
+        device: Optional[str] = None,
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         super().__init__()
         input_kwargs = self._input_kwargs
+        if use_gpu:
+            _deprecated_use_gpu()
         self.setParams(**input_kwargs)

     @classmethod
@@ -238,27 +252,29 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
     """SparkXGBClassifier is a PySpark ML estimator. It implements the XGBoost
     classification algorithm based on XGBoost python library, and it can be used in
     PySpark Pipeline and PySpark ML meta algorithms like
-    :py:class:`~pyspark.ml.tuning.CrossValidator`/
-    :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
-    :py:class:`~pyspark.ml.classification.OneVsRest`
+    - :py:class:`~pyspark.ml.tuning.CrossValidator`/
+    - :py:class:`~pyspark.ml.tuning.TrainValidationSplit`/
+    - :py:class:`~pyspark.ml.classification.OneVsRest`

     SparkXGBClassifier automatically supports most of the parameters in
     :py:class:`xgboost.XGBClassifier` constructor and most of the parameters used in
-    :py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict` method.
+    :py:meth:`xgboost.XGBClassifier.fit` and :py:meth:`xgboost.XGBClassifier.predict`
+    method.

-    SparkXGBClassifier doesn't support setting `device` but support another param
-    `use_gpu`, see doc below for more details.
+    To enable GPU support, set `device` to `cuda` or `gpu`.

-    SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but support
-    another param called `base_margin_col`. see doc below for more details.
+    SparkXGBClassifier doesn't support setting `base_margin` explicitly as well, but
+    support another param called `base_margin_col`. see doc below for more details.

-    SparkXGBClassifier doesn't support setting `output_margin`, but we can get output margin
-    from the raw prediction column. See `raw_prediction_col` param doc below for more details.
+    SparkXGBClassifier doesn't support setting `output_margin`, but we can get output
+    margin from the raw prediction column. See `raw_prediction_col` param doc below for
+    more details.

     SparkXGBClassifier doesn't support `validate_features` and `output_margin` param.

-    SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the `nthread`
-    param for each xgboost worker will be set equal to `spark.task.cpus` config value.
+    SparkXGBClassifier doesn't support setting `nthread` xgboost param, instead, the
+    `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
+    config value.


     Parameters
@@ -300,8 +316,11 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
         How many XGBoost workers to be used to train.
         Each XGBoost worker corresponds to one spark task.
     use_gpu:
-        Boolean value to specify whether the executors are running on GPU
-        instances.
+        .. deprecated:: 2.0.0
+
+        Use `device` instead.
+    device:
+        Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
     force_repartition:
         Boolean value to specify if forcing the input dataset to be repartitioned
         before XGBoost training.
@@ -360,11 +379,12 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
         weight_col: Optional[str] = None,
         base_margin_col: Optional[str] = None,
         num_workers: int = 1,
-        use_gpu: bool = False,
+        use_gpu: Optional[bool] = None,
+        device: Optional[str] = None,
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         super().__init__()
         # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
@@ -372,6 +392,8 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPredictionCol):
         # binary or multinomial input dataset, and we need to remove the fixed default
         # param value as well to avoid causing ambiguity.
         input_kwargs = self._input_kwargs
+        if use_gpu:
+            _deprecated_use_gpu()
         self.setParams(**input_kwargs)
         self._setDefault(objective=None)
@@ -422,19 +444,20 @@ class SparkXGBRanker(_SparkXGBEstimator):
     :py:class:`xgboost.XGBRanker` constructor and most of the parameters used in
     :py:meth:`xgboost.XGBRanker.fit` and :py:meth:`xgboost.XGBRanker.predict` method.

-    SparkXGBRanker doesn't support setting `device` but support another param `use_gpu`,
-    see doc below for more details.
+    To enable GPU support, set `device` to `cuda` or `gpu`.

     SparkXGBRanker doesn't support setting `base_margin` explicitly as well, but support
     another param called `base_margin_col`. see doc below for more details.

     SparkXGBRanker doesn't support setting `output_margin`, but we can get output margin
-    from the raw prediction column. See `raw_prediction_col` param doc below for more details.
+    from the raw prediction column. See `raw_prediction_col` param doc below for more
+    details.

     SparkXGBRanker doesn't support `validate_features` and `output_margin` param.

-    SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the `nthread`
-    param for each xgboost worker will be set equal to `spark.task.cpus` config value.
+    SparkXGBRanker doesn't support setting `nthread` xgboost param, instead, the
+    `nthread` param for each xgboost worker will be set equal to `spark.task.cpus`
+    config value.


     Parameters
@@ -467,13 +490,15 @@ class SparkXGBRanker(_SparkXGBEstimator):
         :py:class:`xgboost.XGBRanker` fit method.
     qid_col:
         Query id column name.

     num_workers:
         How many XGBoost workers to be used to train.
         Each XGBoost worker corresponds to one spark task.
     use_gpu:
-        Boolean value to specify whether the executors are running on GPU
-        instances.
+        .. deprecated:: 2.0.0
+
+        Use `device` instead.
+    device:
+        Device for XGBoost workers, available options are `cpu`, `cuda`, and `gpu`.
     force_repartition:
         Boolean value to specify if forcing the input dataset to be repartitioned
         before XGBoost training.
@@ -538,14 +563,17 @@ class SparkXGBRanker(_SparkXGBEstimator):
         base_margin_col: Optional[str] = None,
         qid_col: Optional[str] = None,
         num_workers: int = 1,
-        use_gpu: bool = False,
+        use_gpu: Optional[bool] = None,
+        device: Optional[str] = None,
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         super().__init__()
         input_kwargs = self._input_kwargs
+        if use_gpu:
+            _deprecated_use_gpu()
         self.setParams(**input_kwargs)

     @classmethod
@@ -7,7 +7,7 @@ import os
 import sys
 import uuid
 from threading import Thread
-from typing import Any, Callable, Dict, Set, Type
+from typing import Any, Callable, Dict, Optional, Set, Type

 import pyspark
 from pyspark import BarrierTaskContext, SparkContext, SparkFiles
@@ -186,3 +186,8 @@ def deserialize_booster(model: str) -> Booster:
         f.write(model)
     booster.load_model(tmp_file_name)
     return booster
+
+
+def use_cuda(device: Optional[str]) -> bool:
+    """Whether xgboost is using CUDA workers."""
+    return device in ("cuda", "gpu")
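The helper treats both accepted spellings of the GPU device as CUDA; anything else, including ``None`` and the ``cpu`` default, is not. For illustration (assuming the helper lives in ``xgboost.spark.utils``):

.. code-block:: python

  from xgboost.spark.utils import use_cuda

  assert use_cuda("cuda") and use_cuda("gpu")
  assert not use_cuda("cpu") and not use_cuda(None)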
@@ -98,8 +98,8 @@ void MismatchedDevices(Context const* booster, Context const* data) {
 - Use a data structure that matches the device ordinal in the booster.
 - Set the device for booster before call to inplace_predict.

-This warning will only be shown once, and subsequent warnings made by the current thread will be
-suppressed.
+This warning will only be shown once for each thread. Subsequent warnings made by the
+current thread will be suppressed.
 )";
     logged = true;
   }
@@ -154,7 +154,7 @@ def spark_diabetes_dataset_feature_cols(spark_session_with_gpu):
 def test_sparkxgb_classifier_with_gpu(spark_iris_dataset):
     from pyspark.ml.evaluation import MulticlassClassificationEvaluator

-    classifier = SparkXGBClassifier(use_gpu=True, num_workers=num_workers)
+    classifier = SparkXGBClassifier(device="cuda", num_workers=num_workers)
     train_df, test_df = spark_iris_dataset
     model = classifier.fit(train_df)
     pred_result_df = model.transform(test_df)
@@ -169,7 +169,7 @@ def test_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_cols):
     train_df, test_df, feature_names = spark_iris_dataset_feature_cols

     classifier = SparkXGBClassifier(
-        features_col=feature_names, use_gpu=True, num_workers=num_workers
+        features_col=feature_names, device="cuda", num_workers=num_workers
     )

     model = classifier.fit(train_df)
@@ -185,7 +185,7 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_cols):
     train_df, test_df, feature_names = spark_iris_dataset_feature_cols

     classifier = SparkXGBClassifier(
-        features_col=feature_names, use_gpu=True, num_workers=num_workers
+        features_col=feature_names, device="cuda", num_workers=num_workers
     )
     grid = ParamGridBuilder().addGrid(classifier.max_depth, [6, 8]).build()
     evaluator = MulticlassClassificationEvaluator(metricName="f1")
@@ -197,11 +197,24 @@ def test_cv_sparkxgb_classifier_feature_cols_with_gpu(spark_iris_dataset_feature_cols):
     f1 = evaluator.evaluate(pred_result_df)
     assert f1 >= 0.97

+    clf = SparkXGBClassifier(
+        features_col=feature_names, use_gpu=True, num_workers=num_workers
+    )
+    grid = ParamGridBuilder().addGrid(clf.max_depth, [6, 8]).build()
+    evaluator = MulticlassClassificationEvaluator(metricName="f1")
+    cv = CrossValidator(
+        estimator=clf, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3
+    )
+    cvModel = cv.fit(train_df)
+    pred_result_df = cvModel.transform(test_df)
+    f1 = evaluator.evaluate(pred_result_df)
+    assert f1 >= 0.97
+

 def test_sparkxgb_regressor_with_gpu(spark_diabetes_dataset):
     from pyspark.ml.evaluation import RegressionEvaluator

-    regressor = SparkXGBRegressor(use_gpu=True, num_workers=num_workers)
+    regressor = SparkXGBRegressor(device="cuda", num_workers=num_workers)
     train_df, test_df = spark_diabetes_dataset
     model = regressor.fit(train_df)
     pred_result_df = model.transform(test_df)
@@ -215,7 +228,7 @@ def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature_cols):

     train_df, test_df, feature_names = spark_diabetes_dataset_feature_cols
     regressor = SparkXGBRegressor(
-        features_col=feature_names, use_gpu=True, num_workers=num_workers
+        features_col=feature_names, device="cuda", num_workers=num_workers
     )

     model = regressor.fit(train_df)
@@ -741,11 +741,6 @@ class TestPySparkLocal:
         with pytest.raises(ValueError, match="early_stopping_rounds"):
             classifier.fit(clf_data.cls_df_train)

-    def test_gpu_param_setting(self, clf_data: ClfData) -> None:
-        py_cls = SparkXGBClassifier(use_gpu=True)
-        train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
-        assert train_params["tree_method"] == "gpu_hist"
-
     def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
         classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"])
         model = classifier.fit(clf_data.cls_df_train)
@@ -756,6 +751,53 @@ class TestPySparkLocal:
         model = classifier.fit(clf_data.cls_df_train)
         model.transform(clf_data.cls_df_test).collect()

+    def test_regressor_params_basic(self) -> None:
+        py_reg = SparkXGBRegressor()
+        assert hasattr(py_reg, "n_estimators")
+        assert py_reg.n_estimators.parent == py_reg.uid
+        assert not hasattr(py_reg, "gpu_id")
+        assert hasattr(py_reg, "device")
+        assert py_reg.getOrDefault(py_reg.n_estimators) == 100
+        assert py_reg.getOrDefault(getattr(py_reg, "objective")), "reg:squarederror"
+        py_reg2 = SparkXGBRegressor(n_estimators=200)
+        assert py_reg2.getOrDefault(getattr(py_reg2, "n_estimators")), 200
+        py_reg3 = py_reg2.copy({getattr(py_reg2, "max_depth"): 10})
+        assert py_reg3.getOrDefault(getattr(py_reg3, "n_estimators")), 200
+        assert py_reg3.getOrDefault(getattr(py_reg3, "max_depth")), 10
+
+    def test_classifier_params_basic(self) -> None:
+        py_clf = SparkXGBClassifier()
+        assert hasattr(py_clf, "n_estimators")
+        assert py_clf.n_estimators.parent == py_clf.uid
+        assert not hasattr(py_clf, "gpu_id")
+        assert hasattr(py_clf, "device")
+
+        assert py_clf.getOrDefault(py_clf.n_estimators) == 100
+        assert py_clf.getOrDefault(getattr(py_clf, "objective")) is None
+        py_clf2 = SparkXGBClassifier(n_estimators=200)
+        assert py_clf2.getOrDefault(getattr(py_clf2, "n_estimators")) == 200
+        py_clf3 = py_clf2.copy({getattr(py_clf2, "max_depth"): 10})
+        assert py_clf3.getOrDefault(getattr(py_clf3, "n_estimators")) == 200
+        assert py_clf3.getOrDefault(getattr(py_clf3, "max_depth")), 10
+
+    def test_classifier_kwargs_basic(self, clf_data: ClfData) -> None:
+        py_clf = SparkXGBClassifier(**clf_data.cls_params)
+        assert hasattr(py_clf, "n_estimators")
+        assert py_clf.n_estimators.parent == py_clf.uid
+        assert not hasattr(py_clf, "gpu_id")
+        assert hasattr(py_clf, "device")
+        assert hasattr(py_clf, "arbitrary_params_dict")
+        assert py_clf.getOrDefault(py_clf.arbitrary_params_dict) == {}
+
+        # Testing overwritten params
+        py_clf = SparkXGBClassifier()
+        py_clf.setParams(x=1, y=2)
+        py_clf.setParams(y=3, z=4)
+        xgb_params = py_clf._gen_xgb_params_dict()
+        assert xgb_params["x"] == 1
+        assert xgb_params["y"] == 3
+        assert xgb_params["z"] == 4
+
     def test_regressor_model_save_load(self, reg_data: RegData) -> None:
         with tempfile.TemporaryDirectory() as tmpdir:
             path = "file:" + tmpdir
@@ -826,6 +868,24 @@ class TestPySparkLocal:
         )
         assert_model_compatible(model.stages[0], tmpdir)

+    def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None:
+        clf = SparkXGBClassifier(device="cuda", tree_method="exact")
+        with pytest.raises(ValueError, match="not supported on GPU"):
+            clf.fit(clf_data.cls_df_train)
+        regressor = SparkXGBRegressor(device="cuda", tree_method="exact")
+        with pytest.raises(ValueError, match="not supported on GPU"):
+            regressor.fit(reg_data.reg_df_train)
+
+        reg = SparkXGBRegressor(device="cuda", tree_method="gpu_hist")
+        reg._validate_params()
+        reg = SparkXGBRegressor(device="cuda")
+        reg._validate_params()
+
+        clf = SparkXGBClassifier(device="cuda", tree_method="gpu_hist")
+        clf._validate_params()
+        clf = SparkXGBClassifier(device="cuda")
+        clf._validate_params()
+
+
 class XgboostLocalTest(SparkTestCase):
     def setUp(self):
@@ -1020,55 +1080,6 @@ class XgboostLocalTest(SparkTestCase):
         assert sklearn_regressor.max_depth == 3
         assert sklearn_regressor.get_params()["sketch_eps"] == 0.5

-    def test_regressor_params_basic(self):
-        py_reg = SparkXGBRegressor()
-        self.assertTrue(hasattr(py_reg, "n_estimators"))
-        self.assertEqual(py_reg.n_estimators.parent, py_reg.uid)
-        self.assertFalse(hasattr(py_reg, "gpu_id"))
-        self.assertFalse(hasattr(py_reg, "device"))
-        self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100)
-        self.assertEqual(py_reg.getOrDefault(py_reg.objective), "reg:squarederror")
-        py_reg2 = SparkXGBRegressor(n_estimators=200)
-        self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200)
-        py_reg3 = py_reg2.copy({py_reg2.max_depth: 10})
-        self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200)
-        self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10)
-
-    def test_classifier_params_basic(self):
-        py_cls = SparkXGBClassifier()
-        self.assertTrue(hasattr(py_cls, "n_estimators"))
-        self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
-        self.assertFalse(hasattr(py_cls, "gpu_id"))
-        self.assertFalse(hasattr(py_cls, "device"))
-        self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100)
-        self.assertEqual(py_cls.getOrDefault(py_cls.objective), None)
-        py_cls2 = SparkXGBClassifier(n_estimators=200)
-        self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200)
-        py_cls3 = py_cls2.copy({py_cls2.max_depth: 10})
-        self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200)
-        self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10)
-
-    def test_classifier_kwargs_basic(self):
-        py_cls = SparkXGBClassifier(**self.cls_params_kwargs)
-        self.assertTrue(hasattr(py_cls, "n_estimators"))
-        self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
-        self.assertFalse(hasattr(py_cls, "gpu_id"))
-        self.assertFalse(hasattr(py_cls, "device"))
-        self.assertTrue(hasattr(py_cls, "arbitrary_params_dict"))
-        expected_kwargs = {"sketch_eps": 0.03}
-        self.assertEqual(
-            py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs
-        )
-
-        # Testing overwritten params
-        py_cls = SparkXGBClassifier()
-        py_cls.setParams(x=1, y=2)
-        py_cls.setParams(y=3, z=4)
-        xgb_params = py_cls._gen_xgb_params_dict()
-        assert xgb_params["x"] == 1
-        assert xgb_params["y"] == 3
-        assert xgb_params["z"] == 4
-
     def test_param_alias(self):
         py_cls = SparkXGBClassifier(features_col="f1", label_col="l1")
         self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1")
@@ -1200,16 +1211,6 @@ class XgboostLocalTest(SparkTestCase):
         classifier = SparkXGBClassifier(num_workers=0)
         self.assertRaises(ValueError, classifier._validate_params)

-    def test_use_gpu_param(self):
-        classifier = SparkXGBClassifier(use_gpu=True, tree_method="exact")
-        self.assertRaises(ValueError, classifier._validate_params)
-        regressor = SparkXGBRegressor(use_gpu=True, tree_method="exact")
-        self.assertRaises(ValueError, regressor._validate_params)
-        regressor = SparkXGBRegressor(use_gpu=True, tree_method="gpu_hist")
-        regressor = SparkXGBRegressor(use_gpu=True)
-        classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist")
-        classifier = SparkXGBClassifier(use_gpu=True)
-
     def test_feature_importances(self):
         reg1 = SparkXGBRegressor(**self.reg_params)
         model = reg1.fit(self.reg_df_train)