Cleanup the train function. (#7377)
* Move attribute setter to callback.
* Remove the internal train function.
* Remove unnecessary initialization.
Parent: 154b15060e
Commit: c74df31bf9
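For context, a minimal usage sketch (not part of the diff): the best-iteration bookkeeping that `_train_internal` used to perform now happens in `CallbackContainer.after_training`, so `train()` still returns a booster carrying the same attributes. The synthetic data below is illustrative only.

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.random((100, 4)), rng.integers(0, 2, 100)
dtrain = xgb.DMatrix(X[:80], label=y[:80])
dvalid = xgb.DMatrix(X[80:], label=y[80:])

booster = xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=50,
    evals=[(dvalid, "validation")],
    early_stopping_rounds=3,
    verbose_eval=False,
)
# After this commit these are set by CallbackContainer.after_training:
print(booster.best_iteration, booster.best_score, booster.best_ntree_limit)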
python-package/xgboost/callback.py

@@ -11,7 +11,7 @@ from typing import Sequence
 import numpy
 
 from . import rabit
-from .core import Booster, DMatrix, XGBoostError
+from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
 from .compat import STRING_TYPES
 
 
@@ -149,6 +149,25 @@ class CallbackContainer:
             assert isinstance(model.cvfolds, list), msg
         else:
             assert isinstance(model, Booster), msg
 
+        if not self.is_cv:
+            num_parallel_tree, _ = _get_booster_layer_trees(model)
+            if model.attr('best_score') is not None:
+                model.best_score = float(cast(str, model.attr('best_score')))
+                model.best_iteration = int(cast(str, model.attr('best_iteration')))
+                # num_class is handled internally
+                model.set_attr(
+                    best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree)
+                )
+                model.best_ntree_limit = int(cast(str, model.attr("best_ntree_limit")))
+            else:
+                # Due to compatibility with version older than 1.4, these attributes are
+                # added to Python object even if early stopping is not used.
+                model.best_iteration = model.num_boosted_rounds() - 1
+                model.set_attr(best_iteration=str(model.best_iteration))
+                model.best_ntree_limit = (model.best_iteration + 1) * num_parallel_tree
+                model.set_attr(best_ntree_limit=str(model.best_ntree_limit))
+
         return model
 
     def before_iteration(
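The new block runs inside the container's own after_training, after every user callback has been invoked. For comparison, a sketch of a user-defined callback using the same hook (the class name ReportBest is hypothetical, not part of the commit):

import xgboost as xgb

class ReportBest(xgb.callback.TrainingCallback):
    """Hypothetical callback demonstrating the after_training hook."""

    def after_training(self, model: xgb.Booster) -> xgb.Booster:
        # User callbacks run before the container's attribute block above,
        # so read the raw booster attribute rather than model.best_score.
        if model.attr("best_score") is not None:
            print("best_score:", model.attr("best_score"))
        return model  # the hook must return the (possibly replaced) model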
python-package/xgboost/training.py

@@ -5,10 +5,10 @@
 import copy
 import os
 import warnings
-from typing import Optional, Dict, Any, Union, Tuple, cast, Sequence
+from typing import Optional, Dict, Any, Union, Tuple, Sequence
 
 import numpy as np
-from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
+from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
 from .core import Metric, Objective
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import callback
@@ -45,93 +45,12 @@ def _configure_custom_metric(
     return eval_metric
 
 
-def _train_internal(
-    params: Dict[str, Any],
-    dtrain: DMatrix,
-    num_boost_round: int = 10,
-    evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
-    obj: Optional[Objective] = None,
-    feval: Optional[Metric] = None,
-    custom_metric: Optional[Metric] = None,
-    xgb_model: Optional[Union[str, os.PathLike, Booster, bytearray]] = None,
-    callbacks: Optional[Sequence[callback.TrainingCallback]] = None,
-    evals_result: callback.TrainingCallback.EvalsLog = None,
-    maximize: Optional[bool] = None,
-    verbose_eval: Optional[Union[bool, int]] = True,
-    early_stopping_rounds: Optional[int] = None,
-) -> Booster:
-    """internal training function"""
-    callbacks = [] if callbacks is None else copy.copy(list(callbacks))
-    metric_fn = _configure_custom_metric(feval, custom_metric)
-    evals = list(evals) if evals else []
-
-    bst = Booster(params, [dtrain] + [d[0] for d in evals])
-
-    if xgb_model is not None:
-        bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
-
-    start_iteration = 0
-
-    _assert_new_callback(callbacks)
-    if verbose_eval:
-        verbose_eval = 1 if verbose_eval is True else verbose_eval
-        callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
-    if early_stopping_rounds:
-        callbacks.append(
-            callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
-        )
-    cb_container = callback.CallbackContainer(
-        callbacks,
-        metric=metric_fn,
-        # For old `feval` parameter, the behavior is unchanged. For the new
-        # `custom_metric`, it will receive proper prediction result when custom objective
-        # is not used.
-        output_margin=callable(obj) or metric_fn is feval,
-    )
-
-    bst = cb_container.before_training(bst)
-
-    for i in range(start_iteration, num_boost_round):
-        if cb_container.before_iteration(bst, i, dtrain, evals):
-            break
-        bst.update(dtrain, i, obj)
-        if cb_container.after_iteration(bst, i, dtrain, evals):
-            break
-
-    bst = cb_container.after_training(bst)
-
-    if evals_result is not None:
-        evals_result.update(cb_container.history)
-
-    # These should be moved into callback functions `after_training`, but until old
-    # callbacks are removed, the train function is the only place for setting the
-    # attributes.
-    num_parallel_tree, _ = _get_booster_layer_trees(bst)
-    if bst.attr('best_score') is not None:
-        bst.best_score = float(cast(str, bst.attr('best_score')))
-        bst.best_iteration = int(cast(str, bst.attr('best_iteration')))
-        # num_class is handled internally
-        bst.set_attr(
-            best_ntree_limit=str((bst.best_iteration + 1) * num_parallel_tree)
-        )
-        bst.best_ntree_limit = int(cast(str, bst.attr("best_ntree_limit")))
-    else:
-        # Due to compatibility with version older than 1.4, these attributes are added
-        # to Python object even if early stopping is not used.
-        bst.best_iteration = bst.num_boosted_rounds() - 1
-        bst.set_attr(best_iteration=str(bst.best_iteration))
-        bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
-        bst.set_attr(best_ntree_limit=str(bst.best_ntree_limit))
-
-    # Copy to serialise and unserialise booster to reset state and free
-    # training memory
-    return bst.copy()
-
-
 @_deprecate_positional_args
 def train(
     params: Dict[str, Any],
     dtrain: DMatrix,
     num_boost_round: int = 10,
     *,
     evals: Optional[Sequence[Tuple[DMatrix, str]]] = None,
     obj: Optional[Objective] = None,
     feval: Optional[Metric] = None,
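Note the bare `*` retained in the signature above: together with the @_deprecate_positional_args decorator imported in the previous hunk, it makes every argument after num_boost_round keyword-only. A sketch of the calling convention, reusing dtrain/dvalid from the earlier snippet:

booster = xgb.train(
    {"objective": "reg:squarederror"},
    dtrain,
    10,                              # num_boost_round may still be positional
    evals=[(dvalid, "validation")],  # keyword-only from here on
    verbose_eval=False,
)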
@@ -223,22 +142,48 @@ def train(
     -------
     Booster : a trained booster model
     """
-    bst = _train_internal(
-        params,
-        dtrain,
-        num_boost_round=num_boost_round,
-        evals=evals,
-        obj=obj,
-        feval=feval,
-        xgb_model=xgb_model,
-        callbacks=callbacks,
-        verbose_eval=verbose_eval,
-        evals_result=evals_result,
-        maximize=maximize,
-        early_stopping_rounds=early_stopping_rounds,
-        custom_metric=custom_metric,
-    )
-    return bst
+    callbacks = [] if callbacks is None else copy.copy(list(callbacks))
+    metric_fn = _configure_custom_metric(feval, custom_metric)
+    evals = list(evals) if evals else []
+
+    bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
+    start_iteration = 0
+
+    _assert_new_callback(callbacks)
+    if verbose_eval:
+        verbose_eval = 1 if verbose_eval is True else verbose_eval
+        callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
+    if early_stopping_rounds:
+        callbacks.append(
+            callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
+        )
+    cb_container = callback.CallbackContainer(
+        callbacks,
+        metric=metric_fn,
+        # For old `feval` parameter, the behavior is unchanged. For the new
+        # `custom_metric`, it will receive proper prediction result when custom objective
+        # is not used.
+        output_margin=callable(obj) or metric_fn is feval,
+    )
+
+    bst = cb_container.before_training(bst)
+
+    for i in range(start_iteration, num_boost_round):
+        if cb_container.before_iteration(bst, i, dtrain, evals):
+            break
+        bst.update(dtrain, i, obj)
+        if cb_container.after_iteration(bst, i, dtrain, evals):
+            break
+
+    bst = cb_container.after_training(bst)
+
+    if evals_result is not None:
+        evals_result.update(cb_container.history)
+
+    # Copy to serialise and unserialise booster to reset state and free
+    # training memory
+    return bst.copy()
 
 
 class CVPack:
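Continuing the earlier sketch: the else branch moved into after_training means the legacy attributes are set even when early stopping is unused, keeping pre-1.4 compatibility. A hedged check, assuming the default num_parallel_tree of 1:

# Without early stopping, best_iteration falls back to the last boosted round.
bst = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=10)
assert bst.best_iteration == bst.num_boosted_rounds() - 1
assert bst.best_ntree_limit == bst.best_iteration + 1  # num_parallel_tree == 1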