diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 7638fa399..be87ed774 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -155,25 +155,10 @@ class XGBModel(XGBModelBase): self.missing = missing if missing is not None else np.nan self.kwargs = kwargs self._Booster = None - if seed: - warnings.warn('The seed parameter is deprecated as of version .6.' - 'Please use random_state instead.' - 'seed is deprecated.', DeprecationWarning) - self.seed = seed - self.random_state = seed - else: - self.seed = random_state - self.random_state = random_state - - if nthread: - warnings.warn('The nthread parameter is deprecated as of version .6.' - 'Please use n_jobs instead.' - 'nthread is deprecated.', DeprecationWarning) - self.nthread = nthread - self.n_jobs = nthread - else: - self.nthread = n_jobs - self.n_jobs = n_jobs + self.seed = seed + self.random_state = random_state + self.nthread = nthread + self.n_jobs = n_jobs def __setstate__(self, state): # backward compatibility code @@ -211,12 +196,24 @@ class XGBModel(XGBModelBase): def get_xgb_params(self): """Get xgboost type parameters.""" xgb_params = self.get_params() - xgb_params.pop('random_state') - xgb_params.pop('n_jobs') + random_state = xgb_params.pop('random_state') + if xgb_params['seed'] is not None: + warnings.warn('The seed parameter is deprecated as of version .6.' + 'Please use random_state instead.' + 'seed is deprecated.', DeprecationWarning) + else: + xgb_params['seed'] = random_state + n_jobs = xgb_params.pop('n_jobs') + if xgb_params['nthread'] is not None: + warnings.warn('The nthread parameter is deprecated as of version .6.' + 'Please use n_jobs instead.' + 'nthread is deprecated.', DeprecationWarning) + else: + xgb_params['nthread'] = n_jobs xgb_params['silent'] = 1 if self.silent else 0 - if self.nthread <= 0: + if xgb_params['nthread'] <= 0: xgb_params.pop('nthread', None) return xgb_params diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 2ef3ea83c..8d61217bb 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -334,17 +334,17 @@ def test_sklearn_random_state(): tm._skip_if_no_sklearn() clf = xgb.XGBClassifier(random_state=402) - assert clf.get_params()['seed'] == 402 + assert clf.get_xgb_params()['seed'] == 402 clf = xgb.XGBClassifier(seed=401) - assert clf.get_params()['seed'] == 401 + assert clf.get_xgb_params()['seed'] == 401 def test_seed_deprecation(): tm._skip_if_no_sklearn() warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: - xgb.XGBClassifier(seed=1) + xgb.XGBClassifier(seed=1).get_xgb_params() assert w[0].category == DeprecationWarning @@ -352,17 +352,17 @@ def test_sklearn_n_jobs(): tm._skip_if_no_sklearn() clf = xgb.XGBClassifier(n_jobs=1) - assert clf.get_params()['nthread'] == 1 + assert clf.get_xgb_params()['nthread'] == 1 clf = xgb.XGBClassifier(nthread=2) - assert clf.get_params()['nthread'] == 2 + assert clf.get_xgb_params()['nthread'] == 2 def test_nthread_deprecation(): tm._skip_if_no_sklearn() warnings.simplefilter("always") with warnings.catch_warnings(record=True) as w: - xgb.XGBClassifier(nthread=1) + xgb.XGBClassifier(nthread=1).get_xgb_params() assert w[0].category == DeprecationWarning @@ -383,3 +383,12 @@ def test_kwargs_error(): params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1} clf = xgb.XGBClassifier(n_jobs=1000, **params) assert isinstance(clf, xgb.XGBClassifier) + + +def test_sklearn_clone(): + tm._skip_if_no_sklearn() + from sklearn.base import clone + + clf = xgb.XGBClassifier(n_jobs=2, nthread=3) + clf.n_jobs = -1 + clone(clf)