[pyspark] Make the model saved by PySpark loadable by a plain XGBoost Booster (#8219)
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import glob
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
@@ -7,6 +8,8 @@ import numpy as np
|
||||
import pytest
|
||||
import testing as tm
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
if tm.no_spark()["condition"]:
|
||||
pytest.skip(msg=tm.no_spark()["reason"], allow_module_level=True)
|
||||
if sys.platform.startswith("win") or sys.platform.startswith("darwin"):
|
||||
@@ -31,7 +34,7 @@ from xgboost.spark import (
|
||||
)
|
||||
from xgboost.spark.core import _non_booster_params
|
||||
|
||||
from xgboost import XGBClassifier, XGBRegressor
|
||||
from xgboost import XGBClassifier, XGBModel, XGBRegressor
|
||||
|
||||
from .utils import SparkTestCase
|
||||
|
||||
@@ -63,7 +66,12 @@ class XgboostLocalTest(SparkTestCase):
|
||||
# >>> reg2.fit(X, y)
|
||||
# >>> reg2.predict(X, ntree_limit=5)
|
||||
# array([0.22185266, 0.77814734], dtype=float32)
|
||||
self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5}
|
||||
self.reg_params = {
|
||||
"max_depth": 5,
|
||||
"n_estimators": 10,
|
||||
"ntree_limit": 5,
|
||||
"max_bin": 9,
|
||||
}
|
||||
self.reg_df_train = self.session.createDataFrame(
|
||||
[
|
||||
(Vectors.dense(1.0, 2.0, 3.0), 0),
|
||||
@@ -428,6 +436,12 @@ class XgboostLocalTest(SparkTestCase):
|
||||
def get_local_tmp_dir(self):
|
||||
return self.tempdir + str(uuid.uuid4())
|
||||
|
||||
def assert_model_compatible(self, model: XGBModel, model_path: str):
    """Assert that the model file found under ``model_path`` loads in plain XGBoost.

    Searches ``model_path`` recursively for a ``model/part-00000`` file, loads
    it into a fresh ``xgb.Booster`` and asserts that its JSON dump equals the
    JSON dump of ``model.get_booster()``.

    Parameters
    ----------
    model :
        Model whose ``get_booster()`` result is used as the reference.
    model_path :
        Local directory to search; it is passed to ``glob`` as-is.
    """
    # sorted() keeps the pick deterministic should more than one file match.
    matches = sorted(
        glob.glob(f"{model_path}/**/model/part-00000", recursive=True)
    )
    # Fail with a readable message instead of an IndexError from ``[0]``.
    self.assertTrue(
        matches, f"No model/part-00000 file found under {model_path}"
    )

    bst = xgb.Booster()
    bst.load_model(matches[0])
    self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json"))
|
||||
|
||||
def test_regressor_params_basic(self):
|
||||
py_reg = SparkXGBRegressor()
|
||||
self.assertTrue(hasattr(py_reg, "n_estimators"))
|
||||
@@ -592,7 +606,8 @@ class XgboostLocalTest(SparkTestCase):
|
||||
)
|
||||
|
||||
def test_regressor_model_save_load(self):
|
||||
path = "file:" + self.get_local_tmp_dir()
|
||||
tmp_dir = self.get_local_tmp_dir()
|
||||
path = "file:" + tmp_dir
|
||||
regressor = SparkXGBRegressor(**self.reg_params)
|
||||
model = regressor.fit(self.reg_df_train)
|
||||
model.save(path)
|
||||
@@ -612,8 +627,11 @@ class XgboostLocalTest(SparkTestCase):
|
||||
with self.assertRaisesRegex(AssertionError, "Expected class name"):
|
||||
SparkXGBClassifierModel.load(path)
|
||||
|
||||
self.assert_model_compatible(model, tmp_dir)
|
||||
|
||||
def test_classifier_model_save_load(self):
|
||||
path = "file:" + self.get_local_tmp_dir()
|
||||
tmp_dir = self.get_local_tmp_dir()
|
||||
path = "file:" + tmp_dir
|
||||
regressor = SparkXGBClassifier(**self.cls_params)
|
||||
model = regressor.fit(self.cls_df_train)
|
||||
model.save(path)
|
||||
@@ -633,12 +651,15 @@ class XgboostLocalTest(SparkTestCase):
|
||||
with self.assertRaisesRegex(AssertionError, "Expected class name"):
|
||||
SparkXGBRegressorModel.load(path)
|
||||
|
||||
self.assert_model_compatible(model, tmp_dir)
|
||||
|
||||
@staticmethod
|
||||
def _get_params_map(params_kv, estimator):
|
||||
return {getattr(estimator, k): v for k, v in params_kv.items()}
|
||||
|
||||
def test_regressor_model_pipeline_save_load(self):
|
||||
path = "file:" + self.get_local_tmp_dir()
|
||||
tmp_dir = self.get_local_tmp_dir()
|
||||
path = "file:" + tmp_dir
|
||||
regressor = SparkXGBRegressor()
|
||||
pipeline = Pipeline(stages=[regressor])
|
||||
pipeline = pipeline.copy(extra=self._get_params_map(self.reg_params, regressor))
|
||||
@@ -656,9 +677,11 @@ class XgboostLocalTest(SparkTestCase):
|
||||
row.prediction, row.expected_prediction_with_params, atol=1e-3
|
||||
)
|
||||
)
|
||||
self.assert_model_compatible(model.stages[0], tmp_dir)
|
||||
|
||||
def test_classifier_model_pipeline_save_load(self):
|
||||
path = "file:" + self.get_local_tmp_dir()
|
||||
tmp_dir = self.get_local_tmp_dir()
|
||||
path = "file:" + tmp_dir
|
||||
classifier = SparkXGBClassifier()
|
||||
pipeline = Pipeline(stages=[classifier])
|
||||
pipeline = pipeline.copy(
|
||||
@@ -678,6 +701,7 @@ class XgboostLocalTest(SparkTestCase):
|
||||
row.probability, row.expected_probability_with_params, atol=1e-3
|
||||
)
|
||||
)
|
||||
self.assert_model_compatible(model.stages[0], tmp_dir)
|
||||
|
||||
def test_classifier_with_cross_validator(self):
|
||||
xgb_classifer = SparkXGBClassifier()
|
||||
|
||||
Reference in New Issue
Block a user