[spark] Make spark model have the same UID with its estimator (#9022)

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
WeichenXu 2023-04-14 02:53:30 +08:00 committed by GitHub
parent 8e0f320db3
commit 191d0aa5cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 1 deletions

View File

@@ -931,7 +931,11 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
result_xgb_model = self._convert_to_sklearn_model( result_xgb_model = self._convert_to_sklearn_model(
bytearray(booster, "utf-8"), config bytearray(booster, "utf-8"), config
) )
return self._copyValues(self._create_pyspark_model(result_xgb_model)) spark_model = self._create_pyspark_model(result_xgb_model)
# According to pyspark ML convention, the model uid should be the same
# as the estimator uid.
spark_model._resetUid(self.uid)
return self._copyValues(spark_model)
def write(self): def write(self):
""" """

View File

@@ -464,6 +464,7 @@ class TestPySparkLocal:
def test_regressor_basic(self, reg_data: RegData) -> None: def test_regressor_basic(self, reg_data: RegData) -> None:
regressor = SparkXGBRegressor(pred_contrib_col="pred_contribs") regressor = SparkXGBRegressor(pred_contrib_col="pred_contribs")
model = regressor.fit(reg_data.reg_df_train) model = regressor.fit(reg_data.reg_df_train)
assert regressor.uid == model.uid
pred_result = model.transform(reg_data.reg_df_test).collect() pred_result = model.transform(reg_data.reg_df_test).collect()
for row in pred_result: for row in pred_result:
np.testing.assert_equal(row.prediction, row.expected_prediction) np.testing.assert_equal(row.prediction, row.expected_prediction)