[breaking] Remove label encoder deprecated in 1.3. (#7357)

This commit is contained in:
Jiaming Yuan 2021-10-28 13:24:29 +08:00 committed by GitHub
parent d05754f558
commit 3c4aa9b2ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 74 additions and 83 deletions

View File

@ -1,5 +1,4 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, R0912, C0302
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines
"""Scikit-Learn Wrapper interface for XGBoost."""
import copy
import warnings
@ -278,14 +277,13 @@ def _wrap_evaluation_matrices(
eval_qid: Optional[List[Any]],
create_dmatrix: Callable,
enable_categorical: bool,
label_transform: Callable = lambda x: x,
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
"""
train_dmatrix = create_dmatrix(
data=X,
label=label_transform(y),
label=y,
group=group,
qid=qid,
weight=sample_weight,
@ -333,7 +331,7 @@ def _wrap_evaluation_matrices(
else:
m = create_dmatrix(
data=valid_X,
label=label_transform(valid_y),
label=valid_y,
weight=sample_weight_eval_set[i],
group=eval_group[i],
qid=eval_qid[i],
@ -1112,9 +1110,6 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
['model', 'objective'], extra_parameters='''
n_estimators : int
Number of boosting rounds.
use_label_encoder : bool
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
code, we recommend that you set this parameter to False.
''')
class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
@ -1123,10 +1118,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
self,
*,
objective: _SklObjective = "binary:logistic",
use_label_encoder: bool = True,
use_label_encoder: bool = False,
**kwargs: Any
) -> None:
# must match the parameters for `get_params`
self.use_label_encoder = use_label_encoder
if use_label_encoder is True:
raise ValueError("Label encoder was removed in 1.6.")
super().__init__(objective=objective, **kwargs)
@_deprecate_positional_args
@ -1148,51 +1146,32 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
callbacks: Optional[List[TrainingCallback]] = None
) -> "XGBClassifier":
# pylint: disable = attribute-defined-outside-init,too-many-statements
can_use_label_encoder = True
label_encoding_check_error = (
"The label must consist of integer "
"labels of form 0, 1, 2, ..., [num_class - 1]."
)
label_encoder_deprecation_msg = (
"The use of label encoder in XGBClassifier is deprecated and will be "
"removed in a future release. To remove this warning, do the "
"following: 1) Pass option use_label_encoder=False when constructing "
"XGBClassifier object; and 2) Encode your labels (y) as integers "
"starting with 0, i.e. 0, 1, 2, ..., [num_class - 1]."
)
evals_result: TrainingCallback.EvalsLog = {}
if _is_cudf_df(y) or _is_cudf_ser(y):
import cupy as cp # pylint: disable=E0401
self.classes_ = cp.unique(y.values)
self.n_classes_ = len(self.classes_)
can_use_label_encoder = False
expected_classes = cp.arange(self.n_classes_)
if (
self.classes_.shape != expected_classes.shape
or not (self.classes_ == expected_classes).all()
):
raise ValueError(label_encoding_check_error)
elif _is_cupy_array(y):
import cupy as cp # pylint: disable=E0401
self.classes_ = cp.unique(y)
self.n_classes_ = len(self.classes_)
can_use_label_encoder = False
expected_classes = cp.arange(self.n_classes_)
if (
self.classes_.shape != expected_classes.shape
or not (self.classes_ == expected_classes).all()
):
raise ValueError(label_encoding_check_error)
else:
self.classes_ = np.unique(np.asarray(y))
self.n_classes_ = len(self.classes_)
if not self.use_label_encoder and (
not np.array_equal(self.classes_, np.arange(self.n_classes_))
):
raise ValueError(label_encoding_check_error)
expected_classes = np.arange(self.n_classes_)
if (
self.classes_.shape != expected_classes.shape
or not (self.classes_ == expected_classes).all()
):
raise ValueError(
f"Invalid classes inferred from unique values of `y`. "
f"Expected: {expected_classes}, got {self.classes_}"
)
params = self.get_xgb_params()
@ -1211,18 +1190,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
params["objective"] = "multi:softprob"
params["num_class"] = self.n_classes_
if self.use_label_encoder:
if not can_use_label_encoder:
raise ValueError('The option use_label_encoder=True is incompatible with inputs ' +
'of type cuDF or cuPy. Please set use_label_encoder=False when ' +
'constructing XGBClassifier object. NOTE: ' +
label_encoder_deprecation_msg)
warnings.warn(label_encoder_deprecation_msg, UserWarning)
self._le = XGBoostLabelEncoder().fit(y)
label_transform = self._le.transform
else:
label_transform = lambda x: x
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing,
@ -1240,7 +1207,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
eval_qid=None,
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
enable_categorical=self.enable_categorical,
label_transform=label_transform,
)
self._Booster = train(
@ -1403,9 +1369,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
extra_parameters='''
n_estimators : int
Number of trees in random forest to fit.
use_label_encoder : bool
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
code, we recommend that you set this parameter to False.
''')
class XGBRFClassifier(XGBClassifier):
# pylint: disable=missing-docstring
@ -1416,14 +1379,12 @@ class XGBRFClassifier(XGBClassifier):
subsample: float = 0.8,
colsample_bynode: float = 0.8,
reg_lambda: float = 1e-5,
use_label_encoder: bool = True,
**kwargs: Any
):
super().__init__(learning_rate=learning_rate,
subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
use_label_encoder=use_label_encoder,
**kwargs)
def get_xgb_params(self) -> Dict[str, Any]:

View File

@ -239,7 +239,7 @@ def test_cudf_training_with_sklearn():
y_cudf_series = ss(data=y.iloc[:, 0])
for y_obj in [y_cudf, y_cudf_series]:
clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False)
clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist')
clf.fit(X_cudf, y_obj, sample_weight=cudf_weights, base_margin=cudf_base_margin,
eval_set=[(X_cudf, y_obj)])
pred = clf.predict(X_cudf)

View File

@ -122,7 +122,7 @@ def test_cupy_training_with_sklearn():
base_margin = np.random.random(50)
cupy_base_margin = cp.array(base_margin)
clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist", use_label_encoder=False)
clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist")
clf.fit(
X,
y,

View File

@ -7,6 +7,7 @@ sys.path.append("tests/python")
# Don't import the test class, otherwise they will run twice.
import test_callback as test_cb # noqa
import test_basic_models as test_bm
import testing as tm
rng = np.random.RandomState(1994)
@ -14,16 +15,12 @@ class TestGPUBasicModels:
cpu_test_cb = test_cb.TestCallbacks()
cpu_test_bm = test_bm.TestModels()
def run_cls(self, X, y, deterministic):
cls = xgb.XGBClassifier(tree_method='gpu_hist',
deterministic_histogram=deterministic,
single_precision_histogram=True)
def run_cls(self, X, y):
cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
cls.fit(X, y)
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
cls = xgb.XGBClassifier(tree_method='gpu_hist',
deterministic_histogram=deterministic,
single_precision_histogram=True)
cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True)
cls.fit(X, y)
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
@ -49,19 +46,22 @@ class TestGPUBasicModels:
kClasses = 4
# Create large values to force rounding.
X = np.random.randn(kRows, kCols) * 1e4
y = np.random.randint(0, kClasses, size=kRows) * 1e4
y = np.random.randint(0, kClasses, size=kRows)
model_0, model_1 = self.run_cls(X, y, True)
model_0, model_1 = self.run_cls(X, y)
assert model_0 == model_1
@pytest.mark.skipif(**tm.no_sklearn())
def test_invalid_gpu_id(self):
X = np.random.randn(10, 5) * 1e4
y = np.random.randint(0, 2, size=10) * 1e4
from sklearn.datasets import load_digits
X, y = load_digits(return_X_y=True)
# should pass with invalid gpu id
cls1 = xgb.XGBClassifier(tree_method='gpu_hist', gpu_id=9999)
cls1.fit(X, y)
# should throw error with fail_on_invalid_gpu_id enabled
cls2 = xgb.XGBClassifier(tree_method='gpu_hist', gpu_id=9999, fail_on_invalid_gpu_id=True)
cls2 = xgb.XGBClassifier(
tree_method='gpu_hist', gpu_id=9999, fail_on_invalid_gpu_id=True
)
try:
cls2.fit(X, y)
assert False, "Should have failed with with fail_on_invalid_gpu_id enabled"

View File

@ -146,8 +146,10 @@ class TestPickling:
os.remove(model_path)
@pytest.mark.skipif(**tm.no_sklearn())
def test_predict_sklearn_pickle(self):
x, y = build_dataset()
from sklearn.datasets import load_digits
x, y = load_digits(return_X_y=True)
kwargs = {'tree_method': 'gpu_hist',
'predictor': 'gpu_predictor',

View File

@ -56,7 +56,6 @@ def test_categorical():
X, y = load_svmlight_file(os.path.join(data_dir, "agaricus.txt.train"))
clf = xgb.XGBClassifier(
tree_method="gpu_hist",
use_label_encoder=False,
enable_categorical=True,
n_estimators=10,
)
@ -98,3 +97,36 @@ def test_categorical():
X = cudf.DataFrame(X)
check_predt(X, y)
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.skipif(**tm.no_cudf())
def test_classififer():
from sklearn.datasets import load_digits
import cupy as cp
import cudf
X, y = load_digits(return_X_y=True)
y *= 10
clf = xgb.XGBClassifier(tree_method="gpu_hist", n_estimators=1)
# numpy
with pytest.raises(ValueError, match=r"Invalid classes.*"):
clf.fit(X, y)
# cupy
X, y = cp.array(X), cp.array(y)
with pytest.raises(ValueError, match=r"Invalid classes.*"):
clf.fit(X, y)
# cudf
X, y = cudf.DataFrame(X), cudf.DataFrame(y)
with pytest.raises(ValueError, match=r"Invalid classes.*"):
clf.fit(X, y)
# pandas
X, y = load_digits(return_X_y=True, as_frame=True)
y *= 10
with pytest.raises(ValueError, match=r"Invalid classes.*"):
clf.fit(X, y)

View File

@ -283,7 +283,6 @@ def test_feature_importances_gain():
random_state=0, tree_method="exact",
learning_rate=0.1,
importance_type="gain",
use_label_encoder=False,
).fit(X, y)
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
@ -306,7 +305,6 @@ def test_feature_importances_gain():
tree_method="exact",
learning_rate=0.1,
importance_type="gain",
use_label_encoder=False,
).fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
@ -315,14 +313,11 @@ def test_feature_importances_gain():
tree_method="exact",
learning_rate=0.1,
importance_type="gain",
use_label_encoder=False,
).fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
# no split can be found
cls = xgb.XGBClassifier(
min_child_weight=1000, tree_method="hist", n_estimators=1, use_label_encoder=False
)
cls = xgb.XGBClassifier(min_child_weight=1000, tree_method="hist", n_estimators=1)
cls.fit(X, y)
assert np.all(cls.feature_importances_ == 0)
@ -497,7 +492,7 @@ def test_classification_with_custom_objective():
X, y
)
cls = xgb.XGBClassifier(use_label_encoder=False, n_estimators=1)
cls = xgb.XGBClassifier(n_estimators=1)
cls.fit(X, y)
is_called = [False]
@ -923,7 +918,7 @@ def test_RFECV():
bst = xgb.XGBClassifier(booster='gblinear', learning_rate=0.1,
n_estimators=10,
objective='binary:logistic',
random_state=0, verbosity=0, use_label_encoder=False)
random_state=0, verbosity=0)
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='roc_auc')
rfecv.fit(X, y)
@ -934,7 +929,7 @@ def test_RFECV():
n_estimators=10,
objective='multi:softprob',
random_state=0, reg_alpha=0.001, reg_lambda=0.01,
scale_pos_weight=0.5, verbosity=0, use_label_encoder=False)
scale_pos_weight=0.5, verbosity=0)
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='neg_log_loss')
rfecv.fit(X, y)
@ -943,7 +938,7 @@ def test_RFECV():
rfecv = RFECV(estimator=reg)
rfecv.fit(X, y)
cls = xgb.XGBClassifier(use_label_encoder=False)
cls = xgb.XGBClassifier()
rfecv = RFECV(estimator=cls, step=1, cv=3,
scoring='neg_mean_squared_error')
rfecv.fit(X, y)
@ -1052,8 +1047,9 @@ def test_deprecate_position_arg():
with pytest.warns(FutureWarning):
model.fit(X, y, w)
with pytest.warns(FutureWarning):
with pytest.raises(ValueError):
xgb.XGBRFClassifier(1, use_label_encoder=True)
model = xgb.XGBRFClassifier(n_estimators=1)
with pytest.warns(FutureWarning):
model.fit(X, y, w)