[pyspark] Update eval_metric validation to support list of strings (#8826)

This commit is contained in:
mzzhang95
2023-03-01 19:24:12 -05:00
committed by GitHub
parent 803d5e3c4c
commit 6cef9a08e9
2 changed files with 23 additions and 2 deletions

View File

@@ -730,6 +730,16 @@ class TestPySparkLocal:
train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
assert train_params["tree_method"] == "gpu_hist"
def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"])
model = classifier.fit(clf_data.cls_df_train)
model.transform(clf_data.cls_df_test).collect()
def test_classifier_with_string_eval_metric(self, clf_data: ClfData) -> None:
classifier = SparkXGBClassifier(eval_metric="auc")
model = classifier.fit(clf_data.cls_df_train)
model.transform(clf_data.cls_df_test).collect()
class XgboostLocalTest(SparkTestCase):
def setUp(self):