[Python] Accept numpy generators as random_state (#9743)
* accept numpy generators for random_state * make linter happy * fix tests
This commit is contained in:
parent
4da4e092b5
commit
be20df8c23
@ -248,7 +248,7 @@ __model_doc = f"""
|
|||||||
Balancing of positive and negative weights.
|
Balancing of positive and negative weights.
|
||||||
base_score : Optional[float]
|
base_score : Optional[float]
|
||||||
The initial prediction score of all instances, global bias.
|
The initial prediction score of all instances, global bias.
|
||||||
random_state : Optional[Union[numpy.random.RandomState, int]]
|
random_state : Optional[Union[numpy.random.RandomState, numpy.random.Generator, int]]
|
||||||
Random number seed.
|
Random number seed.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
@ -651,7 +651,9 @@ class XGBModel(XGBModelBase):
|
|||||||
reg_lambda: Optional[float] = None,
|
reg_lambda: Optional[float] = None,
|
||||||
scale_pos_weight: Optional[float] = None,
|
scale_pos_weight: Optional[float] = None,
|
||||||
base_score: Optional[float] = None,
|
base_score: Optional[float] = None,
|
||||||
random_state: Optional[Union[np.random.RandomState, int]] = None,
|
random_state: Optional[
|
||||||
|
Union[np.random.RandomState, np.random.Generator, int]
|
||||||
|
] = None,
|
||||||
missing: float = np.nan,
|
missing: float = np.nan,
|
||||||
num_parallel_tree: Optional[int] = None,
|
num_parallel_tree: Optional[int] = None,
|
||||||
monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
|
monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
|
||||||
@ -789,6 +791,10 @@ class XGBModel(XGBModelBase):
|
|||||||
params["random_state"] = params["random_state"].randint(
|
params["random_state"] = params["random_state"].randint(
|
||||||
np.iinfo(np.int32).max
|
np.iinfo(np.int32).max
|
||||||
)
|
)
|
||||||
|
elif isinstance(params["random_state"], np.random.Generator):
|
||||||
|
params["random_state"] = int(
|
||||||
|
params["random_state"].integers(np.iinfo(np.int32).max)
|
||||||
|
)
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|||||||
@ -702,6 +702,10 @@ def test_sklearn_random_state():
|
|||||||
clf = xgb.XGBClassifier(random_state=random_state)
|
clf = xgb.XGBClassifier(random_state=random_state)
|
||||||
assert isinstance(clf.get_xgb_params()['random_state'], int)
|
assert isinstance(clf.get_xgb_params()['random_state'], int)
|
||||||
|
|
||||||
|
random_state = np.random.default_rng(seed=404)
|
||||||
|
clf = xgb.XGBClassifier(random_state=random_state)
|
||||||
|
assert isinstance(clf.get_xgb_params()['random_state'], int)
|
||||||
|
|
||||||
|
|
||||||
def test_sklearn_n_jobs():
|
def test_sklearn_n_jobs():
|
||||||
clf = xgb.XGBClassifier(n_jobs=1)
|
clf = xgb.XGBClassifier(n_jobs=1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user