[pyspark] Avoid repartition. (#10408)

Author: Bobby Wang
Date: 2024-06-12 02:26:10 +08:00
Committed by: GitHub
Parent: e0ebbc0746
Commit: cf0c1d0888

3 changed files with 28 additions and 49 deletions


@@ -691,50 +691,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         sklearn_model._Booster.load_config(config)
         return sklearn_model
 
-    def _query_plan_contains_valid_repartition(self, dataset: DataFrame) -> bool:
-        """
-        Returns true if the latest element in the logical plan is a valid repartition.
-        The logical plan string format is like:
-
-        == Optimized Logical Plan ==
-        Repartition 4, true
-        +- LogicalRDD [features#12, label#13L], false
-
-        i.e., the top line in the logical plan is the last operation to execute.
-        So this method checks whether the first line is a "Repartition" operation and
-        whether the resulting dataframe has the same number of partitions as the
-        num_workers param; if so, the dataframe is already well repartitioned and we
-        don't need to repartition it again.
-        """
-        num_partitions = dataset.rdd.getNumPartitions()
-        assert dataset._sc._jvm is not None
-        query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
-            dataset._jdf.queryExecution(), "extended"
-        )
-        start = query_plan.index("== Optimized Logical Plan ==")
-        start += len("== Optimized Logical Plan ==") + 1
-        num_workers = self.getOrDefault(self.num_workers)
-        if (
-            query_plan[start : start + len("Repartition")] == "Repartition"
-            and num_workers == num_partitions
-        ):
-            return True
-        return False
-
     def _repartition_needed(self, dataset: DataFrame) -> bool:
-        """
-        We repartition the dataset if the number of workers is not equal to the number of
-        partitions. There is also a check to make sure there was "active partitioning"
-        where either Round Robin or Hash partitioning was actively used before this stage.
-        """
+        """We repartition the dataset if the number of workers is not equal to the number of
+        partitions."""
         if self.getOrDefault(self.force_repartition):
             return True
-        try:
-            if self._query_plan_contains_valid_repartition(dataset):
-                return False
-        except Exception:  # pylint: disable=broad-except
-            pass
-        return True
+        num_workers = self.getOrDefault(self.num_workers)
+        num_partitions = dataset.rdd.getNumPartitions()
+        return not num_workers == num_partitions
 
     def _get_distributed_train_params(self, dataset: DataFrame) -> Dict[str, Any]:
         """
@@ -871,14 +836,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             num_workers,
         )
 
-        if self._repartition_needed(dataset) or (
-            self.isDefined(self.validationIndicatorCol)
-            and self.getOrDefault(self.validationIndicatorCol) != ""
-        ):
-            # If validationIndicatorCol is defined, we always repartition the dataset
-            # to balance the data, because the user might have unionised the train and
-            # validation datasets without shuffling, in which case some partitions might
-            # contain only train or only validation data.
+        if self._repartition_needed(dataset):
+            # If validationIndicatorCol is defined and the user unionises the train and
+            # validation datasets, they must set force_repartition to true to force a
+            # repartition; otherwise some partitions might contain only train or validation data.
             if self.getOrDefault(self.repartition_random_shuffle):
                 # In some cases, spark round-robin repartition might cause data skew,
                 # and using a random shuffle can address it.
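To illustrate the new contract described in that comment: after this change the estimator no longer repartitions automatically just because a validation indicator column is set, so a user who unionises train and validation frames should opt in via force_repartition. A hypothetical sketch follows; train_df and valid_df are placeholder dataframes assumed to share the same feature/label schema, and the parameter names are the documented xgboost.spark estimator params.

    from pyspark.sql import functions as F
    from xgboost.spark import SparkXGBClassifier

    # Placeholder inputs; both must share the same feature/label schema.
    train = train_df.withColumn("isVal", F.lit(False))
    valid = valid_df.withColumn("isVal", F.lit(True))
    data = train.unionByName(valid)  # concatenates partitions, no shuffle

    # Without force_repartition, whole partitions of `data` can be
    # train-only or validation-only. Opting in rebalances before training.
    clf = SparkXGBClassifier(
        validation_indicator_col="isVal",
        force_repartition=True,
        repartition_random_shuffle=True,  # avoid round-robin data skew
        num_workers=4,
    )
    model = clf.fit(data)

repartition_random_shuffle corresponds to the random-shuffle branch in the hunk above; with force_repartition alone, the dataset is repartitioned with Spark's default round-robin scheme.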