Sklearn kwargs (#2338)
* Added kwargs support for Sklearn API
* Updated NEWS and CONTRIBUTORS
* Fixed CONTRIBUTORS.md
* Added clarification of **kwargs and test for proper usage
* Fixed lint error
* Fixed more lint errors and clf assigned but never used
* Fixed more lint errors
* Fixed more lint errors
* Fixed issue with changes from different branch bleeding over
* Fixed issue with changes from other branch bleeding over
* Added note that kwargs may not be compatible with Sklearn
* Fixed linting on kwargs note
Parent: 6cea1e3fb7
Commit: 0f3a404d91
CONTRIBUTORS.md
@@ -65,3 +65,4 @@ List of Contributors
 * [Adam Pocock](https://github.com/Craigacp)
 * [Rory Mitchell](https://github.com/RAMitchell)
   - Rory is the author of the GPU plugin and also contributed the cmake build system and windows continuous integration
+* [Gideon Whitehead](https://github.com/gaw89)
NEWS.md
@@ -4,6 +4,9 @@ XGBoost Change Log
 This file records the changes in xgboost library in reverse chronological order.
 
 ## in progress version
+* Updated Sklearn API
+  - Updated to allow use of all XGBoost parameters via **kwargs.
+  - Updated nthread to n_jobs and seed to random_state (as per Sklearn convention).
 * Refactored gbm to allow more friendly cache strategy
   - Specialized some prediction routine
 * Automatically remove nan from input data when it is sparse.
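A quick sketch of the renamed parameters (assuming this patch and sklearn are installed):

```python
import xgboost as xgb

# New Sklearn-convention names:
clf = xgb.XGBClassifier(n_jobs=-1, random_state=42)

# The old names are still accepted for now but emit a DeprecationWarning
# (see test_nthread_deprecation in the test diff below):
clf_old = xgb.XGBClassifier(nthread=1, seed=42)
```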
python-package/xgboost/sklearn.py
@@ -101,6 +101,14 @@ class XGBModel(XGBModelBase):
     missing : float, optional
         Value in the data which needs to be present as a missing value. If
         None, defaults to np.nan.
+    **kwargs : dict, optional
+        Keyword arguments for XGBoost Booster object. Full documentation of parameters can
+        be found here: https://github.com/dmlc/xgboost/blob/master/doc/parameter.md.
+        Attempting to set a parameter via the constructor args and **kwargs dict simultaneously
+        will result in a TypeError.
+        Note:
+            **kwargs is unsupported by Sklearn. We do not guarantee that parameters passed via
+            this argument will interact properly with Sklearn.
 
     Note
     ----
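To make the docstring concrete, a minimal usage sketch (assuming this patch and sklearn are installed; 'grow_colmaker' stands in here for any Booster parameter without a named constructor argument):

```python
import xgboost as xgb

# Extra Booster parameters ride along via **kwargs and are stored on
# the estimator as self.kwargs:
clf = xgb.XGBClassifier(n_estimators=100, updater='grow_colmaker')

# Setting the same parameter via a named constructor argument and via a
# **kwargs dict at once is rejected by Python itself with a TypeError
# ("got multiple values for keyword argument 'n_estimators'"):
# xgb.XGBClassifier(n_estimators=100, **{'n_estimators': 10})
```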
@@ -124,7 +132,7 @@ class XGBModel(XGBModelBase):
                  n_jobs=1, nthread=None, gamma=0, min_child_weight=1, max_delta_step=0,
                  subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
-                 base_score=0.5, random_state=0, seed=None, missing=None):
+                 base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
         if not SKLEARN_INSTALLED:
             raise XGBoostError('sklearn needs to be installed in order to use this module')
         self.max_depth = max_depth
@@ -133,7 +141,6 @@ class XGBModel(XGBModelBase):
         self.silent = silent
         self.objective = objective
         self.booster = booster
-
         self.nthread = nthread
         self.gamma = gamma
         self.min_child_weight = min_child_weight
@@ -146,6 +153,7 @@ class XGBModel(XGBModelBase):
         self.scale_pos_weight = scale_pos_weight
         self.base_score = base_score
         self.missing = missing if missing is not None else np.nan
+        self.kwargs = kwargs
         self._Booster = None
         if seed:
             warnings.warn('The seed parameter is deprecated as of version .6.'
@@ -192,6 +200,8 @@ class XGBModel(XGBModelBase):
     def get_params(self, deep=False):
         """Get parameters."""
         params = super(XGBModel, self).get_params(deep=deep)
+        if isinstance(self.kwargs, dict):  # if kwargs is a dict, update params accordingly
+            params.update(self.kwargs)
         if params['missing'] is np.nan:
             params['missing'] = None  # sklearn doesn't handle nan. see #4725
         if not params.get('eval_metric', True):
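A short sketch of the resulting get_params() behavior (assuming sklearn is installed; 'updater' is the same Booster parameter the new tests use):

```python
import xgboost as xgb

clf = xgb.XGBClassifier(n_estimators=10, updater='grow_colmaker')
params = clf.get_params()
assert params['n_estimators'] == 10           # named constructor argument
assert params['updater'] == 'grow_colmaker'   # merged in from self.kwargs
```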
@@ -388,7 +398,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                  n_jobs=1, nthread=None, gamma=0, min_child_weight=1,
                  max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
-                 base_score=0.5, random_state=0, seed=None, missing=None):
+                 base_score=0.5, random_state=0, seed=None, missing=None, **kwargs):
         super(XGBClassifier, self).__init__(max_depth, learning_rate,
                                             n_estimators, silent, objective, booster,
                                             n_jobs, nthread, gamma, min_child_weight,
@@ -396,7 +406,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                                             colsample_bytree, colsample_bylevel,
                                             reg_alpha, reg_lambda,
                                             scale_pos_weight, base_score,
-                                            random_state, seed, missing)
+                                            random_state, seed, missing, **kwargs)
 
     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True):
tests/python/test_with_sklearn.py
@@ -3,6 +3,7 @@ import random
 import xgboost as xgb
 import testing as tm
 import warnings
+from nose.tools import raises
 
 rng = np.random.RandomState(1994)
 
@@ -363,3 +364,22 @@ def test_nthread_deprecation():
     with warnings.catch_warnings(record=True) as w:
         xgb.XGBClassifier(nthread=1)
         assert w[0].category == DeprecationWarning
+
+
+def test_kwargs():
+    tm._skip_if_no_sklearn()
+
+    params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
+    clf = xgb.XGBClassifier(n_estimators=1000, **params)
+    assert clf.get_params()['updater'] == 'grow_gpu'
+    assert clf.get_params()['subsample'] == .5
+    assert clf.get_params()['n_estimators'] == 1000
+
+
+@raises(TypeError)
+def test_kwargs_error():
+    tm._skip_if_no_sklearn()
+
+    params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
+    clf = xgb.XGBClassifier(n_jobs=1000, **params)
+    assert isinstance(clf, xgb.XGBClassifier)
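Note on test_kwargs_error: the TypeError comes from Python itself, not from validation code in XGBoost, because n_jobs arrives both as an explicit keyword argument and inside **params ("got multiple values for keyword argument 'n_jobs'"). This is the same conflict the new docstring note warns about.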