[pyspark] Filter out the unsupported train parameters (#8355)

This commit is contained in:
Bobby Wang 2022-10-18 23:26:02 +08:00 committed by GitHub
parent 3901f5d9db
commit 76f95a6667
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 1 deletions

View File

@ -126,6 +126,11 @@ _unsupported_fit_params = {
"eval_qid", # Use spark param `qid_col` instead "eval_qid", # Use spark param `qid_col` instead
} }
_unsupported_train_params = {
"evals", # Supported by spark param validation_indicator_col
"evals_result", # Won't support yet+
}
_unsupported_predict_params = { _unsupported_predict_params = {
# for classification, we can use rawPrediction as margin # for classification, we can use rawPrediction as margin
"output_margin", "output_margin",
@ -515,6 +520,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
k in _unsupported_xgb_params k in _unsupported_xgb_params
or k in _unsupported_fit_params or k in _unsupported_fit_params
or k in _unsupported_predict_params or k in _unsupported_predict_params
or k in _unsupported_train_params
): ):
raise ValueError(f"Unsupported param '{k}'.") raise ValueError(f"Unsupported param '{k}'.")
_extra_params[k] = v _extra_params[k] = v
@ -620,7 +626,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
@classmethod @classmethod
def _get_xgb_train_call_args(cls, train_params): def _get_xgb_train_call_args(cls, train_params):
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {}) xgb_train_default_args = _get_default_params_from_func(
xgboost.train, _unsupported_train_params
)
booster_params, kwargs_params = {}, {} booster_params, kwargs_params = {}, {}
for key, value in train_params.items(): for key, value in train_params.items():
if key in xgb_train_default_args: if key in xgb_train_default_args:

View File

@ -1126,3 +1126,7 @@ class XgboostLocalTest(SparkTestCase):
classifier = SparkXGBClassifier(early_stopping_rounds=1) classifier = SparkXGBClassifier(early_stopping_rounds=1)
with pytest.raises(ValueError, match="early_stopping_rounds"): with pytest.raises(ValueError, match="early_stopping_rounds"):
classifier.fit(self.cls_df_train) classifier.fit(self.cls_df_train)
def test_unsupported_params(self):
with pytest.raises(ValueError, match="evals_result"):
SparkXGBClassifier(evals_result={})