[pyspark] support gpu transform (#9542)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2023-09-07 12:15:50 +08:00
committed by GitHub
parent 0f35493b65
commit 6c791b5b47
5 changed files with 166 additions and 7 deletions

View File

@@ -888,6 +888,34 @@ class TestPySparkLocal:
clf = SparkXGBClassifier(device="cuda")
clf._validate_params()
def test_gpu_transform(self, clf_data: ClfData) -> None:
"""local mode"""
classifier = SparkXGBClassifier(device="cpu")
model: SparkXGBClassifierModel = classifier.fit(clf_data.cls_df_train)
with tempfile.TemporaryDirectory() as tmpdir:
path = "file:" + tmpdir
model.write().overwrite().save(path)
# The model trained with CPU, transform defaults to cpu
assert not model._gpu_transform()
# without error
model.transform(clf_data.cls_df_test).collect()
model.set_device("cuda")
assert model._gpu_transform()
model_loaded = SparkXGBClassifierModel.load(path)
# The model trained with CPU, transform defaults to cpu
assert not model_loaded._gpu_transform()
# without error
model_loaded.transform(clf_data.cls_df_test).collect()
model_loaded.set_device("cuda")
assert model_loaded._gpu_transform()
class XgboostLocalTest(SparkTestCase):
def setUp(self):