[pyspark] unify the way for determining whether runs on the GPU. (#9724)

This commit is contained in:
Bobby Wang
2023-10-27 11:21:30 +08:00
committed by GitHub
parent f41a08fda8
commit 1323531323
3 changed files with 64 additions and 43 deletions

View File

@@ -888,6 +888,22 @@ class TestPySparkLocal:
clf = SparkXGBClassifier(device="cuda")
clf._validate_params()
def test_gpu_params(self) -> None:
clf = SparkXGBClassifier()
assert not clf._run_on_gpu()
clf = SparkXGBClassifier(device="cuda", tree_method="hist")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(device="cuda")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(tree_method="gpu_hist")
assert clf._run_on_gpu()
clf = SparkXGBClassifier(use_gpu=True)
assert clf._run_on_gpu()
def test_gpu_transform(self, clf_data: ClfData) -> None:
"""local mode"""
classifier = SparkXGBClassifier(device="cpu")
@@ -898,23 +914,23 @@ class TestPySparkLocal:
model.write().overwrite().save(path)
# The model trained with CPU, transform defaults to cpu
assert not model._gpu_transform()
assert not model._run_on_gpu()
# without error
model.transform(clf_data.cls_df_test).collect()
model.set_device("cuda")
assert model._gpu_transform()
assert model._run_on_gpu()
model_loaded = SparkXGBClassifierModel.load(path)
# The model trained with CPU, transform defaults to cpu
assert not model_loaded._gpu_transform()
assert not model_loaded._run_on_gpu()
# without error
model_loaded.transform(clf_data.cls_df_test).collect()
model_loaded.set_device("cuda")
assert model_loaded._gpu_transform()
assert model_loaded._run_on_gpu()
class XgboostLocalTest(SparkTestCase):