[pyspark] Update eval_metric validation to support list of strings (#8826)

This commit is contained in:
mzzhang95
2023-03-01 19:24:12 -05:00
committed by GitHub
parent 803d5e3c4c
commit 6cef9a08e9
2 changed files with 23 additions and 2 deletions

View File

@@ -314,8 +314,19 @@ class _SparkXGBParams(
raise ValueError("Only string type 'objective' param is allowed.")
if self.getOrDefault(self.eval_metric) is not None:
if not isinstance(self.getOrDefault(self.eval_metric), str):
raise ValueError("Only string type 'eval_metric' param is allowed.")
if not (
isinstance(self.getOrDefault(self.eval_metric), str)
or (
isinstance(self.getOrDefault(self.eval_metric), List)
and all(
isinstance(metric, str)
for metric in self.getOrDefault(self.eval_metric)
)
)
):
raise ValueError(
"Only string type or list of string type 'eval_metric' param is allowed."
)
if self.getOrDefault(self.early_stopping_rounds) is not None:
if not (