Let XGBoostError inherit ValueError. (#5696)

This commit is contained in:
Jiaming Yuan 2020-05-26 08:34:56 +08:00 committed by GitHub
parent 8438c7d0e4
commit f145241593
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 21 deletions

View File

@ -26,7 +26,7 @@ from .libpath import find_lib_path
c_bst_ulong = ctypes.c_uint64
class XGBoostError(Exception):
class XGBoostError(ValueError):
"""Error thrown by xgboost trainer."""

View File

@ -537,16 +537,13 @@ class XGBModel(XGBModelBase):
else:
params.update({'eval_metric': eval_metric})
try:
self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(), evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result,
obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
except XGBoostError as e:
raise ValueError(e)
self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(), evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result,
obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
if evals_result:
for val in evals_result.items():
@ -1230,16 +1227,13 @@ class XGBRanker(XGBModel):
'Custom evaluation metric is not yet supported for XGBRanker.')
params.update({'eval_metric': eval_metric})
try:
self._Booster = train(params, train_dmatrix,
self.n_estimators,
early_stopping_rounds=early_stopping_rounds,
evals=evals,
evals_result=evals_result, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
except XGBoostError as e:
raise ValueError(e)
self._Booster = train(params, train_dmatrix,
self.n_estimators,
early_stopping_rounds=early_stopping_rounds,
evals=evals,
evals_result=evals_result, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
self.objective = params["objective"]