Merge pull request #563 from Far0n/eta_decay
learning_rates per boosting round
This commit is contained in:
commit
c16a6222f3
@ -10,7 +10,7 @@ import numpy as np
|
|||||||
from .core import Booster, STRING_TYPES
|
from .core import Booster, STRING_TYPES
|
||||||
|
|
||||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||||
early_stopping_rounds=None, evals_result=None, verbose_eval=True):
|
early_stopping_rounds=None, evals_result=None, verbose_eval=True, learning_rates=None):
|
||||||
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
|
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
|
||||||
"""Train a booster with given parameters.
|
"""Train a booster with given parameters.
|
||||||
|
|
||||||
@ -46,6 +46,10 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
verbose_eval : bool
|
verbose_eval : bool
|
||||||
If `verbose_eval` then the evaluation metric on the validation set, if
|
If `verbose_eval` then the evaluation metric on the validation set, if
|
||||||
given, is printed at each boosting stage.
|
given, is printed at each boosting stage.
|
||||||
|
learning_rates: list or function
|
||||||
|
Learning rate for each boosting round (yields learning rate decay).
|
||||||
|
- list l: eta = l[boosting round]
|
||||||
|
- function f: eta = f(boosting round, num_boost_round)
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -119,7 +123,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
best_msg = ''
|
best_msg = ''
|
||||||
best_score_i = 0
|
best_score_i = 0
|
||||||
|
|
||||||
|
if isinstance(learning_rates, list) and len(learning_rates) < num_boost_round:
|
||||||
|
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
|
||||||
|
|
||||||
for i in range(num_boost_round):
|
for i in range(num_boost_round):
|
||||||
|
if learning_rates is not None:
|
||||||
|
if isinstance(learning_rates, list):
|
||||||
|
bst.set_param({'eta': learning_rates[i]})
|
||||||
|
else:
|
||||||
|
bst.set_param({'eta': learning_rates(i, num_boost_round)})
|
||||||
bst.update(dtrain, i, obj)
|
bst.update(dtrain, i, obj)
|
||||||
bst_eval_set = bst.eval_set(evals, i, feval)
|
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user