Merge pull request #587 from Far0n/py_train
python training continuation & maximize parameter
commit deb802b2be
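In short: xgb.train gains an xgb_model argument for resuming training from a
previously built model, and a maximize flag that tells early stopping which
direction a custom feval should move. A minimal usage sketch (the parameter
names come from this diff; the dataset and file name are illustrative):

    import xgboost as xgb
    from sklearn.datasets import load_digits

    digits = load_digits(2)
    dtrain = xgb.DMatrix(digits['data'], label=digits['target'])

    # Train 5 rounds, save, then continue for 5 more on top of the saved model.
    bst = xgb.train({'silent': 1}, dtrain, num_boost_round=5)
    bst.save_model('continue.model')
    bst = xgb.train({'silent': 1}, dtrain, num_boost_round=5,
                    xgb_model='continue.model')  # a Booster instance also works
    assert len(bst.get_dump()) == 10  # 5 loaded trees + 5 new ones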
python-package/xgboost/training.py

@@ -10,7 +10,8 @@ import numpy as np
 from .core import Booster, STRING_TYPES
 
 
 def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
-          early_stopping_rounds=None, evals_result=None, verbose_eval=True, learning_rates=None):
+          maximize=False, early_stopping_rounds=None, evals_result=None,
+          verbose_eval=True, learning_rates=None, xgb_model=None):
     # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
     """Train a booster with given parameters.
@@ -29,6 +30,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         Customized objective function.
     feval : function
         Customized evaluation function.
+    maximize : bool
+        Whether to maximize feval.
     early_stopping_rounds: int
         Activates early stopping. Validation error needs to decrease at least
         every <early_stopping_rounds> round(s) to continue training.
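Reviewer note: a custom feval returns a (name, value) pair, and train cannot
infer whether a bigger value is better for an arbitrary metric, so the new
maximize flag supplies that direction for early stopping. A hedged sketch
building on the usage example above (the metric and the params/dtrain/dvalid
names are illustrative, not from the diff):

    import numpy as np

    def accuracy(preds, dtrain):
        # feval signature: (predictions, DMatrix) -> (metric_name, value)
        labels = dtrain.get_label()
        return 'accuracy', float(np.mean(labels == (preds > 0.5)))

    bst = xgb.train(params, dtrain, num_boost_round=100,
                    evals=[(dvalid, 'valid')], feval=accuracy,
                    maximize=True,  # accuracy: higher is better
                    early_stopping_rounds=10)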
@@ -50,13 +53,23 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         Learning rate for each boosting round (yields learning rate decay).
         - list l: eta = l[boosting round]
         - function f: eta = f(boosting round, num_boost_round)
+    xgb_model : file name of stored xgb model or 'Booster' instance
+        Xgb model to be loaded before training (allows training continuation).
 
     Returns
     -------
     booster : a trained booster model
     """
     evals = list(evals)
-    bst = Booster(params, [dtrain] + [d[0] for d in evals])
+    ntrees = 0
+    if xgb_model is not None:
+        if not isinstance(xgb_model, STRING_TYPES):
+            xgb_model = xgb_model.save_raw()
+        bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
+        ntrees = len(bst.get_dump())
+    else:
+        bst = Booster(params, [dtrain] + [d[0] for d in evals])
+
 
     if evals_result is not None:
         if not isinstance(evals_result, dict):
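Reviewer note: both accepted forms of xgb_model converge on one code path
here. A Booster instance is serialized with save_raw() and handed to the
Booster constructor as model_file, and ntrees records how many trees the
loaded model already contains. Assuming prev_bst and 'prev.model' hold the
same 3-tree model (names illustrative), these calls are equivalent and each
yields a 10-tree booster, mirroring the test added below:

    bst_a = xgb.train(params, dtrain, num_boost_round=7, xgb_model=prev_bst)      # Booster instance
    bst_b = xgb.train(params, dtrain, num_boost_round=7, xgb_model='prev.model')  # file name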
@@ -69,6 +82,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     if not early_stopping_rounds:
         for i in range(num_boost_round):
             bst.update(dtrain, i, obj)
+            ntrees += 1
             if len(evals) != 0:
                 bst_eval_set = bst.eval_set(evals, i, feval)
                 if isinstance(bst_eval_set, STRING_TYPES):
@@ -91,6 +105,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
                                 evals_result[key][res_key].append(res_val)
                             else:
                                 evals_result[key][res_key] = [res_val]
+        bst.best_iteration = (ntrees - 1)
         return bst
 
     else:
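Reviewer note: without early stopping, best_iteration is now pinned to the
last tree of the (possibly continued) model, so code that reads
bst.best_iteration gets a sensible value even when training never stopped
early. An illustration (params/dtrain assumed from the examples above):

    bst = xgb.train(params, dtrain, num_boost_round=10)  # no early stopping
    # bst.best_iteration == 9; after continuing for another 5 rounds it
    # would be 14, counting trees of the combined model.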
@@ -115,6 +130,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             maximize_metrics = ('auc', 'map', 'ndcg')
             if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
                 maximize_score = True
+        if feval is not None:
+            maximize_score = maximize
 
         if maximize_score:
             best_score = 0.0
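Reviewer note: the resolution order matters. The auc/map/ndcg heuristic only
inspects a built-in eval_metric, and supplying any feval makes the explicit
maximize flag win unconditionally. The resulting behavior, as comments:

    # eval_metric='auc',  feval=None       -> maximize_score = True   (heuristic)
    # eval_metric='rmse', feval=None       -> maximize_score = False
    # feval=my_metric,    maximize=True    -> maximize_score = True   (explicit)
    # feval=my_metric,    maximize omitted -> maximize_score = False  (default)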
@@ -122,7 +139,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             best_score = float('inf')
 
         best_msg = ''
-        best_score_i = 0
+        best_score_i = ntrees
 
         if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
             raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
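Reviewer note (context lines above): learning_rates already accepts either a
list indexed by boosting round or a callable, and the length check compares
against num_boost_round of the current call only, not the combined model. An
illustrative decay callable (the 0.3 base eta is an assumption, not from the
diff):

    def eta_decay(current_round, total_rounds):
        # halve eta over the course of this call's rounds
        return 0.3 * (0.5 ** (current_round / float(total_rounds)))

    bst = xgb.train(params, dtrain, num_boost_round=20,
                    learning_rates=eta_decay)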
@@ -134,6 +151,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
                 else:
                     bst.set_param({'eta': learning_rates(i, num_boost_round)})
             bst.update(dtrain, i, obj)
+            ntrees += 1
             bst_eval_set = bst.eval_set(evals, i, feval)
 
             if isinstance(bst_eval_set, STRING_TYPES):
@@ -162,7 +180,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             if (maximize_score and score > best_score) or \
                     (not maximize_score and score < best_score):
                 best_score = score
-                best_score_i = i
+                best_score_i = (ntrees - 1)
                 best_msg = msg
             elif i - best_score_i >= early_stopping_rounds:
                 sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
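Reviewer note: because ntrees starts at the loaded model's tree count,
best_score_i is now an absolute tree index across the combined model rather
than an index into the current call's rounds. A hedged illustration (dvalid
assumed to be defined):

    bst = xgb.train(params, dtrain, num_boost_round=5)             # trees 0..4
    bst = xgb.train(params, dtrain, num_boost_round=5, xgb_model=bst,
                    evals=[(dvalid, 'valid')], early_stopping_rounds=3)
    # best_score_i counts from the start of the combined model, so a best
    # iteration of 7 means the 8th tree overall, not the 3rd of this call.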
tests/python/test_training_continuation.py (new file, 52 lines)

@@ -0,0 +1,52 @@
+import xgboost as xgb
+import numpy as np
+from sklearn.cross_validation import KFold, train_test_split
+from sklearn.metrics import mean_squared_error
+from sklearn.grid_search import GridSearchCV
+from sklearn.datasets import load_iris, load_digits, load_boston
+import unittest
+
+rng = np.random.RandomState(1337)
+
+
+class TestTrainingContinuation(unittest.TestCase):
+
+    xgb_params = {
+        'colsample_bytree': 0.7,
+        'silent': 1,
+        'nthread': 1,
+    }
+
+    def test_training_continuation(self):
+        digits = load_digits(2)
+        X = digits['data']
+        y = digits['target']
+
+        dtrain = xgb.DMatrix(X, label=y)
+
+        gbdt_01 = xgb.train(self.xgb_params, dtrain, num_boost_round=10)
+        ntrees_01 = len(gbdt_01.get_dump())
+        assert ntrees_01 == 10
+
+        gbdt_02 = xgb.train(self.xgb_params, dtrain, num_boost_round=0)
+        gbdt_02.save_model('xgb_tc.model')
+
+        gbdt_02a = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model=gbdt_02)
+        gbdt_02b = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model="xgb_tc.model")
+        ntrees_02a = len(gbdt_02a.get_dump())
+        ntrees_02b = len(gbdt_02b.get_dump())
+        assert ntrees_02a == 10
+        assert ntrees_02b == 10
+        assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02a.predict(dtrain))
+        assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02b.predict(dtrain))
+
+        gbdt_03 = xgb.train(self.xgb_params, dtrain, num_boost_round=3)
+        gbdt_03.save_model('xgb_tc.model')
+
+        gbdt_03a = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model=gbdt_03)
+        gbdt_03b = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model="xgb_tc.model")
+        ntrees_03a = len(gbdt_03a.get_dump())
+        ntrees_03b = len(gbdt_03b.get_dump())
+        assert ntrees_03a == 10
+        assert ntrees_03b == 10
+        assert mean_squared_error(y, gbdt_03a.predict(dtrain)) == mean_squared_error(y, gbdt_03b.predict(dtrain))