[CI] Skip pyspark sparse tests. (#8675)
This commit is contained in:
parent
b2b6a8aa39
commit
e27cda7626
@ -39,7 +39,8 @@ dependencies:
|
||||
- cloudpickle
|
||||
- shap>=0.41
|
||||
- modin
|
||||
# TODO: Replace it with pyspark>=3.4 once 3.4 released.
|
||||
# - https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz
|
||||
- pyspark>=3.3.1
|
||||
- pip:
|
||||
- datatable
|
||||
# TODO: Replace it with pyspark>=3.4 once 3.4 released.
|
||||
- https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz
|
||||
|
||||
@ -39,6 +39,16 @@ from .utils import SparkTestCase
|
||||
logging.getLogger("py4j").setLevel(logging.INFO)
|
||||
|
||||
|
||||
def no_sparse_unwrap() -> tm.PytestSkip:
|
||||
try:
|
||||
from pyspark.sql.functions import unwrap_udt
|
||||
|
||||
except ImportError:
|
||||
return {"reason": "PySpark<3.4", "condition": True}
|
||||
|
||||
return {"reason": "PySpark<3.4", "condition": False}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def spark() -> Generator[SparkSession, None, None]:
|
||||
config = {
|
||||
@ -1205,6 +1215,7 @@ class XgboostLocalTest(SparkTestCase):
|
||||
np.isclose(row.prediction, row.expected_prediction, atol=1e-3)
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(**no_sparse_unwrap())
|
||||
def test_regressor_with_sparse_optim(self):
|
||||
regressor = SparkXGBRegressor(missing=0.0)
|
||||
model = regressor.fit(self.reg_df_sparse_train)
|
||||
@ -1221,6 +1232,7 @@ class XgboostLocalTest(SparkTestCase):
|
||||
for row1, row2 in zip(pred_result, pred_result2):
|
||||
self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3))
|
||||
|
||||
@pytest.mark.skipif(**no_sparse_unwrap())
|
||||
def test_classifier_with_sparse_optim(self):
|
||||
cls = SparkXGBClassifier(missing=0.0)
|
||||
model = cls.fit(self.cls_df_sparse_train)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user