Fix evaluation result for XGBRanker. (#6594)

* Remove duplicated code, which fixes typo `evals_result` -> `evals_result_`.
This commit is contained in:
Jiaming Yuan
2021-01-12 09:36:41 +08:00
committed by GitHub
parent f2f7dd87b8
commit c709f2aaaf
2 changed files with 11 additions and 21 deletions

View File

@@ -572,6 +572,13 @@ class XGBModel(XGBModelBase):
params.update({"eval_metric": eval_metric})
return model, feval, params
def _set_evaluation_result(self, evals_result: Optional[dict]) -> None:
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
self.evals_result_ = evals_result
@_deprecate_positional_args
def fit(self, X, y, *, sample_weight=None, base_margin=None,
eval_set=None, eval_metric=None, early_stopping_rounds=None,
@@ -678,13 +685,7 @@ class XGBModel(XGBModelBase):
verbose_eval=verbose, xgb_model=model,
callbacks=callbacks)
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]
evals_result[val[0]][evals_result_key] = val[1][
evals_result_key]
self.evals_result_ = evals_result
self._set_evaluation_result(evals_result)
return self
def predict(self, data, output_margin=False, ntree_limit=None,
@@ -1035,13 +1036,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
if not callable(self.objective):
self.objective = params["objective"]
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]
evals_result[val[0]][
evals_result_key] = val[1][evals_result_key]
self.evals_result_ = evals_result
self._set_evaluation_result(evals_result)
return self
fit.__doc__ = XGBModel.fit.__doc__.replace(
@@ -1502,13 +1497,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
callbacks=callbacks)
self.objective = params["objective"]
if evals_result:
for val in evals_result.items():
evals_result_key = list(val[1].keys())[0]
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
self.evals_result = evals_result
self._set_evaluation_result(evals_result)
return self
def predict(self, data, output_margin=False,