[pyspark] Update eval_metric validation to support list of strings (#8826)
This commit is contained in:
parent
803d5e3c4c
commit
6cef9a08e9
@ -314,8 +314,19 @@ class _SparkXGBParams(
|
||||
raise ValueError("Only string type 'objective' param is allowed.")
|
||||
|
||||
if self.getOrDefault(self.eval_metric) is not None:
|
||||
if not isinstance(self.getOrDefault(self.eval_metric), str):
|
||||
raise ValueError("Only string type 'eval_metric' param is allowed.")
|
||||
if not (
|
||||
isinstance(self.getOrDefault(self.eval_metric), str)
|
||||
or (
|
||||
isinstance(self.getOrDefault(self.eval_metric), List)
|
||||
and all(
|
||||
isinstance(metric, str)
|
||||
for metric in self.getOrDefault(self.eval_metric)
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"Only string type or list of string type 'eval_metric' param is allowed."
|
||||
)
|
||||
|
||||
if self.getOrDefault(self.early_stopping_rounds) is not None:
|
||||
if not (
|
||||
|
||||
@ -730,6 +730,16 @@ class TestPySparkLocal:
|
||||
train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
|
||||
assert train_params["tree_method"] == "gpu_hist"
|
||||
|
||||
def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
|
||||
classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"])
|
||||
model = classifier.fit(clf_data.cls_df_train)
|
||||
model.transform(clf_data.cls_df_test).collect()
|
||||
|
||||
def test_classifier_with_string_eval_metric(self, clf_data: ClfData) -> None:
|
||||
classifier = SparkXGBClassifier(eval_metric="auc")
|
||||
model = classifier.fit(clf_data.cls_df_train)
|
||||
model.transform(clf_data.cls_df_test).collect()
|
||||
|
||||
|
||||
class XgboostLocalTest(SparkTestCase):
|
||||
def setUp(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user