[pyspark] unify the way for determining whether runs on the GPU. (#9724)

This commit is contained in:
Bobby Wang
2023-10-27 11:21:30 +08:00
committed by GitHub
parent f41a08fda8
commit 1323531323
3 changed files with 64 additions and 43 deletions

View File

@@ -251,10 +251,10 @@ def test_gpu_transform(spark_diabetes_dataset) -> None:
model: SparkXGBRegressorModel = regressor.fit(train_df)
# The model trained with GPUs, and transform with GPU configurations.
assert model._gpu_transform()
assert model._run_on_gpu()
model.set_device("cpu")
assert not model._gpu_transform()
assert not model._run_on_gpu()
# without error
cpu_rows = model.transform(test_df).select("prediction").collect()
@@ -263,11 +263,11 @@ def test_gpu_transform(spark_diabetes_dataset) -> None:
# The model trained with CPUs. Even with GPU configurations,
# still prefer transforming with CPUs
assert not model._gpu_transform()
assert not model._run_on_gpu()
# Set gpu transform explicitly.
model.set_device("cuda")
assert model._gpu_transform()
assert model._run_on_gpu()
# without error
gpu_rows = model.transform(test_df).select("prediction").collect()