[pyspark] Filter out the unsupported train parameters (#8355)
This commit is contained in:
parent
3901f5d9db
commit
76f95a6667
@ -126,6 +126,11 @@ _unsupported_fit_params = {
|
||||
"eval_qid", # Use spark param `qid_col` instead
|
||||
}
|
||||
|
||||
_unsupported_train_params = {
|
||||
"evals", # Supported by spark param validation_indicator_col
|
||||
"evals_result", # Won't support yet+
|
||||
}
|
||||
|
||||
_unsupported_predict_params = {
|
||||
# for classification, we can use rawPrediction as margin
|
||||
"output_margin",
|
||||
@ -515,6 +520,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
k in _unsupported_xgb_params
|
||||
or k in _unsupported_fit_params
|
||||
or k in _unsupported_predict_params
|
||||
or k in _unsupported_train_params
|
||||
):
|
||||
raise ValueError(f"Unsupported param '{k}'.")
|
||||
_extra_params[k] = v
|
||||
@ -620,7 +626,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
@classmethod
|
||||
def _get_xgb_train_call_args(cls, train_params):
|
||||
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {})
|
||||
xgb_train_default_args = _get_default_params_from_func(
|
||||
xgboost.train, _unsupported_train_params
|
||||
)
|
||||
booster_params, kwargs_params = {}, {}
|
||||
for key, value in train_params.items():
|
||||
if key in xgb_train_default_args:
|
||||
|
||||
@ -1126,3 +1126,7 @@ class XgboostLocalTest(SparkTestCase):
|
||||
classifier = SparkXGBClassifier(early_stopping_rounds=1)
|
||||
with pytest.raises(ValueError, match="early_stopping_rounds"):
|
||||
classifier.fit(self.cls_df_train)
|
||||
|
||||
def test_unsupported_params(self):
|
||||
with pytest.raises(ValueError, match="evals_result"):
|
||||
SparkXGBClassifier(evals_result={})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user