[pyspark] Filter out the unsupported train parameters (#8355)
This commit is contained in:
@@ -126,6 +126,11 @@ _unsupported_fit_params = {
|
||||
"eval_qid", # Use spark param `qid_col` instead
|
||||
}
|
||||
|
||||
_unsupported_train_params = {
|
||||
"evals", # Supported by spark param validation_indicator_col
|
||||
"evals_result", # Won't support yet+
|
||||
}
|
||||
|
||||
_unsupported_predict_params = {
|
||||
# for classification, we can use rawPrediction as margin
|
||||
"output_margin",
|
||||
@@ -515,6 +520,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
k in _unsupported_xgb_params
|
||||
or k in _unsupported_fit_params
|
||||
or k in _unsupported_predict_params
|
||||
or k in _unsupported_train_params
|
||||
):
|
||||
raise ValueError(f"Unsupported param '{k}'.")
|
||||
_extra_params[k] = v
|
||||
@@ -620,7 +626,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
@classmethod
|
||||
def _get_xgb_train_call_args(cls, train_params):
|
||||
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {})
|
||||
xgb_train_default_args = _get_default_params_from_func(
|
||||
xgboost.train, _unsupported_train_params
|
||||
)
|
||||
booster_params, kwargs_params = {}, {}
|
||||
for key, value in train_params.items():
|
||||
if key in xgb_train_default_args:
|
||||
|
||||
Reference in New Issue
Block a user