From f27a7258c6ed81e45020360ba614c050b50ef709 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Tue, 14 Feb 2023 02:16:42 +0800
Subject: [PATCH] Fix feature types param (#8772)

Signed-off-by: Weichen Xu
---
 python-package/xgboost/spark/core.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index b9be4b39b..6d9733817 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -143,6 +143,12 @@ _unsupported_predict_params = {
     "base_margin",  # Use pyspark base_margin_col param instead.
 }
 
+# TODO: supply hint message for all other unsupported params.
+_unsupported_params_hint_message = {
+    "enable_categorical": "`xgboost.spark` estimators do not have 'enable_categorical' param, "
+    "but you can set `feature_types` param and mark categorical features with 'c' string."
+}
+
 # Global prediction names
 Pred = namedtuple(
     "Pred", ("prediction", "raw_prediction", "probability", "pred_contrib")
@@ -540,7 +546,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                     or k in _unsupported_predict_params
                     or k in _unsupported_train_params
                 ):
-                    raise ValueError(f"Unsupported param '{k}'.")
+                    err_msg = _unsupported_params_hint_message.get(
+                        k, f"Unsupported param '{k}'."
+                    )
+                    raise ValueError(err_msg)
                 _extra_params[k] = v
         _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
         self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
@@ -780,6 +789,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             "feature_weights": self.getOrDefault(self.feature_weights),
             "missing": float(self.getOrDefault(self.missing)),
         }
+        if dmatrix_kwargs["feature_types"] is not None:
+            dmatrix_kwargs["enable_categorical"] = True
         booster_params["nthread"] = cpu_per_task
 
         # Remove the parameters whose value is None