From 3c4aa9b2ead21d11ef1589059db2ea50208c55ea Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Thu, 28 Oct 2021 13:24:29 +0800
Subject: [PATCH] [breaking] Remove label encoder deprecated in 1.3. (#7357)

---
 python-package/xgboost/sklearn.py         | 73 ++++++-----------------
 tests/python-gpu/test_from_cudf.py        |  2 +-
 tests/python-gpu/test_from_cupy.py        |  2 +-
 tests/python-gpu/test_gpu_basic_models.py | 24 ++++----
 tests/python-gpu/test_gpu_pickling.py     |  4 +-
 tests/python-gpu/test_gpu_with_sklearn.py | 34 ++++++++++-
 tests/python/test_with_sklearn.py         | 18 +++---
 7 files changed, 74 insertions(+), 83 deletions(-)

diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index e4c4b9928..feea9c5b3 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1,5 +1,4 @@
-# coding: utf-8
-# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, R0912, C0302
+# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines
 """Scikit-Learn Wrapper interface for XGBoost."""
 import copy
 import warnings
@@ -278,14 +277,13 @@ def _wrap_evaluation_matrices(
     eval_qid: Optional[List[Any]],
     create_dmatrix: Callable,
     enable_categorical: bool,
-    label_transform: Callable = lambda x: x,
 ) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
     """Convert array_like evaluation matrices into DMatrix.
 
     Perform validation on the way.
 
     """
     train_dmatrix = create_dmatrix(
         data=X,
-        label=label_transform(y),
+        label=y,
         group=group,
         qid=qid,
         weight=sample_weight,
@@ -333,7 +331,7 @@ def _wrap_evaluation_matrices(
         else:
             m = create_dmatrix(
                 data=valid_X,
-                label=label_transform(valid_y),
+                label=valid_y,
                 weight=sample_weight_eval_set[i],
                 group=eval_group[i],
                 qid=eval_qid[i],
@@ -1112,9 +1110,6 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
     ['model', 'objective'], extra_parameters='''
     n_estimators : int
         Number of boosting rounds.
-    use_label_encoder : bool
-        (Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
-        code, we recommend that you set this parameter to False.
 ''')
 class XGBClassifier(XGBModel, XGBClassifierBase):
     # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
@@ -1123,10 +1118,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         self,
         *,
         objective: _SklObjective = "binary:logistic",
-        use_label_encoder: bool = True,
+        use_label_encoder: bool = False,
         **kwargs: Any
     ) -> None:
+        # must match the parameters for `get_params`
         self.use_label_encoder = use_label_encoder
+        if use_label_encoder is True:
+            raise ValueError("Label encoder was removed in 1.6.")
         super().__init__(objective=objective, **kwargs)
 
     @_deprecate_positional_args
@@ -1148,51 +1146,32 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         callbacks: Optional[List[TrainingCallback]] = None
     ) -> "XGBClassifier":
         # pylint: disable = attribute-defined-outside-init,too-many-statements
-        can_use_label_encoder = True
-        label_encoding_check_error = (
-            "The label must consist of integer "
-            "labels of form 0, 1, 2, ..., [num_class - 1]."
-        )
-        label_encoder_deprecation_msg = (
-            "The use of label encoder in XGBClassifier is deprecated and will be "
-            "removed in a future release. To remove this warning, do the "
-            "following: 1) Pass option use_label_encoder=False when constructing "
-            "XGBClassifier object; and 2) Encode your labels (y) as integers "
-            "starting with 0, i.e. 0, 1, 2, ..., [num_class - 1]."
-        )
-
         evals_result: TrainingCallback.EvalsLog = {}
+
         if _is_cudf_df(y) or _is_cudf_ser(y):
             import cupy as cp  # pylint: disable=E0401
 
             self.classes_ = cp.unique(y.values)
             self.n_classes_ = len(self.classes_)
-            can_use_label_encoder = False
             expected_classes = cp.arange(self.n_classes_)
-            if (
-                self.classes_.shape != expected_classes.shape
-                or not (self.classes_ == expected_classes).all()
-            ):
-                raise ValueError(label_encoding_check_error)
         elif _is_cupy_array(y):
             import cupy as cp  # pylint: disable=E0401
 
             self.classes_ = cp.unique(y)
             self.n_classes_ = len(self.classes_)
-            can_use_label_encoder = False
             expected_classes = cp.arange(self.n_classes_)
-            if (
-                self.classes_.shape != expected_classes.shape
-                or not (self.classes_ == expected_classes).all()
-            ):
-                raise ValueError(label_encoding_check_error)
         else:
             self.classes_ = np.unique(np.asarray(y))
             self.n_classes_ = len(self.classes_)
-            if not self.use_label_encoder and (
-                not np.array_equal(self.classes_, np.arange(self.n_classes_))
-            ):
-                raise ValueError(label_encoding_check_error)
+            expected_classes = np.arange(self.n_classes_)
+        if (
+            self.classes_.shape != expected_classes.shape
+            or not (self.classes_ == expected_classes).all()
+        ):
+            raise ValueError(
+                f"Invalid classes inferred from unique values of `y`. "
+                f"Expected: {expected_classes}, got {self.classes_}"
+            )
 
         params = self.get_xgb_params()
 
@@ -1211,18 +1190,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             params["objective"] = "multi:softprob"
             params["num_class"] = self.n_classes_
 
-        if self.use_label_encoder:
-            if not can_use_label_encoder:
-                raise ValueError('The option use_label_encoder=True is incompatible with inputs ' +
-                                 'of type cuDF or cuPy. Please set use_label_encoder=False when ' +
-                                 'constructing XGBClassifier object. NOTE: ' +
-                                 label_encoder_deprecation_msg)
-            warnings.warn(label_encoder_deprecation_msg, UserWarning)
-            self._le = XGBoostLabelEncoder().fit(y)
-            label_transform = self._le.transform
-        else:
-            label_transform = lambda x: x
-
         model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
         train_dmatrix, evals = _wrap_evaluation_matrices(
             missing=self.missing,
@@ -1240,7 +1207,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             eval_qid=None,
             create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
             enable_categorical=self.enable_categorical,
-            label_transform=label_transform,
         )
 
         self._Booster = train(
@@ -1403,9 +1369,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     extra_parameters='''
     n_estimators : int
         Number of trees in random forest to fit.
-    use_label_encoder : bool
-        (Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
-        code, we recommend that you set this parameter to False.
 ''')
 class XGBRFClassifier(XGBClassifier):
     # pylint: disable=missing-docstring
@@ -1416,14 +1379,12 @@ class XGBRFClassifier(XGBClassifier):
         subsample: float = 0.8,
         colsample_bynode: float = 0.8,
         reg_lambda: float = 1e-5,
-        use_label_encoder: bool = True,
         **kwargs: Any
     ):
         super().__init__(learning_rate=learning_rate,
                          subsample=subsample,
                          colsample_bynode=colsample_bynode,
                          reg_lambda=reg_lambda,
-                         use_label_encoder=use_label_encoder,
                          **kwargs)
 
     def get_xgb_params(self) -> Dict[str, Any]:
diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py
index 904dbf093..6250ab328 100644
--- a/tests/python-gpu/test_from_cudf.py
+++ b/tests/python-gpu/test_from_cudf.py
@@ -239,7 +239,7 @@ def test_cudf_training_with_sklearn():
     y_cudf_series = ss(data=y.iloc[:, 0])
 
     for y_obj in [y_cudf, y_cudf_series]:
-        clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False)
+        clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist')
         clf.fit(X_cudf, y_obj, sample_weight=cudf_weights,
                 base_margin=cudf_base_margin, eval_set=[(X_cudf, y_obj)])
         pred = clf.predict(X_cudf)
diff --git a/tests/python-gpu/test_from_cupy.py b/tests/python-gpu/test_from_cupy.py
index d0504e575..fee5edbb8 100644
--- a/tests/python-gpu/test_from_cupy.py
+++ b/tests/python-gpu/test_from_cupy.py
@@ -122,7 +122,7 @@ def test_cupy_training_with_sklearn():
     base_margin = np.random.random(50)
     cupy_base_margin = cp.array(base_margin)
 
-    clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist", use_label_encoder=False)
+    clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist")
     clf.fit(
         X,
         y,
diff --git a/tests/python-gpu/test_gpu_basic_models.py b/tests/python-gpu/test_gpu_basic_models.py
index b65370f10..06e63bdd5 100644
--- a/tests/python-gpu/test_gpu_basic_models.py
+++ b/tests/python-gpu/test_gpu_basic_models.py
@@ -7,6 +7,7 @@ sys.path.append("tests/python")
 # Don't import the test class, otherwise they will run twice.
 import test_callback as test_cb  # noqa
 import test_basic_models as test_bm
+import testing as tm
 
 rng = np.random.RandomState(1994)
 
@@ -14,16 +15,12 @@ class TestGPUBasicModels:
     cpu_test_cb = test_cb.TestCallbacks()
     cpu_test_bm = test_bm.TestModels()
 
-    def run_cls(self, X, y, deterministic):
-        cls = xgb.XGBClassifier(tree_method='gpu_hist',
-                                deterministic_histogram=deterministic,
-                                single_precision_histogram=True)
+    def run_cls(self, X, y):
+        cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
         cls.fit(X, y)
         cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
 
-        cls = xgb.XGBClassifier(tree_method='gpu_hist',
-                                deterministic_histogram=deterministic,
-                                single_precision_histogram=True)
+        cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
         cls.fit(X, y)
         cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
 
@@ -49,19 +46,22 @@ class TestGPUBasicModels:
         kClasses = 4
         # Create large values to force rounding.
         X = np.random.randn(kRows, kCols) * 1e4
-        y = np.random.randint(0, kClasses, size=kRows) * 1e4
+        y = np.random.randint(0, kClasses, size=kRows)
 
-        model_0, model_1 = self.run_cls(X, y, True)
+        model_0, model_1 = self.run_cls(X, y)
         assert model_0 == model_1
 
+    @pytest.mark.skipif(**tm.no_sklearn())
     def test_invalid_gpu_id(self):
-        X = np.random.randn(10, 5) * 1e4
-        y = np.random.randint(0, 2, size=10) * 1e4
+        from sklearn.datasets import load_digits
+        X, y = load_digits(return_X_y=True)
         # should pass with invalid gpu id
         cls1 = xgb.XGBClassifier(tree_method='gpu_hist', gpu_id=9999)
         cls1.fit(X, y)
         # should throw error with fail_on_invalid_gpu_id enabled
-        cls2 = xgb.XGBClassifier(tree_method='gpu_hist', gpu_id=9999, fail_on_invalid_gpu_id=True)
+        cls2 = xgb.XGBClassifier(
+            tree_method='gpu_hist', gpu_id=9999, fail_on_invalid_gpu_id=True
+        )
         try:
             cls2.fit(X, y)
             assert False, "Should have failed with fail_on_invalid_gpu_id enabled"
diff --git a/tests/python-gpu/test_gpu_pickling.py b/tests/python-gpu/test_gpu_pickling.py
index 19a4bf76a..d368c1ceb 100644
--- a/tests/python-gpu/test_gpu_pickling.py
+++ b/tests/python-gpu/test_gpu_pickling.py
@@ -146,8 +146,10 @@ class TestPickling:
 
         os.remove(model_path)
 
+    @pytest.mark.skipif(**tm.no_sklearn())
     def test_predict_sklearn_pickle(self):
-        x, y = build_dataset()
+        from sklearn.datasets import load_digits
+        x, y = load_digits(return_X_y=True)
 
         kwargs = {'tree_method': 'gpu_hist',
                   'predictor': 'gpu_predictor',
diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py
index f8d510753..5f70ef631 100644
--- a/tests/python-gpu/test_gpu_with_sklearn.py
+++ b/tests/python-gpu/test_gpu_with_sklearn.py
@@ -56,7 +56,6 @@ def test_categorical():
     X, y = load_svmlight_file(os.path.join(data_dir, "agaricus.txt.train"))
     clf = xgb.XGBClassifier(
         tree_method="gpu_hist",
-        use_label_encoder=False,
        enable_categorical=True,
         n_estimators=10,
     )
@@ -98,3 +97,36 @@ def test_categorical():
 
     X = cudf.DataFrame(X)
     check_predt(X, y)
+
+
+@pytest.mark.skipif(**tm.no_cupy())
+@pytest.mark.skipif(**tm.no_cudf())
+def test_classifier():
+    from sklearn.datasets import load_digits
+    import cupy as cp
+    import cudf
+
+    X, y = load_digits(return_X_y=True)
+    y *= 10
+
+    clf = xgb.XGBClassifier(tree_method="gpu_hist", n_estimators=1)
+
+    # numpy
+    with pytest.raises(ValueError, match=r"Invalid classes.*"):
+        clf.fit(X, y)
+
+    # cupy
+    X, y = cp.array(X), cp.array(y)
+    with pytest.raises(ValueError, match=r"Invalid classes.*"):
+        clf.fit(X, y)
+
+    # cudf
+    X, y = cudf.DataFrame(X), cudf.DataFrame(y)
+    with pytest.raises(ValueError, match=r"Invalid classes.*"):
+        clf.fit(X, y)
+
+    # pandas
+    X, y = load_digits(return_X_y=True, as_frame=True)
+    y *= 10
+    with pytest.raises(ValueError, match=r"Invalid classes.*"):
+        clf.fit(X, y)
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 75020eec1..7ca79fecc 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -283,7 +283,6 @@ def test_feature_importances_gain():
         random_state=0, tree_method="exact",
         learning_rate=0.1,
         importance_type="gain",
-        use_label_encoder=False,
     ).fit(X, y)
 
     exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
@@ -306,7 +305,6 @@ def test_feature_importances_gain():
         tree_method="exact",
         learning_rate=0.1,
         importance_type="gain",
-        use_label_encoder=False,
     ).fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
@@ -315,14 +313,11 @@ def test_feature_importances_gain():
     xgb_model = xgb.XGBClassifier(
tree_method="exact", learning_rate=0.1, importance_type="gain", - use_label_encoder=False, ).fit(X, y) np.testing.assert_almost_equal(xgb_model.feature_importances_, exp) # no split can be found - cls = xgb.XGBClassifier( - min_child_weight=1000, tree_method="hist", n_estimators=1, use_label_encoder=False - ) + cls = xgb.XGBClassifier(min_child_weight=1000, tree_method="hist", n_estimators=1) cls.fit(X, y) assert np.all(cls.feature_importances_ == 0) @@ -497,7 +492,7 @@ def test_classification_with_custom_objective(): X, y ) - cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1) + cls = xgb.XGBClassifier(n_estimators=1) cls.fit(X, y) is_called = [False] @@ -923,7 +918,7 @@ def test_RFECV(): bst = xgb.XGBClassifier(booster='gblinear', learning_rate=0.1, n_estimators=10, objective='binary:logistic', - random_state=0, verbosity=0, use_label_encoder=False) + random_state=0, verbosity=0) rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='roc_auc') rfecv.fit(X, y) @@ -934,7 +929,7 @@ def test_RFECV(): n_estimators=10, objective='multi:softprob', random_state=0, reg_alpha=0.001, reg_lambda=0.01, - scale_pos_weight=0.5, verbosity=0, use_label_encoder=False) + scale_pos_weight=0.5, verbosity=0) rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='neg_log_loss') rfecv.fit(X, y) @@ -943,7 +938,7 @@ def test_RFECV(): rfecv = RFECV(estimator=reg) rfecv.fit(X, y) - cls = xgb.XGBClassifier(use_label_encoder=False) + cls = xgb.XGBClassifier() rfecv = RFECV(estimator=cls, step=1, cv=3, scoring='neg_mean_squared_error') rfecv.fit(X, y) @@ -1052,8 +1047,9 @@ def test_deprecate_position_arg(): with pytest.warns(FutureWarning): model.fit(X, y, w) - with pytest.warns(FutureWarning): + with pytest.raises(ValueError): xgb.XGBRFClassifier(1, use_label_encoder=True) + model = xgb.XGBRFClassifier(n_estimators=1) with pytest.warns(FutureWarning): model.fit(X, y, w)