Fix pyspark parameter. (#9460)
- Don't pass the `use_gpu` parameter to the learner. - Fix GPU approx with PySpark.
This commit is contained in:
@@ -151,12 +151,18 @@ def spark_diabetes_dataset_feature_cols(spark_session_with_gpu):
|
||||
return train_df, test_df, data.feature_names
|
||||
|
||||
|
||||
def test_sparkxgb_classifier_with_gpu(spark_iris_dataset):
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
|
||||
def test_sparkxgb_classifier_with_gpu(tree_method: str, spark_iris_dataset) -> None:
|
||||
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
|
||||
|
||||
classifier = SparkXGBClassifier(device="cuda", num_workers=num_workers)
|
||||
classifier = SparkXGBClassifier(
|
||||
device="cuda", num_workers=num_workers, tree_method=tree_method
|
||||
)
|
||||
train_df, test_df = spark_iris_dataset
|
||||
model = classifier.fit(train_df)
|
||||
config = json.loads(model.get_booster().save_config())
|
||||
ctx = config["learner"]["generic_param"]
|
||||
assert ctx["device"] == "cuda:0"
|
||||
pred_result_df = model.transform(test_df)
|
||||
evaluator = MulticlassClassificationEvaluator(metricName="f1")
|
||||
f1 = evaluator.evaluate(pred_result_df)
|
||||
|
||||
Reference in New Issue
Block a user