Allow using RandomState object from Numpy in sklearn interface. (#5049)
This commit is contained in:
parent
4d2779663e
commit
a4f5c86276
@ -230,6 +230,9 @@ class XGBModel(XGBModelBase):
|
|||||||
params['missing'] = None # sklearn doesn't handle nan. see #4725
|
params['missing'] = None # sklearn doesn't handle nan. see #4725
|
||||||
if not params.get('eval_metric', True):
|
if not params.get('eval_metric', True):
|
||||||
del params['eval_metric'] # don't give as None param to Booster
|
del params['eval_metric'] # don't give as None param to Booster
|
||||||
|
if isinstance(params['random_state'], np.random.RandomState):
|
||||||
|
params['random_state'] = params['random_state'].randint(
|
||||||
|
np.iinfo(np.int32).max)
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def get_xgb_params(self):
|
def get_xgb_params(self):
|
||||||
|
|||||||
@ -450,6 +450,10 @@ def test_sklearn_random_state():
|
|||||||
clf = xgb.XGBClassifier(random_state=401)
|
clf = xgb.XGBClassifier(random_state=401)
|
||||||
assert clf.get_xgb_params()['random_state'] == 401
|
assert clf.get_xgb_params()['random_state'] == 401
|
||||||
|
|
||||||
|
random_state = np.random.RandomState(seed=403)
|
||||||
|
clf = xgb.XGBClassifier(random_state=random_state)
|
||||||
|
assert isinstance(clf.get_xgb_params()['random_state'], int)
|
||||||
|
|
||||||
|
|
||||||
def test_sklearn_n_jobs():
|
def test_sklearn_n_jobs():
|
||||||
clf = xgb.XGBClassifier(n_jobs=1)
|
clf = xgb.XGBClassifier(n_jobs=1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user