import glob
import logging
import random
import tempfile
import uuid
from collections import namedtuple
from typing import Any, Generator, Sequence, Type

import numpy as np
import pytest

import xgboost as xgb
from xgboost import testing as tm

_NO_SPARK = tm.no_spark()
pytestmark = [tm.timeout(60), pytest.mark.skipif(**_NO_SPARK)]

# ``pytestmark`` only marks the tests that get collected; it cannot stop the
# PySpark-dependent imports below from raising during module import (collection).
# Skip the whole module up front when Spark is unavailable.
if _NO_SPARK["condition"]:
    pytest.skip(_NO_SPARK["reason"], allow_module_level=True)

from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql import functions as spark_sql_func

from xgboost.spark import (
    SparkXGBClassifier,
    SparkXGBClassifierModel,
    SparkXGBRanker,
    SparkXGBRegressor,
    SparkXGBRegressorModel,
)
from xgboost.spark.core import _non_booster_params
from xgboost.spark.data import pred_contribs

from xgboost import XGBClassifier, XGBModel, XGBRegressor

from .utils import SparkTestCase

logging.getLogger("py4j").setLevel(logging.INFO)


def no_sparse_unwrap() -> tm.PytestSkip:
    """Return skip-marker kwargs for tests needing ``pyspark.sql.functions.unwrap_udt``.

    The condition is true (skip) when that function cannot be imported, which the
    reason string attributes to PySpark < 3.4.
    """
    try:
        from pyspark.sql.functions import unwrap_udt  # noqa: F401  (probe only)
    except ImportError:
        return {"reason": "PySpark<3.4", "condition": True}

    return {"reason": "PySpark<3.4", "condition": False}


@pytest.fixture
def spark() -> Generator[SparkSession, None, None]:
    """Yield a ``local[4]`` SparkSession and stop it after the test.

    Task retries are disabled (``maxFailures=1``) and both the simplified Python
    traceback and the JVM stack trace settings are configured so failures surface
    with full detail.
    """
    config = {
        "spark.master": "local[4]",
        "spark.python.worker.reuse": "false",
        "spark.driver.host": "127.0.0.1",
        "spark.task.maxFailures": "1",
        "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
        "spark.sql.pyspark.jvmStacktrace.enabled": "true",
    }

    builder = SparkSession.builder.appName("XGBoost PySpark Python API Tests")
    for k, v in config.items():
        builder.config(k, v)
    logging.getLogger("pyspark").setLevel(logging.INFO)
    sess = builder.getOrCreate()
    yield sess

    sess.stop()
    sess.sparkContext.stop()


RegWithWeight = namedtuple(
    "RegWithWeight",
    (
        "reg_params_with_eval",
        "reg_df_train_with_eval_weight",
        "reg_df_test_with_eval_weight",
        "reg_with_eval_best_score",
        "reg_with_eval_and_weight_best_score",
    ),
)


@pytest.fixture
def reg_with_weight(
    spark: SparkSession,
) -> Generator[RegWithWeight, None, None]:
    """Build regression data plus single-node XGBRegressor references.

    Three reference models are fit on the same four rows:
    ``reg1`` with sample weights, ``reg2`` with the last two rows as an eval set
    and early stopping, ``reg3`` like ``reg2`` but with train and eval weights.
    Their predictions on the first two rows and the best scores of ``reg2``/``reg3``
    are packaged next to the equivalent Spark DataFrames.
    """
    reg_params_with_eval = {
        "validation_indicator_col": "isVal",
        "early_stopping_rounds": 1,
        "eval_metric": "rmse",
    }

    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
    w = np.array([1.0, 2.0, 1.0, 2.0])
    y = np.array([0, 1, 2, 3])

    reg1 = XGBRegressor()
    reg1.fit(X, y, sample_weight=w)
    predt1 = reg1.predict(X)

    X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
    y_train = np.array([0, 1])
    y_val = np.array([2, 3])
    w_train = np.array([1.0, 2.0])
    w_val = np.array([1.0, 2.0])

    reg2 = XGBRegressor(early_stopping_rounds=1, eval_metric="rmse")
    reg2.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
    )
    predt2 = reg2.predict(X)
    best_score2 = reg2.best_score

    reg3 = XGBRegressor(early_stopping_rounds=1, eval_metric="rmse")
    reg3.fit(
        X_train,
        y_train,
        sample_weight=w_train,
        eval_set=[(X_val, y_val)],
        sample_weight_eval_set=[w_val],
    )
    predt3 = reg3.predict(X)
    best_score3 = reg3.best_score

    # Same four rows; ``isVal`` marks the rows used as X_val above.
    reg_df_train_with_eval_weight = spark.createDataFrame(
        [
            (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
            (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
            (Vectors.dense(4.0, 5.0, 6.0), 2, True, 1.0),
            (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 3, True, 2.0),
        ],
        ["features", "label", "isVal", "weight"],
    )

    reg_df_test_with_eval_weight = spark.createDataFrame(
        [
            (
                Vectors.dense(1.0, 2.0, 3.0),
                float(predt1[0]),
                float(predt2[0]),
                float(predt3[0]),
            ),
            (
                Vectors.sparse(3, {1: 1.0, 2: 5.5}),
                float(predt1[1]),
                float(predt2[1]),
                float(predt3[1]),
            ),
        ],
        [
            "features",
            "expected_prediction_with_weight",
            "expected_prediction_with_eval",
            "expected_prediction_with_weight_and_eval",
        ],
    )

    yield RegWithWeight(
        reg_params_with_eval,
        reg_df_train_with_eval_weight,
        reg_df_test_with_eval_weight,
        best_score2,
        best_score3,
    )


RegData = namedtuple("RegData", ("reg_df_train", "reg_df_test"))


@pytest.fixture
def reg_data(spark: SparkSession) -> Generator[RegData, None, None]:
    """Two-row regression data with a default XGBRegressor's predictions and
    ``pred_contribs`` output as expected values."""
    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    y = np.array([0, 1])
    reg1 = xgb.XGBRegressor()
    reg1.fit(X, y)
    predt0 = reg1.predict(X)
    pred_contrib0: np.ndarray = pred_contribs(reg1, X, None, False)

    # convert np array to pyspark dataframe
    reg_df_train_data = [
        (Vectors.dense(X[0, :]), int(y[0])),
        (Vectors.sparse(3, {1: float(X[1, 1]), 2: float(X[1, 2])}), int(y[1])),
    ]
    reg_df_train = spark.createDataFrame(reg_df_train_data, ["features", "label"])

    reg_df_test = spark.createDataFrame(
        [
            (
                Vectors.dense(X[0, :]),
                float(predt0[0]),
                pred_contrib0[0, :].tolist(),
            ),
            (
                Vectors.sparse(3, {1: 1.0, 2: 5.5}),
                float(predt0[1]),
                pred_contrib0[1, :].tolist(),
            ),
        ],
        [
            "features",
            "expected_prediction",
            "expected_pred_contribs",
        ],
    )
    yield RegData(reg_df_train, reg_df_test)


MultiClfData = namedtuple("MultiClfData", ("multi_clf_df_train", "multi_clf_df_test"))


@pytest.fixture
def multi_clf_data(spark: SparkSession) -> Generator[MultiClfData, None, None]:
    """Four-row, three-class data with a default XGBClassifier's predictions,
    probabilities and ``pred_contribs`` output for the first three rows."""
    X = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [0.0, 1.0, 5.5], [-1.0, -2.0, 1.0]])
    y = np.array([0, 0, 1, 2])
    cls1 = xgb.XGBClassifier()
    cls1.fit(X, y)
    predt0 = cls1.predict(X)
    proba0: np.ndarray = cls1.predict_proba(X)
    pred_contrib0: np.ndarray = pred_contribs(cls1, X, None, False)

    # convert np array to pyspark dataframe
    multi_cls_df_train_data = [
        (Vectors.dense(X[0, :]), int(y[0])),
        (Vectors.dense(X[1, :]), int(y[1])),
        (Vectors.sparse(3, {1: float(X[2, 1]), 2: float(X[2, 2])}), int(y[2])),
        (Vectors.dense(X[3, :]), int(y[3])),
    ]
    multi_clf_df_train = spark.createDataFrame(
        multi_cls_df_train_data, ["features", "label"]
    )

    multi_clf_df_test = spark.createDataFrame(
        [
            (
                Vectors.dense(X[0, :]),
                float(predt0[0]),
                proba0[0, :].tolist(),
                pred_contrib0[0, :].tolist(),
            ),
            (
                Vectors.dense(X[1, :]),
                float(predt0[1]),
                proba0[1, :].tolist(),
                pred_contrib0[1, :].tolist(),
            ),
            (
                Vectors.sparse(3, {1: 1.0, 2: 5.5}),
                float(predt0[2]),
                proba0[2, :].tolist(),
                pred_contrib0[2, :].tolist(),
            ),
        ],
        [
            "features",
            "expected_prediction",
            "expected_probability",
            "expected_pred_contribs",
        ],
    )
    yield MultiClfData(multi_clf_df_train, multi_clf_df_test)


ClfWithWeight = namedtuple(
    "ClfWithWeight",
    (
        "cls_params_with_eval",
        "cls_df_train_with_eval_weight",
        "cls_df_test_with_eval_weight",
        "cls_with_eval_best_score",
        "cls_with_eval_and_weight_best_score",
    ),
)


@pytest.fixture
def clf_with_weight(
    spark: SparkSession,
) -> Generator[ClfWithWeight, None, None]:
    """Test classifier with weight and eval set.

    Mirrors ``reg_with_weight`` for binary classification: ``cls1`` uses sample
    weights, ``cls2`` an eval set with early stopping on logloss, ``cls3`` both.
    """
    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
    w = np.array([1.0, 2.0, 1.0, 2.0])
    y = np.array([0, 1, 0, 1])
    cls1 = XGBClassifier()
    cls1.fit(X, y, sample_weight=w)

    X_train = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    X_val = np.array([[4.0, 5.0, 6.0], [0.0, 6.0, 7.5]])
    y_train = np.array([0, 1])
    y_val = np.array([0, 1])
    w_train = np.array([1.0, 2.0])
    w_val = np.array([1.0, 2.0])

    # early_stopping_rounds / eval_metric are constructor parameters; passing them
    # to fit() is deprecated since xgboost 1.6 and rejected by 2.0.
    cls2 = XGBClassifier(early_stopping_rounds=1, eval_metric="logloss")
    cls2.fit(
        X_train,
        y_train,
        eval_set=[(X_val, y_val)],
    )

    cls3 = XGBClassifier(early_stopping_rounds=1, eval_metric="logloss")
    cls3.fit(
        X_train,
        y_train,
        sample_weight=w_train,
        eval_set=[(X_val, y_val)],
        sample_weight_eval_set=[w_val],
    )

    cls_df_train_with_eval_weight = spark.createDataFrame(
        [
            (Vectors.dense(1.0, 2.0, 3.0), 0, False, 1.0),
            (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, False, 2.0),
            (Vectors.dense(4.0, 5.0, 6.0), 0, True, 1.0),
            (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, True, 2.0),
        ],
        ["features", "label", "isVal", "weight"],
    )
    cls_params_with_eval = {
        "validation_indicator_col": "isVal",
        "early_stopping_rounds": 1,
        "eval_metric": "logloss",
    }
    cls_df_test_with_eval_weight = spark.createDataFrame(
        [
            (
                Vectors.dense(1.0, 2.0, 3.0),
                [float(p) for p in cls1.predict_proba(X)[0, :]],
                [float(p) for p in cls2.predict_proba(X)[0, :]],
                [float(p) for p in cls3.predict_proba(X)[0, :]],
            ),
        ],
        [
            "features",
            "expected_prob_with_weight",
            "expected_prob_with_eval",
            "expected_prob_with_weight_and_eval",
        ],
    )
    cls_with_eval_best_score = cls2.best_score
    cls_with_eval_and_weight_best_score = cls3.best_score
    yield ClfWithWeight(
        cls_params_with_eval,
        cls_df_train_with_eval_weight,
        cls_df_test_with_eval_weight,
        cls_with_eval_best_score,
        cls_with_eval_and_weight_best_score,
    )


ClfData = namedtuple(
    "ClfData", ("cls_params", "cls_df_train", "cls_df_train_large", "cls_df_test")
)


@pytest.fixture
def clf_data(spark: SparkSession) -> Generator[ClfData, None, None]:
    """Two-row binary classification data.

    Expected values come from a default XGBClassifier and from one built with
    ``cls_params``; ``cls_df_train_large`` repeats the two training rows 100 times.
    """
    cls_params = {"max_depth": 5, "n_estimators": 10, "scale_pos_weight": 4}

    X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
    y = np.array([0, 1])
    cl1 = xgb.XGBClassifier()
    cl1.fit(X, y)
    predt0 = cl1.predict(X)
    proba0: np.ndarray = cl1.predict_proba(X)
    pred_contrib0: np.ndarray = pred_contribs(cl1, X, None, True)

    cl2 = xgb.XGBClassifier(**cls_params)
    cl2.fit(X, y)
    predt1 = cl2.predict(X)
    proba1: np.ndarray = cl2.predict_proba(X)
    pred_contrib1: np.ndarray = pred_contribs(cl2, X, None, True)

    # convert np array to pyspark dataframe
    cls_df_train_data = [
        (Vectors.dense(X[0, :]), int(y[0])),
        (Vectors.sparse(3, {1: float(X[1, 1]), 2: float(X[1, 2])}), int(y[1])),
    ]
    cls_df_train = spark.createDataFrame(cls_df_train_data, ["features", "label"])

    cls_df_train_large = spark.createDataFrame(
        cls_df_train_data * 100, ["features", "label"]
    )

    cls_df_test = spark.createDataFrame(
        [
            (
                Vectors.dense(X[0, :]),
                int(predt0[0]),
                proba0[0, :].tolist(),
                pred_contrib0[0, :].tolist(),
                int(predt1[0]),
                proba1[0, :].tolist(),
                pred_contrib1[0, :].tolist(),
            ),
            (
                Vectors.sparse(3, {1: 1.0, 2: 5.5}),
                int(predt0[1]),
                proba0[1, :].tolist(),
                pred_contrib0[1, :].tolist(),
                int(predt1[1]),
                proba1[1, :].tolist(),
                pred_contrib1[1, :].tolist(),
            ),
        ],
        [
            "features",
            "expected_prediction",
            "expected_probability",
            "expected_pred_contribs",
            "expected_prediction_with_params",
            "expected_probability_with_params",
            "expected_pred_contribs_with_params",
        ],
    )
    yield ClfData(cls_params, cls_df_train, cls_df_train_large, cls_df_test)


def assert_model_compatible(model: XGBModel, model_path: str) -> None:
    """Check the booster file saved under ``model_path`` matches ``model``.

    Loads the first ``model/part-00000`` found recursively below ``model_path``
    into a plain ``xgb.Booster`` and compares JSON dumps byte-for-byte.
    """
    bst = xgb.Booster()
    path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0]
    bst.load_model(path)
    np.testing.assert_equal(
        np.array(model.get_booster().save_raw("json")), np.array(bst.save_raw("json"))
    )


def check_sub_dict_match(
    sub_dist: dict, whole_dict: dict, excluding_keys: Sequence[str]
) -> None:
    """Assert each key of ``sub_dist`` (except ``excluding_keys``) is present in
    ``whole_dict`` with an equal value."""
    for k in sub_dist:
        if k in excluding_keys:
            continue
        assert k in whole_dict, f"check on {k} failed: key missing"
        assert (
            sub_dist[k] == whole_dict[k]
        ), f"check on {k} failed: {sub_dist[k]!r} != {whole_dict[k]!r}"


def get_params_map(params_kv: dict, estimator: Any) -> dict:
    """Map param names to ``Param`` attributes of the ``estimator`` instance,
    producing a dict usable as a PySpark ``extra`` param map."""
    return {getattr(estimator, k): v for k, v in params_kv.items()}


class TestPySparkLocal:
    """Pytest-style tests of the Spark estimators against the fixtures above."""

    def test_regressor_basic(self, reg_data: RegData) -> None:
        """Default regressor reproduces single-node predictions and contributions."""
        regressor = SparkXGBRegressor(pred_contrib_col="pred_contribs")
        model = regressor.fit(reg_data.reg_df_train)
        pred_result = model.transform(reg_data.reg_df_test).collect()
        for row in pred_result:
            np.testing.assert_equal(row.prediction, row.expected_prediction)
            np.testing.assert_allclose(
                row.pred_contribs, row.expected_pred_contribs, atol=1e-3
            )

    def test_regressor_with_weight_eval(self, reg_with_weight: RegWithWeight) -> None:
        """Weights, validation indicator and both combined match the references."""
        # with weight
        regressor_with_weight = SparkXGBRegressor(weight_col="weight")
        model_with_weight = regressor_with_weight.fit(
            reg_with_weight.reg_df_train_with_eval_weight
        )
        pred_result_with_weight = model_with_weight.transform(
            reg_with_weight.reg_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_weight:
            assert np.isclose(
                row.prediction, row.expected_prediction_with_weight, atol=1e-3
            )

        # with eval
        regressor_with_eval = SparkXGBRegressor(**reg_with_weight.reg_params_with_eval)
        model_with_eval = regressor_with_eval.fit(
            reg_with_weight.reg_df_train_with_eval_weight
        )
        assert np.isclose(
            model_with_eval._xgb_sklearn_model.best_score,
            reg_with_weight.reg_with_eval_best_score,
            atol=1e-3,
        )
        pred_result_with_eval = model_with_eval.transform(
            reg_with_weight.reg_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_eval:
            np.testing.assert_allclose(
                row.prediction, row.expected_prediction_with_eval, atol=1e-3
            )

        # with weight and eval
        regressor_with_weight_eval = SparkXGBRegressor(
            weight_col="weight", **reg_with_weight.reg_params_with_eval
        )
        model_with_weight_eval = regressor_with_weight_eval.fit(
            reg_with_weight.reg_df_train_with_eval_weight
        )
        pred_result_with_weight_eval = model_with_weight_eval.transform(
            reg_with_weight.reg_df_test_with_eval_weight
        ).collect()
        np.testing.assert_allclose(
            model_with_weight_eval._xgb_sklearn_model.best_score,
            reg_with_weight.reg_with_eval_and_weight_best_score,
            atol=1e-3,
        )
        for row in pred_result_with_weight_eval:
            np.testing.assert_allclose(
                row.prediction,
                row.expected_prediction_with_weight_and_eval,
                atol=1e-3,
            )

    def test_multi_classifier_basic(self, multi_clf_data: MultiClfData) -> None:
        """Three-class classifier matches single-node prediction/proba/contribs."""
        cls = SparkXGBClassifier(pred_contrib_col="pred_contribs")
        model = cls.fit(multi_clf_data.multi_clf_df_train)
        pred_result = model.transform(multi_clf_data.multi_clf_df_test).collect()
        for row in pred_result:
            np.testing.assert_equal(row.prediction, row.expected_prediction)
            np.testing.assert_allclose(
                row.probability, row.expected_probability, rtol=1e-3
            )
            np.testing.assert_allclose(
                row.pred_contribs, row.expected_pred_contribs, atol=1e-3
            )

    def test_classifier_with_weight_eval(self, clf_with_weight: ClfWithWeight) -> None:
        """Classifier counterpart of ``test_regressor_with_weight_eval``."""
        # with weight
        classifier_with_weight = SparkXGBClassifier(weight_col="weight")
        model_with_weight = classifier_with_weight.fit(
            clf_with_weight.cls_df_train_with_eval_weight
        )
        pred_result_with_weight = model_with_weight.transform(
            clf_with_weight.cls_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_weight:
            assert np.allclose(
                row.probability, row.expected_prob_with_weight, atol=1e-3
            )

        # with eval
        classifier_with_eval = SparkXGBClassifier(
            **clf_with_weight.cls_params_with_eval
        )
        model_with_eval = classifier_with_eval.fit(
            clf_with_weight.cls_df_train_with_eval_weight
        )
        assert np.isclose(
            model_with_eval._xgb_sklearn_model.best_score,
            clf_with_weight.cls_with_eval_best_score,
            atol=1e-3,
        )
        pred_result_with_eval = model_with_eval.transform(
            clf_with_weight.cls_df_test_with_eval_weight
        ).collect()
        for row in pred_result_with_eval:
            assert np.allclose(row.probability, row.expected_prob_with_eval, atol=1e-3)

        # with weight and eval
        classifier_with_weight_eval = SparkXGBClassifier(
            weight_col="weight", **clf_with_weight.cls_params_with_eval
        )
        model_with_weight_eval = classifier_with_weight_eval.fit(
            clf_with_weight.cls_df_train_with_eval_weight
        )
        pred_result_with_weight_eval = model_with_weight_eval.transform(
            clf_with_weight.cls_df_test_with_eval_weight
        ).collect()
        np.testing.assert_allclose(
            model_with_weight_eval._xgb_sklearn_model.best_score,
            clf_with_weight.cls_with_eval_and_weight_best_score,
            atol=1e-3,
        )

        for row in pred_result_with_weight_eval:
            np.testing.assert_allclose(
                row.probability, row.expected_prob_with_weight_and_eval, atol=1e-3
            )

    def test_classifier_model_save_load(self, clf_data: ClfData) -> None:
        """A saved classifier model reloads with same uid, params and output,
        and refuses to load as a regressor model."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = "file:" + tmpdir
            clf = SparkXGBClassifier(**clf_data.cls_params)
            model = clf.fit(clf_data.cls_df_train)
            model.save(path)
            loaded_model = SparkXGBClassifierModel.load(path)
            assert model.uid == loaded_model.uid
            for k, v in clf_data.cls_params.items():
                assert loaded_model.getOrDefault(k) == v

            pred_result = loaded_model.transform(clf_data.cls_df_test).collect()
            for row in pred_result:
                np.testing.assert_allclose(
                    row.probability, row.expected_probability_with_params, atol=1e-3
                )

            with pytest.raises(AssertionError, match="Expected class name"):
                SparkXGBRegressorModel.load(path)

            assert_model_compatible(model, tmpdir)

    def test_classifier_basic(self, clf_data: ClfData) -> None:
        """Classifier with ``cls_params`` matches single-node outputs."""
        classifier = SparkXGBClassifier(
            **clf_data.cls_params, pred_contrib_col="pred_contrib"
        )
        model = classifier.fit(clf_data.cls_df_train)
        pred_result = model.transform(clf_data.cls_df_test).collect()
        for row in pred_result:
            np.testing.assert_equal(row.prediction, row.expected_prediction_with_params)
            np.testing.assert_allclose(
                row.probability, row.expected_probability_with_params, rtol=1e-3
            )
            # Float contributions: compare with a tolerance, like the other tests.
            np.testing.assert_allclose(
                row.pred_contrib, row.expected_pred_contribs_with_params, atol=1e-3
            )

    def test_classifier_with_params(self, clf_data: ClfData) -> None:
        """``cls_params`` propagate to the generated param dicts of both the
        estimator and the fitted model."""
        classifier = SparkXGBClassifier(**clf_data.cls_params)
        all_params = dict(
            **(classifier._gen_xgb_params_dict()),
            **(classifier._gen_fit_params_dict()),
            **(classifier._gen_predict_params_dict()),
        )

        check_sub_dict_match(
            clf_data.cls_params, all_params, excluding_keys=_non_booster_params
        )

        model = classifier.fit(clf_data.cls_df_train)
        all_params = dict(
            **(model._gen_xgb_params_dict()),
            **(model._gen_fit_params_dict()),
            **(model._gen_predict_params_dict()),
        )

        check_sub_dict_match(
            clf_data.cls_params, all_params, excluding_keys=_non_booster_params
        )
        pred_result = model.transform(clf_data.cls_df_test).collect()
        for row in pred_result:
            np.testing.assert_equal(row.prediction, row.expected_prediction_with_params)
            np.testing.assert_allclose(
                row.probability, row.expected_probability_with_params, rtol=1e-3
            )

    def test_classifier_model_pipeline_save_load(self, clf_data: ClfData) -> None:
        """A Pipeline wrapping the classifier round-trips through save/load."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = "file:" + tmpdir
            classifier = SparkXGBClassifier()
            pipeline = Pipeline(stages=[classifier])
            pipeline = pipeline.copy(
                extra=get_params_map(clf_data.cls_params, classifier)
            )
            model = pipeline.fit(clf_data.cls_df_train)
            model.save(path)

            loaded_model = PipelineModel.load(path)
            for k, v in clf_data.cls_params.items():
                assert loaded_model.stages[0].getOrDefault(k) == v

            pred_result = loaded_model.transform(clf_data.cls_df_test).collect()
            for row in pred_result:
                np.testing.assert_allclose(
                    row.probability, row.expected_probability_with_params, atol=1e-3
                )
            assert_model_compatible(model.stages[0], tmpdir)

    def test_classifier_with_cross_validator(self, clf_data: ClfData) -> None:
        """The classifier works as a CrossValidator estimator (smoke test)."""
        xgb_classifier = SparkXGBClassifier(n_estimators=1)
        paramMaps = (
            ParamGridBuilder().addGrid(xgb_classifier.max_depth, [1, 2]).build()
        )
        cvBin = CrossValidator(
            estimator=xgb_classifier,
            estimatorParamMaps=paramMaps,
            evaluator=BinaryClassificationEvaluator(),
            seed=1,
            parallelism=4,
            numFolds=2,
        )
        cvBinModel = cvBin.fit(clf_data.cls_df_train_large)
        cvBinModel.transform(clf_data.cls_df_test)

    def test_convert_to_sklearn_model_clf(self, clf_data: ClfData) -> None:
        """``_convert_to_sklearn_model`` rebuilds an XGBClassifier with the
        estimator's parameters, including arbitrary kwargs like ``sketch_eps``."""
        classifier = SparkXGBClassifier(
            n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5
        )
        clf_model = classifier.fit(clf_data.cls_df_train)

        # Check that regardless of what booster, _convert_to_sklearn_model converts
        # to the correct class type
        sklearn_classifier = classifier._convert_to_sklearn_model(
            clf_model.get_booster().save_raw("json"),
            clf_model.get_booster().save_config(),
        )
        assert isinstance(sklearn_classifier, XGBClassifier)
        assert sklearn_classifier.n_estimators == 200
        assert sklearn_classifier.missing == 2.0
        assert sklearn_classifier.max_depth == 3
        assert sklearn_classifier.get_params()["sketch_eps"] == 0.5

    def test_classifier_array_col_as_feature(self, clf_data: ClfData) -> None:
        """An array<double> features column behaves like the vector column."""
        train_dataset = clf_data.cls_df_train.withColumn(
            "features", vector_to_array(spark_sql_func.col("features"))
        )
        test_dataset = clf_data.cls_df_test.withColumn(
            "features", vector_to_array(spark_sql_func.col("features"))
        )
        classifier = SparkXGBClassifier()
        model = classifier.fit(train_dataset)

        pred_result = model.transform(test_dataset).collect()
        for row in pred_result:
            np.testing.assert_equal(row.prediction, row.expected_prediction)
            np.testing.assert_allclose(
                row.probability, row.expected_probability, rtol=1e-3
            )

    def test_classifier_with_feature_names_types_weights(
        self, clf_data: ClfData
    ) -> None:
        """Feature names/types/weights are accepted for fit and transform."""
        classifier = SparkXGBClassifier(
            feature_names=["a1", "a2", "a3"],
            feature_types=["i", "int", "float"],
            feature_weights=[2.0, 5.0, 3.0],
        )
        model = classifier.fit(clf_data.cls_df_train)
        model.transform(clf_data.cls_df_test).collect()

    def test_early_stop_param_validation(self, clf_data: ClfData) -> None:
        """``early_stopping_rounds`` without a validation column is rejected."""
        classifier = SparkXGBClassifier(early_stopping_rounds=1)
        with pytest.raises(ValueError, match="early_stopping_rounds"):
            classifier.fit(clf_data.cls_df_train)

    def test_gpu_param_setting(self, clf_data: ClfData) -> None:
        """``use_gpu=True`` selects the ``gpu_hist`` tree method."""
        py_cls = SparkXGBClassifier(use_gpu=True)
        train_params = py_cls._get_distributed_train_params(clf_data.cls_df_train)
        assert train_params["tree_method"] == "gpu_hist"

    def test_classifier_with_list_eval_metric(self, clf_data: ClfData) -> None:
        """A list-valued ``eval_metric`` is accepted."""
        classifier = SparkXGBClassifier(eval_metric=["auc", "rmse"])
        model = classifier.fit(clf_data.cls_df_train)
        model.transform(clf_data.cls_df_test).collect()

    def test_classifier_with_string_eval_metric(self, clf_data: ClfData) -> None:
        """A string-valued ``eval_metric`` is accepted."""
        classifier = SparkXGBClassifier(eval_metric="auc")
        model = classifier.fit(clf_data.cls_df_train)
        model.transform(clf_data.cls_df_test).collect()


class XgboostLocalTest(SparkTestCase):
    def setUp(self):
        logging.getLogger().setLevel("INFO")
        random.seed(2020)

        # The following code use xgboost python library to train xgb model and predict.
        #
        # >>> import numpy as np
        # >>> import xgboost
        # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]])
        # >>> y = np.array([0, 1])
        # >>> reg1 = xgboost.XGBRegressor()
        # >>> reg1.fit(X, y)
        # >>> reg1.predict(X)
        # array([8.8375784e-04, 9.9911624e-01], dtype=float32)
        # >>> def custom_lr(boosting_round):
        # ...     return 1.0 / (boosting_round + 1)
        # ...
# >>> reg1.fit(X, y, callbacks=[xgboost.callback.LearningRateScheduler(custom_lr)]) # >>> reg1.predict(X) # array([0.02406844, 0.9759315 ], dtype=float32) # >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10) # >>> reg2.fit(X, y) # >>> reg2.predict(X, ntree_limit=5) # array([0.22185266, 0.77814734], dtype=float32) self.reg_params = { "max_depth": 5, "n_estimators": 10, "ntree_limit": 5, "max_bin": 9, } self.reg_df_train = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), ], ["features", "label"], ) self.reg_df_test = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759), ], [ "features", "expected_prediction", "expected_prediction_with_params", "expected_prediction_with_callbacks", ], ) # kwargs test (using the above data, train, we get the same results) self.cls_params_kwargs = {"tree_method": "approx", "sketch_eps": 0.03} # >>> X = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0], [0.0, 1.0, 5.5], [-1.0, -2.0, 1.0]]) # >>> y = np.array([0, 0, 1, 2]) # >>> cl = xgboost.XGBClassifier() # >>> cl.fit(X, y) # >>> cl.predict_proba(np.array([[1.0, 2.0, 3.0]])) # array([[0.5374299 , 0.23128504, 0.23128504]], dtype=float32) # Test classifier with both base margin and without # >>> import numpy as np # >>> import xgboost # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5], [4.0, 5.0, 6.0], [0.0, 6.0, 7.5]]) # >>> w = np.array([1.0, 2.0, 1.0, 2.0]) # >>> y = np.array([0, 1, 0, 1]) # >>> base_margin = np.array([1,0,0,1]) # # This is without the base margin # >>> cls1 = xgboost.XGBClassifier() # >>> cls1.fit(X, y, sample_weight=w) # >>> cls1.predict_proba(np.array([[1.0, 2.0, 3.0]])) # array([[0.3333333, 0.6666667]], dtype=float32) # >>> cls1.predict(np.array([[1.0, 2.0, 3.0]])) # array([1]) # # This is with the same base margin for predict # >>> cls2 = xgboost.XGBClassifier() # >>> cls2.fit(X, y, sample_weight=w, 
base_margin=base_margin) # >>> cls2.predict_proba(np.array([[1.0, 2.0, 3.0]]), base_margin=[0]) # array([[0.44142532, 0.5585747 ]], dtype=float32) # >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0]) # array([1]) # # This is with a different base margin for predict # # >>> cls2 = xgboost.XGBClassifier() # >>> cls2.fit(X, y, sample_weight=w, base_margin=base_margin) # >>> cls2.predict_proba(np.array([[1.0, 2.0, 3.0]]), base_margin=[1]) # array([[0.2252, 0.7747 ]], dtype=float32) # >>> cls2.predict(np.array([[1.0, 2.0, 3.0]]), base_margin=[0]) # array([1]) self.cls_df_train_without_base_margin = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0), (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0), ], ["features", "label", "weight"], ) self.cls_df_test_without_base_margin = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), [0.3333, 0.6666], 1), ], [ "features", "expected_prob_without_base_margin", "expected_prediction_without_base_margin", ], ) self.cls_df_train_with_same_base_margin = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0), (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1), ], ["features", "label", "weight", "base_margin"], ) self.cls_df_test_with_same_base_margin = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0, [0.4415, 0.5585], 1), ], [ "features", "base_margin", "expected_prob_with_base_margin", "expected_prediction_with_base_margin", ], ) self.cls_df_train_with_different_base_margin = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0, 1.0, 1), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1, 2.0, 0), (Vectors.dense(4.0, 5.0, 6.0), 0, 1.0, 0), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 2.0, 1), ], ["features", "label", "weight", "base_margin"], ) self.cls_df_test_with_different_base_margin = 
self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 1, [0.2252, 0.7747], 1), ], [ "features", "base_margin", "expected_prob_with_base_margin", "expected_prediction_with_base_margin", ], ) self.reg_df_sparse_train = self.session.createDataFrame( [ (Vectors.dense(1.0, 0.0, 3.0, 0.0, 0.0), 0), (Vectors.sparse(5, {1: 1.0, 3: 5.5}), 1), (Vectors.sparse(5, {4: -3.0}), 2), ] * 10, ["features", "label"], ) self.cls_df_sparse_train = self.session.createDataFrame( [ (Vectors.dense(1.0, 0.0, 3.0, 0.0, 0.0), 0), (Vectors.sparse(5, {1: 1.0, 3: 5.5}), 1), (Vectors.sparse(5, {4: -3.0}), 0), ] * 10, ["features", "label"], ) def get_local_tmp_dir(self): return self.tempdir + str(uuid.uuid4()) def assert_model_compatible(self, model: XGBModel, model_path: str): bst = xgb.Booster() path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0] bst.load_model(path) self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json")) def test_convert_to_sklearn_model_reg(self) -> None: regressor = SparkXGBRegressor( n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5 ) reg_model = regressor.fit(self.reg_df_train) sklearn_regressor = regressor._convert_to_sklearn_model( reg_model.get_booster().save_raw("json"), reg_model.get_booster().save_config(), ) assert isinstance(sklearn_regressor, XGBRegressor) assert sklearn_regressor.n_estimators == 200 assert sklearn_regressor.missing == 2.0 assert sklearn_regressor.max_depth == 3 assert sklearn_regressor.get_params()["sketch_eps"] == 0.5 def test_regressor_params_basic(self): py_reg = SparkXGBRegressor() self.assertTrue(hasattr(py_reg, "n_estimators")) self.assertEqual(py_reg.n_estimators.parent, py_reg.uid) self.assertFalse(hasattr(py_reg, "gpu_id")) self.assertEqual(py_reg.getOrDefault(py_reg.n_estimators), 100) self.assertEqual(py_reg.getOrDefault(py_reg.objective), "reg:squarederror") py_reg2 = SparkXGBRegressor(n_estimators=200) self.assertEqual(py_reg2.getOrDefault(py_reg2.n_estimators), 200) py_reg3 
= py_reg2.copy({py_reg2.max_depth: 10}) self.assertEqual(py_reg3.getOrDefault(py_reg3.n_estimators), 200) self.assertEqual(py_reg3.getOrDefault(py_reg3.max_depth), 10) def test_classifier_params_basic(self): py_cls = SparkXGBClassifier() self.assertTrue(hasattr(py_cls, "n_estimators")) self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) self.assertFalse(hasattr(py_cls, "gpu_id")) self.assertEqual(py_cls.getOrDefault(py_cls.n_estimators), 100) self.assertEqual(py_cls.getOrDefault(py_cls.objective), None) py_cls2 = SparkXGBClassifier(n_estimators=200) self.assertEqual(py_cls2.getOrDefault(py_cls2.n_estimators), 200) py_cls3 = py_cls2.copy({py_cls2.max_depth: 10}) self.assertEqual(py_cls3.getOrDefault(py_cls3.n_estimators), 200) self.assertEqual(py_cls3.getOrDefault(py_cls3.max_depth), 10) def test_classifier_kwargs_basic(self): py_cls = SparkXGBClassifier(**self.cls_params_kwargs) self.assertTrue(hasattr(py_cls, "n_estimators")) self.assertEqual(py_cls.n_estimators.parent, py_cls.uid) self.assertFalse(hasattr(py_cls, "gpu_id")) self.assertTrue(hasattr(py_cls, "arbitrary_params_dict")) expected_kwargs = {"sketch_eps": 0.03} self.assertEqual( py_cls.getOrDefault(py_cls.arbitrary_params_dict), expected_kwargs ) # Testing overwritten params py_cls = SparkXGBClassifier() py_cls.setParams(x=1, y=2) py_cls.setParams(y=3, z=4) xgb_params = py_cls._gen_xgb_params_dict() assert xgb_params["x"] == 1 assert xgb_params["y"] == 3 assert xgb_params["z"] == 4 def test_param_alias(self): py_cls = SparkXGBClassifier(features_col="f1", label_col="l1") self.assertEqual(py_cls.getOrDefault(py_cls.featuresCol), "f1") self.assertEqual(py_cls.getOrDefault(py_cls.labelCol), "l1") with pytest.raises( ValueError, match="Please use param name features_col instead" ): SparkXGBClassifier(featuresCol="f1") @staticmethod def test_param_value_converter(): py_cls = SparkXGBClassifier(missing=np.float64(1.0), sketch_eps=np.float64(0.3)) # don't check by isintance(v, float) because for numpy 
scalar it will also return True assert py_cls.getOrDefault(py_cls.missing).__class__.__name__ == "float" assert ( py_cls.getOrDefault(py_cls.arbitrary_params_dict)[ "sketch_eps" ].__class__.__name__ == "float64" ) def test_regressor_with_params(self): regressor = SparkXGBRegressor(**self.reg_params) all_params = dict( **(regressor._gen_xgb_params_dict()), **(regressor._gen_fit_params_dict()), **(regressor._gen_predict_params_dict()), ) check_sub_dict_match( self.reg_params, all_params, excluding_keys=_non_booster_params ) model = regressor.fit(self.reg_df_train) all_params = dict( **(model._gen_xgb_params_dict()), **(model._gen_fit_params_dict()), **(model._gen_predict_params_dict()), ) check_sub_dict_match( self.reg_params, all_params, excluding_keys=_non_booster_params ) pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( np.isclose( row.prediction, row.expected_prediction_with_params, atol=1e-3 ) ) def test_regressor_model_save_load(self): tmp_dir = self.get_local_tmp_dir() path = "file:" + tmp_dir regressor = SparkXGBRegressor(**self.reg_params) model = regressor.fit(self.reg_df_train) model.save(path) loaded_model = SparkXGBRegressorModel.load(path) self.assertEqual(model.uid, loaded_model.uid) for k, v in self.reg_params.items(): self.assertEqual(loaded_model.getOrDefault(k), v) pred_result = loaded_model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( np.isclose( row.prediction, row.expected_prediction_with_params, atol=1e-3 ) ) with self.assertRaisesRegex(AssertionError, "Expected class name"): SparkXGBClassifierModel.load(path) self.assert_model_compatible(model, tmp_dir) def test_regressor_model_pipeline_save_load(self): tmp_dir = self.get_local_tmp_dir() path = "file:" + tmp_dir regressor = SparkXGBRegressor() pipeline = Pipeline(stages=[regressor]) pipeline = pipeline.copy(extra=get_params_map(self.reg_params, regressor)) model = pipeline.fit(self.reg_df_train) 
model.save(path) loaded_model = PipelineModel.load(path) for k, v in self.reg_params.items(): self.assertEqual(loaded_model.stages[0].getOrDefault(k), v) pred_result = loaded_model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( np.isclose( row.prediction, row.expected_prediction_with_params, atol=1e-3 ) ) self.assert_model_compatible(model.stages[0], tmp_dir) def test_callbacks(self): from xgboost.callback import LearningRateScheduler path = self.get_local_tmp_dir() def custom_learning_rate(boosting_round): return 1.0 / (boosting_round + 1) cb = [LearningRateScheduler(custom_learning_rate)] regressor = SparkXGBRegressor(callbacks=cb) # Test the save/load of the estimator instead of the model, since # the callbacks param only exists in the estimator but not in the model regressor.save(path) regressor = SparkXGBRegressor.load(path) model = regressor.fit(self.reg_df_train) pred_result = model.transform(self.reg_df_test).collect() for row in pred_result: self.assertTrue( np.isclose( row.prediction, row.expected_prediction_with_callbacks, atol=1e-3 ) ) def test_train_with_initial_model(self): path = self.get_local_tmp_dir() reg1 = SparkXGBRegressor(**self.reg_params) model = reg1.fit(self.reg_df_train) init_booster = model.get_booster() reg2 = SparkXGBRegressor(max_depth=2, n_estimators=2, xgb_model=init_booster) model21 = reg2.fit(self.reg_df_train) pred_res21 = model21.transform(self.reg_df_test).collect() reg2.save(path) reg2 = SparkXGBRegressor.load(path) self.assertTrue(reg2.getOrDefault(reg2.xgb_model) is not None) model22 = reg2.fit(self.reg_df_train) pred_res22 = model22.transform(self.reg_df_test).collect() # Test the transform result is the same for original and loaded model for row1, row2 in zip(pred_res21, pred_res22): self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3)) def test_classifier_with_base_margin(self): cls_without_base_margin = SparkXGBClassifier(weight_col="weight") model_without_base_margin = 
cls_without_base_margin.fit( self.cls_df_train_without_base_margin ) pred_result_without_base_margin = model_without_base_margin.transform( self.cls_df_test_without_base_margin ).collect() for row in pred_result_without_base_margin: self.assertTrue( np.isclose( row.prediction, row.expected_prediction_without_base_margin, atol=1e-3, ) ) np.testing.assert_allclose( row.probability, row.expected_prob_without_base_margin, atol=1e-3 ) cls_with_same_base_margin = SparkXGBClassifier( weight_col="weight", base_margin_col="base_margin" ) model_with_same_base_margin = cls_with_same_base_margin.fit( self.cls_df_train_with_same_base_margin ) pred_result_with_same_base_margin = model_with_same_base_margin.transform( self.cls_df_test_with_same_base_margin ).collect() for row in pred_result_with_same_base_margin: self.assertTrue( np.isclose( row.prediction, row.expected_prediction_with_base_margin, atol=1e-3 ) ) np.testing.assert_allclose( row.probability, row.expected_prob_with_base_margin, atol=1e-3 ) cls_with_different_base_margin = SparkXGBClassifier( weight_col="weight", base_margin_col="base_margin" ) model_with_different_base_margin = cls_with_different_base_margin.fit( self.cls_df_train_with_different_base_margin ) pred_result_with_different_base_margin = ( model_with_different_base_margin.transform( self.cls_df_test_with_different_base_margin ).collect() ) for row in pred_result_with_different_base_margin: self.assertTrue( np.isclose( row.prediction, row.expected_prediction_with_base_margin, atol=1e-3 ) ) np.testing.assert_allclose( row.probability, row.expected_prob_with_base_margin, atol=1e-3 ) def test_num_workers_param(self): regressor = SparkXGBRegressor(num_workers=-1) self.assertRaises(ValueError, regressor._validate_params) classifier = SparkXGBClassifier(num_workers=0) self.assertRaises(ValueError, classifier._validate_params) def test_use_gpu_param(self): classifier = SparkXGBClassifier(use_gpu=True, tree_method="exact") self.assertRaises(ValueError, 
classifier._validate_params) regressor = SparkXGBRegressor(use_gpu=True, tree_method="exact") self.assertRaises(ValueError, regressor._validate_params) regressor = SparkXGBRegressor(use_gpu=True, tree_method="gpu_hist") regressor = SparkXGBRegressor(use_gpu=True) classifier = SparkXGBClassifier(use_gpu=True, tree_method="gpu_hist") classifier = SparkXGBClassifier(use_gpu=True) def test_feature_importances(self): reg1 = SparkXGBRegressor(**self.reg_params) model = reg1.fit(self.reg_df_train) booster = model.get_booster() self.assertEqual(model.get_feature_importances(), booster.get_score()) self.assertEqual( model.get_feature_importances(importance_type="gain"), booster.get_score(importance_type="gain"), ) def test_regressor_array_col_as_feature(self): train_dataset = self.reg_df_train.withColumn( "features", vector_to_array(spark_sql_func.col("features")) ) test_dataset = self.reg_df_test.withColumn( "features", vector_to_array(spark_sql_func.col("features")) ) regressor = SparkXGBRegressor() model = regressor.fit(train_dataset) pred_result = model.transform(test_dataset).collect() for row in pred_result: self.assertTrue( np.isclose(row.prediction, row.expected_prediction, atol=1e-3) ) @pytest.mark.skipif(**no_sparse_unwrap()) def test_regressor_with_sparse_optim(self): regressor = SparkXGBRegressor(missing=0.0) model = regressor.fit(self.reg_df_sparse_train) assert model._xgb_sklearn_model.missing == 0.0 pred_result = model.transform(self.reg_df_sparse_train).collect() # enable sparse optimiaztion regressor2 = SparkXGBRegressor(missing=0.0, enable_sparse_data_optim=True) model2 = regressor2.fit(self.reg_df_sparse_train) assert model2.getOrDefault(model2.enable_sparse_data_optim) assert model2._xgb_sklearn_model.missing == 0.0 pred_result2 = model2.transform(self.reg_df_sparse_train).collect() for row1, row2 in zip(pred_result, pred_result2): self.assertTrue(np.isclose(row1.prediction, row2.prediction, atol=1e-3)) @pytest.mark.skipif(**no_sparse_unwrap()) def 
test_classifier_with_sparse_optim(self): cls = SparkXGBClassifier(missing=0.0) model = cls.fit(self.cls_df_sparse_train) assert model._xgb_sklearn_model.missing == 0.0 pred_result = model.transform(self.cls_df_sparse_train).collect() # enable sparse optimiaztion cls2 = SparkXGBClassifier(missing=0.0, enable_sparse_data_optim=True) model2 = cls2.fit(self.cls_df_sparse_train) assert model2.getOrDefault(model2.enable_sparse_data_optim) assert model2._xgb_sklearn_model.missing == 0.0 pred_result2 = model2.transform(self.cls_df_sparse_train).collect() for row1, row2 in zip(pred_result, pred_result2): self.assertTrue(np.allclose(row1.probability, row2.probability, rtol=1e-3)) def test_empty_validation_data(self) -> None: for tree_method in [ "hist", "approx", ]: # pytest.mark conflict with python unittest df_train = self.session.createDataFrame( [ (Vectors.dense(10.1, 11.2, 11.3), 0, False), (Vectors.dense(1, 1.2, 1.3), 1, False), (Vectors.dense(14.0, 15.0, 16.0), 0, False), (Vectors.dense(1.1, 1.2, 1.3), 1, True), ], ["features", "label", "val_col"], ) classifier = SparkXGBClassifier( num_workers=2, tree_method=tree_method, min_child_weight=0.0, reg_alpha=0, reg_lambda=0, validation_indicator_col="val_col", ) model = classifier.fit(df_train) pred_result = model.transform(df_train).collect() for row in pred_result: self.assertEqual(row.prediction, row.label) def test_empty_train_data(self) -> None: for tree_method in [ "hist", "approx", ]: # pytest.mark conflict with python unittest df_train = self.session.createDataFrame( [ (Vectors.dense(10.1, 11.2, 11.3), 0, True), (Vectors.dense(1, 1.2, 1.3), 1, True), (Vectors.dense(14.0, 15.0, 16.0), 0, True), (Vectors.dense(1.1, 1.2, 1.3), 1, False), ], ["features", "label", "val_col"], ) classifier = SparkXGBClassifier( num_workers=2, min_child_weight=0.0, reg_alpha=0, reg_lambda=0, tree_method=tree_method, validation_indicator_col="val_col", ) model = classifier.fit(df_train) pred_result = model.transform(df_train).collect() for 
row in pred_result: assert row.prediction == 1.0 def test_empty_partition(self): # raw_df.repartition(4) will result int severe data skew, actually, # there is no any data in reducer partition 1, reducer partition 2 # see https://github.com/dmlc/xgboost/issues/8221 for tree_method in [ "hist", "approx", ]: # pytest.mark conflict with python unittest raw_df = self.session.range(0, 100, 1, 50).withColumn( "label", spark_sql_func.when(spark_sql_func.rand(1) > 0.5, 1).otherwise(0), ) vector_assembler = ( VectorAssembler().setInputCols(["id"]).setOutputCol("features") ) data_trans = vector_assembler.setHandleInvalid("keep").transform(raw_df) classifier = SparkXGBClassifier(num_workers=4, tree_method=tree_method) classifier.fit(data_trans) def test_unsupported_params(self): with pytest.raises(ValueError, match="evals_result"): SparkXGBClassifier(evals_result={}) class XgboostRankerLocalTest(SparkTestCase): def setUp(self): self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8") self.ranker_df_train = self.session.createDataFrame( [ (Vectors.dense(1.0, 2.0, 3.0), 0, 0), (Vectors.dense(4.0, 5.0, 6.0), 1, 0), (Vectors.dense(9.0, 4.0, 8.0), 2, 0), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), ], ["features", "label", "qid"], ) self.ranker_df_test = self.session.createDataFrame( [ (Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988), (Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556), (Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570), (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988), (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612), (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826), ], ["features", "qid", "expected_prediction"], ) self.ranker_df_train_1 = self.session.createDataFrame( [ (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9), (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9), (Vectors.dense(1.0, 2.0, 3.0), 0, 8), (Vectors.dense(4.0, 5.0, 6.0), 
1, 8), (Vectors.dense(9.0, 4.0, 8.0), 2, 8), (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7), (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7), (Vectors.dense(1.0, 2.0, 3.0), 0, 6), (Vectors.dense(4.0, 5.0, 6.0), 1, 6), (Vectors.dense(9.0, 4.0, 8.0), 2, 6), ] * 4, ["features", "label", "qid"], ) def test_ranker(self): ranker = SparkXGBRanker(qid_col="qid") assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" model = ranker.fit(self.ranker_df_train) pred_result = model.transform(self.ranker_df_test).collect() for row in pred_result: assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3) def test_ranker_qid_sorted(self): ranker = SparkXGBRanker(qid_col="qid", num_workers=4) assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" model = ranker.fit(self.ranker_df_train_1) model.transform(self.ranker_df_test).collect()