[pyspark] Filter out the unsupported train parameters (#8355)
This commit is contained in:
parent
3901f5d9db
commit
76f95a6667
@ -126,6 +126,11 @@ _unsupported_fit_params = {
|
|||||||
"eval_qid", # Use spark param `qid_col` instead
|
"eval_qid", # Use spark param `qid_col` instead
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_unsupported_train_params = {
|
||||||
|
"evals", # Supported by spark param validation_indicator_col
|
||||||
|
"evals_result", # Won't support yet+
|
||||||
|
}
|
||||||
|
|
||||||
_unsupported_predict_params = {
|
_unsupported_predict_params = {
|
||||||
# for classification, we can use rawPrediction as margin
|
# for classification, we can use rawPrediction as margin
|
||||||
"output_margin",
|
"output_margin",
|
||||||
@ -515,6 +520,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
k in _unsupported_xgb_params
|
k in _unsupported_xgb_params
|
||||||
or k in _unsupported_fit_params
|
or k in _unsupported_fit_params
|
||||||
or k in _unsupported_predict_params
|
or k in _unsupported_predict_params
|
||||||
|
or k in _unsupported_train_params
|
||||||
):
|
):
|
||||||
raise ValueError(f"Unsupported param '{k}'.")
|
raise ValueError(f"Unsupported param '{k}'.")
|
||||||
_extra_params[k] = v
|
_extra_params[k] = v
|
||||||
@ -620,7 +626,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_xgb_train_call_args(cls, train_params):
|
def _get_xgb_train_call_args(cls, train_params):
|
||||||
xgb_train_default_args = _get_default_params_from_func(xgboost.train, {})
|
xgb_train_default_args = _get_default_params_from_func(
|
||||||
|
xgboost.train, _unsupported_train_params
|
||||||
|
)
|
||||||
booster_params, kwargs_params = {}, {}
|
booster_params, kwargs_params = {}, {}
|
||||||
for key, value in train_params.items():
|
for key, value in train_params.items():
|
||||||
if key in xgb_train_default_args:
|
if key in xgb_train_default_args:
|
||||||
|
|||||||
@ -1126,3 +1126,7 @@ class XgboostLocalTest(SparkTestCase):
|
|||||||
classifier = SparkXGBClassifier(early_stopping_rounds=1)
|
classifier = SparkXGBClassifier(early_stopping_rounds=1)
|
||||||
with pytest.raises(ValueError, match="early_stopping_rounds"):
|
with pytest.raises(ValueError, match="early_stopping_rounds"):
|
||||||
classifier.fit(self.cls_df_train)
|
classifier.fit(self.cls_df_train)
|
||||||
|
|
||||||
|
def test_unsupported_params(self):
|
||||||
|
with pytest.raises(ValueError, match="evals_result"):
|
||||||
|
SparkXGBClassifier(evals_result={})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user