[pyspark] Refactor local tests. (#8525)

- Use pytest fixture for spark session.
- Replace hardcoded results.
This commit is contained in:
Jiaming Yuan 2022-12-05 23:49:54 +08:00 committed by GitHub
parent 42c5ee5588
commit e143a4dd7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,8 @@ import glob
import logging
import random
import uuid
from collections import namedtuple
from typing import Generator
import numpy as np
import pytest
@ -17,6 +19,7 @@ from pyspark.ml.feature import VectorAssembler
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql import functions as spark_sql_func
from xgboost.spark import (
SparkXGBClassifier,
@ -34,6 +37,324 @@ from .utils import SparkTestCase
logging.getLogger("py4j").setLevel(logging.INFO)
@pytest.fixture
def spark() -> Generator[SparkSession, None, None]:
config = {
"spark.master": "local[4]",
"spark.python.worker.reuse": "false",
"spark.driver.host": "127.0.0.1",
"spark.task.maxFailures": "1",
"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
"spark.sql.pyspark.jvmStacktrace.enabled": "true",
}
builder = SparkSession.builder.appName("XGBoost PySpark Python API Tests")
for k, v in config.items():
builder.config(k, v)
logging.getLogger("pyspark").setLevel(logging.INFO)
sess = builder.getOrCreate()
yield sess
sess.stop()
sess.sparkContext.stop()
RegWithWeight = namedtuple(
"RegWithWeight",
(
"reg_params_with_eval",
"reg_df_train_with_eval_weight",
"reg_df_test_with_eval_weight",
"reg_with_eval_best_score",
"reg_with_eval_and_weight_best_score",
),
)
@pytest.fixture
def reg_with_weight(
spark: SparkSession,
) -> Generator[RegWithWeight, SparkSession, None]:
reg_params_with_eval = {
"validation_indicator_col": "isVal",
"early_stopping_rounds": 1,
"eval_metric": "rmse",
}
X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
w = np.array([1.0, 2.0, 1.0, 2.0])
y = np.array([0, 1, 2, 3])
reg1 = XGBRegressor()
reg1.fit(X, y, sample_weight=w)
predt1 = reg1.predict(X)
X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
y_train = np.array([0, 1])
y_val = np.array([2, 3])
w_train = np.array([1.0, 2.0])
w_val = np.array([1.0, 2.0])
reg2 = XGBRegressor(early_stopping_rounds=1, eval_metric="rmse")
reg2.fit(
X_train,
y_train,
eval_set=[(X_val, y_val)],
)
predt2 = reg2.predict(X)
best_score2 = reg2.best_score
reg3 = XGBRegressor(early_stopping_rounds=1, eval_metric="rmse")
reg3.fit(
X_train,
y_train,
sample_weight=w_train,
eval_set=[(X_val, y_val)],
sample_weight_eval_set=[w_val],
)
predt3 = reg3.predict(X)
best_score3 = reg3.best_score
reg_df_train_with_eval_weight = spark.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
],
["features", "label", "isVal", "weight"],
)
reg_df_test_with_eval_weight = spark.createDataFrame(
[
(
Vectors.dense(1.0, 2.0, 3.0),
float(predt1[0]),
float(predt2[0]),
float(predt3[0]),
),
(
Vectors.sparse(3, {1: 1.0, 2: 5.5}),
float(predt1[1]),
float(predt2[1]),
float(predt3[1]),
),
],
[
"features",
"expected_prediction_with_weight",
"expected_prediction_with_eval",
"expected_prediction_with_weight_and_eval",
],
)
yield RegWithWeight(
reg_params_with_eval,
reg_df_train_with_eval_weight,
reg_df_test_with_eval_weight,
best_score2,
best_score3,
)
ClfWithWeight = namedtuple(
"ClfWithWeight",
(
"cls_params_with_eval",
"cls_df_train_with_eval_weight",
"cls_df_test_with_eval_weight",
"cls_with_eval_best_score",
"cls_with_eval_and_weight_best_score",
),
)
@pytest.fixture
def clf_with_weight(
spark: SparkSession,
) -> Generator[ClfWithWeight, SparkSession, None]:
"""Test classifier with weight and eval set."""
X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
w = np.array([1.0, 2.0, 1.0, 2.0])
y = np.array([0, 1, 0, 1])
cls1 = XGBClassifier()
cls1.fit(X, y, sample_weight=w)
X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
y_train = np.array([0, 1])
y_val = np.array([0, 1])
w_train = np.array([1.0, 2.0])
w_val = np.array([1.0, 2.0])
cls2 = XGBClassifier()
cls2.fit(
X_train,
y_train,
eval_set=[(X_val, y_val)],
early_stopping_rounds=1,
eval_metric="logloss",
)
cls3 = XGBClassifier()
cls3.fit(
X_train,
y_train,
sample_weight=w_train,
eval_set=[(X_val, y_val)],
sample_weight_eval_set=[w_val],
early_stopping_rounds=1,
eval_metric="logloss",
)
cls_df_train_with_eval_weight = spark.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
],
["features", "label", "isVal", "weight"],
)
cls_params_with_eval = {
"validation_indicator_col": "isVal",
"early_stopping_rounds": 1,
"eval_metric": "logloss",
}
cls_df_test_with_eval_weight = spark.createDataFrame(
[
(
Vectors.dense(1.0, 2.0, 3.0),
[float(p) for p in cls1.predict_proba(X)[0, :]],
[float(p) for p in cls2.predict_proba(X)[0, :]],
[float(p) for p in cls3.predict_proba(X)[0, :]],
),
],
[
"features",
"expected_prob_with_weight",
"expected_prob_with_eval",
"expected_prob_with_weight_and_eval",
],
)
cls_with_eval_best_score = cls2.best_score
cls_with_eval_and_weight_best_score = cls3.best_score
yield ClfWithWeight(
cls_params_with_eval,
cls_df_train_with_eval_weight,
cls_df_test_with_eval_weight,
cls_with_eval_best_score,
cls_with_eval_and_weight_best_score,
)
class TestPySparkLocal:
def test_regressor_with_weight_eval(self, reg_with_weight: RegWithWeight) -> None:
# with weight
regressor_with_weight = SparkXGBRegressor(weight_col="weight")
model_with_weight = regressor_with_weight.fit(
reg_with_weight.reg_df_train_with_eval_weight
)
pred_result_with_weight = model_with_weight.transform(
reg_with_weight.reg_df_test_with_eval_weight
).collect()
for row in pred_result_with_weight:
assert np.isclose(
row.prediction, row.expected_prediction_with_weight, atol=1e-3
)
# with eval
regressor_with_eval = SparkXGBRegressor(**reg_with_weight.reg_params_with_eval)
model_with_eval = regressor_with_eval.fit(
reg_with_weight.reg_df_train_with_eval_weight
)
assert np.isclose(
model_with_eval._xgb_sklearn_model.best_score,
reg_with_weight.reg_with_eval_best_score,
atol=1e-3,
)
pred_result_with_eval = model_with_eval.transform(
reg_with_weight.reg_df_test_with_eval_weight
).collect()
for row in pred_result_with_eval:
np.testing.assert_allclose(
row.prediction, row.expected_prediction_with_eval, atol=1e-3
)
# with weight and eval
regressor_with_weight_eval = SparkXGBRegressor(
weight_col="weight", **reg_with_weight.reg_params_with_eval
)
model_with_weight_eval = regressor_with_weight_eval.fit(
reg_with_weight.reg_df_train_with_eval_weight
)
pred_result_with_weight_eval = model_with_weight_eval.transform(
reg_with_weight.reg_df_test_with_eval_weight
).collect()
np.testing.assert_allclose(
model_with_weight_eval._xgb_sklearn_model.best_score,
reg_with_weight.reg_with_eval_and_weight_best_score,
atol=1e-3,
)
for row in pred_result_with_weight_eval:
np.testing.assert_allclose(
row.prediction,
row.expected_prediction_with_weight_and_eval,
atol=1e-3,
)
def test_classifier_with_weight_eval(self, clf_with_weight: ClfWithWeight) -> None:
# with weight
classifier_with_weight = SparkXGBClassifier(weight_col="weight")
model_with_weight = classifier_with_weight.fit(
clf_with_weight.cls_df_train_with_eval_weight
)
pred_result_with_weight = model_with_weight.transform(
clf_with_weight.cls_df_test_with_eval_weight
).collect()
for row in pred_result_with_weight:
assert np.allclose(
row.probability, row.expected_prob_with_weight, atol=1e-3
)
# with eval
classifier_with_eval = SparkXGBClassifier(
**clf_with_weight.cls_params_with_eval
)
model_with_eval = classifier_with_eval.fit(
clf_with_weight.cls_df_train_with_eval_weight
)
assert np.isclose(
model_with_eval._xgb_sklearn_model.best_score,
clf_with_weight.cls_with_eval_best_score,
atol=1e-3,
)
pred_result_with_eval = model_with_eval.transform(
clf_with_weight.cls_df_test_with_eval_weight
).collect()
for row in pred_result_with_eval:
assert np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3)
# with weight and eval
classifier_with_weight_eval = SparkXGBClassifier(
weight_col="weight", **clf_with_weight.cls_params_with_eval
)
model_with_weight_eval = classifier_with_weight_eval.fit(
clf_with_weight.cls_df_train_with_eval_weight
)
pred_result_with_weight_eval = model_with_weight_eval.transform(
clf_with_weight.cls_df_test_with_eval_weight
).collect()
np.testing.assert_allclose(
model_with_weight_eval._xgb_sklearn_model.best_score,
clf_with_weight.cls_with_eval_and_weight_best_score,
atol=1e-3,
)
for row in pred_result_with_weight_eval:
np.testing.assert_allclose( # failed
row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
)
class XgboostLocalTest(SparkTestCase):
def setUp(self):
logging.getLogger().setLevel("INFO")
@ -167,130 +488,6 @@ class XgboostLocalTest(SparkTestCase):
["features", "expected_probability"],
)
# Test regressor with weight and eval set
# >>> import numpy as np
# >>> import xgboost
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> w = np.array([1.0, 2.0, 1.0, 2.0])
# >>> y = np.array([0, 1, 2, 3])
# >>> reg1 = xgboost.XGBRegressor()
# >>> reg1.fit(X, y, sample_weight=w)
# >>> reg1.predict(X)
# >>> array([1.0679445e-03, 1.0000550e+00, ...
# >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
# >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> y_train = np.array([0, 1])
# >>> y_val = np.array([2, 3])
# >>> w_train = np.array([1.0, 2.0])
# >>> w_val = np.array([1.0, 2.0])
# >>> reg2 = xgboost.XGBRegressor()
# >>> reg2.fit(X_train, y_train, eval_set=[(X_val, y_val)],
# >>> early_stopping_rounds=1, eval_metric='rmse')
# >>> reg2.predict(X)
# >>> array([8.8370638e-04, 9.9911624e-01, ...
# >>> reg2.best_score
# 2.0000002682208837
# >>> reg3 = xgboost.XGBRegressor()
# >>> reg3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)],
# >>> sample_weight_eval_set=[w_val],
# >>> early_stopping_rounds=1, eval_metric='rmse')
# >>> reg3.predict(X)
# >>> array([0.03155671, 0.98874104,...
# >>> reg3.best_score
# 1.9970891552124017
self.reg_df_train_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
],
["features", "label", "isVal", "weight"],
)
self.reg_params_with_eval = {
"validation_indicator_col": "isVal",
"early_stopping_rounds": 1,
"eval_metric": "rmse",
}
self.reg_df_test_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0.001068, 0.00088, 0.03155),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.000055, 0.9991, 0.9887),
],
[
"features",
"expected_prediction_with_weight",
"expected_prediction_with_eval",
"expected_prediction_with_weight_and_eval",
],
)
self.reg_with_eval_best_score = 2.0
self.reg_with_eval_and_weight_best_score = 1.997
# Test classifier with weight and eval set
# >>> import numpy as np
# >>> import xgboost
# >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> w = np.array([1.0, 2.0, 1.0, 2.0])
# >>> y = np.array([0, 1, 0, 1])
# >>> cls1 = xgboost.XGBClassifier()
# >>> cls1.fit(X, y, sample_weight=w)
# >>> cls1.predict_proba(X)
# array([[0.3333333, 0.6666667],...
# >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
# >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
# >>> y_train = np.array([0, 1])
# >>> y_val = np.array([0, 1])
# >>> w_train = np.array([1.0, 2.0])
# >>> w_val = np.array([1.0, 2.0])
# >>> cls2 = xgboost.XGBClassifier()
# >>> cls2.fit(X_train, y_train, eval_set=[(X_val, y_val)],
# >>> early_stopping_rounds=1, eval_metric='logloss')
# >>> cls2.predict_proba(X)
# array([[0.5, 0.5],...
# >>> cls2.best_score
# 0.6931
# >>> cls3 = xgboost.XGBClassifier()
# >>> cls3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)],
# >>> sample_weight_eval_set=[w_val],
# >>> early_stopping_rounds=1, eval_metric='logloss')
# >>> cls3.predict_proba(X)
# array([[0.3344962, 0.6655038],...
# >>> cls3.best_score
# 0.6365
self.cls_df_train_with_eval_weight = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
(Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
],
["features", "label", "isVal", "weight"],
)
self.cls_params_with_eval = {
"validation_indicator_col": "isVal",
"early_stopping_rounds": 1,
"eval_metric": "logloss",
}
self.cls_df_test_with_eval_weight = self.session.createDataFrame(
[
(
Vectors.dense(1.0, 2.0, 3.0),
[0.3333, 0.6666],
[0.5, 0.5],
[0.3097, 0.6903],
),
],
[
"features",
"expected_prob_with_weight",
"expected_prob_with_eval",
"expected_prob_with_weight_and_eval",
],
)
self.cls_with_eval_best_score = 0.6931
self.cls_with_eval_and_weight_best_score = 0.6378
# Test classifier with both base margin and without
# >>> import numpy as np
# >>> import xgboost
@ -790,96 +987,6 @@ class XgboostLocalTest(SparkTestCase):
row.probability, row.expected_prob_with_base_margin, atol=1e-3
)
def test_regressor_with_weight_eval(self):
# with weight
regressor_with_weight = SparkXGBRegressor(weight_col="weight")
model_with_weight = regressor_with_weight.fit(
self.reg_df_train_with_eval_weight
)
pred_result_with_weight = model_with_weight.transform(
self.reg_df_test_with_eval_weight
).collect()
for row in pred_result_with_weight:
assert np.isclose(
row.prediction, row.expected_prediction_with_weight, atol=1e-3
)
# with eval
regressor_with_eval = SparkXGBRegressor(**self.reg_params_with_eval)
model_with_eval = regressor_with_eval.fit(self.reg_df_train_with_eval_weight)
assert np.isclose(
model_with_eval._xgb_sklearn_model.best_score,
self.reg_with_eval_best_score,
atol=1e-3,
), (
f"Expected best score: {self.reg_with_eval_best_score}, but ",
f"get {model_with_eval._xgb_sklearn_model.best_score}",
)
pred_result_with_eval = model_with_eval.transform(
self.reg_df_test_with_eval_weight
).collect()
for row in pred_result_with_eval:
self.assertTrue(
np.isclose(
row.prediction, row.expected_prediction_with_eval, atol=1e-3
),
f"Expect prediction is {row.expected_prediction_with_eval},"
f"but get {row.prediction}",
)
# with weight and eval
regressor_with_weight_eval = SparkXGBRegressor(
weight_col="weight", **self.reg_params_with_eval
)
model_with_weight_eval = regressor_with_weight_eval.fit(
self.reg_df_train_with_eval_weight
)
pred_result_with_weight_eval = model_with_weight_eval.transform(
self.reg_df_test_with_eval_weight
).collect()
self.assertTrue(
np.isclose(
model_with_weight_eval._xgb_sklearn_model.best_score,
self.reg_with_eval_and_weight_best_score,
atol=1e-3,
)
)
for row in pred_result_with_weight_eval:
self.assertTrue(
np.isclose(
row.prediction,
row.expected_prediction_with_weight_and_eval,
atol=1e-3,
)
)
def test_classifier_with_weight_eval(self):
# with weight and eval
# Added scale_pos_weight because in 1.4.2, the original answer returns 0.5 which
# doesn't really indicate this working correctly.
classifier_with_weight_eval = SparkXGBClassifier(
weight_col="weight", scale_pos_weight=4, **self.cls_params_with_eval
)
model_with_weight_eval = classifier_with_weight_eval.fit(
self.cls_df_train_with_eval_weight
)
pred_result_with_weight_eval = model_with_weight_eval.transform(
self.cls_df_test_with_eval_weight
).collect()
self.assertTrue(
np.isclose(
model_with_weight_eval._xgb_sklearn_model.best_score,
self.cls_with_eval_and_weight_best_score,
atol=1e-3,
)
)
for row in pred_result_with_weight_eval:
self.assertTrue(
np.allclose(
row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
)
)
def test_num_workers_param(self):
regressor = SparkXGBRegressor(num_workers=-1)
self.assertRaises(ValueError, regressor._validate_params)