Signed-off-by: Weichen Xu <weichen.xu@databricks.com> Co-authored-by: WeichenXu <weichen.xu@databricks.com>
This commit is contained in:
parent
3218f6cd3c
commit
e882fb3262
@ -866,7 +866,11 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
result_xgb_model = self._convert_to_sklearn_model(
|
result_xgb_model = self._convert_to_sklearn_model(
|
||||||
bytearray(booster, "utf-8"), config
|
bytearray(booster, "utf-8"), config
|
||||||
)
|
)
|
||||||
return self._copyValues(self._create_pyspark_model(result_xgb_model))
|
spark_model = self._create_pyspark_model(result_xgb_model)
|
||||||
|
# According to pyspark ML convention, the model uid should be the same
|
||||||
|
# with estimator uid.
|
||||||
|
spark_model._resetUid(self.uid)
|
||||||
|
return self._copyValues(spark_model)
|
||||||
|
|
||||||
def write(self):
|
def write(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user