45 lines
1.8 KiB
Python
Executable File
45 lines
1.8 KiB
Python
Executable File
#!/usr/bin/python
|
|
import sys
|
|
import numpy as np
|
|
sys.path.append('../../wrapper')
|
|
import xgboost as xgb
|
|
###
|
|
# advanced: customized loss function
|
|
#
|
|
print('start running example to used cutomized objective function')

# Load the training and test matrices; both paths are relative to the
# current working directory.
dtrain = xgb.DMatrix('../data/agaricus.txt.train')
dtest = xgb.DMatrix('../data/agaricus.txt.test')

# NOTE: for a customized objective function the 'objective' parameter is left
# at its default, so the predictions we get back are raw margin values.
# You must know what you are doing.
param = {
    'max_depth': 2,
    'eta': 1,
    'silent': 1,
}
# Data sets (with display names) that are evaluated during training.
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
# Number of boosting rounds.
num_round = 2
|
|
|
|
# user-defined objective function: given the predictions, return the gradient and the second-order gradient
|
|
# this is the log-likelihood (logistic) loss
|
|
def logregobj(preds, dtrain):
    """Customized logistic-loss objective.

    Args:
        preds: array of raw margin predictions (scores before the logistic
            transformation).
        dtrain: data object whose ``get_label()`` returns the target labels,
            aligned element-wise with ``preds``.

    Returns:
        A ``(grad, hess)`` pair: the first- and second-order gradients of the
        log-likelihood loss with respect to the margin, one value per row.
    """
    labels = dtrain.get_label()
    # sigmoid(x) == exp(-log(1 + exp(-x))).  logaddexp evaluates
    # log(1 + exp(-x)) without ever computing exp() of a large positive
    # number, so very negative margins no longer overflow.
    prob = np.exp(-np.logaddexp(0.0, -preds))
    grad = prob - labels
    hess = prob * (1.0 - prob)
    return grad, hess
|
|
|
|
# user defined evaluation function, return a pair metric_name, result
|
|
# NOTE: when you use a customized loss function, the default prediction value is the margin
|
|
# this may make the built-in evaluation metrics not function properly
|
|
# for example, with logistic loss the prediction is the score before the logistic transformation
|
|
# the built-in evaluation error assumes its input is after the logistic transformation
|
|
# Keep this in mind when you use the customization; you may need to write a customized evaluation function
|
|
def evalerror(preds, dtrain):
    """Customized evaluation metric: classification error rate.

    Args:
        preds: array of raw margin predictions (before the logistic
            transformation).
        dtrain: data object whose ``get_label()`` returns the target labels,
            aligned element-wise with ``preds``.

    Returns:
        A ``(metric_name, result)`` pair: the string ``'error'`` and the
        fraction of rows whose predicted class differs from the label.

    Raises:
        ZeroDivisionError: if there are no labels.
    """
    labels = dtrain.get_label()
    # preds are margins, so the decision cutoff is 0 rather than 0.5.
    predicted_positive = preds > 0.0
    # Count mismatches in NumPy instead of iterating the array with the
    # builtin sum().
    n_wrong = np.count_nonzero(labels != predicted_positive)
    return 'error', float(n_wrong) / len(labels)
|
|
|
|
# Train with the customized objective and evaluation function.  Training can
# also be done step by step; see the implementation of train in xgboost.py.
# The watchlist and both callbacks are passed by keyword so the call does not
# depend on their positional order in xgb.train's signature.
bst = xgb.train(param, dtrain, num_round, evals=watchlist,
                obj=logregobj, feval=evalerror)
|