[pyspark] Use quantile dmatrix. (#8284)
This commit is contained in:
@@ -1047,67 +1047,79 @@ class XgboostLocalTest(SparkTestCase):
|
||||
for row in pred_result:
|
||||
assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)
|
||||
|
||||
def test_empty_validation_data(self):
|
||||
df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(10.1, 11.2, 11.3), 0, False),
|
||||
(Vectors.dense(1, 1.2, 1.3), 1, False),
|
||||
(Vectors.dense(14.0, 15.0, 16.0), 0, False),
|
||||
(Vectors.dense(1.1, 1.2, 1.3), 1, True),
|
||||
],
|
||||
["features", "label", "val_col"],
|
||||
)
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=2,
|
||||
min_child_weight=0.0,
|
||||
reg_alpha=0,
|
||||
reg_lambda=0,
|
||||
validation_indicator_col="val_col",
|
||||
)
|
||||
model = classifier.fit(df_train)
|
||||
pred_result = model.transform(df_train).collect()
|
||||
for row in pred_result:
|
||||
self.assertEqual(row.prediction, row.label)
|
||||
def test_empty_validation_data(self) -> None:
|
||||
for tree_method in [
|
||||
"hist",
|
||||
"approx",
|
||||
]: # pytest.mark conflict with python unittest
|
||||
df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(10.1, 11.2, 11.3), 0, False),
|
||||
(Vectors.dense(1, 1.2, 1.3), 1, False),
|
||||
(Vectors.dense(14.0, 15.0, 16.0), 0, False),
|
||||
(Vectors.dense(1.1, 1.2, 1.3), 1, True),
|
||||
],
|
||||
["features", "label", "val_col"],
|
||||
)
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=2,
|
||||
tree_method=tree_method,
|
||||
min_child_weight=0.0,
|
||||
reg_alpha=0,
|
||||
reg_lambda=0,
|
||||
validation_indicator_col="val_col",
|
||||
)
|
||||
model = classifier.fit(df_train)
|
||||
pred_result = model.transform(df_train).collect()
|
||||
for row in pred_result:
|
||||
self.assertEqual(row.prediction, row.label)
|
||||
|
||||
def test_empty_train_data(self):
|
||||
df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(10.1, 11.2, 11.3), 0, True),
|
||||
(Vectors.dense(1, 1.2, 1.3), 1, True),
|
||||
(Vectors.dense(14.0, 15.0, 16.0), 0, True),
|
||||
(Vectors.dense(1.1, 1.2, 1.3), 1, False),
|
||||
],
|
||||
["features", "label", "val_col"],
|
||||
)
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=2,
|
||||
min_child_weight=0.0,
|
||||
reg_alpha=0,
|
||||
reg_lambda=0,
|
||||
validation_indicator_col="val_col",
|
||||
)
|
||||
model = classifier.fit(df_train)
|
||||
pred_result = model.transform(df_train).collect()
|
||||
for row in pred_result:
|
||||
self.assertEqual(row.prediction, 1.0)
|
||||
def test_empty_train_data(self) -> None:
|
||||
for tree_method in [
|
||||
"hist",
|
||||
"approx",
|
||||
]: # pytest.mark conflict with python unittest
|
||||
df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(10.1, 11.2, 11.3), 0, True),
|
||||
(Vectors.dense(1, 1.2, 1.3), 1, True),
|
||||
(Vectors.dense(14.0, 15.0, 16.0), 0, True),
|
||||
(Vectors.dense(1.1, 1.2, 1.3), 1, False),
|
||||
],
|
||||
["features", "label", "val_col"],
|
||||
)
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=2,
|
||||
min_child_weight=0.0,
|
||||
reg_alpha=0,
|
||||
reg_lambda=0,
|
||||
tree_method=tree_method,
|
||||
validation_indicator_col="val_col",
|
||||
)
|
||||
model = classifier.fit(df_train)
|
||||
pred_result = model.transform(df_train).collect()
|
||||
for row in pred_result:
|
||||
assert row.prediction == 1.0
|
||||
|
||||
def test_empty_partition(self):
|
||||
# raw_df.repartition(4) will result int severe data skew, actually,
|
||||
# there is no any data in reducer partition 1, reducer partition 2
|
||||
# see https://github.com/dmlc/xgboost/issues/8221
|
||||
raw_df = self.session.range(0, 100, 1, 50).withColumn(
|
||||
"label", spark_sql_func.when(spark_sql_func.rand(1) > 0.5, 1).otherwise(0)
|
||||
)
|
||||
vector_assembler = (
|
||||
VectorAssembler().setInputCols(["id"]).setOutputCol("features")
|
||||
)
|
||||
data_trans = vector_assembler.setHandleInvalid("keep").transform(raw_df)
|
||||
data_trans.show(100)
|
||||
for tree_method in [
|
||||
"hist",
|
||||
"approx",
|
||||
]: # pytest.mark conflict with python unittest
|
||||
raw_df = self.session.range(0, 100, 1, 50).withColumn(
|
||||
"label",
|
||||
spark_sql_func.when(spark_sql_func.rand(1) > 0.5, 1).otherwise(0),
|
||||
)
|
||||
vector_assembler = (
|
||||
VectorAssembler().setInputCols(["id"]).setOutputCol("features")
|
||||
)
|
||||
data_trans = vector_assembler.setHandleInvalid("keep").transform(raw_df)
|
||||
|
||||
classifier = SparkXGBClassifier(
|
||||
num_workers=4,
|
||||
)
|
||||
classifier.fit(data_trans)
|
||||
classifier = SparkXGBClassifier(num_workers=4, tree_method=tree_method)
|
||||
classifier.fit(data_trans)
|
||||
|
||||
def test_early_stop_param_validation(self):
|
||||
classifier = SparkXGBClassifier(early_stopping_rounds=1)
|
||||
|
||||
Reference in New Issue
Block a user