From ff71c69adf801f0dd9b8121f0f0a8d7cb18dfdb7 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Mon, 26 Sep 2022 17:43:03 +0800
Subject: [PATCH] [pyspark] Add validation for param 'early_stopping_rounds' and 'validation_indicator_col' (#8250)

Signed-off-by: Weichen Xu
---
 python-package/xgboost/spark/core.py        | 10 ++++++++++
 tests/python/test_spark/test_spark_local.py |  5 +++++
 2 files changed, 15 insertions(+)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index d1bc0e04b..ffeeae8a7 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -292,6 +292,16 @@ class _SparkXGBParams(
             if not isinstance(self.getOrDefault(self.eval_metric), str):
                 raise ValueError("Only string type 'eval_metric' param is allowed.")
 
+        if self.getOrDefault(self.early_stopping_rounds) is not None:
+            if not (
+                self.isDefined(self.validationIndicatorCol)
+                and self.getOrDefault(self.validationIndicatorCol)
+            ):
+                raise ValueError(
+                    "If 'early_stopping_rounds' param is set, you need to set "
+                    "'validation_indicator_col' param as well."
+                )
+
         if self.getOrDefault(self.enable_sparse_data_optim):
             if self.getOrDefault(self.missing) != 0.0:
                 # If DMatrix is constructed from csr / csc matrix, then inactive elements
diff --git a/tests/python/test_spark/test_spark_local.py b/tests/python/test_spark/test_spark_local.py
index 0ad487098..3894bed4b 100644
--- a/tests/python/test_spark/test_spark_local.py
+++ b/tests/python/test_spark/test_spark_local.py
@@ -1145,3 +1145,8 @@ class XgboostLocalTest(SparkTestCase):
             num_workers=4,
         )
         classifier.fit(data_trans)
+
+    def test_early_stop_param_validation(self):
+        classifier = SparkXGBClassifier(early_stopping_rounds=1)
+        with pytest.raises(ValueError, match="early_stopping_rounds"):
+            classifier.fit(self.cls_df_train)