From be20df8c23c063f9b5ff242e66c29ebd66578ca6 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Thu, 2 Nov 2023 00:20:44 +0100 Subject: [PATCH] [Python] Accept numpy generators as `random_state` (#9743) * accept numpy generators for random_state * make linter happy * fix tests --- python-package/xgboost/sklearn.py | 10 ++++++++-- tests/python/test_with_sklearn.py | 4 ++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index cb738477b..d5e20439a 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -248,7 +248,7 @@ __model_doc = f""" Balancing of positive and negative weights. base_score : Optional[float] The initial prediction score of all instances, global bias. - random_state : Optional[Union[numpy.random.RandomState, int]] + random_state : Optional[Union[numpy.random.RandomState, numpy.random.Generator, int]] Random number seed. .. note:: @@ -651,7 +651,9 @@ class XGBModel(XGBModelBase): reg_lambda: Optional[float] = None, scale_pos_weight: Optional[float] = None, base_score: Optional[float] = None, - random_state: Optional[Union[np.random.RandomState, int]] = None, + random_state: Optional[ + Union[np.random.RandomState, np.random.Generator, int] + ] = None, missing: float = np.nan, num_parallel_tree: Optional[int] = None, monotone_constraints: Optional[Union[Dict[str, int], str]] = None, @@ -789,6 +791,10 @@ class XGBModel(XGBModelBase): params["random_state"] = params["random_state"].randint( np.iinfo(np.int32).max ) + elif isinstance(params["random_state"], np.random.Generator): + params["random_state"] = int( + params["random_state"].integers(np.iinfo(np.int32).max) + ) return params diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index b40ae67c5..c919a01ad 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -702,6 +702,10 @@ def test_sklearn_random_state(): clf = xgb.XGBClassifier(random_state=random_state) assert isinstance(clf.get_xgb_params()['random_state'], int) + random_state = np.random.default_rng(seed=404) + clf = xgb.XGBClassifier(random_state=random_state) + assert isinstance(clf.get_xgb_params()['random_state'], int) + def test_sklearn_n_jobs(): clf = xgb.XGBClassifier(n_jobs=1)