[backport] Fix feature types param (#8772) (#8801)

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Co-authored-by: WeichenXu <weichen.xu@databricks.com>
This commit is contained in:
Jiaming Yuan 2023-02-15 01:39:20 +08:00 committed by GitHub
parent 60303db2ee
commit 08a547f5c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -140,6 +140,13 @@ _unsupported_predict_params = {
} }
# Maps an unsupported param name to a tailored hint; when a rejected param has
# an entry here, that hint is raised as the ValueError message in place of the
# generic "Unsupported param" text.
# TODO: supply hint message for all other unsupported params.
_unsupported_params_hint_message = {
    "enable_categorical": (
        "`xgboost.spark` estimators do not have 'enable_categorical' param, "
        "but you can set `feature_types` param "
        "and mark categorical features with 'c' string."
    ),
}
class _SparkXGBParams( class _SparkXGBParams(
HasFeaturesCol, HasFeaturesCol,
HasLabelCol, HasLabelCol,
@ -523,7 +530,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
or k in _unsupported_predict_params or k in _unsupported_predict_params
or k in _unsupported_train_params or k in _unsupported_train_params
): ):
raise ValueError(f"Unsupported param '{k}'.") err_msg = _unsupported_params_hint_message.get(
k, f"Unsupported param '{k}'."
)
raise ValueError(err_msg)
_extra_params[k] = v _extra_params[k] = v
_existing_extra_params = self.getOrDefault(self.arbitrary_params_dict) _existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params}) self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
@ -749,6 +759,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
"feature_weights": self.getOrDefault(self.feature_weights), "feature_weights": self.getOrDefault(self.feature_weights),
"missing": float(self.getOrDefault(self.missing)), "missing": float(self.getOrDefault(self.missing)),
} }
if dmatrix_kwargs["feature_types"] is not None:
dmatrix_kwargs["enable_categorical"] = True
booster_params["nthread"] = cpu_per_task booster_params["nthread"] = cpu_per_task
use_gpu = self.getOrDefault(self.use_gpu) use_gpu = self.getOrDefault(self.use_gpu)