[pyspark] unify the way for determining whether runs on the GPU. (#9724)

This commit is contained in:
Bobby Wang 2023-10-27 11:21:30 +08:00 committed by GitHub
parent f41a08fda8
commit 1323531323
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 64 additions and 43 deletions

View File

@ -347,7 +347,7 @@ class _SparkXGBParams(
def _validate_gpu_params(self) -> None: def _validate_gpu_params(self) -> None:
"""Validate the gpu parameters and gpu configurations""" """Validate the gpu parameters and gpu configurations"""
if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): if self._run_on_gpu():
ss = _get_spark_session() ss = _get_spark_session()
sc = ss.sparkContext sc = ss.sparkContext
@ -414,9 +414,7 @@ class _SparkXGBParams(
) )
if self.getOrDefault(self.features_cols): if self.getOrDefault(self.features_cols):
if not use_cuda(self.getOrDefault(self.device)) and not self.getOrDefault( if not self._run_on_gpu():
self.use_gpu
):
raise ValueError( raise ValueError(
"features_col param with list value requires `device=cuda`." "features_col param with list value requires `device=cuda`."
) )
@ -473,6 +471,15 @@ class _SparkXGBParams(
self._validate_gpu_params() self._validate_gpu_params()
def _run_on_gpu(self) -> bool:
"""If train or transform on the gpu according to the parameters"""
return (
use_cuda(self.getOrDefault(self.device))
or self.getOrDefault(self.use_gpu)
or self.getOrDefault(self.getParam("tree_method")) == "gpu_hist"
)
def _validate_and_convert_feature_col_as_float_col_list( def _validate_and_convert_feature_col_as_float_col_list(
dataset: DataFrame, features_col_names: List[str] dataset: DataFrame, features_col_names: List[str]
@ -905,7 +912,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
"""Check if stage-level scheduling is not needed, """Check if stage-level scheduling is not needed,
return true to skip stage-level scheduling""" return true to skip stage-level scheduling"""
if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): if self._run_on_gpu():
ss = _get_spark_session() ss = _get_spark_session()
sc = ss.sparkContext sc = ss.sparkContext
@ -1022,9 +1029,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
dmatrix_kwargs, dmatrix_kwargs,
) = self._get_xgb_parameters(dataset) ) = self._get_xgb_parameters(dataset)
run_on_gpu = use_cuda(self.getOrDefault(self.device)) or self.getOrDefault( run_on_gpu = self._run_on_gpu()
self.use_gpu
)
is_local = _is_local(_get_spark_session().sparkContext) is_local = _is_local(_get_spark_session().sparkContext)
num_workers = self.getOrDefault(self.num_workers) num_workers = self.getOrDefault(self.num_workers)
@ -1318,12 +1324,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
dataset = dataset.drop(pred_struct_col) dataset = dataset.drop(pred_struct_col)
return dataset return dataset
def _gpu_transform(self) -> bool: def _run_on_gpu(self) -> bool:
"""If gpu is used to do the prediction, true to gpu prediction""" """If gpu is used to do the prediction according to the parameters
and spark configurations"""
use_gpu_by_params = super()._run_on_gpu()
if _is_local(_get_spark_session().sparkContext): if _is_local(_get_spark_session().sparkContext):
# if it's local model, we just use the internal "device" # if it's local model, no need to check the spark configurations
return use_cuda(self.getOrDefault(self.device)) return use_gpu_by_params
gpu_per_task = ( gpu_per_task = (
_get_spark_session() _get_spark_session()
@ -1333,15 +1342,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
# User don't set gpu configurations, just use cpu # User don't set gpu configurations, just use cpu
if gpu_per_task is None: if gpu_per_task is None:
if use_cuda(self.getOrDefault(self.device)): if use_gpu_by_params:
get_logger("XGBoost-PySpark").warning( get_logger("XGBoost-PySpark").warning(
"Do the prediction on the CPUs since " "Do the prediction on the CPUs since "
"no gpu configurations are set" "no gpu configurations are set"
) )
return False return False
# User already sets the gpu configurations, we just use the internal "device". # User already sets the gpu configurations.
return use_cuda(self.getOrDefault(self.device)) return use_gpu_by_params
def _transform(self, dataset: DataFrame) -> DataFrame: def _transform(self, dataset: DataFrame) -> DataFrame:
# pylint: disable=too-many-statements, too-many-locals # pylint: disable=too-many-statements, too-many-locals
@ -1367,7 +1376,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
_, schema = self._out_schema() _, schema = self._out_schema()
is_local = _is_local(_get_spark_session().sparkContext) is_local = _is_local(_get_spark_session().sparkContext)
run_on_gpu = self._gpu_transform() run_on_gpu = self._run_on_gpu()
@pandas_udf(schema) # type: ignore @pandas_udf(schema) # type: ignore
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
@ -1381,9 +1390,10 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
dev_ordinal = -1 dev_ordinal = -1
if is_cudf_available(): msg = "Do the inference on the CPUs"
if run_on_gpu:
if is_cudf_available() and is_cupy_available():
if is_local: if is_local:
if run_on_gpu and is_cupy_available():
import cupy as cp # pylint: disable=import-error import cupy as cp # pylint: disable=import-error
total_gpus = cp.cuda.runtime.getDeviceCount() total_gpus = cp.cuda.runtime.getDeviceCount()
@ -1392,23 +1402,18 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
# For transform local mode, default the dev_ordinal to # For transform local mode, default the dev_ordinal to
# (partition id) % gpus. # (partition id) % gpus.
dev_ordinal = partition_id % total_gpus dev_ordinal = partition_id % total_gpus
elif run_on_gpu: else:
dev_ordinal = _get_gpu_id(context) dev_ordinal = _get_gpu_id(context)
if dev_ordinal >= 0: if dev_ordinal >= 0:
device = "cuda:" + str(dev_ordinal) device = "cuda:" + str(dev_ordinal)
get_logger("XGBoost-PySpark").info( msg = "Do the inference with device: " + device
"Do the inference with device: %s", device
)
model.set_params(device=device) model.set_params(device=device)
else: else:
get_logger("XGBoost-PySpark").info("Do the inference on the CPUs") msg = "Couldn't get the correct gpu id, fallback the inference on the CPUs"
else: else:
msg = ( msg = "CUDF or Cupy is unavailable, fallback the inference on the CPUs"
"CUDF is unavailable, fallback the inference on the CPUs"
if run_on_gpu
else "Do the inference on the CPUs"
)
get_logger("XGBoost-PySpark").info(msg) get_logger("XGBoost-PySpark").info(msg)
def to_gpu_if_possible(data: ArrayLike) -> ArrayLike: def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:

View File

@ -251,10 +251,10 @@ def test_gpu_transform(spark_diabetes_dataset) -> None:
model: SparkXGBRegressorModel = regressor.fit(train_df) model: SparkXGBRegressorModel = regressor.fit(train_df)
# The model trained with GPUs, and transform with GPU configurations. # The model trained with GPUs, and transform with GPU configurations.
assert model._gpu_transform() assert model._run_on_gpu()
model.set_device("cpu") model.set_device("cpu")
assert not model._gpu_transform() assert not model._run_on_gpu()
# without error # without error
cpu_rows = model.transform(test_df).select("prediction").collect() cpu_rows = model.transform(test_df).select("prediction").collect()
@ -263,11 +263,11 @@ def test_gpu_transform(spark_diabetes_dataset) -> None:
# The model trained with CPUs. Even with GPU configurations, # The model trained with CPUs. Even with GPU configurations,
# still prefer transforming with CPUs # still prefer transforming with CPUs
assert not model._gpu_transform() assert not model._run_on_gpu()
# Set gpu transform explicitly. # Set gpu transform explicitly.
model.set_device("cuda") model.set_device("cuda")
assert model._gpu_transform() assert model._run_on_gpu()
# without error # without error
gpu_rows = model.transform(test_df).select("prediction").collect() gpu_rows = model.transform(test_df).select("prediction").collect()

View File

@ -888,6 +888,22 @@ class TestPySparkLocal:
clf = SparkXGBClassifier(device="cuda") clf = SparkXGBClassifier(device="cuda")
clf._validate_params() clf._validate_params()
def test_gpu_params(self) -> None:
clf = SparkXGBClassifier()
assert not clf._run_on_gpu()
clf = SparkXGBClassifier(device="cuda", tree_method="hist")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(device="cuda")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(tree_method="gpu_hist")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(use_gpu=True)
assert clf._run_on_gpu()
def test_gpu_transform(self, clf_data: ClfData) -> None: def test_gpu_transform(self, clf_data: ClfData) -> None:
"""local mode""" """local mode"""
classifier = SparkXGBClassifier(device="cpu") classifier = SparkXGBClassifier(device="cpu")
@ -898,23 +914,23 @@ class TestPySparkLocal:
model.write().overwrite().save(path) model.write().overwrite().save(path)
# The model trained with CPU, transform defaults to cpu # The model trained with CPU, transform defaults to cpu
assert not model._gpu_transform() assert not model._run_on_gpu()
# without error # without error
model.transform(clf_data.cls_df_test).collect() model.transform(clf_data.cls_df_test).collect()
model.set_device("cuda") model.set_device("cuda")
assert model._gpu_transform() assert model._run_on_gpu()
model_loaded = SparkXGBClassifierModel.load(path) model_loaded = SparkXGBClassifierModel.load(path)
# The model trained with CPU, transform defaults to cpu # The model trained with CPU, transform defaults to cpu
assert not model_loaded._gpu_transform() assert not model_loaded._run_on_gpu()
# without error # without error
model_loaded.transform(clf_data.cls_df_test).collect() model_loaded.transform(clf_data.cls_df_test).collect()
model_loaded.set_device("cuda") model_loaded.set_device("cuda")
assert model_loaded._gpu_transform() assert model_loaded._run_on_gpu()
class XgboostLocalTest(SparkTestCase): class XgboostLocalTest(SparkTestCase):