From 08a547f5c24e89a390f7cc07ebd64eec3a546800 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Wed, 15 Feb 2023 01:39:20 +0800
Subject: [PATCH] [backport] Fix feature types param (#8772) (#8801)

Signed-off-by: Weichen Xu
Co-authored-by: WeichenXu
---
 python-package/xgboost/spark/core.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index caa6e3cd0..4f770e139 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -140,6 +140,13 @@ _unsupported_predict_params = {
 }
 
 
+# TODO: supply hint message for all other unsupported params.
+_unsupported_params_hint_message = {
+    "enable_categorical": "`xgboost.spark` estimators do not have 'enable_categorical' param, "
+    "but you can set `feature_types` param and mark categorical features with 'c' string."
+}
+
+
 class _SparkXGBParams(
     HasFeaturesCol,
     HasLabelCol,
@@ -523,7 +530,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                     or k in _unsupported_predict_params
                     or k in _unsupported_train_params
                 ):
-                    raise ValueError(f"Unsupported param '{k}'.")
+                    err_msg = _unsupported_params_hint_message.get(
+                        k, f"Unsupported param '{k}'."
+                    )
+                    raise ValueError(err_msg)
                 _extra_params[k] = v
         _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
         self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
@@ -749,6 +759,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             "feature_weights": self.getOrDefault(self.feature_weights),
             "missing": float(self.getOrDefault(self.missing)),
         }
+        if dmatrix_kwargs["feature_types"] is not None:
+            dmatrix_kwargs["enable_categorical"] = True
         booster_params["nthread"] = cpu_per_task
         use_gpu = self.getOrDefault(self.use_gpu)
 