From 13235313231f073a710a3bc05e5ab6d1d94b1e52 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Fri, 27 Oct 2023 11:21:30 +0800
Subject: [PATCH] [pyspark] unify the way of determining whether to run on the GPU. (#9724)

---
 python-package/xgboost/spark/core.py          | 75 ++++++++++---------
 .../test_gpu_with_spark/test_gpu_spark.py     |  8 +-
 .../test_with_spark/test_spark_local.py       | 24 +++++-
 3 files changed, 64 insertions(+), 43 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 9fe73005a..bad3a2382 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -347,7 +347,7 @@ class _SparkXGBParams(
 
     def _validate_gpu_params(self) -> None:
         """Validate the gpu parameters and gpu configurations"""
-        if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
+        if self._run_on_gpu():
             ss = _get_spark_session()
             sc = ss.sparkContext
 
@@ -414,9 +414,7 @@ class _SparkXGBParams(
             )
 
         if self.getOrDefault(self.features_cols):
-            if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault(
-                self.use_gpu
-            ):
+            if not self._run_on_gpu():
                 raise ValueError(
                     "features_col param with list value requires `device=cuda`."
                 )
@@ -473,6 +471,15 @@ class _SparkXGBParams(
 
         self._validate_gpu_params()
 
+    def _run_on_gpu(self) -> bool:
+        """Whether training or transform runs on the GPU according to the parameters"""
+
+        return (
+            use_cuda(self.getOrDefault(self.device))
+            or self.getOrDefault(self.use_gpu)
+            or self.getOrDefault(self.getParam("tree_method")) == "gpu_hist"
+        )
+
 
 def _validate_and_convert_feature_col_as_float_col_list(
     dataset: DataFrame, features_col_names: List[str]
@@ -905,7 +912,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         """Check if stage-level scheduling is not needed, return true to skip
         stage-level scheduling"""
 
-        if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
+        if self._run_on_gpu():
             ss = _get_spark_session()
             sc = ss.sparkContext
 
@@ -1022,9 +1029,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             dmatrix_kwargs,
         ) = self._get_xgb_parameters(dataset)
 
-        run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(
-            self.use_gpu
-        )
+        run_on_gpu = self._run_on_gpu()
+
         is_local = _is_local(_get_spark_session().sparkContext)
 
         num_workers = self.getOrDefault(self.num_workers)
@@ -1318,12 +1324,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
             dataset = dataset.drop(pred_struct_col)
         return dataset
 
-    def _gpu_transform(self) -> bool:
-        """If gpu is used to do the prediction, true to gpu prediction"""
+    def _run_on_gpu(self) -> bool:
+        """Whether the prediction runs on the GPU according to the parameters
+        and the spark configurations"""
+
+        use_gpu_by_params = super()._run_on_gpu()
 
         if _is_local(_get_spark_session().sparkContext):
-            # if it's local model, we just use the internal "device"
-            return use_cuda(self.getOrDefault(self.device))
+            # if it's local mode, no need to check the spark configurations
+            return use_gpu_by_params
 
         gpu_per_task = (
             _get_spark_session()
             .conf.get("spark.task.resource.gpu.amount")
         )
 
         # User don't set gpu configurations, just use cpu
         if gpu_per_task is None:
-            if use_cuda(self.getOrDefault(self.device)):
+            if use_gpu_by_params:
                 get_logger("XGBoost-PySpark").warning(
                     "Do the prediction on the CPUs since "
                     "no gpu configurations are set"
                 )
             return False
 
-        # User already sets the gpu configurations, we just use the internal "device".
-        return use_cuda(self.getOrDefault(self.device))
+        # The user has already set the gpu configurations.
+        return use_gpu_by_params
 
     def _transform(self, dataset: DataFrame) -> DataFrame:
         # pylint: disable=too-many-statements, too-many-locals
@@ -1367,7 +1376,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
         _, schema = self._out_schema()
 
         is_local = _is_local(_get_spark_session().sparkContext)
-        run_on_gpu = self._gpu_transform()
+        run_on_gpu = self._run_on_gpu()
 
         @pandas_udf(schema)  # type: ignore
         def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
 
             dev_ordinal = -1
 
-            if is_cudf_available():
-                if is_local:
-                    if run_on_gpu and is_cupy_available():
+            msg = "Do the inference on the CPUs"
+            if run_on_gpu:
+                if is_cudf_available() and is_cupy_available():
+                    if is_local:
                         import cupy as cp  # pylint: disable=import-error
 
                         total_gpus = cp.cuda.runtime.getDeviceCount()
                         if total_gpus > 0:
                             partition_id = context.partitionId()
                             # For transform local mode, default the dev_ordinal to
                             # (partition id) % gpus.
                             dev_ordinal = partition_id % total_gpus
-                elif run_on_gpu:
-                    dev_ordinal = _get_gpu_id(context)
+                    else:
+                        dev_ordinal = _get_gpu_id(context)
 
-                if dev_ordinal >= 0:
-                    device = "cuda:" + str(dev_ordinal)
-                    get_logger("XGBoost-PySpark").info(
-                        "Do the inference with device: %s", device
-                    )
-                    model.set_params(device=device)
+                    if dev_ordinal >= 0:
+                        device = "cuda:" + str(dev_ordinal)
+                        msg = "Do the inference with device: " + device
+                        model.set_params(device=device)
+                    else:
+                        msg = "Couldn't get the correct gpu id, falling back to inference on the CPUs"
                 else:
-                    get_logger("XGBoost-PySpark").info("Do the inference on the CPUs")
-            else:
-                msg = (
-                    "CUDF is unavailable, fallback the inference on the CPUs"
-                    if run_on_gpu
-                    else "Do the inference on the CPUs"
-                )
-                get_logger("XGBoost-PySpark").info(msg)
+                    msg = "cuDF or CuPy is unavailable, falling back to inference on the CPUs"
+
+            get_logger("XGBoost-PySpark").info(msg)
 
             def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:
                 """Move the data to gpu if possible"""
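
With the unified helper in place, `device="cuda"`, `use_gpu=True`, and
`tree_method="gpu_hist"` all take the same code path through `_run_on_gpu()`.
A minimal sketch of the resulting behavior, mirroring the new `test_gpu_params`
test below (`_run_on_gpu` is an internal helper, shown here only for
illustration):

    from xgboost.spark import SparkXGBClassifier

    # All three spellings are now recognized as the same request to run on the
    # GPU; device="cuda" is the current form, use_gpu and tree_method="gpu_hist"
    # are older spellings that the unified check still honors.
    for params in ({"device": "cuda"}, {"use_gpu": True}, {"tree_method": "gpu_hist"}):
        clf = SparkXGBClassifier(**params)
        assert clf._run_on_gpu()
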
diff --git a/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py b/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py
index 513554e43..3bf94c954 100644
--- a/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py
+++ b/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py
@@ -251,10 +251,10 @@ def test_gpu_transform(spark_diabetes_dataset) -> None:
     model: SparkXGBRegressorModel = regressor.fit(train_df)
 
     # The model trained with GPUs, and transform with GPU configurations.
-    assert model._gpu_transform()
+    assert model._run_on_gpu()
 
     model.set_device("cpu")
-    assert not model._gpu_transform()
+    assert not model._run_on_gpu()
 
     # without error
     cpu_rows = model.transform(test_df).select("prediction").collect()
@@ -263,11 +263,11 @@ def test_gpu_transform(spark_diabetes_dataset) -> None:
 
     # The model trained with CPUs. Even with GPU configurations,
    # still prefer transforming with CPUs
-    assert not model._gpu_transform()
+    assert not model._run_on_gpu()
 
     # Set gpu transform explicitly.
     model.set_device("cuda")
-    assert model._gpu_transform()
+    assert model._run_on_gpu()
 
     # without error
     gpu_rows = model.transform(test_df).select("prediction").collect()
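
The cluster-mode assertions above depend on the task-level GPU resource
configuration: without `spark.task.resource.gpu.amount`, the model-side
`_run_on_gpu()` returns False and prediction falls back to the CPUs with a
warning. A sketch of a session setup that satisfies the check (the master URL
and resource amounts are placeholders, not part of this patch):

    from pyspark.sql import SparkSession

    # spark.task.resource.gpu.amount is the key _run_on_gpu() reads on a cluster;
    # without it, transform runs on the CPUs even when device="cuda" is set.
    spark = (
        SparkSession.builder.master("spark://host:7077")
        .config("spark.executor.resource.gpu.amount", "1")
        .config("spark.task.resource.gpu.amount", "1")
        .getOrCreate()
    )
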
diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py
index 861e67a75..2c5ee3690 100644
--- a/tests/test_distributed/test_with_spark/test_spark_local.py
+++ b/tests/test_distributed/test_with_spark/test_spark_local.py
@@ -888,6 +888,22 @@ class TestPySparkLocal:
             clf = SparkXGBClassifier(device="cuda")
             clf._validate_params()
 
+    def test_gpu_params(self) -> None:
+        clf = SparkXGBClassifier()
+        assert not clf._run_on_gpu()
+
+        clf = SparkXGBClassifier(device="cuda", tree_method="hist")
+        assert clf._run_on_gpu()
+
+        clf = SparkXGBClassifier(device="cuda")
+        assert clf._run_on_gpu()
+
+        clf = SparkXGBClassifier(tree_method="gpu_hist")
+        assert clf._run_on_gpu()
+
+        clf = SparkXGBClassifier(use_gpu=True)
+        assert clf._run_on_gpu()
+
     def test_gpu_transform(self, clf_data: ClfData) -> None:
         """local mode"""
         classifier = SparkXGBClassifier(device="cpu")
@@ -898,23 +914,23 @@ class TestPySparkLocal:
         model.write().overwrite().save(path)
 
         # The model trained with CPU, transform defaults to cpu
-        assert not model._gpu_transform()
+        assert not model._run_on_gpu()
         # without error
         model.transform(clf_data.cls_df_test).collect()
 
         model.set_device("cuda")
-        assert model._gpu_transform()
+        assert model._run_on_gpu()
 
         model_loaded = SparkXGBClassifierModel.load(path)
         # The model trained with CPU, transform defaults to cpu
-        assert not model_loaded._gpu_transform()
+        assert not model_loaded._run_on_gpu()
         # without error
         model_loaded.transform(clf_data.cls_df_test).collect()
 
         model_loaded.set_device("cuda")
-        assert model_loaded._gpu_transform()
+        assert model_loaded._run_on_gpu()
 
 
 class XgboostLocalTest(SparkTestCase):
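
For the local-mode tests above, the decision reduces to the parameters alone,
and `set_device` can flip it after training. A condensed sketch of that flow,
assuming `train_df` and `test_df` are placeholder DataFrames in a `local[*]`
session:

    from xgboost.spark import SparkXGBClassifier

    model = SparkXGBClassifier(device="cpu").fit(train_df)
    assert not model._run_on_gpu()    # trained on the CPU, predicts on the CPU

    model.set_device("cuda")          # request GPU prediction after training
    assert model._run_on_gpu()        # local mode: the parameters alone decide
    # Each task then picks device cuda:(partition_id % total_gpus) for inference.
    preds = model.transform(test_df).collect()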