From ba9b4cb1eecb90f7e3129cc6b8ece5e09dfd5472 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 17 May 2024 13:28:39 +0800 Subject: [PATCH] Fix pylint. (#10296) --- python-package/xgboost/spark/core.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 2f24effe5..8134ec7e7 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -527,7 +527,8 @@ def _validate_and_convert_feature_col_as_array_col( (DoubleType, FloatType, LongType, IntegerType, ShortType), ): raise ValueError( - "If feature column is array type, its elements must be number type." + "If feature column is array type, its elements must be number type, " + f"got {features_col_datatype.elementType}." ) features_array_col = features_col.cast(ArrayType(FloatType())).alias(alias.data) elif isinstance(features_col_datatype, VectorUDT): @@ -1379,15 +1380,15 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable): # to avoid the `self` object to be pickled to remote. xgb_sklearn_model = self._xgb_sklearn_model - has_base_margin = False + base_margin_col = None if ( self.isDefined(self.base_margin_col) and self.getOrDefault(self.base_margin_col) != "" ): - has_base_margin = True base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias( alias.margin ) + has_base_margin = base_margin_col is not None features_col, feature_col_names = self._get_feature_col(dataset) enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim) @@ -1472,6 +1473,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable): yield predict_func(model, X, base_margin) if has_base_margin: + assert base_margin_col is not None pred_col = predict_udf(struct(*features_col, base_margin_col)) else: pred_col = predict_udf(struct(*features_col))