Sklearn convention update (#2323)

* Added n_jobs and random_state to keep up to date with the sklearn API.
  Deprecated nthread and seed.  Added tests for the new params and
  deprecations.

* Fixed docstring to reflect updates to n_jobs and random_state.

* Fixed whitespace issues and removed nose import.

* Added deprecation note for nthread and seed in docstring.

* Attempted fix of deprecation tests.

* Second attempted fix to tests.

* Set n_jobs to 1.
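
In practice the change means the wrapper accepts the sklearn-style names while still honoring the old ones; a minimal sketch of both spellings (parameter values are arbitrary, assuming a build with this patch applied):

    import xgboost as xgb

    # New sklearn-consistent names introduced by this change.
    clf = xgb.XGBClassifier(n_jobs=4, random_state=42)

    # Old names still work, but now emit a DeprecationWarning.
    clf_old = xgb.XGBClassifier(nthread=4, seed=42)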
gaw89 2017-05-22 09:22:05 -04:00 committed by Yuan (Terry) Tang
parent da1629e848
commit 6cea1e3fb7
2 changed files with 74 additions and 11 deletions

python-package/xgboost/sklearn.py

@@ -4,6 +4,7 @@
 from __future__ import absolute_import
 import numpy as np
+import warnings
 from .core import Booster, DMatrix, XGBoostError
 from .training import train
@ -12,6 +13,8 @@ from .training import train
from .compat import (SKLEARN_INSTALLED, XGBModelBase, from .compat import (SKLEARN_INSTALLED, XGBModelBase,
XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder) XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
warnings.simplefilter('always', DeprecationWarning)
def _objective_decorator(func): def _objective_decorator(func):
"""Decorate an objective function """Decorate an objective function
@@ -68,7 +71,9 @@ class XGBModel(XGBModelBase):
     booster: string
         Specify which booster to use: gbtree, gblinear or dart.
     nthread : int
-        Number of parallel threads used to run xgboost.
+        Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
+    n_jobs : int
+        Number of parallel threads used to run xgboost. (replaces nthread)
     gamma : float
         Minimum loss reduction required to make a further partition on a leaf node of the tree.
     min_child_weight : int
@@ -87,11 +92,12 @@ class XGBModel(XGBModelBase):
         L2 regularization term on weights
     scale_pos_weight : float
         Balancing of positive and negative weights.
     base_score:
         The initial prediction score of all instances, global bias.
     seed : int
-        Random number seed.
+        Random number seed. (Deprecated, please use random_state)
+    random_state : int
+        Random number seed. (replaces seed)
     missing : float, optional
         Value in the data which needs to be present as a missing value. If
         None, defaults to np.nan.
@@ -115,10 +121,10 @@ class XGBModel(XGBModelBase):
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
                  silent=True, objective="reg:linear", booster='gbtree',
-                 nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
+                 n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
                  subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
-                 base_score=0.5, seed=0, missing=None):
+                 base_score=0.5, random_state=0, seed=None, missing=None):
         if not SKLEARN_INSTALLED:
             raise XGBoostError('sklearn needs to be installed in order to use this module')
         self.max_depth = max_depth
@@ -138,11 +144,28 @@ class XGBModel(XGBModelBase):
         self.reg_alpha = reg_alpha
         self.reg_lambda = reg_lambda
         self.scale_pos_weight = scale_pos_weight
         self.base_score = base_score
-        self.seed = seed
         self.missing = missing if missing is not None else np.nan
         self._Booster = None
+        if seed:
+            warnings.warn('The seed parameter is deprecated as of version 0.6. '
+                          'Please use random_state instead. '
+                          'seed is deprecated.', DeprecationWarning)
+            self.seed = seed
+            self.random_state = seed
+        else:
+            self.seed = random_state
+            self.random_state = random_state
+        if nthread:
+            warnings.warn('The nthread parameter is deprecated as of version 0.6. '
+                          'Please use n_jobs instead. '
+                          'nthread is deprecated.', DeprecationWarning)
+            self.nthread = nthread
+            self.n_jobs = nthread
+        else:
+            self.nthread = n_jobs
+            self.n_jobs = n_jobs

     def __setstate__(self, state):
         # backward compatibility code
@@ -178,6 +201,8 @@ class XGBModel(XGBModelBase):
     def get_xgb_params(self):
         """Get xgboost type parameters."""
         xgb_params = self.get_params()
+        xgb_params.pop('random_state')
+        xgb_params.pop('n_jobs')

         xgb_params['silent'] = 1 if self.silent else 0
@@ -360,17 +385,18 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     def __init__(self, max_depth=3, learning_rate=0.1,
                  n_estimators=100, silent=True,
                  objective="binary:logistic", booster='gbtree',
-                 nthread=-1, gamma=0, min_child_weight=1,
+                 n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
                  max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
-                 base_score=0.5, seed=0, missing=None):
+                 base_score=0.5, random_state=0, seed=None, missing=None):
         super(XGBClassifier, self).__init__(max_depth, learning_rate,
                                             n_estimators, silent, objective, booster,
-                                            nthread, gamma, min_child_weight,
+                                            n_jobs, nthread, gamma, min_child_weight,
                                             max_delta_step, subsample,
                                             colsample_bytree, colsample_bylevel,
                                             reg_alpha, reg_lambda,
-                                            scale_pos_weight, base_score, seed, missing)
+                                            scale_pos_weight, base_score,
+                                            random_state, seed, missing)

     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True):
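
Taken together: the constructor mirrors each new name onto its old counterpart, and get_xgb_params() strips the sklearn-style names, so the native booster is still configured through seed and nthread. A small sketch of the resulting behavior (the assertions follow directly from the diff above):

    import xgboost as xgb

    clf = xgb.XGBClassifier(random_state=7, n_jobs=2)

    # Both spellings end up set on the estimator.
    assert clf.seed == 7 and clf.random_state == 7
    assert clf.nthread == 2 and clf.n_jobs == 2

    # Only the legacy names are forwarded to the native booster.
    params = clf.get_xgb_params()
    assert 'random_state' not in params and 'n_jobs' not in params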

tests/python/test_with_sklearn.py

@@ -2,6 +2,7 @@ import numpy as np
 import random
 import xgboost as xgb
 import testing as tm
+import warnings

 rng = np.random.RandomState(1994)
@@ -326,3 +327,39 @@ def test_split_value_histograms():
     assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
     assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
     assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
+
+
+def test_sklearn_random_state():
+    tm._skip_if_no_sklearn()
+
+    clf = xgb.XGBClassifier(random_state=402)
+    assert clf.get_params()['seed'] == 402
+
+    clf = xgb.XGBClassifier(seed=401)
+    assert clf.get_params()['seed'] == 401
+
+
+def test_seed_deprecation():
+    tm._skip_if_no_sklearn()
+    warnings.simplefilter("always")
+
+    with warnings.catch_warnings(record=True) as w:
+        xgb.XGBClassifier(seed=1)
+        assert w[0].category == DeprecationWarning
+
+
+def test_sklearn_n_jobs():
+    tm._skip_if_no_sklearn()
+
+    clf = xgb.XGBClassifier(n_jobs=1)
+    assert clf.get_params()['nthread'] == 1
+
+    clf = xgb.XGBClassifier(nthread=2)
+    assert clf.get_params()['nthread'] == 2
+
+
+def test_nthread_deprecation():
+    tm._skip_if_no_sklearn()
+    warnings.simplefilter("always")
+
+    with warnings.catch_warnings(record=True) as w:
+        xgb.XGBClassifier(nthread=1)
+        assert w[0].category == DeprecationWarning
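
Because the module installs simplefilter('always', DeprecationWarning), callers who still pass the legacy names will see the warning on every construction. A hedged usage sketch for suppressing it while migrating (warnings is the standard-library module):

    import warnings
    import xgboost as xgb

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', DeprecationWarning)
        clf = xgb.XGBClassifier(nthread=4, seed=1)  # no warning emitted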