Clean up the train function. (#7377)

* Move attribute setter to callback.
* Remove the internal train function.
* Remove unnecessary initialization.
Jiaming Yuan 2021-11-02 18:00:26 +08:00 committed by GitHub
parent 154b15060e
commit c74df31bf9
2 changed files with 65 additions and 101 deletions
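From the user's side nothing changes: the early-stopping attributes that the removed internal helper used to set on the returned Booster are now written by the callback container's after_training hook. A minimal usage sketch of the public API (the synthetic data and parameter values are illustrative, not taken from this PR):

import numpy as np
import xgboost as xgb

# illustrative synthetic binary-classification data
X = np.random.rand(200, 4)
y = np.random.randint(0, 2, size=200)
dtrain = xgb.DMatrix(X[:150], label=y[:150])
dvalid = xgb.DMatrix(X[150:], label=y[150:])

bst = xgb.train(
    {"objective": "binary:logistic", "eta": 0.1},
    dtrain,
    num_boost_round=100,
    evals=[(dvalid, "validation")],
    early_stopping_rounds=5,
    verbose_eval=False,
)
# These attributes are now populated in CallbackContainer.after_training
# rather than in the removed _train_internal helper.
print(bst.best_iteration, bst.best_score, bst.best_ntree_limit)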

python-package/xgboost/callback.py

@@ -11,7 +11,7 @@ from typing import Sequence
import numpy
from . import rabit
from .core import Booster, DMatrix, XGBoostError
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
from .compat import STRING_TYPES
@@ -149,6 +149,25 @@ class CallbackContainer:
assert isinstance(model.cvfolds, list), msg
else:
assert isinstance(model, Booster), msg
if not self.is_cv:
num_parallel_tree, _ = _get_booster_layer_trees(model)
if model.attr('best_score') is not None:
model.best_score = float(cast(str, model.attr('best_score')))
model.best_iteration = int(cast(str, model.attr('best_iteration')))
# num_class is handled internally
model.set_attr(
best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree)
)
model.best_ntree_limit = int(cast(str, model.attr("best_ntree_limit")))
else:
# Due to compatibility with version older than 1.4, these attributes are
# added to Python object even if early stopping is not used.
model.best_iteration = model.num_boosted_rounds() - 1
model.set_attr(best_iteration=str(model.best_iteration))
model.best_ntree_limit = (model.best_iteration + 1) * num_parallel_tree
model.set_attr(best_ntree_limit=str(model.best_ntree_limit))
return model
def before_iteration(
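The value written for best_ntree_limit is derived from the best iteration and the number of trees built per boosting round (num_parallel_tree, as reported by _get_booster_layer_trees). A small check of that relation, assuming this version's behaviour; the data and parameter values are made up for illustration:

import numpy as np
import xgboost as xgb

X, y = np.random.rand(100, 3), np.random.rand(100)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train(
    {"num_parallel_tree": 4, "eta": 0.3},  # 4 trees per boosting round
    dtrain,
    num_boost_round=5,
)
# Without early stopping, after_training falls back to the last round:
# best_iteration = num_boosted_rounds() - 1, and
# best_ntree_limit = (best_iteration + 1) * num_parallel_tree.
assert bst.best_iteration == 4
assert bst.best_ntree_limit == (4 + 1) * 4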

python-package/xgboost/training.py

@@ -5,10 +5,10 @@
import copy
import os
import warnings
from typing import Optional, Dict, Any, Union, Tuple, cast, Sequence
from typing import Optional, Dict, Any, Union, Tuple, Sequence
import numpy as np
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
from .core import Metric, Objective
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import callback
@@ -45,93 +45,12 @@ def _configure_custom_metric(
return eval_metric
def _train_internal(
params: Dict[str, Any],
dtrain: DMatrix,
num_boost_round: int = 10,
evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
obj: Optional[Objective] = None,
feval: Optional[Metric] = None,
custom_metric: Optional[Metric] = None,
xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
evals_result: callback.TrainingCallback.EvalsLog = None,
maximize: Optional[bool] = None,
verbose_eval: Optional[Union[bool, int]] = True,
early_stopping_rounds: Optional[int] = None,
) -> Booster:
"""internal training function"""
callbacks = [] if callbacks is None else copy.copy(list(callbacks))
metric_fn = _configure_custom_metric(feval, custom_metric)
evals = list(evals) if evals else []
bst = Booster(params, [dtrain] + [d[0] for d in evals])
if xgb_model is not None:
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
start_iteration = 0
_assert_new_callback(callbacks)
if verbose_eval:
verbose_eval = 1 if verbose_eval is True else verbose_eval
callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
if early_stopping_rounds:
callbacks.append(
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
)
cb_container = callback.CallbackContainer(
callbacks,
metric=metric_fn,
# For old `feval` parameter, the behavior is unchanged. For the new
# `custom_metric`, it will receive proper prediction result when custom objective
# is not used.
output_margin=callable(obj) or metric_fn is feval,
)
bst = cb_container.before_training(bst)
for i in range(start_iteration, num_boost_round):
if cb_container.before_iteration(bst, i, dtrain, evals):
break
bst.update(dtrain, i, obj)
if cb_container.after_iteration(bst, i, dtrain, evals):
break
bst = cb_container.after_training(bst)
if evals_result is not None:
evals_result.update(cb_container.history)
# These should be moved into callback functions `after_training`, but until old
# callbacks are removed, the train function is the only place for setting the
# attributes.
num_parallel_tree, _ = _get_booster_layer_trees(bst)
if bst.attr('best_score') is not None:
bst.best_score = float(cast(str, bst.attr('best_score')))
bst.best_iteration = int(cast(str, bst.attr('best_iteration')))
# num_class is handled internally
bst.set_attr(
best_ntree_limit=str((bst.best_iteration + 1) * num_parallel_tree)
)
bst.best_ntree_limit = int(cast(str, bst.attr("best_ntree_limit")))
else:
# Due to compatibility with version older than 1.4, these attributes are added
# to Python object even if early stopping is not used.
bst.best_iteration = bst.num_boosted_rounds() - 1
bst.set_attr(best_iteration=str(bst.best_iteration))
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
bst.set_attr(best_ntree_limit=str(bst.best_ntree_limit))
# Copy to serialise and unserialise booster to reset state and free
# training memory
return bst.copy()
@_deprecate_positional_args
def train(
params: Dict[str, Any],
dtrain: DMatrix,
num_boost_round: int = 10,
*,
evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
obj: Optional[Objective] = None,
feval: Optional[Metric] = None,
@@ -223,22 +142,48 @@ def train(
-------
Booster : a trained booster model
"""
bst = _train_internal(
params,
dtrain,
num_boost_round=num_boost_round,
evals=evals,
obj=obj,
feval=feval,
xgb_model=xgb_model,
callbacks=callbacks,
verbose_eval=verbose_eval,
evals_result=evals_result,
maximize=maximize,
early_stopping_rounds=early_stopping_rounds,
custom_metric=custom_metric,
callbacks = [] if callbacks is None else copy.copy(list(callbacks))
metric_fn = _configure_custom_metric(feval, custom_metric)
evals = list(evals) if evals else []
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
start_iteration = 0
_assert_new_callback(callbacks)
if verbose_eval:
verbose_eval = 1 if verbose_eval is True else verbose_eval
callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
if early_stopping_rounds:
callbacks.append(
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
)
cb_container = callback.CallbackContainer(
callbacks,
metric=metric_fn,
# For old `feval` parameter, the behavior is unchanged. For the new
# `custom_metric`, it will receive proper prediction result when custom objective
# is not used.
output_margin=callable(obj) or metric_fn is feval,
)
return bst
bst = cb_container.before_training(bst)
for i in range(start_iteration, num_boost_round):
if cb_container.before_iteration(bst, i, dtrain, evals):
break
bst.update(dtrain, i, obj)
if cb_container.after_iteration(bst, i, dtrain, evals):
break
bst = cb_container.after_training(bst)
if evals_result is not None:
evals_result.update(cb_container.history)
# Copy to serialise and unserialise booster to reset state and free
# training memory
return bst.copy()
class CVPack:
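The output_margin flag passed to CallbackContainer above keeps the legacy feval behaviour while letting the newer custom_metric receive transformed predictions whenever no custom objective is supplied. A hedged sketch of a new-style metric used with the cleaned-up train (the metric name, data, and parameter values are illustrative):

from typing import Tuple
import numpy as np
import xgboost as xgb

def error_rate(predt: np.ndarray, dmat: xgb.DMatrix) -> Tuple[str, float]:
    # With custom_metric and no custom objective, predt holds transformed
    # predictions (probabilities here), not the raw margins that the old
    # feval parameter would receive.
    y = dmat.get_label()
    return "my_error", float(np.mean((predt > 0.5) != y))

X = np.random.rand(120, 5)
y = np.random.randint(0, 2, size=120)
dtrain = xgb.DMatrix(X[:100], label=y[:100])
dvalid = xgb.DMatrix(X[100:], label=y[100:])

evals_result = {}
bst = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=10,
    evals=[(dtrain, "train"), (dvalid, "valid")],
    custom_metric=error_rate,
    evals_result=evals_result,
    verbose_eval=False,
)
print(evals_result["valid"]["my_error"])  # one value per boosting round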