Fix feature types param (#8772)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
parent
52d0230b58
commit
f27a7258c6
@ -143,6 +143,12 @@ _unsupported_predict_params = {
|
||||
"base_margin", # Use pyspark base_margin_col param instead.
|
||||
}
|
||||
|
||||
# TODO: supply hint message for all other unsupported params.
|
||||
_unsupported_params_hint_message = {
|
||||
"enable_categorical": "`xgboost.spark` estimators do not have 'enable_categorical' param, "
|
||||
"but you can set `feature_types` param and mark categorical features with 'c' string."
|
||||
}
|
||||
|
||||
# Global prediction names
|
||||
Pred = namedtuple(
|
||||
"Pred", ("prediction", "raw_prediction", "probability", "pred_contrib")
|
||||
@ -540,7 +546,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
or k in _unsupported_predict_params
|
||||
or k in _unsupported_train_params
|
||||
):
|
||||
raise ValueError(f"Unsupported param '{k}'.")
|
||||
err_msg = _unsupported_params_hint_message.get(
|
||||
k, f"Unsupported param '{k}'."
|
||||
)
|
||||
raise ValueError(err_msg)
|
||||
_extra_params[k] = v
|
||||
_existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
|
||||
self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
|
||||
@ -780,6 +789,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
"feature_weights": self.getOrDefault(self.feature_weights),
|
||||
"missing": float(self.getOrDefault(self.missing)),
|
||||
}
|
||||
if dmatrix_kwargs["feature_types"] is not None:
|
||||
dmatrix_kwargs["enable_categorical"] = True
|
||||
booster_params["nthread"] = cpu_per_task
|
||||
|
||||
# Remove the parameters whose value is None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user