[pyspark] Update eval_metric validation to support list of strings (#8826)

This commit is contained in:
mzzhang95 2023-03-01 19:24:12 -05:00 committed by GitHub
parent 803d5e3c4c
commit 6cef9a08e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 2 deletions

View File

@@ -314,8 +314,19 @@ class _SparkXGBParams(
raise ValueError("Only string type 'objective' param is allowed.")
if self.getOrDefault(self.eval_metric) is not None:
if not isinstance(self.getOrDefault(self.eval_metric), str):
raise ValueError("Only string type 'eval_metric' param is allowed.")
if not (
isinstance(self.getOrDefault(self.eval_metric), str)
or (
isinstance(self.getOrDefault(self.eval_metric), List)
and all(
isinstance(metric, str)
for metric in self.getOrDefault(self.eval_metric)
)
)
):
raise ValueError(
"Only string type or list of string type 'eval_metric' param is allowed."
)
if self.getOrDefault(self.early_stopping_rounds) is not None:
if not (

View File

@@ -730,6 +730,16 @@ class TestPySparkLocal:
train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
assert train_params["tree_method"] == "gpu_hist"
def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
    """Fit and transform a classifier whose ``eval_metric`` is a list of strings.

    There are no assertions: the test passes when construction, ``fit``,
    ``transform`` and ``collect`` all complete without raising.
    """
    metrics = ["auc", "rmse"]
    fitted = SparkXGBClassifier(eval_metric=metrics).fit(clf_data.cls_df_train)
    predictions = fitted.transform(clf_data.cls_df_test)
    predictions.collect()
def test_classifier_with_string_eval_metric(self, clf_data: ClfData) -> None:
    """Fit and transform a classifier whose ``eval_metric`` is a single string.

    There are no assertions: the test passes when construction, ``fit``,
    ``transform`` and ``collect`` all complete without raising.
    """
    metric = "auc"
    fitted = SparkXGBClassifier(eval_metric=metric).fit(clf_data.cls_df_train)
    predictions = fitted.transform(clf_data.cls_df_test)
    predictions.collect()
class XgboostLocalTest(SparkTestCase):
def setUp(self):