[backport] [pyspark] support gpu transform (#9542) (#9559)

---------

Co-authored-by: Bobby Wang <wbo4958@gmail.com>
Jiaming Yuan 2023-09-07 17:21:09 +08:00 committed by GitHub
parent 4d387cbfbf
commit e75dd75bb2
5 changed files with 166 additions and 7 deletions
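As a quick orientation before the diff: a minimal sketch of how the GPU transform support added here is used from PySpark, pieced together from the tests added below. The toy DataFrame and num_workers value are illustrative, not part of this commit.

from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBRegressor

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 1.0), (Vectors.dense(3.0, 4.0), 0.0)],
    ["features", "label"],
)

# Train with device="cuda"; on a cluster this requires GPU task resources to be configured.
regressor = SparkXGBRegressor(device="cuda", num_workers=1)
model = regressor.fit(df)

# With this commit, transform() also runs on GPU when the device is cuda/gpu.
model.transform(df).select("prediction").show()

# The new set_device helper switches the same fitted model to CPU prediction.
model.set_device("cpu")
model.transform(df).select("prediction").show()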

python-package/xgboost/compat.py

@@ -88,6 +88,18 @@ def is_cudf_available() -> bool:
        return False


def is_cupy_available() -> bool:
    """Check whether the cupy package is available."""
    if importlib.util.find_spec("cupy") is None:
        return False
    try:
        import cupy

        return True
    except ImportError:
        return False


try:
    import scipy.sparse as scipy_sparse
    from scipy.sparse import csr_matrix as scipy_csr
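The new helper mirrors the existing is_cudf_available check. A minimal sketch (not part of this diff) of how call sites can guard optional GPU imports with it; maybe_to_gpu is a hypothetical helper name used only for illustration.

from xgboost.compat import is_cudf_available, is_cupy_available


def maybe_to_gpu(values):
    # Hypothetical helper: only touch the GPU libraries when they are installed.
    if is_cudf_available() and is_cupy_available():
        import cudf  # pylint: disable=import-error

        return cudf.DataFrame(values)
    return values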

python-package/xgboost/spark/core.py

@@ -59,7 +59,7 @@ from scipy.special import expit, softmax  # pylint: disable=no-name-in-module

import xgboost
from xgboost import XGBClassifier
-from xgboost.compat import is_cudf_available
+from xgboost.compat import is_cudf_available, is_cupy_available
from xgboost.core import Booster, _check_distributed_params
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
from xgboost.training import train as worker_train
@@ -242,6 +242,13 @@ class _SparkXGBParams(
        TypeConverters.toList,
    )

    def set_device(self, value: str) -> "_SparkXGBParams":
        """Set the device; valid values: cpu, cuda, gpu."""
        _check_distributed_params({"device": value})
        assert value in ("cpu", "cuda", "gpu")
        self.set(self.device, value)
        return self

    @classmethod
    def _xgb_cls(cls) -> Type[XGBModel]:
        """
@@ -1193,6 +1200,31 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
            dataset = dataset.drop(pred_struct_col)
        return dataset

    def _gpu_transform(self) -> bool:
        """Whether GPU should be used for the prediction."""
        if _is_local(_get_spark_session().sparkContext):
            # In local mode, just follow the internal "device" param.
            return use_cuda(self.getOrDefault(self.device))

        gpu_per_task = (
            _get_spark_session()
            .sparkContext.getConf()
            .get("spark.task.resource.gpu.amount")
        )

        # The user has not set any GPU configuration, so just use the CPU.
        if gpu_per_task is None:
            if use_cuda(self.getOrDefault(self.device)):
                get_logger("XGBoost-PySpark").warning(
                    "Do the prediction on the CPUs since "
                    "no gpu configurations are set"
                )
            return False

        # The user has set GPU configurations; follow the internal "device" param.
        return use_cuda(self.getOrDefault(self.device))

    def _transform(self, dataset: DataFrame) -> DataFrame:
        # pylint: disable=too-many-statements, too-many-locals
        # Save xgb_sklearn_model and predict_params to be local variable
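In cluster mode, _gpu_transform keys off spark.task.resource.gpu.amount. For reference, a sketch of a Spark session configuration under which the GPU path would be taken; the amounts are illustrative, and the discovery-script path is the one referenced by the tests in this PR.

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # One GPU per executor and per task so each XGBoost task can claim a device.
    .config("spark.executor.resource.gpu.amount", "1")
    .config("spark.task.resource.gpu.amount", "1")
    # Script that reports the GPUs available on each node.
    .config(
        "spark.executor.resource.gpu.discoveryScript",
        "tests/test_distributed/test_gpu_with_spark/discover_gpu.sh",
    )
    .getOrCreate()
)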
@@ -1216,21 +1248,77 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
        _, schema = self._out_schema()

        is_local = _is_local(_get_spark_session().sparkContext)
        run_on_gpu = self._gpu_transform()

        @pandas_udf(schema)  # type: ignore
        def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
            assert xgb_sklearn_model is not None
            model = xgb_sklearn_model

            from pyspark import TaskContext

            context = TaskContext.get()
            assert context is not None

            dev_ordinal = -1
            if is_cudf_available():
                if is_local:
                    if run_on_gpu and is_cupy_available():
                        import cupy as cp  # pylint: disable=import-error

                        total_gpus = cp.cuda.runtime.getDeviceCount()
                        if total_gpus > 0:
                            partition_id = context.partitionId()
                            # In local mode, default dev_ordinal to
                            # (partition id) % total_gpus.
                            dev_ordinal = partition_id % total_gpus
                elif run_on_gpu:
                    dev_ordinal = _get_gpu_id(context)

                if dev_ordinal >= 0:
                    device = "cuda:" + str(dev_ordinal)
                    get_logger("XGBoost-PySpark").info(
                        "Do the inference with device: %s", device
                    )
                    model.set_params(device=device)
                else:
                    get_logger("XGBoost-PySpark").info("Do the inference on the CPUs")
            else:
                msg = (
                    "CUDF is unavailable, fallback the inference on the CPUs"
                    if run_on_gpu
                    else "Do the inference on the CPUs"
                )
                get_logger("XGBoost-PySpark").info(msg)

            def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:
                """Move the data to GPU if possible."""
                if dev_ordinal >= 0:
                    import cudf  # pylint: disable=import-error
                    import cupy as cp  # pylint: disable=import-error

                    # The device must be set after importing cudf, which resets
                    # the device id to 0.
                    # See https://github.com/rapidsai/cudf/issues/11386
                    cp.cuda.runtime.setDevice(dev_ordinal)  # pylint: disable=I1101
                    df = cudf.DataFrame(data)
                    del data
                    return df
                return data

            for data in iterator:
                if enable_sparse_data_optim:
                    X = _read_csr_matrix_from_unwrapped_spark_vec(data)
                else:
                    if feature_col_names is not None:
-                       X = data[feature_col_names]
+                       tmp = data[feature_col_names]
                    else:
-                       X = stack_series(data[alias.data])
+                       tmp = stack_series(data[alias.data])
+                   X = to_gpu_if_possible(tmp)

                if has_base_margin:
-                   base_margin = data[alias.margin].to_numpy()
+                   base_margin = to_gpu_if_possible(data[alias.margin])
                else:
                    base_margin = None
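In local mode the UDF above spreads partitions over the visible devices round-robin. A standalone sketch of that mapping, assuming cupy is installed; pick_device_ordinal is a name made up for illustration.

import cupy as cp


def pick_device_ordinal(partition_id: int) -> int:
    """Map a Spark partition onto a local GPU, or return -1 when no GPU exists."""
    total_gpus = cp.cuda.runtime.getDeviceCount()
    return partition_id % total_gpus if total_gpus > 0 else -1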

python-package/xgboost/spark/utils.py

@@ -10,7 +10,7 @@ from threading import Thread
from typing import Any, Callable, Dict, Optional, Set, Type

import pyspark
-from pyspark import BarrierTaskContext, SparkContext, SparkFiles
+from pyspark import BarrierTaskContext, SparkContext, SparkFiles, TaskContext
from pyspark.sql.session import SparkSession

from xgboost import Booster, XGBModel, collective
@@ -129,7 +129,7 @@ def _is_local(spark_context: SparkContext) -> bool:
    return spark_context._jsc.sc().isLocal()


-def _get_gpu_id(task_context: BarrierTaskContext) -> int:
+def _get_gpu_id(task_context: TaskContext) -> int:
    """Get the gpu id from the task resources"""
    if task_context is None:
        # This is a safety check.
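The hunk is cut off after the safety check. For reference, a sketch of how a task-level GPU id is typically read from Spark's task resources via TaskContext.resources(); this is illustrative, not necessarily the exact remainder of _get_gpu_id.

from pyspark import TaskContext


def read_gpu_ordinal(task_context: TaskContext) -> int:
    """Return the first GPU address that Spark's resource scheduler assigned to this task."""
    resources = task_context.resources()
    if "gpu" not in resources:
        raise RuntimeError("The current task has no GPU resource assigned.")
    # ResourceInformation.addresses holds the device ordinals as strings, e.g. ["0"].
    return int(resources["gpu"].addresses[0].strip())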

tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py

@@ -2,6 +2,7 @@ import json
import logging
import subprocess

import numpy as np
import pytest
import sklearn
@@ -13,7 +14,7 @@ from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession

-from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
+from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor, SparkXGBRegressorModel

gpu_discovery_script_path = "tests/test_distributed/test_gpu_with_spark/discover_gpu.sh"
@@ -242,3 +243,33 @@ def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature
    evaluator = RegressionEvaluator(metricName="rmse")
    rmse = evaluator.evaluate(pred_result_df)
    assert rmse <= 65.0


def test_gpu_transform(spark_diabetes_dataset) -> None:
    regressor = SparkXGBRegressor(device="cuda", num_workers=num_workers)
    train_df, test_df = spark_diabetes_dataset
    model: SparkXGBRegressorModel = regressor.fit(train_df)

    # The model was trained with GPUs, so transform follows the GPU configuration.
    assert model._gpu_transform()
    model.set_device("cpu")
    assert not model._gpu_transform()
    # Runs without error.
    cpu_rows = model.transform(test_df).select("prediction").collect()

    regressor = SparkXGBRegressor(device="cpu", num_workers=num_workers)
    model = regressor.fit(train_df)
    # The model was trained with CPUs. Even with GPU configurations available,
    # transforming still prefers the CPUs.
    assert not model._gpu_transform()

    # Request GPU transform explicitly.
    model.set_device("cuda")
    assert model._gpu_transform()
    # Runs without error.
    gpu_rows = model.transform(test_df).select("prediction").collect()

    for cpu, gpu in zip(cpu_rows, gpu_rows):
        np.testing.assert_allclose(cpu.prediction, gpu.prediction, atol=1e-3)

tests/test_distributed/test_with_spark/test_spark_local.py

@@ -888,6 +888,34 @@ class TestPySparkLocal:
        clf = SparkXGBClassifier(device="cuda")
        clf._validate_params()

    def test_gpu_transform(self, clf_data: ClfData) -> None:
        """local mode"""
        classifier = SparkXGBClassifier(device="cpu")
        model: SparkXGBClassifierModel = classifier.fit(clf_data.cls_df_train)

        with tempfile.TemporaryDirectory() as tmpdir:
            path = "file:" + tmpdir
            model.write().overwrite().save(path)

            # The model was trained on CPU, so transform defaults to CPU.
            assert not model._gpu_transform()
            # Runs without error.
            model.transform(clf_data.cls_df_test).collect()

            model.set_device("cuda")
            assert model._gpu_transform()

            model_loaded = SparkXGBClassifierModel.load(path)
            # The loaded model was trained on CPU, so transform defaults to CPU.
            assert not model_loaded._gpu_transform()
            # Runs without error.
            model_loaded.transform(clf_data.cls_df_test).collect()

            model_loaded.set_device("cuda")
            assert model_loaded._gpu_transform()


class XgboostLocalTest(SparkTestCase):
    def setUp(self):