diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 66efb01fa..39ad73415 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -230,6 +230,9 @@ class XGBModel(XGBModelBase): params['missing'] = None # sklearn doesn't handle nan. see #4725 if not params.get('eval_metric', True): del params['eval_metric'] # don't give as None param to Booster + if isinstance(params['random_state'], np.random.RandomState): + params['random_state'] = params['random_state'].randint( + np.iinfo(np.int32).max) return params def get_xgb_params(self): diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 43c1f2767..a098a97ae 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -450,6 +450,10 @@ def test_sklearn_random_state(): clf = xgb.XGBClassifier(random_state=401) assert clf.get_xgb_params()['random_state'] == 401 + random_state = np.random.RandomState(seed=403) + clf = xgb.XGBClassifier(random_state=random_state) + assert isinstance(clf.get_xgb_params()['random_state'], int) + def test_sklearn_n_jobs(): clf = xgb.XGBClassifier(n_jobs=1)