[pyspark] Add param validation for "objective" and "eval_metric" param, and remove invalid booster params (#8173)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
parent 9b32e6e2dc
commit d03794ce7a

@@ -114,10 +114,10 @@ _unsupported_xgb_params = [
 _unsupported_fit_params = {
     "sample_weight",  # Supported by spark param weightCol
-    # Supported by spark param weightCol # and validationIndicatorCol
-    "eval_set",
-    "sample_weight_eval_set",
+    "eval_set",  # Supported by spark param validation_indicator_col
+    "sample_weight_eval_set",  # Supported by spark param weight_col + validation_indicator_col
     "base_margin",  # Supported by spark param base_margin_col
+    "base_margin_eval_set",  # Supported by spark param base_margin_col + validation_indicator_col
     "group",  # Use spark param `qid_col` instead
     "qid",  # Use spark param `qid_col` instead
     "eval_group",  # Use spark param `qid_col` instead

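The comments above map each rejected fit() kwarg to the Spark-level param that replaces it. As a minimal sketch of that mapping (the column names "weight", "isVal", and "margin" are made up for illustration), the same information is supplied through DataFrame columns rather than fit() arguments:

    from xgboost.spark import SparkXGBRegressor

    regressor = SparkXGBRegressor(
        features_col="features",
        label_col="label",
        weight_col="weight",                # replaces sample_weight / sample_weight_eval_set
        validation_indicator_col="isVal",   # replaces eval_set
        base_margin_col="margin",           # replaces base_margin / base_margin_eval_set
    )
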
@@ -287,6 +287,14 @@ class _SparkXGBParams(
                 "If features_cols param set, then features_col param is ignored."
             )
 
+        if self.getOrDefault(self.objective) is not None:
+            if not isinstance(self.getOrDefault(self.objective), str):
+                raise ValueError("Only string type 'objective' param is allowed.")
+
+        if self.getOrDefault(self.eval_metric) is not None:
+            if not isinstance(self.getOrDefault(self.eval_metric), str):
+                raise ValueError("Only string type 'eval_metric' param is allowed.")
+
         if self.getOrDefault(self.enable_sparse_data_optim):
             if self.getOrDefault(self.missing) != 0.0:
                 # If DMatrix is constructed from csr / csc matrix, then inactive elements

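The new checks only constrain the param type; the objective / eval_metric names themselves are not validated here. A sketch of the resulting behavior (the ValueError surfaces when the estimator validates its params, i.e. when fit() runs, not at construction time):

    from xgboost.spark import SparkXGBRegressor

    # String values pass the new type check.
    reg_ok = SparkXGBRegressor(objective="reg:squarederror", eval_metric="rmse")

    # A callable objective (allowed in the sklearn API) is rejected during validation with:
    #   ValueError: Only string type 'objective' param is allowed.
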
@@ -578,7 +586,6 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         params.update(fit_params)
         params["verbose_eval"] = verbose_eval
         classification = self._xgb_cls() == XGBClassifier
-        num_classes = int(dataset.select(countDistinct(alias.label)).collect()[0][0])
         if classification:
             num_classes = int(
                 dataset.select(countDistinct(alias.label)).collect()[0][0]

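With this change the countDistinct() scan that infers num_classes runs only for classifiers, so regressors no longer pay for an extra aggregation over the dataset. For reference, a standalone sketch of what that expression computes, assuming an active SparkSession named spark:

    from pyspark.sql.functions import countDistinct

    df = spark.createDataFrame([(0.0,), (1.0,), (1.0,), (2.0,)], ["label"])
    # Equivalent to dataset.select(countDistinct(alias.label)).collect()[0][0]
    num_classes = int(df.select(countDistinct("label")).collect()[0][0])  # -> 3
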
@@ -610,6 +617,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 kwargs_params[key] = value
             else:
                 booster_params[key] = value
+
+        booster_params = {
+            k: v for k, v in booster_params.items() if k not in _non_booster_params
+        }
         return booster_params, kwargs_params
 
     def _fit(self, dataset):

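The added comprehension drops params that are consumed by the Spark layer (tracked in _non_booster_params) so they are not forwarded to the booster as unknown training params. A small illustrative sketch, with made-up _non_booster_params values:

    _non_booster_params = {"n_estimators", "use_gpu"}  # illustrative values only
    booster_params = {"max_depth": 6, "n_estimators": 100, "use_gpu": False}

    booster_params = {
        k: v for k, v in booster_params.items() if k not in _non_booster_params
    }
    # -> {"max_depth": 6}
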
@@ -211,6 +211,11 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
 
     def __init__(self, **kwargs):
         super().__init__()
+        # The default 'objective' param value comes from sklearn `XGBClassifier` ctor,
+        # but in pyspark we will automatically set objective param depending on
+        # binary or multinomial input dataset, and we need to remove the fixed default
+        # param value as well to avoid causing ambiguity.
+        self._setDefault(objective=None)
         self.setParams(**kwargs)
 
     @classmethod

@@ -227,6 +232,10 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
             raise ValueError(
                 "Spark Xgboost classifier estimator does not support `qid_col` param."
             )
+        if self.getOrDefault(self.objective):  # pylint: disable=no-member
+            raise ValueError(
+                "Setting custom 'objective' param is not allowed in 'SparkXGBClassifier'."
+            )
 
 
 _set_pyspark_xgb_cls_param_attrs(SparkXGBClassifier, SparkXGBClassifierModel)

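Because SparkXGBClassifier now picks the objective automatically from the input dataset (binary or multinomial, per the _setDefault(objective=None) change above), user-supplied values are rejected. A sketch of what this looks like from the user side; the error surfaces when param validation runs, typically at fit() time:

    from xgboost.spark import SparkXGBClassifier

    clf = SparkXGBClassifier()  # objective stays None and is chosen at fit time

    # A custom objective fails validation with:
    #   ValueError: Setting custom 'objective' param is not allowed in 'SparkXGBClassifier'.
    clf_bad = SparkXGBClassifier(objective="multi:softprob")
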
@@ -433,6 +433,7 @@ class XgboostLocalTest(SparkTestCase):
         self.assertEqual(py_reg.n_estimators.parent, py_reg.uid)
         self.assertFalse(hasattr(py_reg, "gpu_id"))
         self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100)
+        self.assertEqual(py_reg.getOrDefault(py_reg.objective), "reg:squarederror")
         py_reg2 = SparkXGBRegressor(n_estimators=200)
         self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200)
         py_reg3 = py_reg2.copy({py_reg2.max_depth: 10})

@@ -445,6 +446,7 @@ class XgboostLocalTest(SparkTestCase):
         self.assertEqual(py_cls.n_estimators.parent, py_cls.uid)
         self.assertFalse(hasattr(py_cls, "gpu_id"))
         self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100)
+        self.assertEqual(py_cls.getOrDefault(py_cls.objective), None)
         py_cls2 = SparkXGBClassifier(n_estimators=200)
         self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200)
         py_cls3 = py_cls2.copy({py_cls2.max_depth: 10})