Fix pylint. (#10296)
This commit is contained in:
parent
835e59e538
commit
ba9b4cb1ee
@ -527,7 +527,8 @@ def _validate_and_convert_feature_col_as_array_col(
|
||||
(DoubleType, FloatType, LongType, IntegerType, ShortType),
|
||||
):
|
||||
raise ValueError(
|
||||
"If feature column is array type, its elements must be number type."
|
||||
"If feature column is array type, its elements must be number type, "
|
||||
f"got {features_col_datatype.elementType}."
|
||||
)
|
||||
features_array_col = features_col.cast(ArrayType(FloatType())).alias(alias.data)
|
||||
elif isinstance(features_col_datatype, VectorUDT):
|
||||
@ -1379,15 +1380,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
# to avoid the `self` object to be pickled to remote.
|
||||
xgb_sklearn_model = self._xgb_sklearn_model
|
||||
|
||||
has_base_margin = False
|
||||
base_margin_col = None
|
||||
if (
|
||||
self.isDefined(self.base_margin_col)
|
||||
and self.getOrDefault(self.base_margin_col) != ""
|
||||
):
|
||||
has_base_margin = True
|
||||
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
|
||||
alias.margin
|
||||
)
|
||||
has_base_margin = base_margin_col is not None
|
||||
|
||||
features_col, feature_col_names = self._get_feature_col(dataset)
|
||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
||||
@ -1472,6 +1473,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
||||
yield predict_func(model, X, base_margin)
|
||||
|
||||
if has_base_margin:
|
||||
assert base_margin_col is not None
|
||||
pred_col = predict_udf(struct(*features_col, base_margin_col))
|
||||
else:
|
||||
pred_col = predict_udf(struct(*features_col))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user