[pyspark] fix empty data issue when constructing DMatrix (#8245)
Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
commit 520586ffa7
parent 70df36c99c

@@ -658,12 +658,17 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                 col(self.getOrDefault(self.weightCol)).alias(alias.weight)
             )

+        has_validation_col = False
         if self.isDefined(self.validationIndicatorCol) and self.getOrDefault(
             self.validationIndicatorCol
         ):
             select_cols.append(
                 col(self.getOrDefault(self.validationIndicatorCol)).alias(alias.valid)
             )
+            # In some cases, see https://issues.apache.org/jira/browse/SPARK-40407,
+            # the df.repartition can result in some reducer partitions without data,
+            # which will cause exception or hanging issue when creating DMatrix.
+            has_validation_col = True

         if self.isDefined(self.base_margin_col) and self.getOrDefault(
             self.base_margin_col
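
For context, the new has_validation_col flag simply records whether the caller configured a validation indicator column; it is computed once on the driver and therefore has the same value on every worker. A minimal usage sketch of what flips it to True (illustrative only; the toy DataFrame, column names, and local session are assumptions mirroring the tests added later in this commit):

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Rows whose "val_col" is True are routed to the evaluation set.
df_train = spark.createDataFrame(
    [
        (Vectors.dense(10.1, 11.2, 11.3), 0, False),
        (Vectors.dense(1.0, 1.2, 1.3), 1, False),
        (Vectors.dense(14.0, 15.0, 16.0), 0, False),
        (Vectors.dense(1.1, 1.2, 1.3), 1, True),
    ],
    ["features", "label", "val_col"],
)

# Because validation_indicator_col is defined, the estimator selects that
# column under alias.valid and sets has_validation_col = True, so every
# worker later constructs a validation DMatrix (possibly an empty one).
classifier = SparkXGBClassifier(validation_indicator_col="val_col")
model = classifier.fit(df_train)
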
@@ -765,6 +770,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
                     gpu_id,
                     dmatrix_kwargs,
                     enable_sparse_data_optim=enable_sparse_data_optim,
+                    has_validation_col=has_validation_col,
                 )
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]

@@ -147,12 +147,13 @@ def _read_csr_matrix_from_unwrapped_spark_vec(part: pd.DataFrame) -> csr_matrix:
     )


-def create_dmatrix_from_partitions(
+def create_dmatrix_from_partitions(  # pylint: disable=too-many-arguments
     iterator: Iterator[pd.DataFrame],
     feature_cols: Optional[Sequence[str]],
     gpu_id: Optional[int],
     kwargs: Dict[str, Any],  # use dict to make sure this parameter is passed.
     enable_sparse_data_optim: bool,
+    has_validation_col: bool,
 ) -> Tuple[DMatrix, Optional[DMatrix]]:
     """Create DMatrix from spark data partitions. This is not particularly efficient as
     we need to convert the pandas series format to numpy then concatenate all the data.

@@ -173,7 +174,7 @@ def create_dmatrix_from_partitions(

     def append_m(part: pd.DataFrame, name: str, is_valid: bool) -> None:
         nonlocal n_features
-        if name in part.columns:
+        if name in part.columns and part[name].shape[0] > 0:
             array = part[name]
             if name == alias.data:
                 array = stack_series(array)

@@ -224,6 +225,10 @@ def create_dmatrix_from_partitions(
             train_data[name].append(array)

     def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix:
+        if len(values) == 0:
+            # We must construct an empty DMatrix to bypass the AllReduce
+            return DMatrix(data=np.empty((0, 0)), **kwargs)
+
         data = concat_or_none(values[alias.data])
         label = concat_or_none(values.get(alias.label, None))
         weight = concat_or_none(values.get(alias.weight, None))
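
The len(values) == 0 branch above is the core of the fix: a worker whose partition carries no rows still constructs a zero-row DMatrix, so the collective AllReduce performed during DMatrix construction sees every worker. A minimal standalone sketch of that empty construction (plain single-process XGBoost, no Spark; the variable name is illustrative):

import numpy as np
import xgboost as xgb

# A 0x0 array produces a DMatrix with no rows and no columns. It is cheap to
# build, but in distributed training it still takes the worker through the
# same constructor (and its collective step) as its data-carrying peers.
empty = xgb.DMatrix(data=np.empty((0, 0)))
print(empty.num_row(), empty.num_col())  # -> 0 0
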
@@ -247,9 +252,14 @@ def create_dmatrix_from_partitions(
         it = PartIter(train_data, gpu_id)
         dtrain = DeviceQuantileDMatrix(it, **kwargs)

-    dvalid = make(valid_data, kwargs) if len(valid_data) != 0 else None
-
+    # Using has_validation_col here to indicate if there is validation col
+    # instead of getting it from iterator, since the iterator may be empty
+    # in some special case. That is to say, we must ensure every worker
+    # construct DMatrix even there is no any data since we need to ensure every
+    # worker do the AllReduce when constructing DMatrix, or else it may hang
+    # forever.
+    dvalid = make(valid_data, kwargs) if has_validation_col else None
     assert dtrain.num_col() == n_features
     if dvalid is not None:
         assert dvalid.num_col() == dtrain.num_col()

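The comment block above carries the subtle invariant: whether to build dvalid must be decided from driver-side information (has_validation_col), not from what a worker happens to find in its own partition, because a collective call only completes when every worker makes it. A self-contained sketch of the two policies (make_matrix and both helpers are illustrative stand-ins, not the library's API):

from typing import Dict, List, Optional


def make_matrix(values: Dict[str, List[float]]) -> str:
    # Stand-in for DMatrix construction; in distributed XGBoost this path
    # ends in an AllReduce that blocks until every worker reaches it.
    return f"matrix(rows={sum(len(v) for v in values.values())})"


def build_dvalid_by_local_data(valid_data: Dict[str, List[float]]) -> Optional[str]:
    # Old policy: a worker whose validation partition is empty skips the
    # constructor while its peers block inside the AllReduce -- a hang.
    return make_matrix(valid_data) if len(valid_data) != 0 else None


def build_dvalid_by_flag(
    valid_data: Dict[str, List[float]], has_validation_col: bool
) -> Optional[str]:
    # New policy: the branch depends only on a flag identical on all workers,
    # so either every worker builds the (possibly empty) matrix or none does.
    return make_matrix(valid_data) if has_validation_col else None


print(build_dvalid_by_local_data({}))  # None -- this worker would opt out
print(build_dvalid_by_flag({}, True))  # matrix(rows=0) -- still participates
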
@@ -68,11 +68,11 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:
     if is_dqm:
         cols = [f"feat-{i}" for i in range(n_features)]
         train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), cols, 0, kwargs, False
+            iter(dfs), cols, 0, kwargs, False, True
         )
     else:
         train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), None, None, kwargs, False
+            iter(dfs), None, None, kwargs, False, True
         )

     assert valid_Xy is not None

@@ -17,6 +17,7 @@ from pyspark.ml.evaluation import (
     BinaryClassificationEvaluator,
     MulticlassClassificationEvaluator,
 )
+from pyspark.ml.feature import VectorAssembler
 from pyspark.ml.functions import vector_to_array
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

@@ -1058,3 +1059,65 @@ class XgboostLocalTest(SparkTestCase):

         for row in pred_result:
             assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
+
+    def test_empty_validation_data(self):
+        df_train = self.session.createDataFrame(
+            [
+                (Vectors.dense(10.1, 11.2, 11.3), 0, False),
+                (Vectors.dense(1, 1.2, 1.3), 1, False),
+                (Vectors.dense(14.0, 15.0, 16.0), 0, False),
+                (Vectors.dense(1.1, 1.2, 1.3), 1, True),
+            ],
+            ["features", "label", "val_col"],
+        )
+        classifier = SparkXGBClassifier(
+            num_workers=2,
+            min_child_weight=0.0,
+            reg_alpha=0,
+            reg_lambda=0,
+            validation_indicator_col="val_col",
+        )
+        model = classifier.fit(df_train)
+        pred_result = model.transform(df_train).collect()
+        for row in pred_result:
+            self.assertEqual(row.prediction, row.label)
+
+    def test_empty_train_data(self):
+        df_train = self.session.createDataFrame(
+            [
+                (Vectors.dense(10.1, 11.2, 11.3), 0, True),
+                (Vectors.dense(1, 1.2, 1.3), 1, True),
+                (Vectors.dense(14.0, 15.0, 16.0), 0, True),
+                (Vectors.dense(1.1, 1.2, 1.3), 1, False),
+            ],
+            ["features", "label", "val_col"],
+        )
+        classifier = SparkXGBClassifier(
+            num_workers=2,
+            min_child_weight=0.0,
+            reg_alpha=0,
+            reg_lambda=0,
+            validation_indicator_col="val_col",
+        )
+        model = classifier.fit(df_train)
+        pred_result = model.transform(df_train).collect()
+        for row in pred_result:
+            self.assertEqual(row.prediction, 1.0)
+
+    def test_empty_partition(self):
+        # raw_df.repartition(4) will result int severe data skew, actually,
+        # there is no any data in reducer partition 1, reducer partition 2
+        # see https://github.com/dmlc/xgboost/issues/8221
+        raw_df = self.session.range(0, 100, 1, 50).withColumn(
+            "label", spark_sql_func.when(spark_sql_func.rand(1) > 0.5, 1).otherwise(0)
+        )
+        vector_assembler = (
+            VectorAssembler().setInputCols(["id"]).setOutputCol("features")
+        )
+        data_trans = vector_assembler.setHandleInvalid("keep").transform(raw_df)
+        data_trans.show(100)
+
+        classifier = SparkXGBClassifier(
+            num_workers=4,
+        )
+        classifier.fit(data_trans)
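
The skew that test_empty_partition exercises can be observed directly: repartitioning the 50-partition range DataFrame down to the number of workers can leave some shuffle partitions with no rows at all (the exact layout depends on the Spark version and hashing, see SPARK-40407). A quick inspection sketch, assuming a local session:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").getOrCreate()

# Same input shape as the test: 100 ids spread over 50 input partitions,
# then shuffled into 4 partitions (one per training worker).
raw_df = spark.range(0, 100, 1, 50)
sizes = raw_df.repartition(4).rdd.glom().map(len).collect()

# glom() materialises each partition as a list; zero-length entries are the
# empty reducer partitions that previously crashed or hung DMatrix creation.
print(sizes)  # e.g. [100, 0, 0, 0]
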
@@ -102,7 +102,7 @@ class SparkTestCase(TestSparkContext, TestTempDir, unittest.TestCase):
     def setUpClass(cls):
         cls.setup_env(
             {
-                "spark.master": "local[2]",
+                "spark.master": "local[4]",
                 "spark.python.worker.reuse": "false",
                 "spark.driver.host": "127.0.0.1",
                 "spark.task.maxFailures": "1",