--------- Co-authored-by: Bobby Wang <wbo4958@gmail.com>
This commit is contained in:
parent
4d387cbfbf
commit
e75dd75bb2
@ -88,6 +88,18 @@ def is_cudf_available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_cupy_available() -> bool:
|
||||
"""Check cupy package available or not"""
|
||||
if importlib.util.find_spec("cupy") is None:
|
||||
return False
|
||||
try:
|
||||
import cupy
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
import scipy.sparse as scipy_sparse
|
||||
from scipy.sparse import csr_matrix as scipy_csr
|
||||
|
||||
@ -59,7 +59,7 @@ from scipy.special import expit, softmax # pylint: disable=no-name-in-module
|
||||
|
||||
import xgboost
|
||||
from xgboost import XGBClassifier
|
||||
from xgboost.compat import is_cudf_available
|
||||
from xgboost.compat import is_cudf_available, is_cupy_available
|
||||
from xgboost.core import Booster, _check_distributed_params
|
||||
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
|
||||
from xgboost.training import train as worker_train
|
||||
@ -242,6 +242,13 @@ class _SparkXGBParams(
|
||||
TypeConverters.toList,
|
||||
)
|
||||
|
||||
def set_device(self, value: str) -> "_SparkXGBParams":
|
||||
"""Set device, optional value: cpu, cuda, gpu"""
|
||||
_check_distributed_params({"device": value})
|
||||
assert value in ("cpu", "cuda", "gpu")
|
||||
self.set(self.device, value)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def _xgb_cls(cls) -> Type[XGBModel]:
|
||||
"""
|
||||
@ -1193,6 +1200,31 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
dataset = dataset.drop(pred_struct_col)
|
||||
return dataset
|
||||
|
||||
def _gpu_transform(self) -> bool:
|
||||
"""If gpu is used to do the prediction, true to gpu prediction"""
|
||||
|
||||
if _is_local(_get_spark_session().sparkContext):
|
||||
# if it's local model, we just use the internal "device"
|
||||
return use_cuda(self.getOrDefault(self.device))
|
||||
|
||||
gpu_per_task = (
|
||||
_get_spark_session()
|
||||
.sparkContext.getConf()
|
||||
.get("spark.task.resource.gpu.amount")
|
||||
)
|
||||
|
||||
# User don't set gpu configurations, just use cpu
|
||||
if gpu_per_task is None:
|
||||
if use_cuda(self.getOrDefault(self.device)):
|
||||
get_logger("XGBoost-PySpark").warning(
|
||||
"Do the prediction on the CPUs since "
|
||||
"no gpu configurations are set"
|
||||
)
|
||||
return False
|
||||
|
||||
# User already sets the gpu configurations, we just use the internal "device".
|
||||
return use_cuda(self.getOrDefault(self.device))
|
||||
|
||||
def _transform(self, dataset: DataFrame) -> DataFrame:
|
||||
# pylint: disable=too-many-statements, too-many-locals
|
||||
# Save xgb_sklearn_model and predict_params to be local variable
|
||||
@ -1216,21 +1248,77 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
_, schema = self._out_schema()
|
||||
|
||||
is_local = _is_local(_get_spark_session().sparkContext)
|
||||
run_on_gpu = self._gpu_transform()
|
||||
|
||||
@pandas_udf(schema) # type: ignore
|
||||
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
|
||||
assert xgb_sklearn_model is not None
|
||||
model = xgb_sklearn_model
|
||||
|
||||
from pyspark import TaskContext
|
||||
|
||||
context = TaskContext.get()
|
||||
assert context is not None
|
||||
|
||||
dev_ordinal = -1
|
||||
|
||||
if is_cudf_available():
|
||||
if is_local:
|
||||
if run_on_gpu and is_cupy_available():
|
||||
import cupy as cp # pylint: disable=import-error
|
||||
|
||||
total_gpus = cp.cuda.runtime.getDeviceCount()
|
||||
if total_gpus > 0:
|
||||
partition_id = context.partitionId()
|
||||
# For transform local mode, default the dev_ordinal to
|
||||
# (partition id) % gpus.
|
||||
dev_ordinal = partition_id % total_gpus
|
||||
elif run_on_gpu:
|
||||
dev_ordinal = _get_gpu_id(context)
|
||||
|
||||
if dev_ordinal >= 0:
|
||||
device = "cuda:" + str(dev_ordinal)
|
||||
get_logger("XGBoost-PySpark").info(
|
||||
"Do the inference with device: %s", device
|
||||
)
|
||||
model.set_params(device=device)
|
||||
else:
|
||||
get_logger("XGBoost-PySpark").info("Do the inference on the CPUs")
|
||||
else:
|
||||
msg = (
|
||||
"CUDF is unavailable, fallback the inference on the CPUs"
|
||||
if run_on_gpu
|
||||
else "Do the inference on the CPUs"
|
||||
)
|
||||
get_logger("XGBoost-PySpark").info(msg)
|
||||
|
||||
def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:
|
||||
"""Move the data to gpu if possible"""
|
||||
if dev_ordinal >= 0:
|
||||
import cudf # pylint: disable=import-error
|
||||
import cupy as cp # pylint: disable=import-error
|
||||
|
||||
# We must set the device after import cudf, which will change the device id to 0
|
||||
# See https://github.com/rapidsai/cudf/issues/11386
|
||||
cp.cuda.runtime.setDevice(dev_ordinal) # pylint: disable=I1101
|
||||
df = cudf.DataFrame(data)
|
||||
del data
|
||||
return df
|
||||
return data
|
||||
|
||||
for data in iterator:
|
||||
if enable_sparse_data_optim:
|
||||
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
|
||||
else:
|
||||
if feature_col_names is not None:
|
||||
X = data[feature_col_names]
|
||||
tmp = data[feature_col_names]
|
||||
else:
|
||||
X = stack_series(data[alias.data])
|
||||
tmp = stack_series(data[alias.data])
|
||||
X = to_gpu_if_possible(tmp)
|
||||
|
||||
if has_base_margin:
|
||||
base_margin = data[alias.margin].to_numpy()
|
||||
base_margin = to_gpu_if_possible(data[alias.margin])
|
||||
else:
|
||||
base_margin = None
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from threading import Thread
|
||||
from typing import Any, Callable, Dict, Optional, Set, Type
|
||||
|
||||
import pyspark
|
||||
from pyspark import BarrierTaskContext, SparkContext, SparkFiles
|
||||
from pyspark import BarrierTaskContext, SparkContext, SparkFiles, TaskContext
|
||||
from pyspark.sql.session import SparkSession
|
||||
|
||||
from xgboost import Booster, XGBModel, collective
|
||||
@ -129,7 +129,7 @@ def _is_local(spark_context: SparkContext) -> bool:
|
||||
return spark_context._jsc.sc().isLocal()
|
||||
|
||||
|
||||
def _get_gpu_id(task_context: BarrierTaskContext) -> int:
|
||||
def _get_gpu_id(task_context: TaskContext) -> int:
|
||||
"""Get the gpu id from the task resources"""
|
||||
if task_context is None:
|
||||
# This is a safety check.
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sklearn
|
||||
|
||||
@ -13,7 +14,7 @@ from pyspark.ml.linalg import Vectors
|
||||
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
|
||||
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor, SparkXGBRegressorModel
|
||||
|
||||
gpu_discovery_script_path = "tests/test_distributed/test_gpu_with_spark/discover_gpu.sh"
|
||||
|
||||
@ -242,3 +243,33 @@ def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature
|
||||
evaluator = RegressionEvaluator(metricName="rmse")
|
||||
rmse = evaluator.evaluate(pred_result_df)
|
||||
assert rmse <= 65.0
|
||||
|
||||
|
||||
def test_gpu_transform(spark_diabetes_dataset) -> None:
|
||||
regressor = SparkXGBRegressor(device="cuda", num_workers=num_workers)
|
||||
train_df, test_df = spark_diabetes_dataset
|
||||
model: SparkXGBRegressorModel = regressor.fit(train_df)
|
||||
|
||||
# The model trained with GPUs, and transform with GPU configurations.
|
||||
assert model._gpu_transform()
|
||||
|
||||
model.set_device("cpu")
|
||||
assert not model._gpu_transform()
|
||||
# without error
|
||||
cpu_rows = model.transform(test_df).select("prediction").collect()
|
||||
|
||||
regressor = SparkXGBRegressor(device="cpu", num_workers=num_workers)
|
||||
model = regressor.fit(train_df)
|
||||
|
||||
# The model trained with CPUs. Even with GPU configurations,
|
||||
# still prefer transforming with CPUs
|
||||
assert not model._gpu_transform()
|
||||
|
||||
# Set gpu transform explicitly.
|
||||
model.set_device("cuda")
|
||||
assert model._gpu_transform()
|
||||
# without error
|
||||
gpu_rows = model.transform(test_df).select("prediction").collect()
|
||||
|
||||
for cpu, gpu in zip(cpu_rows, gpu_rows):
|
||||
np.testing.assert_allclose(cpu.prediction, gpu.prediction, atol=1e-3)
|
||||
|
||||
@ -888,6 +888,34 @@ class TestPySparkLocal:
|
||||
clf = SparkXGBClassifier(device="cuda")
|
||||
clf._validate_params()
|
||||
|
||||
def test_gpu_transform(self, clf_data: ClfData) -> None:
|
||||
"""local mode"""
|
||||
classifier = SparkXGBClassifier(device="cpu")
|
||||
model: SparkXGBClassifierModel = classifier.fit(clf_data.cls_df_train)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = "file:" + tmpdir
|
||||
model.write().overwrite().save(path)
|
||||
|
||||
# The model trained with CPU, transform defaults to cpu
|
||||
assert not model._gpu_transform()
|
||||
|
||||
# without error
|
||||
model.transform(clf_data.cls_df_test).collect()
|
||||
|
||||
model.set_device("cuda")
|
||||
assert model._gpu_transform()
|
||||
|
||||
model_loaded = SparkXGBClassifierModel.load(path)
|
||||
|
||||
# The model trained with CPU, transform defaults to cpu
|
||||
assert not model_loaded._gpu_transform()
|
||||
# without error
|
||||
model_loaded.transform(clf_data.cls_df_test).collect()
|
||||
|
||||
model_loaded.set_device("cuda")
|
||||
assert model_loaded._gpu_transform()
|
||||
|
||||
|
||||
class XgboostLocalTest(SparkTestCase):
|
||||
def setUp(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user