[pyspark] Avoid repartition. (#10408)
parent e0ebbc0746
commit cf0c1d0888
@@ -267,7 +267,7 @@ An example submit command is shown below with additional spark configurations an
   --conf spark.task.cpus=1 \
   --conf spark.executor.resource.gpu.amount=1 \
   --conf spark.task.resource.gpu.amount=0.08 \
-  --packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
+  --packages com.nvidia:rapids-4-spark_2.12:24.04.1 \
   --conf spark.plugins=com.nvidia.spark.SQLPlugin \
   --conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
   --archives xgboost_env.tar.gz#environment \
@@ -276,3 +276,21 @@ An example submit command is shown below with additional spark configurations an
 When rapids plugin is enabled, both of the JVM rapids plugin and the cuDF Python package
 are required. More configuration options can be found in the RAPIDS link above along with
 details on the plugin.
+
+Advanced Usage
+==============
+
+XGBoost needs to repartition the input dataset into num_workers partitions to ensure that
+num_workers training tasks run at the same time. However, repartitioning is a costly operation.
+
+If the data can be read from its source and fed directly to XGBoost without an
+intermediate shuffle stage, users can avoid repartitioning by setting the Spark
+configuration parameters ``spark.sql.files.maxPartitionNum`` and
+``spark.sql.files.minPartitionNum`` to num_workers. This tells Spark to automatically
+partition the dataset into the desired number of partitions.
+
+However, if the input dataset is skewed (i.e. the data is not evenly distributed), setting
+the partition number to num_workers may not be efficient. In this case, users can set
+``force_repartition=True`` to explicitly force XGBoost to repartition the dataset, even
+if the partition number already equals num_workers. This ensures the data is evenly
+distributed across the workers.
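
As an editorial illustration of the configuration described in the added documentation (not part of this commit), a minimal PySpark sketch; the Parquet path, column names, and num_workers value are assumptions:

    from pyspark.sql import SparkSession
    from xgboost.spark import SparkXGBClassifier

    num_workers = 4  # number of concurrent XGBoost training tasks

    spark = (
        SparkSession.builder
        # Ask Spark to plan exactly num_workers file-scan partitions so that no
        # extra shuffle/repartition is needed before training (requires a Spark
        # version that supports spark.sql.files.maxPartitionNum).
        .config("spark.sql.files.maxPartitionNum", num_workers)
        .config("spark.sql.files.minPartitionNum", num_workers)
        .getOrCreate()
    )

    # Hypothetical training data with a vector "features" column and a "label" column.
    df = spark.read.parquet("/data/train.parquet")

    clf = SparkXGBClassifier(
        features_col="features", label_col="label", num_workers=num_workers
    )
    model = clf.fit(df)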
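Similarly, a hedged sketch of the ``force_repartition`` option described above for skewed input; ``df`` is the (possibly skewed) DataFrame from the previous sketch:

    from xgboost.spark import SparkXGBClassifier

    # Even if df already has num_workers partitions, force_repartition=True makes
    # the estimator repartition it again so the rows are balanced across workers.
    clf = SparkXGBClassifier(
        features_col="features",
        label_col="label",
        num_workers=4,
        force_repartition=True,
    )
    model = clf.fit(df)
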
@@ -691,50 +691,15 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         sklearn_model._Booster.load_config(config)
         return sklearn_model
 
-    def _query_plan_contains_valid_repartition(self, dataset: DataFrame) -> bool:
-        """
-        Returns true if the latest element in the logical plan is a valid repartition
-        The logic plan string format is like:
-
-        == Optimized Logical Plan ==
-        Repartition 4, true
-        +- LogicalRDD [features#12, label#13L], false
-
-        i.e., the top line in the logical plan is the last operation to execute.
-        so, in this method, we check the first line, if it is a "Repartition" operation,
-        and the result dataframe has the same partition number with num_workers param,
-        then it means the dataframe is well repartitioned and we don't need to
-        repartition the dataframe again.
-        """
-        num_partitions = dataset.rdd.getNumPartitions()
-        assert dataset._sc._jvm is not None
-        query_plan = dataset._sc._jvm.PythonSQLUtils.explainString(
-            dataset._jdf.queryExecution(), "extended"
-        )
-        start = query_plan.index("== Optimized Logical Plan ==")
-        start += len("== Optimized Logical Plan ==") + 1
-        num_workers = self.getOrDefault(self.num_workers)
-        if (
-            query_plan[start : start + len("Repartition")] == "Repartition"
-            and num_workers == num_partitions
-        ):
-            return True
-        return False
-
     def _repartition_needed(self, dataset: DataFrame) -> bool:
         """
         We repartition the dataset if the number of workers is not equal to the number of
-        partitions. There is also a check to make sure there was "active partitioning"
-        where either Round Robin or Hash partitioning was actively used before this stage.
-        """
+        partitions."""
         if self.getOrDefault(self.force_repartition):
             return True
-        try:
-            if self._query_plan_contains_valid_repartition(dataset):
-                return False
-        except Exception:  # pylint: disable=broad-except
-            pass
-        return True
+        num_workers = self.getOrDefault(self.num_workers)
+        num_partitions = dataset.rdd.getNumPartitions()
+        return not num_workers == num_partitions
 
     def _get_distributed_train_params(self, dataset: DataFrame) -> Dict[str, Any]:
         """
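Continuing the earlier sketch (editorial, not from the commit): with the simplified check above, whether the estimator repartitions now depends only on ``force_repartition`` and the current partition count. For example:

    num_workers = 4
    # Repartitioning is skipped when the DataFrame already has num_workers
    # partitions and force_repartition is not set.
    if df.rdd.getNumPartitions() != num_workers:
        df = df.repartition(num_workers)  # roughly what the estimator does before training
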
@@ -871,14 +836,10 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 num_workers,
             )
 
-        if self._repartition_needed(dataset) or (
-            self.isDefined(self.validationIndicatorCol)
-            and self.getOrDefault(self.validationIndicatorCol) != ""
-        ):
-            # If validationIndicatorCol defined, we always repartition dataset
-            # to balance data, because user might unionise train and validation dataset,
-            # without shuffling data then some partitions might contain only train or validation
-            # dataset.
+        if self._repartition_needed(dataset):
+            # If validationIndicatorCol defined, and if user unionise train and validation
+            # dataset, users must set force_repartition to true to force repartition.
+            # Or else some partitions might contain only train or validation dataset.
             if self.getOrDefault(self.repartition_random_shuffle):
                 # In some cases, spark round-robin repartition might cause data skew
                 # use random shuffle can address it.
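
The comments above note that Spark's round-robin repartition can leave the data skewed and that a random shuffle addresses it; a hedged sketch of opting in from user code (``df`` and the column names are assumptions):

    from xgboost.spark import SparkXGBClassifier

    # repartition_random_shuffle asks the estimator to shuffle rows randomly when it
    # repartitions, instead of relying on Spark's round-robin repartition alone.
    clf = SparkXGBClassifier(
        features_col="features",
        label_col="label",
        num_workers=4,
        repartition_random_shuffle=True,
    )
    model = clf.fit(df)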
@@ -474,7 +474,7 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):
 
         classifier = SparkXGBClassifier(num_workers=self.n_workers)
         basic = self.cls_df_train_distributed
-        self.assertTrue(classifier._repartition_needed(basic))
+        self.assertTrue(not classifier._repartition_needed(basic))
         bad_repartitioned = basic.repartition(self.n_workers + 1)
        self.assertTrue(classifier._repartition_needed(bad_repartitioned))
         good_repartitioned = basic.repartition(self.n_workers)