Signed-off-by: Weichen Xu <weichen.xu@databricks.com> Co-authored-by: WeichenXu <weichen.xu@databricks.com>
This commit is contained in:
parent
60303db2ee
commit
08a547f5c2
@ -140,6 +140,13 @@ _unsupported_predict_params = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: supply hint message for all other unsupported params.
|
||||||
|
_unsupported_params_hint_message = {
|
||||||
|
"enable_categorical": "`xgboost.spark` estimators do not have 'enable_categorical' param, "
|
||||||
|
"but you can set `feature_types` param and mark categorical features with 'c' string."
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class _SparkXGBParams(
|
class _SparkXGBParams(
|
||||||
HasFeaturesCol,
|
HasFeaturesCol,
|
||||||
HasLabelCol,
|
HasLabelCol,
|
||||||
@ -523,7 +530,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
or k in _unsupported_predict_params
|
or k in _unsupported_predict_params
|
||||||
or k in _unsupported_train_params
|
or k in _unsupported_train_params
|
||||||
):
|
):
|
||||||
raise ValueError(f"Unsupported param '{k}'.")
|
err_msg = _unsupported_params_hint_message.get(
|
||||||
|
k, f"Unsupported param '{k}'."
|
||||||
|
)
|
||||||
|
raise ValueError(err_msg)
|
||||||
_extra_params[k] = v
|
_extra_params[k] = v
|
||||||
_existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
|
_existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
|
||||||
self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
|
self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
|
||||||
@ -749,6 +759,8 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
"feature_weights": self.getOrDefault(self.feature_weights),
|
"feature_weights": self.getOrDefault(self.feature_weights),
|
||||||
"missing": float(self.getOrDefault(self.missing)),
|
"missing": float(self.getOrDefault(self.missing)),
|
||||||
}
|
}
|
||||||
|
if dmatrix_kwargs["feature_types"] is not None:
|
||||||
|
dmatrix_kwargs["enable_categorical"] = True
|
||||||
booster_params["nthread"] = cpu_per_task
|
booster_params["nthread"] = cpu_per_task
|
||||||
use_gpu = self.getOrDefault(self.use_gpu)
|
use_gpu = self.getOrDefault(self.use_gpu)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user