diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 4f770e139..de3ae02cc 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -866,7 +866,11 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable): result_xgb_model = self._convert_to_sklearn_model( bytearray(booster, "utf-8"), config ) - return self._copyValues(self._create_pyspark_model(result_xgb_model)) + spark_model = self._create_pyspark_model(result_xgb_model) + # According to pyspark ML convention, the model uid should be the same + # with estimator uid. + spark_model._resetUid(self.uid) + return self._copyValues(spark_model) def write(self): """