[pyspark] Refactor local tests. (#8525)
- Use a pytest fixture for the Spark session.
- Replace hardcoded expected results with values computed by single-node XGBoost.
This commit is contained in:
parent
42c5ee5588
commit
e143a4dd7e
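The pattern this commit adopts, as a minimal standalone sketch (the names `spark_sess`, `expected_predt`, and `test_session_and_expected` below are illustrative, not part of this diff): a pytest fixture yields a SparkSession and tears it down after each test, and expected values are computed from single-node XGBoost at fixture time instead of being hardcoded into the tests.

# Minimal sketch of the fixture pattern; names are illustrative, not from the diff.
from typing import Generator

import numpy as np
import pytest
from pyspark.sql import SparkSession
from xgboost import XGBRegressor


@pytest.fixture
def spark_sess() -> Generator[SparkSession, None, None]:
    sess = SparkSession.builder.master("local[2]").appName("sketch").getOrCreate()
    yield sess  # the test body runs while the session is alive
    sess.stop()  # teardown runs after the test, pass or fail


@pytest.fixture
def expected_predt() -> np.ndarray:
    # Expected values are computed from single-node XGBoost at fixture time
    # rather than pasted into the test as hardcoded constants.
    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    y = np.array([0.0, 1.0])
    return XGBRegressor(n_estimators=4).fit(X, y).predict(X)


def test_session_and_expected(
    spark_sess: SparkSession, expected_predt: np.ndarray
) -> None:
    df = spark_sess.createDataFrame([(float(p),) for p in expected_predt], ["predt"])
    assert df.count() == expected_predt.shape[0]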
@@ -2,6 +2,8 @@ import glob
import logging
import random
import uuid
from collections import namedtuple
from typing import Generator

import numpy as np
import pytest
@@ -17,6 +19,7 @@ from pyspark.ml.feature import VectorAssembler
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql import functions as spark_sql_func
from xgboost.spark import (
    SparkXGBClassifier,
@@ -34,6 +37,324 @@ from .utils import SparkTestCase

logging.getLogger("py4j").setLevel(logging.INFO)


@pytest.fixture
def spark() -> Generator[SparkSession, None, None]:
    config = {
        "spark.master": "local[4]",
        "spark.python.worker.reuse": "false",
        "spark.driver.host": "127.0.0.1",
        "spark.task.maxFailures": "1",
        "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
        "spark.sql.pyspark.jvmStacktrace.enabled": "true",
    }

    builder = SparkSession.builder.appName("XGBoost PySpark Python API Tests")
    for k, v in config.items():
        builder.config(k, v)
    logging.getLogger("pyspark").setLevel(logging.INFO)
    sess = builder.getOrCreate()
    yield sess

    sess.stop()
    sess.sparkContext.stop()


RegWithWeight = namedtuple(
    "RegWithWeight",
    (
        "reg_params_with_eval",
        "reg_df_train_with_eval_weight",
        "reg_df_test_with_eval_weight",
        "reg_with_eval_best_score",
        "reg_with_eval_and_weight_best_score",
    ),
)


@pytest.fixture
def reg_with_weight(
    spark: SparkSession,
) -> Generator[RegWithWeight, None, None]:
    reg_params_with_eval = {
        "validation_indicator_col": "isVal",
        "early_stopping_rounds": 1,
        "eval_metric": "rmse",
    }

    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
    w = np.array([1.0, 2.0, 1.0, 2.0])
    y = np.array([0, 1, 2, 3])

    reg1 = XGBRegressor()
    reg1.fit(X, y, sample_weight=w)
    predt1 = reg1.predict(X)

    X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
    y_train = np.array([0, 1])
    y_val = np.array([2, 3])
    w_train = np.array([1.0, 2.0])
    w_val = np.array([1.0, 2.0])

    reg2 = XGBRegressor(early_stopping_rounds=1, eval_metric="rmse")
    reg2.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
    )
    predt2 = reg2.predict(X)
    best_score2 = reg2.best_score

    reg3 = XGBRegressor(early_stopping_rounds=1, eval_metric="rmse")
    reg3.fit(
        X_train,
        y_train,
        sample_weight=w_train,
        eval_set=[(X_val, y_val)],
        sample_weight_eval_set=[w_val],
    )
    predt3 = reg3.predict(X)
    best_score3 = reg3.best_score

    reg_df_train_with_eval_weight = spark.createDataFrame(
        [
            (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
            (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
            (Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
            (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
        ],
        ["features", "label", "isVal", "weight"],
    )

    reg_df_test_with_eval_weight = spark.createDataFrame(
        [
            (
                Vectors.dense(1.0, 2.0, 3.0),
                float(predt1[0]),
                float(predt2[0]),
                float(predt3[0]),
            ),
            (
                Vectors.sparse(3, {1: 1.0, 2: 5.5}),
                float(predt1[1]),
                float(predt2[1]),
                float(predt3[1]),
            ),
        ],
        [
            "features",
            "expected_prediction_with_weight",
            "expected_prediction_with_eval",
            "expected_prediction_with_weight_and_eval",
        ],
    )
    yield RegWithWeight(
        reg_params_with_eval,
        reg_df_train_with_eval_weight,
        reg_df_test_with_eval_weight,
        best_score2,
        best_score3,
    )


ClfWithWeight = namedtuple(
    "ClfWithWeight",
    (
        "cls_params_with_eval",
        "cls_df_train_with_eval_weight",
        "cls_df_test_with_eval_weight",
        "cls_with_eval_best_score",
        "cls_with_eval_and_weight_best_score",
    ),
)


@pytest.fixture
def clf_with_weight(
    spark: SparkSession,
) -> Generator[ClfWithWeight, None, None]:
    """Data and expected results for testing the classifier with weight and eval set."""

    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
    w = np.array([1.0, 2.0, 1.0, 2.0])
    y = np.array([0, 1, 0, 1])
    cls1 = XGBClassifier()
    cls1.fit(X, y, sample_weight=w)

    X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
    y_train = np.array([0, 1])
    y_val = np.array([0, 1])
    w_train = np.array([1.0, 2.0])
    w_val = np.array([1.0, 2.0])
    cls2 = XGBClassifier(early_stopping_rounds=1, eval_metric="logloss")
    cls2.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
    )

    cls3 = XGBClassifier(early_stopping_rounds=1, eval_metric="logloss")
    cls3.fit(
        X_train,
        y_train,
        sample_weight=w_train,
        eval_set=[(X_val, y_val)],
        sample_weight_eval_set=[w_val],
    )

    cls_df_train_with_eval_weight = spark.createDataFrame(
        [
            (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
            (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
            (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
            (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
        ],
        ["features", "label", "isVal", "weight"],
    )
    cls_params_with_eval = {
        "validation_indicator_col": "isVal",
        "early_stopping_rounds": 1,
        "eval_metric": "logloss",
    }
    cls_df_test_with_eval_weight = spark.createDataFrame(
        [
            (
                Vectors.dense(1.0, 2.0, 3.0),
                [float(p) for p in cls1.predict_proba(X)[0, :]],
                [float(p) for p in cls2.predict_proba(X)[0, :]],
                [float(p) for p in cls3.predict_proba(X)[0, :]],
            ),
        ],
        [
            "features",
            "expected_prob_with_weight",
            "expected_prob_with_eval",
            "expected_prob_with_weight_and_eval",
        ],
    )
    cls_with_eval_best_score = cls2.best_score
    cls_with_eval_and_weight_best_score = cls3.best_score
    yield ClfWithWeight(
        cls_params_with_eval,
        cls_df_train_with_eval_weight,
        cls_df_test_with_eval_weight,
        cls_with_eval_best_score,
        cls_with_eval_and_weight_best_score,
    )


class TestPySparkLocal:
    def test_regressor_with_weight_eval(self, reg_with_weight: RegWithWeight) -> None:
        # with weight
        regressor_with_weight = SparkXGBRegressor(weight_col="weight")
        model_with_weight = regressor_with_weight.fit(
            reg_with_weight.reg_df_train_with_eval_weight
        )
        pred_result_with_weight = model_with_weight.transform(
            reg_with_weight.reg_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_weight:
            assert np.isclose(
                row.prediction, row.expected_prediction_with_weight, atol=1e-3
            )

        # with eval
        regressor_with_eval = SparkXGBRegressor(**reg_with_weight.reg_params_with_eval)
        model_with_eval = regressor_with_eval.fit(
            reg_with_weight.reg_df_train_with_eval_weight
        )
        assert np.isclose(
            model_with_eval._xgb_sklearn_model.best_score,
            reg_with_weight.reg_with_eval_best_score,
            atol=1e-3,
        )

        pred_result_with_eval = model_with_eval.transform(
            reg_with_weight.reg_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_eval:
            np.testing.assert_allclose(
                row.prediction, row.expected_prediction_with_eval, atol=1e-3
            )
        # with weight and eval
        regressor_with_weight_eval = SparkXGBRegressor(
            weight_col="weight", **reg_with_weight.reg_params_with_eval
        )
        model_with_weight_eval = regressor_with_weight_eval.fit(
            reg_with_weight.reg_df_train_with_eval_weight
        )
        pred_result_with_weight_eval = model_with_weight_eval.transform(
            reg_with_weight.reg_df_test_with_eval_weight
        ).collect()
        np.testing.assert_allclose(
            model_with_weight_eval._xgb_sklearn_model.best_score,
            reg_with_weight.reg_with_eval_and_weight_best_score,
            atol=1e-3,
        )
        for row in pred_result_with_weight_eval:
            np.testing.assert_allclose(
                row.prediction,
                row.expected_prediction_with_weight_and_eval,
                atol=1e-3,
            )

    def test_classifier_with_weight_eval(self, clf_with_weight: ClfWithWeight) -> None:
        # with weight
        classifier_with_weight = SparkXGBClassifier(weight_col="weight")
        model_with_weight = classifier_with_weight.fit(
            clf_with_weight.cls_df_train_with_eval_weight
        )
        pred_result_with_weight = model_with_weight.transform(
            clf_with_weight.cls_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_weight:
            assert np.allclose(
                row.probability, row.expected_prob_with_weight, atol=1e-3
            )
        # with eval
        classifier_with_eval = SparkXGBClassifier(
            **clf_with_weight.cls_params_with_eval
        )
        model_with_eval = classifier_with_eval.fit(
            clf_with_weight.cls_df_train_with_eval_weight
        )
        assert np.isclose(
            model_with_eval._xgb_sklearn_model.best_score,
            clf_with_weight.cls_with_eval_best_score,
            atol=1e-3,
        )
        pred_result_with_eval = model_with_eval.transform(
            clf_with_weight.cls_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_eval:
            assert np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3)
        # with weight and eval
        classifier_with_weight_eval = SparkXGBClassifier(
            weight_col="weight", **clf_with_weight.cls_params_with_eval
        )
        model_with_weight_eval = classifier_with_weight_eval.fit(
            clf_with_weight.cls_df_train_with_eval_weight
        )
        pred_result_with_weight_eval = model_with_weight_eval.transform(
            clf_with_weight.cls_df_test_with_eval_weight
        ).collect()
        np.testing.assert_allclose(
            model_with_weight_eval._xgb_sklearn_model.best_score,
            clf_with_weight.cls_with_eval_and_weight_best_score,
            atol=1e-3,
        )

        for row in pred_result_with_weight_eval:
            np.testing.assert_allclose(
                row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
            )


class XgboostLocalTest(SparkTestCase):
    def setUp(self):
        logging.getLogger().setLevel("INFO")
@@ -167,130 +488,6 @@ class XgboostLocalTest(SparkTestCase):
            ["features", "expected_probability"],
        )

        # Test regressor with weight and eval set
        # >>> import numpy as np
        # >>> import xgboost
        # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
        # >>> w = np.array([1.0, 2.0, 1.0, 2.0])
        # >>> y = np.array([0, 1, 2, 3])
        # >>> reg1 = xgboost.XGBRegressor()
        # >>> reg1.fit(X, y, sample_weight=w)
        # >>> reg1.predict(X)
        # >>> array([1.0679445e-03, 1.0000550e+00, ...
        # >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
        # >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
        # >>> y_train = np.array([0, 1])
        # >>> y_val = np.array([2, 3])
        # >>> w_train = np.array([1.0, 2.0])
        # >>> w_val = np.array([1.0, 2.0])
        # >>> reg2 = xgboost.XGBRegressor()
        # >>> reg2.fit(X_train, y_train, eval_set=[(X_val, y_val)],
        # >>>          early_stopping_rounds=1, eval_metric='rmse')
        # >>> reg2.predict(X)
        # >>> array([8.8370638e-04, 9.9911624e-01, ...
        # >>> reg2.best_score
        # 2.0000002682208837
        # >>> reg3 = xgboost.XGBRegressor()
        # >>> reg3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)],
        # >>>          sample_weight_eval_set=[w_val],
        # >>>          early_stopping_rounds=1, eval_metric='rmse')
        # >>> reg3.predict(X)
        # >>> array([0.03155671, 0.98874104,...
        # >>> reg3.best_score
        # 1.9970891552124017
        self.reg_df_train_with_eval_weight = self.session.createDataFrame(
            [
                (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
                (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
                (Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
                (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
            ],
            ["features", "label", "isVal", "weight"],
        )
        self.reg_params_with_eval = {
            "validation_indicator_col": "isVal",
            "early_stopping_rounds": 1,
            "eval_metric": "rmse",
        }
        self.reg_df_test_with_eval_weight = self.session.createDataFrame(
            [
                (Vectors.dense(1.0, 2.0, 3.0), 0.001068, 0.00088, 0.03155),
                (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.000055, 0.9991, 0.9887),
            ],
            [
                "features",
                "expected_prediction_with_weight",
                "expected_prediction_with_eval",
                "expected_prediction_with_weight_and_eval",
            ],
        )
        self.reg_with_eval_best_score = 2.0
        self.reg_with_eval_and_weight_best_score = 1.997

        # Test classifier with weight and eval set
        # >>> import numpy as np
        # >>> import xgboost
        # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
        # >>> w = np.array([1.0, 2.0, 1.0, 2.0])
        # >>> y = np.array([0, 1, 0, 1])
        # >>> cls1 = xgboost.XGBClassifier()
        # >>> cls1.fit(X, y, sample_weight=w)
        # >>> cls1.predict_proba(X)
        # array([[0.3333333, 0.6666667],...
        # >>> X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
        # >>> X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
        # >>> y_train = np.array([0, 1])
        # >>> y_val = np.array([0, 1])
        # >>> w_train = np.array([1.0, 2.0])
        # >>> w_val = np.array([1.0, 2.0])
        # >>> cls2 = xgboost.XGBClassifier()
        # >>> cls2.fit(X_train, y_train, eval_set=[(X_val, y_val)],
        # >>>          early_stopping_rounds=1, eval_metric='logloss')
        # >>> cls2.predict_proba(X)
        # array([[0.5, 0.5],...
        # >>> cls2.best_score
        # 0.6931
        # >>> cls3 = xgboost.XGBClassifier()
        # >>> cls3.fit(X_train, y_train, sample_weight=w_train, eval_set=[(X_val, y_val)],
        # >>>          sample_weight_eval_set=[w_val],
        # >>>          early_stopping_rounds=1, eval_metric='logloss')
        # >>> cls3.predict_proba(X)
        # array([[0.3344962, 0.6655038],...
        # >>> cls3.best_score
        # 0.6365
        self.cls_df_train_with_eval_weight = self.session.createDataFrame(
            [
                (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
                (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
                (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
                (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
            ],
            ["features", "label", "isVal", "weight"],
        )
        self.cls_params_with_eval = {
            "validation_indicator_col": "isVal",
            "early_stopping_rounds": 1,
            "eval_metric": "logloss",
        }
        self.cls_df_test_with_eval_weight = self.session.createDataFrame(
            [
                (
                    Vectors.dense(1.0, 2.0, 3.0),
                    [0.3333, 0.6666],
                    [0.5, 0.5],
                    [0.3097, 0.6903],
                ),
            ],
            [
                "features",
                "expected_prob_with_weight",
                "expected_prob_with_eval",
                "expected_prob_with_weight_and_eval",
            ],
        )
        self.cls_with_eval_best_score = 0.6931
        self.cls_with_eval_and_weight_best_score = 0.6378

        # Test classifier with both base margin and without
        # >>> import numpy as np
        # >>> import xgboost
@@ -790,96 +987,6 @@ class XgboostLocalTest(SparkTestCase):
                    row.probability, row.expected_prob_with_base_margin, atol=1e-3
                )

    def test_regressor_with_weight_eval(self):
        # with weight
        regressor_with_weight = SparkXGBRegressor(weight_col="weight")
        model_with_weight = regressor_with_weight.fit(
            self.reg_df_train_with_eval_weight
        )
        pred_result_with_weight = model_with_weight.transform(
            self.reg_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_weight:
            assert np.isclose(
                row.prediction, row.expected_prediction_with_weight, atol=1e-3
            )

        # with eval
        regressor_with_eval = SparkXGBRegressor(**self.reg_params_with_eval)
        model_with_eval = regressor_with_eval.fit(self.reg_df_train_with_eval_weight)
        assert np.isclose(
            model_with_eval._xgb_sklearn_model.best_score,
            self.reg_with_eval_best_score,
            atol=1e-3,
        ), (
            f"Expected best score: {self.reg_with_eval_best_score}, but ",
            f"get {model_with_eval._xgb_sklearn_model.best_score}",
        )

        pred_result_with_eval = model_with_eval.transform(
            self.reg_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_eval:
            self.assertTrue(
                np.isclose(
                    row.prediction, row.expected_prediction_with_eval, atol=1e-3
                ),
                f"Expect prediction is {row.expected_prediction_with_eval},"
                f"but get {row.prediction}",
            )
        # with weight and eval
        regressor_with_weight_eval = SparkXGBRegressor(
            weight_col="weight", **self.reg_params_with_eval
        )
        model_with_weight_eval = regressor_with_weight_eval.fit(
            self.reg_df_train_with_eval_weight
        )
        pred_result_with_weight_eval = model_with_weight_eval.transform(
            self.reg_df_test_with_eval_weight
        ).collect()
        self.assertTrue(
            np.isclose(
                model_with_weight_eval._xgb_sklearn_model.best_score,
                self.reg_with_eval_and_weight_best_score,
                atol=1e-3,
            )
        )
        for row in pred_result_with_weight_eval:
            self.assertTrue(
                np.isclose(
                    row.prediction,
                    row.expected_prediction_with_weight_and_eval,
                    atol=1e-3,
                )
            )

    def test_classifier_with_weight_eval(self):
        # with weight and eval
        # Added scale_pos_weight because in 1.4.2, the original answer returns 0.5 which
        # doesn't really indicate this working correctly.
        classifier_with_weight_eval = SparkXGBClassifier(
            weight_col="weight", scale_pos_weight=4, **self.cls_params_with_eval
        )
        model_with_weight_eval = classifier_with_weight_eval.fit(
            self.cls_df_train_with_eval_weight
        )
        pred_result_with_weight_eval = model_with_weight_eval.transform(
            self.cls_df_test_with_eval_weight
        ).collect()
        self.assertTrue(
            np.isclose(
                model_with_weight_eval._xgb_sklearn_model.best_score,
                self.cls_with_eval_and_weight_best_score,
                atol=1e-3,
            )
        )
        for row in pred_result_with_weight_eval:
            self.assertTrue(
                np.allclose(
                    row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
                )
            )

    def test_num_workers_param(self):
        regressor = SparkXGBRegressor(num_workers=-1)
        self.assertRaises(ValueError, regressor._validate_params)