[pyspark] Update eval_metric validation to support list of strings (#8826)
This commit is contained in:
parent
803d5e3c4c
commit
6cef9a08e9
@ -314,8 +314,19 @@ class _SparkXGBParams(
|
|||||||
raise ValueError("Only string type 'objective' param is allowed.")
|
raise ValueError("Only string type 'objective' param is allowed.")
|
||||||
|
|
||||||
if self.getOrDefault(self.eval_metric) is not None:
|
if self.getOrDefault(self.eval_metric) is not None:
|
||||||
if not isinstance(self.getOrDefault(self.eval_metric), str):
|
if not (
|
||||||
raise ValueError("Only string type 'eval_metric' param is allowed.")
|
isinstance(self.getOrDefault(self.eval_metric), str)
|
||||||
|
or (
|
||||||
|
isinstance(self.getOrDefault(self.eval_metric), List)
|
||||||
|
and all(
|
||||||
|
isinstance(metric, str)
|
||||||
|
for metric in self.getOrDefault(self.eval_metric)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Only string type or list of string type 'eval_metric' param is allowed."
|
||||||
|
)
|
||||||
|
|
||||||
if self.getOrDefault(self.early_stopping_rounds) is not None:
|
if self.getOrDefault(self.early_stopping_rounds) is not None:
|
||||||
if not (
|
if not (
|
||||||
|
|||||||
@ -730,6 +730,16 @@ class TestPySparkLocal:
|
|||||||
train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
|
train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
|
||||||
assert train_params["tree_method"] == "gpu_hist"
|
assert train_params["tree_method"] == "gpu_hist"
|
||||||
|
|
||||||
|
def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
|
||||||
|
classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"])
|
||||||
|
model = classifier.fit(clf_data.cls_df_train)
|
||||||
|
model.transform(clf_data.cls_df_test).collect()
|
||||||
|
|
||||||
|
def test_classifier_with_string_eval_metric(self, clf_data: ClfData) -> None:
|
||||||
|
classifier = SparkXGBClassifier(eval_metric="auc")
|
||||||
|
model = classifier.fit(clf_data.cls_df_train)
|
||||||
|
model.transform(clf_data.cls_df_test).collect()
|
||||||
|
|
||||||
|
|
||||||
class XgboostLocalTest(SparkTestCase):
|
class XgboostLocalTest(SparkTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user