[pyspark] Avoid repartition. (#10408)
This commit is contained in:
@@ -691,50 +691,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
sklearn_model._Booster.load_config(config)
|
||||
return sklearn_model
|
||||
|
||||
def _query_plan_contains_valid_repartition(self, dataset: DataFrame) -> bool:
|
||||
"""
|
||||
Returns true if the latest element in the logical plan is a valid repartition
|
||||
The logic plan string format is like:
|
||||
|
||||
== Optimized Logical Plan ==
|
||||
Repartition 4, true
|
||||
+- LogicalRDD [features#12, label#13L], false
|
||||
|
||||
i.e., the top line in the logical plan is the last operation to execute.
|
||||
so, in this method, we check the first line, if it is a "Repartition" operation,
|
||||
and the result dataframe has the same partition number with num_workers param,
|
||||
then it means the dataframe is well repartitioned and we don't need to
|
||||
repartition the dataframe again.
|
||||
"""
|
||||
num_partitions = dataset.rdd.getNumPartitions()
|
||||
assert dataset._sc._jvm is not None
|
||||
query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
|
||||
dataset._jdf.queryExecution(), "extended"
|
||||
)
|
||||
start = query_plan.index("== Optimized Logical Plan ==")
|
||||
start += len("== Optimized Logical Plan ==") + 1
|
||||
num_workers = self.getOrDefault(self.num_workers)
|
||||
if (
|
||||
query_plan[start : start + len("Repartition")] == "Repartition"
|
||||
and num_workers == num_partitions
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _repartition_needed(self, dataset: DataFrame) -> bool:
|
||||
"""
|
||||
We repartition the dataset if the number of workers is not equal to the number of
|
||||
partitions. There is also a check to make sure there was "active partitioning"
|
||||
where either Round Robin or Hash partitioning was actively used before this stage.
|
||||
"""
|
||||
partitions."""
|
||||
if self.getOrDefault(self.force_repartition):
|
||||
return True
|
||||
try:
|
||||
if self._query_plan_contains_valid_repartition(dataset):
|
||||
return False
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
return True
|
||||
num_workers = self.getOrDefault(self.num_workers)
|
||||
num_partitions = dataset.rdd.getNumPartitions()
|
||||
return not num_workers == num_partitions
|
||||
|
||||
def _get_distributed_train_params(self, dataset: DataFrame) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -871,14 +836,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
num_workers,
|
||||
)
|
||||
|
||||
if self._repartition_needed(dataset) or (
|
||||
self.isDefined(self.validationIndicatorCol)
|
||||
and self.getOrDefault(self.validationIndicatorCol) != ""
|
||||
):
|
||||
# If validationIndicatorCol defined, we always repartition dataset
|
||||
# to balance data, because user might unionise train and validation dataset,
|
||||
# without shuffling data then some partitions might contain only train or validation
|
||||
# dataset.
|
||||
if self._repartition_needed(dataset):
|
||||
# If validationIndicatorCol defined, and if user unionise train and validation
|
||||
# dataset, users must set force_repartition to true to force repartition.
|
||||
# Or else some partitions might contain only train or validation dataset.
|
||||
if self.getOrDefault(self.repartition_random_shuffle):
|
||||
# In some cases, spark round-robin repartition might cause data skew
|
||||
# use random shuffle can address it.
|
||||
|
||||
Reference in New Issue
Block a user