[PYTHON] Refactor trainnig API to use callback

This commit is contained in:
tqchen
2016-05-19 17:47:11 -07:00
parent 03996dd4e8
commit 149589c583
18 changed files with 492 additions and 278 deletions

View File

@@ -12,15 +12,18 @@ print ('running cross validation')
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'error'}, seed = 0)
metrics={'error'}, seed = 0,
callbacks=[xgb.callback.print_evaluation(show_stdv=True)])
print ('running cross validation, disable standard deviation display')
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'error'}, seed = 0, show_stdv = False)
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
metrics={'error'}, seed = 0,
callbacks=[xgb.callback.print_evaluation(show_stdv=False),
xgb.callback.early_stop(3)])
print (res)
print ('running cross validation, with preprocessing function')
# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
@@ -58,4 +61,3 @@ param = {'max_depth':2, 'eta':1, 'silent':1}
# train with customized objective
xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
obj = logregobj, feval=evalerror)