Sklearn convention update (#2323)

* Added n_jobs and random_state to keep up to date with the sklearn API.
Deprecated nthread and seed. Added tests for the new params and the
deprecations.

* Fixed docstring to reflect updates to n_jobs and random_state.

* Fixed whitespace issues and removed nose import.

* Added deprecation note for nthread and seed in docstring.

* Attempted fix of deprecation tests.

* Second attempted fix to tests.

* Set n_jobs to 1.
gaw89 2017-05-22 09:22:05 -04:00 committed by Yuan (Terry) Tang
parent da1629e848
commit 6cea1e3fb7
2 changed files with 74 additions and 11 deletions
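
In user code, the change looks like this (a minimal sketch, assuming this patched xgboost build and scikit-learn are installed):

import warnings
import xgboost as xgb

# Preferred, sklearn-style names introduced by this commit.
clf = xgb.XGBClassifier(n_jobs=4, random_state=42)

# The old names still work, but now emit a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    xgb.XGBClassifier(nthread=4, seed=42)
assert caught and all(w.category is DeprecationWarning for w in caught)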

View File

@@ -4,6 +4,7 @@
from __future__ import absolute_import
import numpy as np
import warnings
from .core import Booster, DMatrix, XGBoostError
from .training import train
@@ -12,6 +13,8 @@ from .training import train
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
warnings.simplefilter('always', DeprecationWarning)
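# The 'always' filter ensures the new DeprecationWarnings are actually
# surfaced; Python silences DeprecationWarning by default. Note that this
# mutates global warning-filter state at import time.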
def _objective_decorator(func):
"""Decorate an objective function
@@ -68,7 +71,9 @@ class XGBModel(XGBModelBase):
booster : string
Specify which booster to use: gbtree, gblinear or dart.
nthread : int
Number of parallel threads used to run xgboost.
Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
n_jobs : int
Number of parallel threads used to run xgboost. (replaces nthread)
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int
@@ -87,11 +92,12 @@ class XGBModel(XGBModelBase):
L2 regularization term on weights
scale_pos_weight : float
Balancing of positive and negative weights.
base_score : float
The initial prediction score of all instances, global bias.
seed : int
Random number seed.
Random number seed. (Deprecated, please use random_state)
random_state : int
Random number seed. (replaces seed)
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
@@ -115,10 +121,10 @@ class XGBModel(XGBModelBase):
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
silent=True, objective="reg:linear", booster='gbtree',
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, seed=0, missing=None):
base_score=0.5, random_state=0, seed=None, missing=None):
if not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use this module')
self.max_depth = max_depth
@@ -138,11 +144,28 @@ class XGBModel(XGBModelBase):
self.reg_alpha = reg_alpha
self.reg_lambda = reg_lambda
self.scale_pos_weight = scale_pos_weight
self.base_score = base_score
self.seed = seed
self.missing = missing if missing is not None else np.nan
self._Booster = None
if seed is not None:
    warnings.warn('The seed parameter is deprecated as of version 0.6. '
                  'Please use random_state instead.', DeprecationWarning)
self.seed = seed
self.random_state = seed
else:
self.seed = random_state
self.random_state = random_state
if nthread is not None:
    warnings.warn('The nthread parameter is deprecated as of version 0.6. '
                  'Please use n_jobs instead.', DeprecationWarning)
self.nthread = nthread
self.n_jobs = nthread
else:
self.nthread = n_jobs
self.n_jobs = n_jobs
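# Both the old and the new attribute are set in every branch, so that
# get_params() keeps exposing 'nthread' and 'seed' for backward
# compatibility while 'n_jobs' and 'random_state' are the canonical names.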
def __setstate__(self, state):
# backward compatibility code
@@ -178,6 +201,8 @@ class XGBModel(XGBModelBase):
def get_xgb_params(self):
"""Get xgboost type parameters."""
xgb_params = self.get_params()
xgb_params.pop('random_state')
xgb_params.pop('n_jobs')
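# The sklearn-style aliases are dropped before the parameters are handed
# to the core library, which only understands 'nthread' and 'seed';
# get_params() still returns the values under those old names.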
xgb_params['silent'] = 1 if self.silent else 0
@@ -360,17 +385,18 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
def __init__(self, max_depth=3, learning_rate=0.1,
n_estimators=100, silent=True,
objective="binary:logistic", booster='gbtree',
nthread=-1, gamma=0, min_child_weight=1,
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
base_score=0.5, seed=0, missing=None):
base_score=0.5, random_state=0, seed=None, missing=None):
super(XGBClassifier, self).__init__(max_depth, learning_rate,
n_estimators, silent, objective, booster,
nthread, gamma, min_child_weight,
n_jobs, nthread, gamma, min_child_weight,
max_delta_step, subsample,
colsample_bytree, colsample_bylevel,
reg_alpha, reg_lambda,
scale_pos_weight, base_score, seed, missing)
scale_pos_weight, base_score,
random_state, seed, missing)
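# The arguments above are passed positionally, so their order must match
# XGBModel.__init__ exactly; n_jobs and random_state are slotted in right
# next to the deprecated parameters they replace.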
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):

View File

@@ -2,6 +2,7 @@ import numpy as np
import random
import xgboost as xgb
import testing as tm
import warnings
rng = np.random.RandomState(1994)
@@ -326,3 +327,39 @@ def test_split_value_histograms():
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
def test_sklearn_random_state():
tm._skip_if_no_sklearn()
clf = xgb.XGBClassifier(random_state=402)
assert clf.get_params()['seed'] == 402
clf = xgb.XGBClassifier(seed=401)
assert clf.get_params()['seed'] == 401
def test_seed_deprecation():
    tm._skip_if_no_sklearn()
    with warnings.catch_warnings(record=True) as w:
        # Make sure the warning is not swallowed by default filters; the
        # filter change is undone when the context manager exits.
        warnings.simplefilter("always")
        xgb.XGBClassifier(seed=1)
    assert w[0].category == DeprecationWarning
def test_sklearn_n_jobs():
tm._skip_if_no_sklearn()
clf = xgb.XGBClassifier(n_jobs=1)
assert clf.get_params()['nthread'] == 1
clf = xgb.XGBClassifier(nthread=2)
assert clf.get_params()['nthread'] == 2
def test_nthread_deprecation():
    tm._skip_if_no_sklearn()
    with warnings.catch_warnings(record=True) as w:
        # Same pattern as test_seed_deprecation above.
        warnings.simplefilter("always")
        xgb.XGBClassifier(nthread=1)
    assert w[0].category == DeprecationWarning
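
As a quick illustration of the get_xgb_params() change above, the alias bookkeeping never leaks into the booster configuration (a sketch under the same assumptions as the tests):

import xgboost as xgb

clf = xgb.XGBClassifier(n_jobs=2, random_state=7)
params = clf.get_xgb_params()
# The sklearn-only aliases were popped ...
assert 'n_jobs' not in params and 'random_state' not in params
# ... while the names the core library understands carry their values.
assert params['nthread'] == 2 and params['seed'] == 7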