* Remove duplicated code, which fixes typo `evals_result` -> `evals_result_`.
This commit is contained in:
parent
8e321adac8
commit
6a29afb480
@ -4,6 +4,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import warnings
|
import warnings
|
||||||
import json
|
import json
|
||||||
|
from typing import Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
||||||
from .training import train
|
from .training import train
|
||||||
@ -494,6 +495,13 @@ class XGBModel(XGBModelBase):
|
|||||||
# Delete the attribute after load
|
# Delete the attribute after load
|
||||||
self.get_booster().set_attr(scikit_learn=None)
|
self.get_booster().set_attr(scikit_learn=None)
|
||||||
|
|
||||||
|
def _set_evaluation_result(self, evals_result: Optional[dict]) -> None:
|
||||||
|
if evals_result:
|
||||||
|
for val in evals_result.items():
|
||||||
|
evals_result_key = list(val[1].keys())[0]
|
||||||
|
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
|
||||||
|
self.evals_result_ = evals_result
|
||||||
|
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(self, X, y, *, sample_weight=None, base_margin=None,
|
def fit(self, X, y, *, sample_weight=None, base_margin=None,
|
||||||
eval_set=None, eval_metric=None, early_stopping_rounds=None,
|
eval_set=None, eval_metric=None, early_stopping_rounds=None,
|
||||||
@ -601,12 +609,7 @@ class XGBModel(XGBModelBase):
|
|||||||
verbose_eval=verbose, xgb_model=xgb_model,
|
verbose_eval=verbose, xgb_model=xgb_model,
|
||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|
||||||
if evals_result:
|
self._set_evaluation_result(evals_result)
|
||||||
for val in evals_result.items():
|
|
||||||
evals_result_key = list(val[1].keys())[0]
|
|
||||||
evals_result[val[0]][evals_result_key] = val[1][
|
|
||||||
evals_result_key]
|
|
||||||
self.evals_result_ = evals_result
|
|
||||||
|
|
||||||
if early_stopping_rounds is not None:
|
if early_stopping_rounds is not None:
|
||||||
self.best_score = self._Booster.best_score
|
self.best_score = self._Booster.best_score
|
||||||
@ -919,12 +922,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|
||||||
self.objective = xgb_options["objective"]
|
self.objective = xgb_options["objective"]
|
||||||
if evals_result:
|
self._set_evaluation_result(evals_result)
|
||||||
for val in evals_result.items():
|
|
||||||
evals_result_key = list(val[1].keys())[0]
|
|
||||||
evals_result[val[0]][
|
|
||||||
evals_result_key] = val[1][evals_result_key]
|
|
||||||
self.evals_result_ = evals_result
|
|
||||||
|
|
||||||
if early_stopping_rounds is not None:
|
if early_stopping_rounds is not None:
|
||||||
self.best_score = self._Booster.best_score
|
self.best_score = self._Booster.best_score
|
||||||
@ -1328,12 +1326,7 @@ class XGBRanker(XGBModel):
|
|||||||
|
|
||||||
self.objective = params["objective"]
|
self.objective = params["objective"]
|
||||||
|
|
||||||
if evals_result:
|
self._set_evaluation_result(evals_result)
|
||||||
for val in evals_result.items():
|
|
||||||
evals_result_key = list(val[1].keys())[0]
|
|
||||||
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
|
|
||||||
self.evals_result = evals_result
|
|
||||||
|
|
||||||
if early_stopping_rounds is not None:
|
if early_stopping_rounds is not None:
|
||||||
self.best_score = self._Booster.best_score
|
self.best_score = self._Booster.best_score
|
||||||
self.best_iteration = self._Booster.best_iteration
|
self.best_iteration = self._Booster.best_iteration
|
||||||
|
|||||||
@ -122,6 +122,8 @@ def test_ranking():
|
|||||||
model = xgb.sklearn.XGBRanker(**params)
|
model = xgb.sklearn.XGBRanker(**params)
|
||||||
model.fit(x_train, y_train, group=train_group,
|
model.fit(x_train, y_train, group=train_group,
|
||||||
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
|
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
|
||||||
|
assert model.evals_result()
|
||||||
|
|
||||||
pred = model.predict(x_test)
|
pred = model.predict(x_test)
|
||||||
|
|
||||||
train_data = xgb.DMatrix(x_train, y_train)
|
train_data = xgb.DMatrix(x_train, y_train)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user