[pyspark] Filter out the unsupported train parameters (#8355)

This commit is contained in:
Bobby Wang
2022-10-18 23:26:02 +08:00
committed by GitHub
parent 3901f5d9db
commit 76f95a6667
2 changed files with 13 additions and 1 deletions

View File

@@ -126,6 +126,11 @@ _unsupported_fit_params = {
"eval_qid", # Use spark param `qid_col` instead
}
_unsupported_train_params = {
"evals", # Supported by spark param validation_indicator_col
"evals_result", # Won't support yet+
}
_unsupported_predict_params = {
# for classification, we can use rawPrediction as margin
"output_margin",
@@ -515,6 +520,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
k in _unsupported_xgb_params
or k in _unsupported_fit_params
or k in _unsupported_predict_params
or k in _unsupported_train_params
):
raise ValueError(f"Unsupported param '{k}'.")
_extra_params[k] = v
@@ -620,7 +626,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
@classmethod
def _get_xgb_train_call_args(cls, train_params):
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {})
xgb_train_default_args = _get_default_params_from_func(
xgboost.train, _unsupported_train_params
)
booster_params, kwargs_params = {}, {}
for key, value in train_params.items():
if key in xgb_train_default_args: