Fix pylint. (#10296)

This commit is contained in:
Jiaming Yuan 2024-05-17 13:28:39 +08:00 committed by GitHub
parent 835e59e538
commit ba9b4cb1ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -527,7 +527,8 @@ def _validate_and_convert_feature_col_as_array_col(
(DoubleType, FloatType, LongType, IntegerType, ShortType),
):
raise ValueError(
"If feature column is array type, its elements must be number type."
"If feature column is array type, its elements must be number type, "
f"got {features_col_datatype.elementType}."
)
features_array_col = features_col.cast(ArrayType(FloatType())).alias(alias.data)
elif isinstance(features_col_datatype, VectorUDT):
@ -1379,15 +1380,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
# to avoid the `self` object to be pickled to remote.
xgb_sklearn_model = self._xgb_sklearn_model
has_base_margin = False
base_margin_col = None
if (
self.isDefined(self.base_margin_col)
and self.getOrDefault(self.base_margin_col) != ""
):
has_base_margin = True
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
alias.margin
)
has_base_margin = base_margin_col is not None
features_col, feature_col_names = self._get_feature_col(dataset)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
@ -1472,6 +1473,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
yield predict_func(model, X, base_margin)
if has_base_margin:
assert base_margin_col is not None
pred_col = predict_udf(struct(*features_col, base_margin_col))
else:
pred_col = predict_udf(struct(*features_col))