a correct version

tqchen 2014-05-15 21:11:46 -07:00
parent 2be3f6ece0
commit 6af6d64f0b
5 changed files with 24 additions and 15 deletions

View File

@@ -13,6 +13,8 @@ dpath = 'data'
 # load in training data, directly use numpy
 dtrain = np.loadtxt( dpath+'/training.csv', delimiter=',', skiprows=1, converters={32: lambda x:int(x=='s') } )
+print 'finish loading from csv '
 label = dtrain[:,32]
 data = dtrain[:,1:31]
 # rescale weight to make it same as test set
@@ -25,25 +27,28 @@ sum_wneg = sum( weight[i] for i in xrange(len(label)) if label[i] == 0.0 )
 print 'weight statistics: wpos=%g, wneg=%g, ratio=%g' % ( sum_wpos, sum_wneg, sum_wneg/sum_wpos )
 # construct xgboost.DMatrix from numpy array, treat -999.0 as missing value
-xtrain = xgb.DMatrix( data, label=label, missing = -999.0 )
+xgmat = xgb.DMatrix( data, label=label, missing = -999.0, weight=weight )
 # setup parameters for xgboost
-params = {}
+param = {}
 # use logistic regression loss
 param['loss_type'] = 3
 # scale weight of positive examples
-param['scale_pos_weight'] = sum_wpos/sum_wpos
+param['scale_pos_weight'] = sum_wneg/sum_wpos
 param['bst:eta'] = 0.1
 param['bst:max_depth'] = 6
-param['eval_metric'] = 'ams@0.15'
+param['eval_metric'] = 'auc'
 param['silent'] = 1
+param['eval_train'] = 1
 param['nthread'] = 16
+# you can directly throw param in, though we want to watch multiple metrics here
+plst = param.items()+[('eval_metric', 'ams@0.15')]
+watchlist = [ (xgmat,'train') ]
 # boost 120 trees
 num_round = 120
 print 'loading data end, start to boost trees'
-bst = xgb.train( xtrain, param, num_round );
+bst = xgb.train( plst, xgmat, num_round, watchlist );
 # save out model
 bst.save_model('higgs.model')
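Note on the substance of this hunk: the old script computed scale_pos_weight as sum_wpos/sum_wpos (which is always 1), never attached the event weights to the DMatrix, and watched only a single metric. Below is a minimal sketch of the corrected flow written against the present-day xgboost Python API; the objective/eta/max_depth names, num_boost_round, and the evals= keyword are assumptions from today's library, not part of the 2014 wrapper shown here (which used loss_type and bst:-prefixed keys).

import numpy as np
import xgboost as xgb

# converter handles both str and bytes, depending on numpy version
dtrain = np.loadtxt('data/training.csv', delimiter=',', skiprows=1,
                    converters={32: lambda x: 1 if x in ('s', b's') else 0})
label = dtrain[:, 32]
data = dtrain[:, 1:31]
weight = dtrain[:, 31]

sum_wpos = weight[label == 1.0].sum()
sum_wneg = weight[label == 0.0].sum()

# -999.0 marks missing entries; per-event weights travel with the DMatrix
xgmat = xgb.DMatrix(data, label=label, missing=-999.0, weight=weight)
param = {
    'objective': 'binary:logitraw',           # raw margin scores suit the ranking step
    'scale_pos_weight': sum_wneg / sum_wpos,  # the fix: wneg/wpos, not wpos/wpos (== 1)
    'eta': 0.1,
    'max_depth': 6,
    'nthread': 16,
}
# passing a list of pairs lets eval_metric appear twice, so both are watched
plst = list(param.items()) + [('eval_metric', 'auc'), ('eval_metric', 'ams@0.15')]
bst = xgb.train(plst, xgmat, num_boost_round=120, evals=[(xgmat, 'train')])
bst.save_model('higgs.model')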

View File

@@ -1,5 +1,5 @@
 #!/usr/bin/python
-# this is the example script to use xgboost to train
+# make prediction
 import sys
 import numpy as np
 # add path of xgboost python module
@@ -17,13 +17,14 @@ threshold_ratio = 0.15
 # load in test data, directly use numpy
 dtest = np.loadtxt( dpath+'/test.csv', delimiter=',', skiprows=1 )
 data = dtest[:,1:31]
-idx = dtest[:,1]
-xtest = xgb.DMatrix( data, missing = -999.0 )
-bst = xgb.Booster()
+idx = dtest[:,0]
+print 'finish loading from csv '
+xgmat = xgb.DMatrix( data, missing = -999.0 )
+bst = xgb.Booster({'nthread':16})
 bst.load_model( modelfile )
-ypred = bst.predict( dtest )
+ypred = bst.predict( xgmat )
 res = [ ( int(idx[i]), ypred[i] ) for i in xrange(len(ypred)) ]
 rorder = {}
@@ -31,7 +32,7 @@ for k, v in sorted( res, key = lambda x:-x[1] ):
     rorder[ k ] = len(rorder) + 1
 # write out predictions
-ntop = int( ratio * len(rorder ) )
+ntop = int( threshold_ratio * len(rorder ) )
 fo = open(outfile, 'w')
 nhit = 0
 ntot = 0
@@ -46,7 +47,7 @@ for k, v in res:
     ntot += 1
 fo.close()
-print 'finished writing into model file'
+print 'finished writing into prediction file'
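The logic around these hunks ranks events by predicted score and calls the top threshold_ratio fraction signal ('s'), the rest background ('b'). A vectorized numpy sketch of that rank-and-threshold step, for illustration only (the stand-in ypred/idx arrays are hypothetical; in the script they come from bst.predict and the test CSV):

import numpy as np

# stand-ins for the script's outputs, so the sketch runs on its own
ypred = np.random.randn(20)
idx = np.arange(350000, 350020).astype(float)

threshold_ratio = 0.15
rank = np.empty(len(ypred), dtype=int)
rank[np.argsort(-ypred)] = np.arange(1, len(ypred) + 1)  # rank 1 = highest score
ntop = int(threshold_ratio * len(ypred))
labels = np.where(rank <= ntop, 's', 'b')
for event_id, r, lb in zip(idx.astype(int), rank, labels):
    print('%d,%d,%s' % (event_id, r, lb))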

View File

@@ -33,7 +33,7 @@ def ctypes2numpy( cptr, length ):
 # data matrix used in xgboost
 class DMatrix:
     # constructor
-    def __init__(self, data=None, label=None, missing=0.0):
+    def __init__(self, data=None, label=None, missing=0.0, weight = None):
         self.handle = xglib.XGDMatrixCreate()
         if data == None:
             return
@@ -51,6 +51,8 @@ class DMatrix:
             raise Exception, "can not intialize DMatrix from"+str(type(data))
         if label != None:
             self.set_label(label)
+        if weight !=None:
+            self.set_weight(weight)
     # convert data from csr matrix
     def __init_from_csr(self,csr):
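With this hunk the Python DMatrix constructor accepts a weight keyword and forwards it to set_weight, so per-instance weights can be attached at construction time. A small usage sketch (the data values are made up for illustration; the set_weight method is the one the constructor above calls):

import numpy as np
import xgboost as xgb

data = np.random.rand(100, 30)
label = np.random.randint(0, 2, size=100)
weight = np.ones(100)

xgmat = xgb.DMatrix(data, label=label, missing=-999.0, weight=weight)
# equivalent to constructing without weight and then calling:
# xgmat.set_weight(weight)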

View File

@@ -57,7 +57,7 @@ namespace xgboost{
        DMatrix(void){}
        /*! \brief get the number of instances */
        inline size_t Size() const{
-           return info.labels.size();
+           return data.NumRow();
        }
        /*!
         * \brief load from text file

View File

@@ -110,6 +110,7 @@ namespace xgboost{
            virtual float Eval(const std::vector<float> &preds,
                               const DMatrix::Info &info) const {
                const unsigned ndata = static_cast<unsigned>(preds.size());
+               utils::Assert( info.weights.size() == ndata, "we need weight to evaluate ams");
                std::vector< std::pair<float, unsigned> > rec(ndata);
                #pragma omp parallel for schedule( static )
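The assertion guards the AMS computation, which needs one weight per prediction: the metric is built from weighted sums. ams@0.15 refers to the Approximate Median Significance from the HiggsML challenge, AMS = sqrt(2*((s+b+b_r)*ln(1+s/(b+b_r)) - s)) with regularization b_r = 10, where s and b are the weighted true-signal and true-background totals inside the selected top fraction of predictions. An illustrative Python rendering of that computation, not the committed C++:

import numpy as np

def ams_at(preds, labels, weights, ratio=0.15, b_r=10.0):
    order = np.argsort(-preds)                  # most signal-like first
    ntop = int(ratio * len(preds))
    sel = order[:ntop]                          # selected region
    s = weights[sel][labels[sel] == 1].sum()    # weighted true signal
    b = weights[sel][labels[sel] == 0].sum()    # weighted true background
    return np.sqrt(2.0 * ((s + b + b_r) * np.log(1.0 + s / (b + b_r)) - s))

# toy check: a clean signal-heavy selection gives a positive significance
preds = np.array([2.0, 1.5, 0.3, -0.2])
labels = np.array([1, 1, 0, 0])
weights = np.ones(4)
print(ams_at(preds, labels, weights, ratio=0.5))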