From 6cea1e3fb7bdc435aeba5ea6fc5567e29249486e Mon Sep 17 00:00:00 2001 From: gaw89 Date: Mon, 22 May 2017 09:22:05 -0400 Subject: [PATCH] Sklearn convention update (#2323) * Added n_jobs and random_state to keep up to date with sklearn API. Deprecated nthread and seed. Added tests for new params and deprecations. * Fixed docstring to reflect updates to n_jobs and random_state. * Fixed whitespace issues and removed nose import. * Added deprecation note for nthread and seed in docstring. * Attempted fix of deprecation tests. * Second attempted fix to tests. * Set n_jobs to 1. --- python-package/xgboost/sklearn.py | 48 ++++++++++++++++++++++++------- tests/python/test_with_sklearn.py | 37 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 11 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 919d929d4..66df796e6 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -4,6 +4,7 @@ from __future__ import absolute_import import numpy as np +import warnings from .core import Booster, DMatrix, XGBoostError from .training import train @@ -12,6 +13,8 @@ from .training import train from .compat import (SKLEARN_INSTALLED, XGBModelBase, XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder) +warnings.simplefilter('always', DeprecationWarning) + def _objective_decorator(func): """Decorate an objective function @@ -68,7 +71,9 @@ class XGBModel(XGBModelBase): booster: string Specify which booster to use: gbtree, gblinear or dart. nthread : int - Number of parallel threads used to run xgboost. + Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs) + n_jobs : int + Number of parallel threads used to run xgboost. (replaces nthread) gamma : float Minimum loss reduction required to make a further partition on a leaf node of the tree. min_child_weight : int @@ -87,11 +92,12 @@ class XGBModel(XGBModelBase): L2 regularization term on weights scale_pos_weight : float Balancing of positive and negative weights. - base_score: The initial prediction score of all instances, global bias. seed : int - Random number seed. + Random number seed. (Deprecated, please use random_state) + random_state : int + Random number seed. (replaces seed) missing : float, optional Value in the data which needs to be present as a missing value. If None, defaults to np.nan. @@ -115,10 +121,10 @@ class XGBModel(XGBModelBase): def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True, objective="reg:linear", booster='gbtree', - nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, + n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1, - base_score=0.5, seed=0, missing=None): + base_score=0.5, random_state=0, seed=None, missing=None): if not SKLEARN_INSTALLED: raise XGBoostError('sklearn needs to be installed in order to use this module') self.max_depth = max_depth @@ -138,11 +144,28 @@ class XGBModel(XGBModelBase): self.reg_alpha = reg_alpha self.reg_lambda = reg_lambda self.scale_pos_weight = scale_pos_weight - self.base_score = base_score - self.seed = seed self.missing = missing if missing is not None else np.nan self._Booster = None + if seed: + warnings.warn('The seed parameter is deprecated as of version .6.' + 'Please use random_state instead.' + 'seed is deprecated.', DeprecationWarning) + self.seed = seed + self.random_state = seed + else: + self.seed = random_state + self.random_state = random_state + + if nthread: + warnings.warn('The nthread parameter is deprecated as of version .6.' + 'Please use n_jobs instead.' + 'nthread is deprecated.', DeprecationWarning) + self.nthread = nthread + self.n_jobs = nthread + else: + self.nthread = n_jobs + self.n_jobs = n_jobs def __setstate__(self, state): # backward compatibility code @@ -178,6 +201,8 @@ class XGBModel(XGBModelBase): def get_xgb_params(self): """Get xgboost type parameters.""" xgb_params = self.get_params() + xgb_params.pop('random_state') + xgb_params.pop('n_jobs') xgb_params['silent'] = 1 if self.silent else 0 @@ -360,17 +385,18 @@ class XGBClassifier(XGBModel, XGBClassifierBase): def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True, objective="binary:logistic", booster='gbtree', - nthread=-1, gamma=0, min_child_weight=1, + n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1, reg_alpha=0, reg_lambda=1, scale_pos_weight=1, - base_score=0.5, seed=0, missing=None): + base_score=0.5, random_state=0, seed=None, missing=None): super(XGBClassifier, self).__init__(max_depth, learning_rate, n_estimators, silent, objective, booster, - nthread, gamma, min_child_weight, + n_jobs, nthread, gamma, min_child_weight, max_delta_step, subsample, colsample_bytree, colsample_bylevel, reg_alpha, reg_lambda, - scale_pos_weight, base_score, seed, missing) + scale_pos_weight, base_score, + random_state, seed, missing) def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True): diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index b5fc29411..b676a1a08 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -2,6 +2,7 @@ import numpy as np import random import xgboost as xgb import testing as tm +import warnings rng = np.random.RandomState(1994) @@ -326,3 +327,39 @@ def test_split_value_histograms(): assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2 assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2 assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2 + + +def test_sklearn_random_state(): + tm._skip_if_no_sklearn() + + clf = xgb.XGBClassifier(random_state=402) + assert clf.get_params()['seed'] == 402 + + clf = xgb.XGBClassifier(seed=401) + assert clf.get_params()['seed'] == 401 + + +def test_seed_deprecation(): + tm._skip_if_no_sklearn() + warnings.simplefilter("always") + with warnings.catch_warnings(record=True) as w: + xgb.XGBClassifier(seed=1) + assert w[0].category == DeprecationWarning + + +def test_sklearn_n_jobs(): + tm._skip_if_no_sklearn() + + clf = xgb.XGBClassifier(n_jobs=1) + assert clf.get_params()['nthread'] == 1 + + clf = xgb.XGBClassifier(nthread=2) + assert clf.get_params()['nthread'] == 2 + + +def test_nthread_deprecation(): + tm._skip_if_no_sklearn() + warnings.simplefilter("always") + with warnings.catch_warnings(record=True) as w: + xgb.XGBClassifier(nthread=1) + assert w[0].category == DeprecationWarning