From 6cef9a08e9e6bfc54b32da2b1108949d2f1a5de2 Mon Sep 17 00:00:00 2001 From: mzzhang95 <48953222+mzzhang95@users.noreply.github.com> Date: Wed, 1 Mar 2023 19:24:12 -0500 Subject: [PATCH] [pyspark] Update eval_metric validation to support list of strings (#8826) --- python-package/xgboost/spark/core.py | 15 +++++++++++++-- .../test_with_spark/test_spark_local.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 6d9733817..8a13e88cc 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -314,8 +314,19 @@ class _SparkXGBParams( raise ValueError("Only string type 'objective' param is allowed.") if self.getOrDefault(self.eval_metric) is not None: - if not isinstance(self.getOrDefault(self.eval_metric), str): - raise ValueError("Only string type 'eval_metric' param is allowed.") + if not ( + isinstance(self.getOrDefault(self.eval_metric), str) + or ( + isinstance(self.getOrDefault(self.eval_metric), List) + and all( + isinstance(metric, str) + for metric in self.getOrDefault(self.eval_metric) + ) + ) + ): + raise ValueError( + "Only string type or list of string type 'eval_metric' param is allowed." + ) if self.getOrDefault(self.early_stopping_rounds) is not None: if not ( diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index 27f1ef06f..b86a1930f 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -730,6 +730,16 @@ class TestPySparkLocal: train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train) assert train_params["tree_method"] == "gpu_hist" + def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None: + classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"]) + model = classifier.fit(clf_data.cls_df_train) + model.transform(clf_data.cls_df_test).collect() + + def test_classifier_with_string_eval_metric(self, clf_data: ClfData) -> None: + classifier = SparkXGBClassifier(eval_metric="auc") + model = classifier.fit(clf_data.cls_df_train) + model.transform(clf_data.cls_df_test).collect() + class XgboostLocalTest(SparkTestCase): def setUp(self):