[pyspark] Avoid repartition. (#10408)

This commit is contained in:
Bobby Wang
2024-06-12 02:26:10 +08:00
committed by GitHub
parent e0ebbc0746
commit cf0c1d0888
3 changed files with 28 additions and 49 deletions

View File

@@ -474,7 +474,7 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):
classifier = SparkXGBClassifier(num_workers=self.n_workers)
basic = self.cls_df_train_distributed
self.assertTrue(classifier._repartition_needed(basic))
self.assertTrue(not classifier._repartition_needed(basic))
bad_repartitioned = basic.repartition(self.n_workers + 1)
self.assertTrue(classifier._repartition_needed(bad_repartitioned))
good_repartitioned = basic.repartition(self.n_workers)