Fix pylint. (#10296)
This commit is contained in:
parent
835e59e538
commit
ba9b4cb1ee
@ -527,7 +527,8 @@ def _validate_and_convert_feature_col_as_array_col(
|
|||||||
(DoubleType, FloatType, LongType, IntegerType, ShortType),
|
(DoubleType, FloatType, LongType, IntegerType, ShortType),
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If feature column is array type, its elements must be number type."
|
"If feature column is array type, its elements must be number type, "
|
||||||
|
f"got {features_col_datatype.elementType}."
|
||||||
)
|
)
|
||||||
features_array_col = features_col.cast(ArrayType(FloatType())).alias(alias.data)
|
features_array_col = features_col.cast(ArrayType(FloatType())).alias(alias.data)
|
||||||
elif isinstance(features_col_datatype, VectorUDT):
|
elif isinstance(features_col_datatype, VectorUDT):
|
||||||
@ -1379,15 +1380,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
# to avoid the `self` object to be pickled to remote.
|
# to avoid the `self` object to be pickled to remote.
|
||||||
xgb_sklearn_model = self._xgb_sklearn_model
|
xgb_sklearn_model = self._xgb_sklearn_model
|
||||||
|
|
||||||
has_base_margin = False
|
base_margin_col = None
|
||||||
if (
|
if (
|
||||||
self.isDefined(self.base_margin_col)
|
self.isDefined(self.base_margin_col)
|
||||||
and self.getOrDefault(self.base_margin_col) != ""
|
and self.getOrDefault(self.base_margin_col) != ""
|
||||||
):
|
):
|
||||||
has_base_margin = True
|
|
||||||
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
|
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
|
||||||
alias.margin
|
alias.margin
|
||||||
)
|
)
|
||||||
|
has_base_margin = base_margin_col is not None
|
||||||
|
|
||||||
features_col, feature_col_names = self._get_feature_col(dataset)
|
features_col, feature_col_names = self._get_feature_col(dataset)
|
||||||
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
|
||||||
@ -1472,6 +1473,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
yield predict_func(model, X, base_margin)
|
yield predict_func(model, X, base_margin)
|
||||||
|
|
||||||
if has_base_margin:
|
if has_base_margin:
|
||||||
|
assert base_margin_col is not None
|
||||||
pred_col = predict_udf(struct(*features_col, base_margin_col))
|
pred_col = predict_udf(struct(*features_col, base_margin_col))
|
||||||
else:
|
else:
|
||||||
pred_col = predict_udf(struct(*features_col))
|
pred_col = predict_udf(struct(*features_col))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user