[pyspark] Update eval_metric validation to support list of strings (#8826)
This commit is contained in:
@@ -730,6 +730,16 @@ class TestPySparkLocal:
|
||||
train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
|
||||
assert train_params["tree_method"] == "gpu_hist"
|
||||
|
||||
def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
|
||||
classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"])
|
||||
model = classifier.fit(clf_data.cls_df_train)
|
||||
model.transform(clf_data.cls_df_test).collect()
|
||||
|
||||
def test_classifier_with_string_eval_metric(self, clf_data: ClfData) -> None:
|
||||
classifier = SparkXGBClassifier(eval_metric="auc")
|
||||
model = classifier.fit(clf_data.cls_df_train)
|
||||
model.transform(clf_data.cls_df_test).collect()
|
||||
|
||||
|
||||
class XgboostLocalTest(SparkTestCase):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user