[pyspark] Add validation for param 'early_stopping_rounds' and 'validation_indicator_col' (#8250)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
parent
0cd11b893a
commit
ff71c69adf
@@ -292,6 +292,16 @@ class _SparkXGBParams(
|
|||||||
if not isinstance(self.getOrDefault(self.eval_metric), str):
|
if not isinstance(self.getOrDefault(self.eval_metric), str):
|
||||||
raise ValueError("Only string type 'eval_metric' param is allowed.")
|
raise ValueError("Only string type 'eval_metric' param is allowed.")
|
||||||
|
|
||||||
|
if self.getOrDefault(self.early_stopping_rounds) is not None:
|
||||||
|
if not (
|
||||||
|
self.isDefined(self.validationIndicatorCol)
|
||||||
|
and self.getOrDefault(self.validationIndicatorCol)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"If 'early_stopping_rounds' param is set, you need to set "
|
||||||
|
"'validation_indicator_col' param as well."
|
||||||
|
)
|
||||||
|
|
||||||
if self.getOrDefault(self.enable_sparse_data_optim):
|
if self.getOrDefault(self.enable_sparse_data_optim):
|
||||||
if self.getOrDefault(self.missing) != 0.0:
|
if self.getOrDefault(self.missing) != 0.0:
|
||||||
# If DMatrix is constructed from csr / csc matrix, then inactive elements
|
# If DMatrix is constructed from csr / csc matrix, then inactive elements
|
||||||
|
|||||||
@@ -1145,3 +1145,8 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
num_workers=4,
|
num_workers=4,
|
||||||
)
|
)
|
||||||
classifier.fit(data_trans)
|
classifier.fit(data_trans)
|
||||||
|
|
||||||
|
def test_early_stop_param_validation(self):
|
||||||
|
classifier = SparkXGBClassifier(early_stopping_rounds=1)
|
||||||
|
with pytest.raises(ValueError, match="early_stopping_rounds"):
|
||||||
|
classifier.fit(self.cls_df_train)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user