add cutomized training
This commit is contained in:
parent
ebde99bde8
commit
9c2bb12cd1
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/python
|
||||
import sys
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
# append the path to xgboost
|
||||
sys.path.append('../')
|
||||
@ -80,3 +81,22 @@ dtrain.set_label(labels)
|
||||
evallist = [(dtest,'eval'), (dtrain,'train')]
|
||||
bst = xgb.train( param, dtrain, num_round, evallist )
|
||||
|
||||
###
|
||||
# cutomsized loss function, set loss_type to 0, so that predict get untransformed score
|
||||
#
|
||||
print 'start running example to used cutomized objective function'
|
||||
|
||||
|
||||
# note: set loss_type properly, loss_type=2 means the prediction will get logistic transformed
|
||||
# in most case, we may want to set loss_type = 0, to get untransformed score to compute gradient
|
||||
bst = param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'loss_type':2 }
|
||||
|
||||
# user define objective function, given prediction, return gradient and second order gradient
|
||||
def logregobj( preds, dtrain ):
|
||||
labels = dtrain.get_label()
|
||||
grad = preds - labels
|
||||
hess = preds * (1.0-preds)
|
||||
return grad, hess
|
||||
|
||||
# training with customized objective, we can also do step by step training, simply look at xgboost.py's implementation of train
|
||||
bst = xgb.train( param, dtrain, num_round, evallist, logregobj )
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
import ctypes
|
||||
import os
|
||||
# optinally have scipy sparse, though not necessary
|
||||
import numpy as np
|
||||
import numpy
|
||||
import numpy.ctypeslib
|
||||
import scipy.sparse as scp
|
||||
|
||||
# set this line correctly
|
||||
@ -71,8 +72,8 @@ class DMatrix:
|
||||
# get label from dmatrix
|
||||
def get_label(self):
|
||||
length = ctypes.c_ulong()
|
||||
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length));
|
||||
return [ labels[i] for i in xrange(length.value) ]
|
||||
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length))
|
||||
return numpy.array( [labels[i] for i in xrange(length.value)] )
|
||||
# clear everything
|
||||
def clear(self):
|
||||
xglib.XGDMatrixClear(self.handle)
|
||||
@ -111,6 +112,14 @@ class Booster:
|
||||
""" update """
|
||||
assert isinstance(dtrain, DMatrix)
|
||||
xglib.XGBoosterUpdateOneIter( self.handle, dtrain.handle )
|
||||
def boost(self, dtrain, grad, hess, bst_group = -1):
|
||||
""" update """
|
||||
assert len(grad) == len(hess)
|
||||
assert isinstance(dtrain, DMatrix)
|
||||
xglib.XGBoosterBoostOneIter( self.handle, dtrain.handle,
|
||||
(ctypes.c_float*len(grad))(*grad),
|
||||
(ctypes.c_float*len(hess))(*hess),
|
||||
len(grad), bst_group )
|
||||
def update_interact(self, dtrain, action, booster_index=None):
|
||||
""" beta: update with specified action"""
|
||||
assert isinstance(dtrain, DMatrix)
|
||||
@ -126,10 +135,10 @@ class Booster:
|
||||
xglib.XGBoosterEvalOneIter( self.handle, it, dmats, evnames, len(evals) )
|
||||
def eval(self, mat, name = 'eval', it = 0 ):
|
||||
self.eval_set( [(mat,name)], it)
|
||||
def predict(self, data):
|
||||
def predict(self, data, bst_group = -1):
|
||||
length = ctypes.c_ulong()
|
||||
preds = xglib.XGBoosterPredict( self.handle, data.handle, ctypes.byref(length))
|
||||
return [ preds[i] for i in xrange(length.value) ]
|
||||
preds = xglib.XGBoosterPredict( self.handle, data.handle, ctypes.byref(length), bst_group)
|
||||
return numpy.array( [ preds[i] for i in xrange(length.value)])
|
||||
def save_model(self, fname):
|
||||
""" save model to file """
|
||||
xglib.XGBoosterSaveModel( self.handle, ctypes.c_char_p(fname) )
|
||||
@ -140,12 +149,21 @@ class Booster:
|
||||
"""dump model into text file"""
|
||||
xglib.XGBoosterDumpModel( self.handle, ctypes.c_char_p(fname), ctypes.c_char_p(fmap) )
|
||||
|
||||
def train(params, dtrain, num_boost_round = 10, evals = []):
|
||||
def train(params, dtrain, num_boost_round = 10, evals = [], obj=None):
|
||||
""" train a booster with given paramaters """
|
||||
bst = Booster(params, [dtrain] )
|
||||
if obj == None:
|
||||
for i in xrange(num_boost_round):
|
||||
bst.update( dtrain )
|
||||
if len(evals) != 0:
|
||||
bst.eval_set( evals, i )
|
||||
else:
|
||||
# try customized objective function
|
||||
for i in xrange(num_boost_round):
|
||||
pred = bst.predict( dtrain )
|
||||
grad, hess = obj( pred, dtrain )
|
||||
bst.boost( dtrain, grad, hess )
|
||||
if len(evals) != 0:
|
||||
bst.eval_set( evals, i )
|
||||
return bst
|
||||
|
||||
|
||||
@ -102,11 +102,34 @@ namespace xgboost{
|
||||
xgboost::regrank::RegRankBoostLearner::LoadModel(fname);
|
||||
this->init_model = true;
|
||||
}
|
||||
const float *Pred( const DMatrix &dmat, size_t *len ){
|
||||
this->Predict( this->preds_, dmat );
|
||||
const float *Pred( const DMatrix &dmat, size_t *len, int bst_group ){
|
||||
this->CheckInit();
|
||||
|
||||
this->Predict( this->preds_, dmat, bst_group );
|
||||
*len = this->preds_.size();
|
||||
return &this->preds_[0];
|
||||
}
|
||||
inline void BoostOneIter( const DMatrix &train,
|
||||
float *grad, float *hess, size_t len, int bst_group ){
|
||||
this->grad_.resize( len ); this->hess_.resize( len );
|
||||
memcpy( &this->grad_[0], grad, sizeof(float)*len );
|
||||
memcpy( &this->hess_[0], hess, sizeof(float)*len );
|
||||
|
||||
if( grad_.size() == train.Size() ){
|
||||
if( bst_group < 0 ) bst_group = 0;
|
||||
base_gbm.DoBoost(grad_, hess_, train.data, train.info.root_index, bst_group);
|
||||
}else{
|
||||
utils::Assert( bst_group == -1, "must set bst_group to -1 to support all group boosting" );
|
||||
int ngroup = base_gbm.NumBoosterGroup();
|
||||
utils::Assert( grad_.size() == train.Size() * (size_t)ngroup, "BUG: UpdateOneIter: mclass" );
|
||||
std::vector<float> tgrad( train.Size() ), thess( train.Size() );
|
||||
for( int g = 0; g < ngroup; ++ g ){
|
||||
memcpy( &tgrad[0], &grad_[g*tgrad.size()], sizeof(float)*tgrad.size() );
|
||||
memcpy( &thess[0], &hess_[g*tgrad.size()], sizeof(float)*tgrad.size() );
|
||||
base_gbm.DoBoost(tgrad, thess, train.data, train.info.root_index, g );
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
};
|
||||
@ -182,6 +205,13 @@ extern "C"{
|
||||
bst->CheckInit(); dtr->CheckInit();
|
||||
bst->UpdateOneIter( *dtr );
|
||||
}
|
||||
void XGBoosterBoostOneIter( void *handle, void *dtrain,
|
||||
float *grad, float *hess, size_t len, int bst_group ){
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
DMatrix *dtr = static_cast<DMatrix*>(dtrain);
|
||||
bst->CheckInit(); dtr->CheckInit();
|
||||
bst->BoostOneIter( *dtr, grad, hess, len, bst_group );
|
||||
}
|
||||
void XGBoosterEvalOneIter( void *handle, int iter, void *dmats[], const char *evnames[], size_t len ){
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->CheckInit();
|
||||
@ -194,8 +224,8 @@ extern "C"{
|
||||
}
|
||||
bst->EvalOneIter( iter, mats, names, stdout );
|
||||
}
|
||||
const float *XGBoosterPredict( void *handle, void *dmat, size_t *len ){
|
||||
return static_cast<Booster*>(handle)->Pred( *static_cast<DMatrix*>(dmat), len );
|
||||
const float *XGBoosterPredict( void *handle, void *dmat, size_t *len, int bst_group ){
|
||||
return static_cast<Booster*>(handle)->Pred( *static_cast<DMatrix*>(dmat), len, bst_group );
|
||||
}
|
||||
void XGBoosterLoadModel( void *handle, const char *fname ){
|
||||
static_cast<Booster*>(handle)->LoadModel( fname );
|
||||
|
||||
@ -127,6 +127,19 @@ extern "C"{
|
||||
* \param dtrain training data
|
||||
*/
|
||||
void XGBoosterUpdateOneIter( void *handle, void *dtrain );
|
||||
|
||||
/*!
|
||||
* \brief update the model, by directly specify gradient and second order gradient,
|
||||
* this can be used to replace UpdateOneIter, to support customized loss function
|
||||
* \param handle handle
|
||||
* \param dtrain training data
|
||||
* \param grad gradient statistics
|
||||
* \param hess second order gradient statistics
|
||||
* \param len length of grad/hess array
|
||||
* \param bst_group boost group we are working at, default = -1
|
||||
*/
|
||||
void XGBoosterBoostOneIter( void *handle, void *dtrain,
|
||||
float *grad, float *hess, size_t len, int bst_group );
|
||||
/*!
|
||||
* \brief print evaluation statistics to stdout for xgboost
|
||||
* \param handle handle
|
||||
@ -141,8 +154,9 @@ extern "C"{
|
||||
* \param handle handle
|
||||
* \param dmat data matrix
|
||||
* \param len used to store length of returning result
|
||||
* \param bst_group booster group, if model contains multiple booster group, default = -1 means predict for all groups
|
||||
*/
|
||||
const float *XGBoosterPredict( void *handle, void *dmat, size_t *len );
|
||||
const float *XGBoosterPredict( void *handle, void *dmat, size_t *len, int bst_group );
|
||||
/*!
|
||||
* \brief load model from existing file
|
||||
* \param handle handle
|
||||
|
||||
@ -262,7 +262,6 @@ namespace xgboost{
|
||||
base_gbm.InteractRePredict(data.data, j, buffer_offset + j);
|
||||
}
|
||||
}
|
||||
private:
|
||||
/*! \brief get un-transformed prediction*/
|
||||
inline void PredictRaw(std::vector<float> &preds, const DMatrix &data, int bst_group = -1 ){
|
||||
int buffer_offset = this->FindBufferOffset(data);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user