[pyspark] Fix xgboost spark estimator dataset repartition issues (#8231)
This commit is contained in:
parent
3fd331f8f2
commit
ab342af242
@ -20,7 +20,7 @@ from pyspark.ml.param.shared import (
|
|||||||
HasWeightCol,
|
HasWeightCol,
|
||||||
)
|
)
|
||||||
from pyspark.ml.util import MLReadable, MLWritable
|
from pyspark.ml.util import MLReadable, MLWritable
|
||||||
from pyspark.sql.functions import col, countDistinct, pandas_udf, struct
|
from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
|
||||||
from pyspark.sql.types import (
|
from pyspark.sql.types import (
|
||||||
ArrayType,
|
ArrayType,
|
||||||
DoubleType,
|
DoubleType,
|
||||||
@ -164,6 +164,12 @@ class _SparkXGBParams(
|
|||||||
+ "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
|
+ "Note: The auto repartitioning judgement is not fully accurate, so it is recommended"
|
||||||
+ "to have force_repartition be True.",
|
+ "to have force_repartition be True.",
|
||||||
)
|
)
|
||||||
|
repartition_random_shuffle = Param(
|
||||||
|
Params._dummy(),
|
||||||
|
"repartition_random_shuffle",
|
||||||
|
"A boolean variable. Set repartition_random_shuffle=true if you want to random shuffle "
|
||||||
|
"dataset when repartitioning is required. By default is True.",
|
||||||
|
)
|
||||||
feature_names = Param(
|
feature_names = Param(
|
||||||
Params._dummy(), "feature_names", "A list of str to specify feature names."
|
Params._dummy(), "feature_names", "A list of str to specify feature names."
|
||||||
)
|
)
|
||||||
@ -270,15 +276,6 @@ class _SparkXGBParams(
|
|||||||
f"It cannot be less than 1 [Default is 1]"
|
f"It cannot be less than 1 [Default is 1]"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
|
||||||
self.getOrDefault(self.force_repartition)
|
|
||||||
and self.getOrDefault(self.num_workers) == 1
|
|
||||||
):
|
|
||||||
get_logger(self.__class__.__name__).warning(
|
|
||||||
"You set force_repartition to true when there is no need for a repartition."
|
|
||||||
"Therefore, that parameter will be ignored."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.getOrDefault(self.features_cols):
|
if self.getOrDefault(self.features_cols):
|
||||||
if not self.getOrDefault(self.use_gpu):
|
if not self.getOrDefault(self.use_gpu):
|
||||||
raise ValueError("features_cols param requires enabling use_gpu.")
|
raise ValueError("features_cols param requires enabling use_gpu.")
|
||||||
@ -470,6 +467,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
num_workers=1,
|
num_workers=1,
|
||||||
use_gpu=False,
|
use_gpu=False,
|
||||||
force_repartition=False,
|
force_repartition=False,
|
||||||
|
repartition_random_shuffle=True,
|
||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None,
|
feature_types=None,
|
||||||
arbitrary_params_dict={},
|
arbitrary_params_dict={},
|
||||||
@ -695,8 +693,21 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
|||||||
num_workers,
|
num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._repartition_needed(dataset):
|
if self._repartition_needed(dataset) or (
|
||||||
dataset = dataset.repartition(num_workers)
|
self.isDefined(self.validationIndicatorCol)
|
||||||
|
and self.getOrDefault(self.validationIndicatorCol)
|
||||||
|
):
|
||||||
|
# If validationIndicatorCol is defined, we always repartition the dataset
|
||||||
|
# to balance the data, because the user might have unioned the train and validation
|
||||||
|
# datasets; without shuffling, some partitions might contain only train or only
|
||||||
|
# validation data.
|
||||||
|
if self.getOrDefault(self.repartition_random_shuffle):
|
||||||
|
# In some cases, Spark's round-robin repartitioning might cause data skew;
|
||||||
|
# using a random shuffle can address it.
|
||||||
|
dataset = dataset.repartition(num_workers, rand(1))
|
||||||
|
else:
|
||||||
|
dataset = dataset.repartition(num_workers)
|
||||||
|
|
||||||
train_params = self._get_distributed_train_params(dataset)
|
train_params = self._get_distributed_train_params(dataset)
|
||||||
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
|
booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
|
||||||
train_params
|
train_params
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user