Sklearn convention update (#2323)
* Added n_jobs and random_state to keep up to date with the sklearn API; deprecated nthread and seed. Added tests for the new params and deprecations.
* Fixed docstring to reflect updates to n_jobs and random_state.
* Fixed whitespace issues and removed the nose import.
* Added a deprecation note for nthread and seed in the docstring.
* Attempted fix of the deprecation tests.
* Second attempted fix to the tests.
* Set n_jobs to 1.
This commit is contained in:
parent
da1629e848
commit
6cea1e3fb7
@ -4,6 +4,7 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import warnings
|
||||||
from .core import Booster, DMatrix, XGBoostError
|
from .core import Booster, DMatrix, XGBoostError
|
||||||
from .training import train
|
from .training import train
|
||||||
|
|
||||||
@ -12,6 +13,8 @@ from .training import train
|
|||||||
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
|
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
|
||||||
XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
|
XGBClassifierBase, XGBRegressorBase, XGBLabelEncoder)
|
||||||
|
|
||||||
|
warnings.simplefilter('always', DeprecationWarning)
|
||||||
|
|
||||||
|
|
||||||
def _objective_decorator(func):
|
def _objective_decorator(func):
|
||||||
"""Decorate an objective function
|
"""Decorate an objective function
|
||||||
@ -68,7 +71,9 @@ class XGBModel(XGBModelBase):
|
|||||||
booster: string
|
booster: string
|
||||||
Specify which booster to use: gbtree, gblinear or dart.
|
Specify which booster to use: gbtree, gblinear or dart.
|
||||||
nthread : int
|
nthread : int
|
||||||
Number of parallel threads used to run xgboost.
|
Number of parallel threads used to run xgboost. (Deprecated, please use n_jobs)
|
||||||
|
n_jobs : int
|
||||||
|
Number of parallel threads used to run xgboost. (replaces nthread)
|
||||||
gamma : float
|
gamma : float
|
||||||
Minimum loss reduction required to make a further partition on a leaf node of the tree.
|
Minimum loss reduction required to make a further partition on a leaf node of the tree.
|
||||||
min_child_weight : int
|
min_child_weight : int
|
||||||
@ -87,11 +92,12 @@ class XGBModel(XGBModelBase):
|
|||||||
L2 regularization term on weights
|
L2 regularization term on weights
|
||||||
scale_pos_weight : float
|
scale_pos_weight : float
|
||||||
Balancing of positive and negative weights.
|
Balancing of positive and negative weights.
|
||||||
|
|
||||||
base_score:
|
base_score:
|
||||||
The initial prediction score of all instances, global bias.
|
The initial prediction score of all instances, global bias.
|
||||||
seed : int
|
seed : int
|
||||||
Random number seed.
|
Random number seed. (Deprecated, please use random_state)
|
||||||
|
random_state : int
|
||||||
|
Random number seed. (replaces seed)
|
||||||
missing : float, optional
|
missing : float, optional
|
||||||
Value in the data which needs to be present as a missing value. If
|
Value in the data which needs to be present as a missing value. If
|
||||||
None, defaults to np.nan.
|
None, defaults to np.nan.
|
||||||
@ -115,10 +121,10 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
|
||||||
silent=True, objective="reg:linear", booster='gbtree',
|
silent=True, objective="reg:linear", booster='gbtree',
|
||||||
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
|
n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
|
||||||
subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||||
base_score=0.5, seed=0, missing=None):
|
base_score=0.5, random_state=0, seed=None, missing=None):
|
||||||
if not SKLEARN_INSTALLED:
|
if not SKLEARN_INSTALLED:
|
||||||
raise XGBoostError('sklearn needs to be installed in order to use this module')
|
raise XGBoostError('sklearn needs to be installed in order to use this module')
|
||||||
self.max_depth = max_depth
|
self.max_depth = max_depth
|
||||||
@ -138,11 +144,28 @@ class XGBModel(XGBModelBase):
|
|||||||
self.reg_alpha = reg_alpha
|
self.reg_alpha = reg_alpha
|
||||||
self.reg_lambda = reg_lambda
|
self.reg_lambda = reg_lambda
|
||||||
self.scale_pos_weight = scale_pos_weight
|
self.scale_pos_weight = scale_pos_weight
|
||||||
|
|
||||||
self.base_score = base_score
|
self.base_score = base_score
|
||||||
self.seed = seed
|
|
||||||
self.missing = missing if missing is not None else np.nan
|
self.missing = missing if missing is not None else np.nan
|
||||||
self._Booster = None
|
self._Booster = None
|
||||||
|
if seed:
|
||||||
|
warnings.warn('The seed parameter is deprecated as of version .6.'
|
||||||
|
'Please use random_state instead.'
|
||||||
|
'seed is deprecated.', DeprecationWarning)
|
||||||
|
self.seed = seed
|
||||||
|
self.random_state = seed
|
||||||
|
else:
|
||||||
|
self.seed = random_state
|
||||||
|
self.random_state = random_state
|
||||||
|
|
||||||
|
if nthread:
|
||||||
|
warnings.warn('The nthread parameter is deprecated as of version .6.'
|
||||||
|
'Please use n_jobs instead.'
|
||||||
|
'nthread is deprecated.', DeprecationWarning)
|
||||||
|
self.nthread = nthread
|
||||||
|
self.n_jobs = nthread
|
||||||
|
else:
|
||||||
|
self.nthread = n_jobs
|
||||||
|
self.n_jobs = n_jobs
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
# backward compatibility code
|
# backward compatibility code
|
||||||
@ -178,6 +201,8 @@ class XGBModel(XGBModelBase):
|
|||||||
def get_xgb_params(self):
|
def get_xgb_params(self):
|
||||||
"""Get xgboost type parameters."""
|
"""Get xgboost type parameters."""
|
||||||
xgb_params = self.get_params()
|
xgb_params = self.get_params()
|
||||||
|
xgb_params.pop('random_state')
|
||||||
|
xgb_params.pop('n_jobs')
|
||||||
|
|
||||||
xgb_params['silent'] = 1 if self.silent else 0
|
xgb_params['silent'] = 1 if self.silent else 0
|
||||||
|
|
||||||
@ -360,17 +385,18 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
def __init__(self, max_depth=3, learning_rate=0.1,
|
def __init__(self, max_depth=3, learning_rate=0.1,
|
||||||
n_estimators=100, silent=True,
|
n_estimators=100, silent=True,
|
||||||
objective="binary:logistic", booster='gbtree',
|
objective="binary:logistic", booster='gbtree',
|
||||||
nthread=-1, gamma=0, min_child_weight=1,
|
n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
|
||||||
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
|
||||||
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
|
||||||
base_score=0.5, seed=0, missing=None):
|
base_score=0.5, random_state=0, seed=None, missing=None):
|
||||||
super(XGBClassifier, self).__init__(max_depth, learning_rate,
|
super(XGBClassifier, self).__init__(max_depth, learning_rate,
|
||||||
n_estimators, silent, objective, booster,
|
n_estimators, silent, objective, booster,
|
||||||
nthread, gamma, min_child_weight,
|
n_jobs, nthread, gamma, min_child_weight,
|
||||||
max_delta_step, subsample,
|
max_delta_step, subsample,
|
||||||
colsample_bytree, colsample_bylevel,
|
colsample_bytree, colsample_bylevel,
|
||||||
reg_alpha, reg_lambda,
|
reg_alpha, reg_lambda,
|
||||||
scale_pos_weight, base_score, seed, missing)
|
scale_pos_weight, base_score,
|
||||||
|
random_state, seed, missing)
|
||||||
|
|
||||||
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
|
||||||
early_stopping_rounds=None, verbose=True):
|
early_stopping_rounds=None, verbose=True):
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import testing as tm
|
import testing as tm
|
||||||
|
import warnings
|
||||||
|
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
@ -326,3 +327,39 @@ def test_split_value_histograms():
|
|||||||
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
|
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
|
||||||
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
|
assert gbdt.get_split_value_histogram("f28", bins=5).shape[0] == 2
|
||||||
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
|
assert gbdt.get_split_value_histogram("f28", bins=None).shape[0] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_sklearn_random_state():
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(random_state=402)
|
||||||
|
assert clf.get_params()['seed'] == 402
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(seed=401)
|
||||||
|
assert clf.get_params()['seed'] == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_deprecation():
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
xgb.XGBClassifier(seed=1)
|
||||||
|
assert w[0].category == DeprecationWarning
|
||||||
|
|
||||||
|
|
||||||
|
def test_sklearn_n_jobs():
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(n_jobs=1)
|
||||||
|
assert clf.get_params()['nthread'] == 1
|
||||||
|
|
||||||
|
clf = xgb.XGBClassifier(nthread=2)
|
||||||
|
assert clf.get_params()['nthread'] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_nthread_deprecation():
|
||||||
|
tm._skip_if_no_sklearn()
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
xgb.XGBClassifier(nthread=1)
|
||||||
|
assert w[0].category == DeprecationWarning
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user