[pyspark] Add validation for params 'early_stopping_rounds' and 'validation_indicator_col' (#8250)

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
WeichenXu
2022-09-26 17:43:03 +08:00
committed by GitHub
parent 0cd11b893a
commit ff71c69adf
2 changed files with 15 additions and 0 deletions

View File

@@ -292,6 +292,16 @@ class _SparkXGBParams(
if not isinstance(self.getOrDefault(self.eval_metric), str):
raise ValueError("Only string type 'eval_metric' param is allowed.")
if self.getOrDefault(self.early_stopping_rounds) is not None:
if not (
self.isDefined(self.validationIndicatorCol)
and self.getOrDefault(self.validationIndicatorCol)
):
raise ValueError(
"If 'early_stopping_rounds' param is set, you need to set "
"'validation_indicator_col' param as well."
)
if self.getOrDefault(self.enable_sparse_data_optim):
if self.getOrDefault(self.missing) != 0.0:
# If DMatrix is constructed from csr / csc matrix, then inactive elements