Sklearn convention update (#2323)
* Added n_jobs and random_state to keep up to date with the scikit-learn API, and deprecated nthread and seed; added tests for the new parameters and the deprecations. * Fixed the docstring to reflect the updates to n_jobs and random_state. * Fixed whitespace issues and removed the nose import. * Added a deprecation note for nthread and seed to the docstring. * First attempted fix of the deprecation tests. * Second attempted fix of the tests. * Set n_jobs to 1.
This commit is contained in:
parent
da1629e848
commit
6cea1e3fb7
@@ -4,6 +4,7 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import numpy as np
|
||||
import warnings
|
||||
from .core import Booster, DMatrix, XGBoostError
|
||||
from .training import train
|
||||
|
||||
@@ -12,6 +13,8 @@ from .training import train
|
||||
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
|
||||
XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
|
||||
|
||||
warnings.simplefilter('always', DeprecationWarning)
|
||||
|
||||
|
||||
def _objective_decorator(func):
|
||||
"""Decorate an objective function
|
||||
@@ -68,7 +71,9 @@ class XGBModel(XGBModelBase):
|
||||
booster: string
|
||||
Specify which booster to use: gbtree, gblinear or dart.
|
||||
nthread : int
|
||||
Number of parallel threads used to run xgboost.
|
||||
Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
|
||||
n_jobs : int
|
||||
Number of parallel threads used to run xgboost. (replaces nthread)
|
||||
gamma : float
|
||||
Minimum loss reduction required to make a further partition on a leaf node of the tree.
|
||||
min_child_weight : int
|
||||
@@ -87,11 +92,12 @@ class XGBModel(XGBModelBase):
|
||||
L2 regularization term on weights
|
||||
scale_pos_weight : float
|
||||
Balancing of positive and negative weights.
|
||||
|
||||
base_score:
|
||||
The initial prediction score of all instances, global bias.
|
||||
seed : int
|
||||
Random number seed.
|
||||
Random number seed. (Deprecated, please use random_state)
|
||||
random_state : int
|
||||
Random number seed. (replaces seed)
|
||||
missing : float, optional
|
||||
Value in the data which needs to be present as a missing value. If
|
||||
None, defaults to np.nan.
|
||||
@@ -115,10 +121,10 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
||||
silent=True, objective="reg:linear", booster='gbtree',
|
||||
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||
subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
base_score=0.5, seed=0, missing=None):
|
||||
base_score=0.5, random_state=0, seed=None, missing=None):
|
||||
if not SKLEARN_INSTALLED:
|
||||
raise XGBoostError('sklearn needs to be installed in order to use this module')
|
||||
self.max_depth = max_depth
|
||||
@@ -138,11 +144,28 @@ class XGBModel(XGBModelBase):
|
||||
self.reg_alpha = reg_alpha
|
||||
self.reg_lambda = reg_lambda
|
||||
self.scale_pos_weight = scale_pos_weight
|
||||
|
||||
self.base_score = base_score
|
||||
self.seed = seed
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self._Booster = None
|
||||
if seed:
|
||||
warnings.warn('The seed parameter is deprecated as of version .6.'
|
||||
'Please use random_state instead.'
|
||||
'seed is deprecated.', DeprecationWarning)
|
||||
self.seed = seed
|
||||
self.random_state = seed
|
||||
else:
|
||||
self.seed = random_state
|
||||
self.random_state = random_state
|
||||
|
||||
if nthread:
|
||||
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
||||
'Please use n_jobs instead.'
|
||||
'nthread is deprecated.', DeprecationWarning)
|
||||
self.nthread = nthread
|
||||
self.n_jobs = nthread
|
||||
else:
|
||||
self.nthread = n_jobs
|
||||
self.n_jobs = n_jobs
|
||||
|
||||
def __setstate__(self, state):
|
||||
# backward compatibility code
|
||||
@@ -178,6 +201,8 @@ class XGBModel(XGBModelBase):
|
||||
def get_xgb_params(self):
|
||||
"""Get xgboost type parameters."""
|
||||
xgb_params = self.get_params()
|
||||
xgb_params.pop('random_state')
|
||||
xgb_params.pop('n_jobs')
|
||||
|
||||
xgb_params['silent'] = 1 if self.silent else 0
|
||||
|
||||
@@ -360,17 +385,18 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
def __init__(self, max_depth=3, learning_rate=0.1,
|
||||
n_estimators=100, silent=True,
|
||||
objective="binary:logistic", booster='gbtree',
|
||||
nthread=-1, gamma=0, min_child_weight=1,
|
||||
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
|
||||
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||
base_score=0.5, seed=0, missing=None):
|
||||
base_score=0.5, random_state=0, seed=None, missing=None):
|
||||
super(XGBClassifier, self).__init__(max_depth, learning_rate,
|
||||
n_estimators, silent, objective, booster,
|
||||
nthread, gamma, min_child_weight,
|
||||
n_jobs, nthread, gamma, min_child_weight,
|
||||
max_delta_step, subsample,
|
||||
colsample_bytree, colsample_bylevel,
|
||||
reg_alpha, reg_lambda,
|
||||
scale_pos_weight, base_score, seed, missing)
|
||||
scale_pos_weight, base_score,
|
||||
random_state, seed, missing)
|
||||
|
||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True):
|
||||
|
||||
@@ -2,6 +2,7 @@ import numpy as np
|
||||
import random
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import warnings
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
@@ -326,3 +327,39 @@ def test_split_value_histograms():
|
||||
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
|
||||
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
|
||||
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
|
||||
|
||||
|
||||
def test_sklearn_random_state():
|
||||
tm._skip_if_no_sklearn()
|
||||
|
||||
clf = xgb.XGBClassifier(random_state=402)
|
||||
assert clf.get_params()['seed'] == 402
|
||||
|
||||
clf = xgb.XGBClassifier(seed=401)
|
||||
assert clf.get_params()['seed'] == 401
|
||||
|
||||
|
||||
def test_seed_deprecation():
    """Passing seed= must emit a DeprecationWarning (use random_state).

    Fix over the original: warnings.simplefilter("always") is now called
    INSIDE the catch_warnings context, so the global warning-filter state
    is restored on exit instead of leaking into every later test. The
    assertion also scans all recorded warnings with issubclass() rather
    than requiring the DeprecationWarning to be exactly w[0], which is
    brittle if an unrelated warning fires first.
    """
    tm._skip_if_no_sklearn()

    with warnings.catch_warnings(record=True) as w:
        # catch_warnings restores the filter when the block exits.
        warnings.simplefilter("always")
        xgb.XGBClassifier(seed=1)
    assert any(issubclass(item.category, DeprecationWarning) for item in w)
|
||||
|
||||
|
||||
def test_sklearn_n_jobs():
    """Both n_jobs and the legacy nthread kwarg populate the 'nthread' param.

    Covers the sklearn-convention alias: XGBModel.__init__ maps n_jobs
    (and, for backward compatibility, nthread) onto the booster's
    'nthread' parameter exposed through get_params().
    """
    tm._skip_if_no_sklearn()

    # (constructor kwargs, expected value of get_params()['nthread'])
    cases = (
        ({'n_jobs': 1}, 1),    # new-style keyword
        ({'nthread': 2}, 2),   # legacy keyword, still honoured
    )
    for kwargs, expected in cases:
        estimator = xgb.XGBClassifier(**kwargs)
        assert estimator.get_params()['nthread'] == expected
|
||||
|
||||
|
||||
def test_nthread_deprecation():
    """Passing nthread= must emit a DeprecationWarning (use n_jobs).

    Fix over the original: warnings.simplefilter("always") is now called
    INSIDE the catch_warnings context, so the global warning-filter state
    is restored on exit instead of leaking into every later test. The
    assertion also scans all recorded warnings with issubclass() rather
    than requiring the DeprecationWarning to be exactly w[0], which is
    brittle if an unrelated warning fires first.
    """
    tm._skip_if_no_sklearn()

    with warnings.catch_warnings(record=True) as w:
        # catch_warnings restores the filter when the block exits.
        warnings.simplefilter("always")
        xgb.XGBClassifier(nthread=1)
    assert any(issubclass(item.category, DeprecationWarning) for item in w)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user