diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 6569f7e3d..cc62b354d 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -23,13 +23,7 @@ from typing import ( import numpy from . import collective -from .core import ( - Booster, - DMatrix, - XGBoostError, - _get_booster_layer_trees, - _parse_eval_str, -) +from .core import Booster, DMatrix, XGBoostError, _parse_eval_str __all__ = [ "TrainingCallback", @@ -177,22 +171,14 @@ class CallbackContainer: assert isinstance(model, Booster), msg if not self.is_cv: - num_parallel_tree, _ = _get_booster_layer_trees(model) if model.attr("best_score") is not None: model.best_score = float(cast(str, model.attr("best_score"))) model.best_iteration = int(cast(str, model.attr("best_iteration"))) - # num_class is handled internally - model.set_attr( - best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree) - ) - model.best_ntree_limit = int(cast(str, model.attr("best_ntree_limit"))) else: # Due to compatibility with version older than 1.4, these attributes are # added to Python object even if early stopping is not used. model.best_iteration = model.num_boosted_rounds() - 1 model.set_attr(best_iteration=str(model.best_iteration)) - model.best_ntree_limit = (model.best_iteration + 1) * num_parallel_tree - model.set_attr(best_ntree_limit=str(model.best_ntree_limit)) return model diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 30aa771e3..68346d900 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -126,25 +126,6 @@ def _parse_eval_str(result: str) -> List[Tuple[str, float]]: IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int]) -def _convert_ntree_limit( - booster: "Booster", ntree_limit: Optional[int], iteration_range: IterRange -) -> IterRange: - if ntree_limit is not None and ntree_limit != 0: - warnings.warn( - "ntree_limit is deprecated, use `iteration_range` or model " - "slicing instead.", - UserWarning, - ) - if iteration_range is not None and iteration_range[1] != 0: - raise ValueError( - "Only one of `iteration_range` and `ntree_limit` can be non zero." - ) - num_parallel_tree, _ = _get_booster_layer_trees(booster) - num_parallel_tree = max([num_parallel_tree, 1]) - iteration_range = (0, ntree_limit // num_parallel_tree) - return iteration_range - - def _expect(expectations: Sequence[Type], got: Type) -> str: """Translate input error into string. @@ -1508,41 +1489,6 @@ Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] -def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]: - """Get number of trees added to booster per-iteration. This function will be removed - once `best_ntree_limit` is dropped in favor of `best_iteration`. Returns - `num_parallel_tree` and `num_groups`. - - """ - config = json.loads(model.save_config()) - booster = config["learner"]["gradient_booster"]["name"] - if booster == "gblinear": - num_parallel_tree = 0 - elif booster == "dart": - num_parallel_tree = int( - config["learner"]["gradient_booster"]["gbtree"]["gbtree_model_param"][ - "num_parallel_tree" - ] - ) - elif booster == "gbtree": - try: - num_parallel_tree = int( - config["learner"]["gradient_booster"]["gbtree_model_param"][ - "num_parallel_tree" - ] - ) - except KeyError: - num_parallel_tree = int( - config["learner"]["gradient_booster"]["gbtree_train_param"][ - "num_parallel_tree" - ] - ) - else: - raise ValueError(f"Unknown booster: {booster}") - num_groups = int(config["learner"]["learner_model_param"]["num_class"]) - return num_parallel_tree, num_groups - - def _configure_metrics(params: BoosterParam) -> BoosterParam: if ( isinstance(params, dict) @@ -1576,11 +1522,11 @@ class Booster: """ Parameters ---------- - params : dict + params : Parameters for boosters. - cache : list + cache : List of cache items. - model_file : string/os.PathLike/Booster/bytearray + model_file : Path to the model file if it's string or PathLike. """ cache = cache if cache is not None else [] @@ -2100,7 +2046,6 @@ class Booster: self, data: DMatrix, output_margin: bool = False, - ntree_limit: int = 0, pred_leaf: bool = False, pred_contribs: bool = False, approx_contribs: bool = False, @@ -2127,9 +2072,6 @@ class Booster: output_margin : Whether to output the raw untransformed margin value. - ntree_limit : - Deprecated, use `iteration_range` instead. - pred_leaf : When this option is on, the output will be a matrix of (nsample, ntrees) with each record indicating the predicted leaf index of @@ -2196,7 +2138,6 @@ class Booster: raise TypeError("Expecting data to be a DMatrix object, got: ", type(data)) if validate_features: self._validate_dmatrix_features(data) - iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range) args = { "type": 0, "training": training, @@ -2522,8 +2463,6 @@ class Booster: self.best_iteration = int(self.attr("best_iteration")) # type: ignore if self.attr("best_score") is not None: self.best_score = float(self.attr("best_score")) # type: ignore - if self.attr("best_ntree_limit") is not None: - self.best_ntree_limit = int(self.attr("best_ntree_limit")) # type: ignore def num_boosted_rounds(self) -> int: """Get number of boosted rounds. For gblinear this is reset to 0 after diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 8c679b75b..88bd1c819 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1653,14 +1653,11 @@ class DaskScikitLearnBase(XGBModel): self, X: _DataT, output_margin: bool = False, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[_DaskCollection] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: _assert_dask_support() - msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." - assert ntree_limit is None, msg return self.client.sync( self._predict_async, X, @@ -1694,12 +1691,9 @@ class DaskScikitLearnBase(XGBModel): def apply( self, X: _DataT, - ntree_limit: Optional[int] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: _assert_dask_support() - msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." - assert ntree_limit is None, msg return self.client.sync(self._apply_async, X, iteration_range=iteration_range) def __await__(self) -> Awaitable[Any]: @@ -1993,14 +1987,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa def predict_proba( self, X: _DaskCollection, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[_DaskCollection] = None, iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: _assert_dask_support() - msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." - assert ntree_limit is None, msg return self._client_sync( self._predict_proba_async, X=X, diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 563ff8659..fffc0eb9b 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -36,7 +36,6 @@ from .core import ( Objective, QuantileDMatrix, XGBoostError, - _convert_ntree_limit, _deprecate_positional_args, _parse_eval_str, ) @@ -391,8 +390,7 @@ __model_doc = f""" metric will be used for early stopping. - If early stopping occurs, the model will have three additional fields: - :py:attr:`best_score`, :py:attr:`best_iteration` and - :py:attr:`best_ntree_limit`. + :py:attr:`best_score`, :py:attr:`best_iteration`. .. note:: @@ -1117,7 +1115,6 @@ class XGBModel(XGBModelBase): self, X: ArrayLike, output_margin: bool = False, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, @@ -1135,8 +1132,6 @@ class XGBModel(XGBModelBase): Data to predict with. output_margin : Whether to output the raw untransformed margin value. - ntree_limit : - Deprecated, use `iteration_range` instead. validate_features : When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. @@ -1156,9 +1151,6 @@ class XGBModel(XGBModelBase): """ with config_context(verbosity=self.verbosity): - iteration_range = _convert_ntree_limit( - self.get_booster(), ntree_limit, iteration_range - ) iteration_range = self._get_iteration_range(iteration_range) if self._can_use_inplace_predict(): try: @@ -1197,7 +1189,6 @@ class XGBModel(XGBModelBase): def apply( self, X: ArrayLike, - ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None, ) -> np.ndarray: """Return the predicted leaf every tree for each sample. If the model is trained @@ -1211,9 +1202,6 @@ class XGBModel(XGBModelBase): iteration_range : See :py:meth:`predict`. - ntree_limit : - Deprecated, use ``iteration_range`` instead. - Returns ------- X_leaves : array_like, shape=[n_samples, n_trees] @@ -1223,9 +1211,6 @@ class XGBModel(XGBModelBase): """ with config_context(verbosity=self.verbosity): - iteration_range = _convert_ntree_limit( - self.get_booster(), ntree_limit, iteration_range - ) iteration_range = self._get_iteration_range(iteration_range) test_dmatrix = DMatrix( X, @@ -1309,10 +1294,6 @@ class XGBModel(XGBModelBase): """ return int(self._early_stopping_attr("best_iteration")) - @property - def best_ntree_limit(self) -> int: - return int(self._early_stopping_attr("best_ntree_limit")) - @property def feature_importances_(self) -> np.ndarray: """Feature importances property, return depends on `importance_type` @@ -1562,7 +1543,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): self, X: ArrayLike, output_margin: bool = False, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, @@ -1571,7 +1551,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): class_probs = super().predict( X=X, output_margin=output_margin, - ntree_limit=ntree_limit, validate_features=validate_features, base_margin=base_margin, iteration_range=iteration_range, @@ -1599,7 +1578,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): def predict_proba( self, X: ArrayLike, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, @@ -1614,8 +1592,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): ---------- X : array_like Feature matrix. See :ref:`py-data` for a list of supported types. - ntree_limit : int - Deprecated, use `iteration_range` instead. validate_features : bool When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. @@ -1642,7 +1618,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): if self.objective == "multi:softmax": raw_predt = super().predict( X=X, - ntree_limit=ntree_limit, validate_features=validate_features, base_margin=base_margin, iteration_range=iteration_range, @@ -1652,7 +1627,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase): return class_prob class_probs = super().predict( X=X, - ntree_limit=ntree_limit, validate_features=validate_features, base_margin=base_margin, iteration_range=iteration_range, @@ -2074,7 +2048,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn): self, X: ArrayLike, output_margin: bool = False, - ntree_limit: Optional[int] = None, validate_features: bool = True, base_margin: Optional[ArrayLike] = None, iteration_range: Optional[Tuple[int, int]] = None, @@ -2083,20 +2056,18 @@ class XGBRanker(XGBModel, XGBRankerMixIn): return super().predict( X, output_margin, - ntree_limit, validate_features, base_margin, - iteration_range, + iteration_range=iteration_range, ) def apply( self, X: ArrayLike, - ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None, ) -> ArrayLike: X, _ = _get_qid(X, None) - return super().apply(X, ntree_limit, iteration_range) + return super().apply(X, iteration_range) def score(self, X: ArrayLike, y: ArrayLike) -> float: """Evaluate score for data using the last evaluation metric. If the model is diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index 6e2d4c6db..f2c5e1197 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -11,7 +11,6 @@ from xgboost import DataIter, DMatrix, QuantileDMatrix, XGBModel from xgboost.compat import concat from .._typing import ArrayLike -from ..core import _convert_ntree_limit from .utils import get_logger # type: ignore @@ -343,8 +342,7 @@ def pred_contribs( strict_shape: bool = False, ) -> np.ndarray: """Predict contributions with data with the full model.""" - iteration_range = _convert_ntree_limit(model.get_booster(), None, None) - iteration_range = model._get_iteration_range(iteration_range) + iteration_range = model._get_iteration_range(None) data_dmatrix = DMatrix( data, base_margin=base_margin, diff --git a/tests/ci_build/conda_env/aarch64_test.yml b/tests/ci_build/conda_env/aarch64_test.yml index fe30eced1..42a2fe1e4 100644 --- a/tests/ci_build/conda_env/aarch64_test.yml +++ b/tests/ci_build/conda_env/aarch64_test.yml @@ -31,6 +31,5 @@ dependencies: - pyspark - cloudpickle - pip: - - shap - awscli - auditwheel diff --git a/tests/ci_build/conda_env/linux_cpu_test.yml b/tests/ci_build/conda_env/linux_cpu_test.yml index 7977abcd4..bf657708d 100644 --- a/tests/ci_build/conda_env/linux_cpu_test.yml +++ b/tests/ci_build/conda_env/linux_cpu_test.yml @@ -37,7 +37,6 @@ dependencies: - pyarrow - protobuf - cloudpickle -- shap>=0.41 - modin # TODO: Replace it with pyspark>=3.4 once 3.4 released. # - https://ml-team-public-read.s3.us-west-2.amazonaws.com/pyspark-3.4.0.dev0.tar.gz diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index d248e14df..00791e19d 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -146,6 +146,7 @@ def main(args: argparse.Namespace) -> None: "tests/python/test_config.py", "tests/python/test_data_iterator.py", "tests/python/test_dt.py", + "tests/python/test_predict.py", "tests/python/test_quantile_dmatrix.py", "tests/python/test_tree_regularization.py", "tests/python-gpu/test_gpu_data_iterator.py", diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 516cbd6cf..f9d6f37e1 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -64,7 +64,7 @@ class TestModels: num_round = 2 bst = xgb.train(param, dtrain, num_round, watchlist) # this is prediction - preds = bst.predict(dtest, ntree_limit=num_round) + preds = bst.predict(dtest, iteration_range=(0, num_round)) labels = dtest.get_label() err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) @@ -83,7 +83,7 @@ class TestModels: bst2 = xgb.Booster(params=param, model_file=model_path) dtest2 = xgb.DMatrix(dtest_path) - preds2 = bst2.predict(dtest2, ntree_limit=num_round) + preds2 = bst2.predict(dtest2, iteration_range=(0, num_round)) # assert they are the same assert np.sum(np.abs(preds2 - preds)) == 0 @@ -96,7 +96,7 @@ class TestModels: # check whether custom evaluation metrics work bst = xgb.train(param, dtrain, num_round, watchlist, feval=my_logloss) - preds3 = bst.predict(dtest, ntree_limit=num_round) + preds3 = bst.predict(dtest, iteration_range=(0, num_round)) assert all(preds3 == preds) # check whether sample_type and normalize_type work @@ -110,7 +110,7 @@ class TestModels: param['sample_type'] = p[0] param['normalize_type'] = p[1] bst = xgb.train(param, dtrain, num_round, watchlist) - preds = bst.predict(dtest, ntree_limit=num_round) + preds = bst.predict(dtest, iteration_range=(0, num_round)) err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) assert err < 0.1 @@ -472,8 +472,8 @@ class TestModels: X, y = load_iris(return_X_y=True) cls = xgb.XGBClassifier(n_estimators=2) cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)]) - assert cls.get_booster().best_ntree_limit == 2 - assert cls.best_ntree_limit == cls.get_booster().best_ntree_limit + assert cls.get_booster().best_iteration == cls.n_estimators - 1 + assert cls.best_iteration == cls.get_booster().best_iteration with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "cls.json") @@ -481,8 +481,8 @@ class TestModels: cls = xgb.XGBClassifier(n_estimators=2) cls.load_model(path) - assert cls.get_booster().best_ntree_limit == 2 - assert cls.best_ntree_limit == cls.get_booster().best_ntree_limit + assert cls.get_booster().best_iteration == cls.n_estimators - 1 + assert cls.best_iteration == cls.get_booster().best_iteration def run_slice( self, diff --git a/tests/python/test_cli.py b/tests/python/test_cli.py index 69e8df83d..3d7415232 100644 --- a/tests/python/test_cli.py +++ b/tests/python/test_cli.py @@ -102,7 +102,6 @@ eval[test] = {data_path} booster.feature_names = None booster.feature_types = None booster.set_attr(best_iteration=None) - booster.set_attr(best_ntree_limit=None) booster.save_model(model_out_py) py_predt = booster.predict(data) diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index cb400df87..6f89edd16 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -1,4 +1,4 @@ -'''Tests for running inplace prediction.''' +"""Tests for running inplace prediction.""" from concurrent.futures import ThreadPoolExecutor import numpy as np @@ -17,10 +17,10 @@ def run_threaded_predict(X, rows, predict_func): per_thread = 20 with ThreadPoolExecutor(max_workers=10) as e: for i in range(0, rows, int(rows / per_thread)): - if hasattr(X, 'iloc'): - predictor = X.iloc[i:i+per_thread, :] + if hasattr(X, "iloc"): + predictor = X.iloc[i : i + per_thread, :] else: - predictor = X[i:i+per_thread, ...] + predictor = X[i : i + per_thread, ...] f = e.submit(predict_func, predictor) results.append(f) @@ -61,27 +61,31 @@ def run_predict_leaf(predictor): validate_leaf_output(leaf, num_parallel_tree) - ntree_limit = 2 + n_iters = 2 sliced = booster.predict( - m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit, strict_shape=True + m, + pred_leaf=True, + iteration_range=(0, n_iters), + strict_shape=True, ) first = sliced[0, ...] - assert np.prod(first.shape) == classes * num_parallel_tree * ntree_limit + assert np.prod(first.shape) == classes * num_parallel_tree * n_iters # When there's only 1 tree, the output is a 1 dim vector booster = xgb.train({"tree_method": "hist"}, num_boost_round=1, dtrain=m) - assert booster.predict(m, pred_leaf=True).shape == (rows, ) + assert booster.predict(m, pred_leaf=True).shape == (rows,) return leaf def test_predict_leaf(): - run_predict_leaf('cpu_predictor') + run_predict_leaf("cpu_predictor") def test_predict_shape(): from sklearn.datasets import fetch_california_housing + X, y = fetch_california_housing(return_X_y=True) reg = xgb.XGBRegressor(n_estimators=1) reg.fit(X, y) @@ -119,13 +123,14 @@ def test_predict_shape(): class TestInplacePredict: - '''Tests for running inplace prediction''' + """Tests for running inplace prediction""" + @classmethod def setup_class(cls): cls.rows = 1000 cls.cols = 10 - cls.missing = 11 # set to integer for testing + cls.missing = 11 # set to integer for testing cls.rng = np.random.RandomState(1994) @@ -139,7 +144,7 @@ class TestInplacePredict: cls.test = xgb.DMatrix(cls.X[:10, ...], missing=cls.missing) cls.num_boost_round = 10 - cls.booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=10) + cls.booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=10) def test_predict(self): booster = self.booster @@ -162,28 +167,22 @@ class TestInplacePredict: predt_from_array = booster.inplace_predict( X[:10, ...], iteration_range=(0, 4), missing=self.missing ) - predt_from_dmatrix = booster.predict(test, ntree_limit=4) + predt_from_dmatrix = booster.predict(test, iteration_range=(0, 4)) np.testing.assert_allclose(predt_from_dmatrix, predt_from_array) - with pytest.raises(ValueError): - booster.predict(test, ntree_limit=booster.best_ntree_limit + 1) with pytest.raises(ValueError): booster.predict(test, iteration_range=(0, booster.best_iteration + 2)) default = booster.predict(test) range_full = booster.predict(test, iteration_range=(0, self.num_boost_round)) - ntree_full = booster.predict(test, ntree_limit=self.num_boost_round) np.testing.assert_allclose(range_full, default) - np.testing.assert_allclose(ntree_full, default) range_full = booster.predict( test, iteration_range=(0, booster.best_iteration + 1) ) - ntree_full = booster.predict(test, ntree_limit=booster.best_ntree_limit) np.testing.assert_allclose(range_full, default) - np.testing.assert_allclose(ntree_full, default) def predict_dense(x): inplace_predt = booster.inplace_predict(x) @@ -251,6 +250,7 @@ class TestInplacePredict: @pytest.mark.skipif(**tm.no_pandas()) def test_pd_dtypes(self) -> None: from pandas.api.types import is_bool_dtype + for orig, x in pd_dtypes(): dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes] if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]): diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index 30de920f7..088b681ff 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -60,7 +60,7 @@ def test_ranking_with_weighted_data(): assert all(p <= q for p, q in zip(auc_rec, auc_rec[1:])) for i in range(1, 11): - pred = bst.predict(dtrain, ntree_limit=i) + pred = bst.predict(dtrain, iteration_range=(0, i)) # is_sorted[i]: is i-th group correctly sorted by the ranking predictor? is_sorted = [] for k in range(0, 20, 5): diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py index 258af760c..3ec1f1ffb 100644 --- a/tests/python/test_training_continuation.py +++ b/tests/python/test_training_continuation.py @@ -95,44 +95,39 @@ class TestTrainingContinuation: res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class)) assert res1 == res2 - gbdt_04 = xgb.train(xgb_params_02, dtrain_2class, - num_boost_round=3) - assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + - 1) * self.num_parallel_tree - + gbdt_04 = xgb.train(xgb_params_02, dtrain_2class, num_boost_round=3) res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) - res2 = mean_squared_error(y_2class, - gbdt_04.predict( - dtrain_2class, - ntree_limit=gbdt_04.best_ntree_limit)) + res2 = mean_squared_error( + y_2class, + gbdt_04.predict( + dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1) + ) + ) assert res1 == res2 - gbdt_04 = xgb.train(xgb_params_02, dtrain_2class, - num_boost_round=7, xgb_model=gbdt_04) - assert gbdt_04.best_ntree_limit == ( - gbdt_04.best_iteration + 1) * self.num_parallel_tree - + gbdt_04 = xgb.train( + xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04 + ) res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) - res2 = mean_squared_error(y_2class, - gbdt_04.predict( - dtrain_2class, - ntree_limit=gbdt_04.best_ntree_limit)) + res2 = mean_squared_error( + y_2class, + gbdt_04.predict( + dtrain_2class, iteration_range=(0, gbdt_04.best_iteration + 1) + ) + ) assert res1 == res2 gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=7) - assert gbdt_05.best_ntree_limit == ( - gbdt_05.best_iteration + 1) * self.num_parallel_tree gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05) - assert gbdt_05.best_ntree_limit == ( - gbdt_05.best_iteration + 1) * self.num_parallel_tree res1 = gbdt_05.predict(dtrain_5class) - res2 = gbdt_05.predict(dtrain_5class, - ntree_limit=gbdt_05.best_ntree_limit) + res2 = gbdt_05.predict( + dtrain_5class, iteration_range=(0, gbdt_05.best_iteration + 1) + ) np.testing.assert_almost_equal(res1, res2) @pytest.mark.skipif(**tm.no_sklearn()) diff --git a/tests/python/test_with_shap.py b/tests/python/test_with_shap.py index eab98f487..63d0fd11b 100644 --- a/tests/python/test_with_shap.py +++ b/tests/python/test_with_shap.py @@ -13,9 +13,9 @@ except Exception: pytestmark = pytest.mark.skipif(shap is None, reason="Requires shap package") -# Check integration is not broken from xgboost side -# Changes in binary format may cause problems -def test_with_shap(): +# xgboost removed ntree_limit in 2.0, which breaks the SHAP package. +@pytest.mark.xfail +def test_with_shap() -> None: from sklearn.datasets import fetch_california_housing X, y = fetch_california_housing(return_X_y=True) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 90d4dff18..67620e6dd 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -63,9 +63,15 @@ def test_multiclass_classification(objective): assert xgb_model.get_booster().num_boosted_rounds() == 100 preds = xgb_model.predict(X[test_index]) # test other params in XGBClassifier().fit - preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3) - preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0) - preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3) + preds2 = xgb_model.predict( + X[test_index], output_margin=True, iteration_range=(0, 1) + ) + preds3 = xgb_model.predict( + X[test_index], output_margin=True, iteration_range=None + ) + preds4 = xgb_model.predict( + X[test_index], output_margin=False, iteration_range=(0, 1) + ) labels = y[test_index] check_pred(preds, labels, output_margin=False) @@ -86,25 +92,21 @@ def test_multiclass_classification(objective): assert proba.shape[1] == cls.n_classes_ -def test_best_ntree_limit(): +def test_best_iteration(): from sklearn.datasets import load_iris X, y = load_iris(return_X_y=True) - def train(booster, forest): + def train(booster: str, forest: Optional[int]) -> None: rounds = 4 cls = xgb.XGBClassifier( n_estimators=rounds, num_parallel_tree=forest, booster=booster ).fit( X, y, eval_set=[(X, y)], early_stopping_rounds=3 ) + assert cls.best_iteration == rounds - 1 - if forest: - assert cls.best_ntree_limit == rounds * forest - else: - assert cls.best_ntree_limit == 0 - - # best_ntree_limit is used by default, assert that under gblinear it's + # best_iteration is used by default, assert that under gblinear it's # automatically ignored due to being 0. cls.predict(X) @@ -430,12 +432,15 @@ def test_regression(): preds = xgb_model.predict(X[test_index]) # test other params in XGBRegressor().fit - preds2 = xgb_model.predict(X[test_index], output_margin=True, - ntree_limit=3) - preds3 = xgb_model.predict(X[test_index], output_margin=True, - ntree_limit=0) - preds4 = xgb_model.predict(X[test_index], output_margin=False, - ntree_limit=3) + preds2 = xgb_model.predict( + X[test_index], output_margin=True, iteration_range=(0, 3) + ) + preds3 = xgb_model.predict( + X[test_index], output_margin=True, iteration_range=None + ) + preds4 = xgb_model.predict( + X[test_index], output_margin=False, iteration_range=(0, 3) + ) labels = y[test_index] assert mean_squared_error(preds, labels) < 25 diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index a8c64713f..0ffdb2a2b 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -169,7 +169,7 @@ def reg_with_weight( ) -RegData = namedtuple("RegData", ("reg_df_train", "reg_df_test")) +RegData = namedtuple("RegData", ("reg_df_train", "reg_df_test", "reg_params")) @pytest.fixture @@ -181,6 +181,13 @@ def reg_data(spark: SparkSession) -> Generator[RegData, None, None]: predt0 = reg1.predict(X) pred_contrib0: np.ndarray = pred_contribs(reg1, X, None, False) + reg_params = { + "max_depth": 5, + "n_estimators": 10, + "iteration_range": [0, 5], + "max_bin": 9, + } + # convert np array to pyspark dataframe reg_df_train_data = [ (Vectors.dense(X[0, :]), int(y[0])), @@ -188,26 +195,34 @@ def reg_data(spark: SparkSession) -> Generator[RegData, None, None]: ] reg_df_train = spark.createDataFrame(reg_df_train_data, ["features", "label"]) + reg2 = xgb.XGBRegressor(max_depth=5, n_estimators=10) + reg2.fit(X, y) + predt2 = reg2.predict(X, iteration_range=[0, 5]) + # array([0.22185266, 0.77814734], dtype=float32) + reg_df_test = spark.createDataFrame( [ ( Vectors.dense(X[0, :]), float(predt0[0]), pred_contrib0[0, :].tolist(), + float(predt2[0]), ), ( Vectors.sparse(3, {1: 1.0, 2: 5.5}), float(predt0[1]), pred_contrib0[1, :].tolist(), + float(predt2[1]), ), ], [ "features", "expected_prediction", "expected_pred_contribs", + "expected_prediction_with_params", ], ) - yield RegData(reg_df_train, reg_df_test) + yield RegData(reg_df_train, reg_df_test, reg_params) MultiClfData = namedtuple("MultiClfData", ("multi_clf_df_train", "multi_clf_df_test")) @@ -740,6 +755,76 @@ class TestPySparkLocal: model = classifier.fit(clf_data.cls_df_train) model.transform(clf_data.cls_df_test).collect() + def test_regressor_model_save_load(self, reg_data: RegData) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = "file:" + tmpdir + regressor = SparkXGBRegressor(**reg_data.reg_params) + model = regressor.fit(reg_data.reg_df_train) + model.save(path) + loaded_model = SparkXGBRegressorModel.load(path) + assert model.uid == loaded_model.uid + for k, v in reg_data.reg_params.items(): + assert loaded_model.getOrDefault(k) == v + + pred_result = loaded_model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + + with pytest.raises(AssertionError, match="Expected class name"): + SparkXGBClassifierModel.load(path) + + assert_model_compatible(model, tmpdir) + + def test_regressor_with_params(self, reg_data: RegData) -> None: + regressor = SparkXGBRegressor(**reg_data.reg_params) + all_params = dict( + **(regressor._gen_xgb_params_dict()), + **(regressor._gen_fit_params_dict()), + **(regressor._gen_predict_params_dict()), + ) + check_sub_dict_match( + reg_data.reg_params, all_params, excluding_keys=_non_booster_params + ) + + model = regressor.fit(reg_data.reg_df_train) + all_params = dict( + **(model._gen_xgb_params_dict()), + **(model._gen_fit_params_dict()), + **(model._gen_predict_params_dict()), + ) + check_sub_dict_match( + reg_data.reg_params, all_params, excluding_keys=_non_booster_params + ) + pred_result = model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + + def test_regressor_model_pipeline_save_load(self, reg_data: RegData) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = "file:" + tmpdir + regressor = SparkXGBRegressor() + pipeline = Pipeline(stages=[regressor]) + pipeline = pipeline.copy( + extra=get_params_map(reg_data.reg_params, regressor) + ) + model = pipeline.fit(reg_data.reg_df_train) + model.save(path) + + loaded_model = PipelineModel.load(path) + for k, v in reg_data.reg_params.items(): + assert loaded_model.stages[0].getOrDefault(k) == v + + pred_result = loaded_model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + assert_model_compatible(model.stages[0], tmpdir) + class XgboostLocalTest(SparkTestCase): def setUp(self): @@ -918,12 +1003,6 @@ class XgboostLocalTest(SparkTestCase): def get_local_tmp_dir(self): return self.tempdir + str(uuid.uuid4()) - def assert_model_compatible(self, model: XGBModel, model_path: str): - bst = xgb.Booster() - path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0] - bst.load_model(path) - self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json")) - def test_convert_to_sklearn_model_reg(self) -> None: regressor = SparkXGBRegressor( n_estimators=200, missing=2.0, max_depth=3, sketch_eps=0.5 @@ -1007,80 +1086,6 @@ class XgboostLocalTest(SparkTestCase): == "float64" ) - def test_regressor_with_params(self): - regressor = SparkXGBRegressor(**self.reg_params) - all_params = dict( - **(regressor._gen_xgb_params_dict()), - **(regressor._gen_fit_params_dict()), - **(regressor._gen_predict_params_dict()), - ) - check_sub_dict_match( - self.reg_params, all_params, excluding_keys=_non_booster_params - ) - - model = regressor.fit(self.reg_df_train) - all_params = dict( - **(model._gen_xgb_params_dict()), - **(model._gen_fit_params_dict()), - **(model._gen_predict_params_dict()), - ) - check_sub_dict_match( - self.reg_params, all_params, excluding_keys=_non_booster_params - ) - pred_result = model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_params, atol=1e-3 - ) - ) - - def test_regressor_model_save_load(self): - tmp_dir = self.get_local_tmp_dir() - path = "file:" + tmp_dir - regressor = SparkXGBRegressor(**self.reg_params) - model = regressor.fit(self.reg_df_train) - model.save(path) - loaded_model = SparkXGBRegressorModel.load(path) - self.assertEqual(model.uid, loaded_model.uid) - for k, v in self.reg_params.items(): - self.assertEqual(loaded_model.getOrDefault(k), v) - - pred_result = loaded_model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_params, atol=1e-3 - ) - ) - - with self.assertRaisesRegex(AssertionError, "Expected class name"): - SparkXGBClassifierModel.load(path) - - self.assert_model_compatible(model, tmp_dir) - - def test_regressor_model_pipeline_save_load(self): - tmp_dir = self.get_local_tmp_dir() - path = "file:" + tmp_dir - regressor = SparkXGBRegressor() - pipeline = Pipeline(stages=[regressor]) - pipeline = pipeline.copy(extra=get_params_map(self.reg_params, regressor)) - model = pipeline.fit(self.reg_df_train) - model.save(path) - - loaded_model = PipelineModel.load(path) - for k, v in self.reg_params.items(): - self.assertEqual(loaded_model.stages[0].getOrDefault(k), v) - - pred_result = loaded_model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_params, atol=1e-3 - ) - ) - self.assert_model_compatible(model.stages[0], tmp_dir) - def test_callbacks(self): from xgboost.callback import LearningRateScheduler diff --git a/tests/test_distributed/test_with_spark/test_spark_local_cluster.py b/tests/test_distributed/test_with_spark/test_spark_local_cluster.py index 528b770ff..199a8087d 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local_cluster.py +++ b/tests/test_distributed/test_with_spark/test_spark_local_cluster.py @@ -1,16 +1,24 @@ import json +import logging import os import random +import tempfile import uuid +from collections import namedtuple import numpy as np import pytest +import xgboost as xgb from xgboost import testing as tm +from xgboost.callback import LearningRateScheduler pytestmark = pytest.mark.skipif(**tm.no_spark()) +from typing import Generator + from pyspark.ml.linalg import Vectors +from pyspark.sql import SparkSession from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor from xgboost.spark.utils import _get_max_num_concurrent_tasks @@ -18,51 +26,119 @@ from xgboost.spark.utils import _get_max_num_concurrent_tasks from .utils import SparkLocalClusterTestCase +@pytest.fixture +def spark() -> Generator[SparkSession, None, None]: + config = { + "spark.master": "local-cluster[2, 2, 1024]", + "spark.python.worker.reuse": "false", + "spark.driver.host": "127.0.0.1", + "spark.task.maxFailures": "1", + "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false", + "spark.sql.pyspark.jvmStacktrace.enabled": "true", + "spark.cores.max": "4", + "spark.task.cpus": "1", + "spark.executor.cores": "2", + } + + builder = SparkSession.builder.appName("XGBoost PySpark Python API Tests") + for k, v in config.items(): + builder.config(k, v) + logging.getLogger("pyspark").setLevel(logging.INFO) + sess = builder.getOrCreate() + yield sess + + sess.stop() + sess.sparkContext.stop() + + +RegData = namedtuple("RegData", ("reg_df_train", "reg_df_test", "reg_params")) + + +@pytest.fixture +def reg_data(spark: SparkSession) -> Generator[RegData, None, None]: + reg_params = {"max_depth": 5, "n_estimators": 10, "iteration_range": (0, 5)} + + X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) + y = np.array([0, 1]) + + def custom_lr(boosting_round): + return 1.0 / (boosting_round + 1) + + reg1 = xgb.XGBRegressor(callbacks=[LearningRateScheduler(custom_lr)]) + reg1.fit(X, y) + predt1 = reg1.predict(X) + # array([0.02406833, 0.97593164], dtype=float32) + + reg2 = xgb.XGBRegressor(max_depth=5, n_estimators=10) + reg2.fit(X, y) + predt2 = reg2.predict(X, iteration_range=(0, 5)) + # array([0.22185263, 0.77814734], dtype=float32) + + reg_df_train = spark.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), + ], + ["features", "label"], + ) + reg_df_test = spark.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0.0, float(predt2[0]), float(predt1[0])), + ( + Vectors.sparse(3, {1: 1.0, 2: 5.5}), + 1.0, + float(predt2[1]), + float(predt1[1]), + ), + ], + [ + "features", + "expected_prediction", + "expected_prediction_with_params", + "expected_prediction_with_callbacks", + ], + ) + yield RegData(reg_df_train, reg_df_test, reg_params) + + +class TestPySparkLocalCluster: + def test_regressor_basic_with_params(self, reg_data: RegData) -> None: + regressor = SparkXGBRegressor(**reg_data.reg_params) + model = regressor.fit(reg_data.reg_df_train) + pred_result = model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_params, atol=1e-3 + ) + + def test_callbacks(self, reg_data: RegData) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, str(uuid.uuid4())) + + def custom_lr(boosting_round): + return 1.0 / (boosting_round + 1) + + cb = [LearningRateScheduler(custom_lr)] + regressor = SparkXGBRegressor(callbacks=cb) + + # Test the save/load of the estimator instead of the model, since + # the callbacks param only exists in the estimator but not in the model + regressor.save(path) + regressor = SparkXGBRegressor.load(path) + + model = regressor.fit(reg_data.reg_df_train) + pred_result = model.transform(reg_data.reg_df_test).collect() + for row in pred_result: + assert np.isclose( + row.prediction, row.expected_prediction_with_callbacks, atol=1e-3 + ) + + class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): def setUp(self): random.seed(2020) self.n_workers = _get_max_num_concurrent_tasks(self.session.sparkContext) - # The following code use xgboost python library to train xgb model and predict. - # - # >>> import numpy as np - # >>> import xgboost - # >>> X = np.array([[1.0, 2.0, 3.0], [0.0, 1.0, 5.5]]) - # >>> y = np.array([0, 1]) - # >>> reg1 = xgboost.XGBRegressor() - # >>> reg1.fit(X, y) - # >>> reg1.predict(X) - # array([8.8363886e-04, 9.9911636e-01], dtype=float32) - # >>> def custom_lr(boosting_round, num_boost_round): - # ... return 1.0 / (boosting_round + 1) - # ... - # >>> reg1.fit(X, y, callbacks=[xgboost.callback.reset_learning_rate(custom_lr)]) - # >>> reg1.predict(X) - # array([0.02406833, 0.97593164], dtype=float32) - # >>> reg2 = xgboost.XGBRegressor(max_depth=5, n_estimators=10) - # >>> reg2.fit(X, y) - # >>> reg2.predict(X, ntree_limit=5) - # array([0.22185263, 0.77814734], dtype=float32) - self.reg_params = {"max_depth": 5, "n_estimators": 10, "ntree_limit": 5} - self.reg_df_train = self.session.createDataFrame( - [ - (Vectors.dense(1.0, 2.0, 3.0), 0), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1), - ], - ["features", "label"], - ) - self.reg_df_test = self.session.createDataFrame( - [ - (Vectors.dense(1.0, 2.0, 3.0), 0.0, 0.2219, 0.02406), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 1.0, 0.7781, 0.9759), - ], - [ - "features", - "expected_prediction", - "expected_prediction_with_params", - "expected_prediction_with_callbacks", - ], - ) # Distributed section # Binary classification @@ -218,42 +294,6 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): self.reg_best_score_eval = 5.239e-05 self.reg_best_score_weight_and_eval = 4.850e-05 - def test_regressor_basic_with_params(self): - regressor = SparkXGBRegressor(**self.reg_params) - model = regressor.fit(self.reg_df_train) - pred_result = model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_params, atol=1e-3 - ) - ) - - def test_callbacks(self): - from xgboost.callback import LearningRateScheduler - - path = os.path.join(self.tempdir, str(uuid.uuid4())) - - def custom_learning_rate(boosting_round): - return 1.0 / (boosting_round + 1) - - cb = [LearningRateScheduler(custom_learning_rate)] - regressor = SparkXGBRegressor(callbacks=cb) - - # Test the save/load of the estimator instead of the model, since - # the callbacks param only exists in the estimator but not in the model - regressor.save(path) - regressor = SparkXGBRegressor.load(path) - - model = regressor.fit(self.reg_df_train) - pred_result = model.transform(self.reg_df_test).collect() - for row in pred_result: - self.assertTrue( - np.isclose( - row.prediction, row.expected_prediction_with_callbacks, atol=1e-3 - ) - ) - def test_classifier_distributed_basic(self): classifier = SparkXGBClassifier(num_workers=self.n_workers, n_estimators=100) model = classifier.fit(self.cls_df_train_distributed) @@ -409,7 +449,6 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase): pred_result = model.transform( self.cls_df_test_distributed_lower_estimators ).collect() - print(pred_result) for row in pred_result: self.assertTrue(np.isclose(row.expected_label, row.prediction, atol=1e-3)) self.assertTrue(