adjust weight

This commit is contained in:
antinucleon 2014-09-02 15:22:08 -06:00
parent c75275a861
commit 5177fa02e4

View File

@ -3,10 +3,11 @@
import ctypes import ctypes
import os import os
# optinally have scipy sparse, though not necessary # optinally have scipy sparse, though not necessary
import numpy import numpy as np
import sys import sys
import numpy.ctypeslib import numpy.ctypeslib
import scipy.sparse as scp import scipy.sparse as scp
import random
# set this line correctly # set this line correctly
if os.name == 'nt': if os.name == 'nt':
@ -32,18 +33,30 @@ xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
def ctypes2numpy(cptr, length, dtype): def ctypes2numpy(cptr, length, dtype):
# convert a ctypes pointer array to numpy """convert a ctypes pointer array to numpy array """
assert isinstance(cptr, ctypes.POINTER(ctypes.c_float)) assert isinstance(cptr, ctypes.POINTER(ctypes.c_float))
res = numpy.zeros(length, dtype=dtype) res = numpy.zeros(length, dtype=dtype)
assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]) assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0])
return res return res
# data matrix used in xgboost
class DMatrix: class DMatrix:
"""data matrix used in xgboost"""
# constructor # constructor
def __init__(self, data, label=None, missing=0.0, weight = None): def __init__(self, data, label=None, missing=0.0, weight = None):
""" constructor of DMatrix
Args:
data: string/numpy array/scipy.sparse
data source, string type is the path of svmlight format txt file or xgb buffer
label: list or numpy 1d array, optional
label of training data
missing: float
value in data which need to be present as missing value
weight: list or numpy 1d array, optional
weight for each instances
"""
# force into void_p, mac need to pass things in as void_p # force into void_p, mac need to pass things in as void_p
if data == None: if data is None:
self.handle = None self.handle = None
return return
if isinstance(data, str): if isinstance(data, str):
@ -63,22 +76,25 @@ class DMatrix:
self.set_label(label) self.set_label(label)
if weight !=None: if weight !=None:
self.set_weight(weight) self.set_weight(weight)
# convert data from csr matrix
def __init_from_csr(self, csr): def __init_from_csr(self, csr):
"""convert data from csr matrix"""
assert len(csr.indices) == len(csr.data) assert len(csr.indices) == len(csr.data)
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR( self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR(
(ctypes.c_ulong * len(csr.indptr))(*csr.indptr), (ctypes.c_ulong * len(csr.indptr))(*csr.indptr),
(ctypes.c_uint * len(csr.indices))(*csr.indices), (ctypes.c_uint * len(csr.indices))(*csr.indices),
(ctypes.c_float * len(csr.data))(*csr.data), (ctypes.c_float * len(csr.data))(*csr.data),
len(csr.indptr), len(csr.data))) len(csr.indptr), len(csr.data)))
# convert data from numpy matrix
def __init_from_npy2d(self,mat,missing): def __init_from_npy2d(self,mat,missing):
"""convert data from numpy matrix"""
data = numpy.array(mat.reshape(mat.size), dtype='float32') data = numpy.array(mat.reshape(mat.size), dtype='float32')
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat( self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
mat.shape[0], mat.shape[1], ctypes.c_float(missing))) mat.shape[0], mat.shape[1], ctypes.c_float(missing)))
# destructor
def __del__(self): def __del__(self):
"""destructor"""
xglib.XGDMatrixFree(self.handle) xglib.XGDMatrixFree(self.handle)
def get_float_info(self, field): def get_float_info(self, field):
length = ctypes.c_ulong() length = ctypes.c_ulong()
@ -96,16 +112,39 @@ class DMatrix:
def set_uint_info(self, field, data): def set_uint_info(self, field, data):
xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
(ctypes.c_uint*len(data))(*data), len(data)) (ctypes.c_uint*len(data))(*data), len(data))
# load data from file
def save_binary(self, fname, silent=True): def save_binary(self, fname, silent=True):
"""save DMatrix to XGBoost buffer
Args:
fname: string
name of buffer file
slient: bool, option
whether print info
Returns:
None
"""
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent)) xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
# set label of dmatrix
def set_label(self, label): def set_label(self, label):
"""set label of dmatrix
Args:
label: list
label for DMatrix
Returns:
None
"""
self.set_float_info('label', label) self.set_float_info('label', label)
# set weight of each instances
def set_weight(self, weight): def set_weight(self, weight):
"""set weight of each instances
Args:
weight: float
weight for positive instance
Returns:
None
"""
self.set_float_info('weight', weight) self.set_float_info('weight', weight)
# set initialized margin prediction
def set_base_margin(self, margin): def set_base_margin(self, margin):
""" """
set base margin of booster to start from set base margin of booster to start from
@ -116,31 +155,149 @@ class DMatrix:
see also example/demo.py see also example/demo.py
""" """
self.set_float_info('base_margin', margin) self.set_float_info('base_margin', margin)
# set group size of dmatrix, used for rank
def set_group(self, group): def set_group(self, group):
"""set group size of dmatrix, used for rank
Args:
group:
Returns:
None
"""
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group)) xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
# get label from dmatrix
def get_label(self): def get_label(self):
"""get label from dmatrix
Args:
None
Returns:
list, label of data
"""
return self.get_float_info('label') return self.get_float_info('label')
# get weight from dmatrix
def get_weight(self): def get_weight(self):
"""get weight from dmatrix
Args:
None
Returns:
float, weight
"""
return self.get_float_info('weight') return self.get_float_info('weight')
# get base_margin from dmatrix
def get_base_margin(self): def get_base_margin(self):
"""get base_margin from dmatrix
Args:
None
Returns:
float, base margin
"""
return self.get_float_info('base_margin') return self.get_float_info('base_margin')
def num_row(self): def num_row(self):
"""get number of rows
Args:
None
Returns:
int, num rows
"""
return xglib.XGDMatrixNumRow(self.handle) return xglib.XGDMatrixNumRow(self.handle)
# slice the DMatrix to return a new DMatrix that only contains rindex
def slice(self, rindex): def slice(self, rindex):
"""slice the DMatrix to return a new DMatrix that only contains rindex
Args:
rindex: list
list of index to be chosen
Returns:
res: DMatrix
new DMatrix with chosen index
"""
res = DMatrix(None) res = DMatrix(None)
res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix( res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix(
self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex))) self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex)))
return res return res
class CVPack:
def __init__(self, dtrain, dtest, param):
self.dtrain = dtrain
self.dtest = dtest
self.watchlist = watchlist = [ (dtrain,'train'), (dtest, 'test') ]
self.bst = Booster(param, [dtrain,dtest])
def update(self,r):
self.bst.update(self.dtrain, r)
def eval(self,r):
return self.bst.eval_set(self.watchlist, r)
def mknfold(dall, nfold, param, seed, weightscale=None, evals=[], set_pos_weight=None):
"""
mk nfold list of cvpack from randidx
"""
randidx = range(dall.num_row())
random.seed(seed)
random.shuffle(randidx)
idxset = []
kstep = len(randidx) / nfold
for i in range(nfold):
idxset.append(randidx[ (i*kstep) : min(len(randidx),(i+1)*kstep) ])
ret = []
for k in range(nfold):
trainlst = []
for j in range(nfold):
if j == k:
testlst = idxset[j]
else:
trainlst += idxset[j]
dtrain = dall.slice(trainlst)
dtest = dall.slice(testlst)
# rescale weight of dtrain and dtest
if weightscale != None:
dtrain.set_weight( dtrain.get_weight() * weightscale * dall.num_row() / dtrain.num_row() )
dtest.set_weight( dtest.get_weight() * weightscale * dall.num_row() / dtest.num_row() )
if set_pos_weight != None:
label = dtrain.get_label()
weight = dtrain.get_weight()
sum_wpos = sum( weight[i] for i in range(len(label)) if label[i] == 1.0 )
sum_wneg = sum( weight[i] for i in range(len(label)) if label[i] == 0.0 )
param['scale_pos_weight'] = sum_wneg/sum_wpos
plst = param.items() + [('eval_metric', itm) for itm in evals]
ret.append(CVPack(dtrain, dtest, plst))
return ret
def aggcv(rlist):
"""
aggregate cross validation results
"""
cvmap = {}
arr = rlist[0].split()
ret = arr[0]
for it in arr[1:]:
k, v = it.split(':')
cvmap[k] = [float(v)]
for line in rlist[1:]:
arr = line.split()
assert ret == arr[0]
for it in arr[1:]:
k, v = it.split(':')
cvmap[k].append(float(v))
for k, v in sorted(cvmap.items(), key = lambda x:x[0]):
v = np.array(v)
ret += '\t%s:%f+%f' % (k, np.mean(v), np.std(v))
return ret
class Booster: class Booster:
"""learner class """ """learner class """
def __init__(self, params={}, cache=[], model_file = None): def __init__(self, params={}, cache=[], model_file = None):
""" constructor, param: """ """ constructor
Args:
params: dict
params for boosters
cache: list
list of cache item
model_file: string
path of model file
Returns:
None
"""
for d in cache: for d in cache:
assert isinstance(d, DMatrix) assert isinstance(d, DMatrix)
dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache]) dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache])
@ -166,16 +323,30 @@ class Booster:
xglib.XGBoosterSetParam( xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(k.encode('utf-8')), self.handle, ctypes.c_char_p(k.encode('utf-8')),
ctypes.c_char_p(str(v).encode('utf-8'))) ctypes.c_char_p(str(v).encode('utf-8')))
def update(self, dtrain, it): def update(self, dtrain, it):
""" """
update update
dtrain: the training DMatrix Args:
it: current iteration number dtrain: DMatrix
the training DMatrix
it: int
current iteration number
Returns:
None
""" """
assert isinstance(dtrain, DMatrix) assert isinstance(dtrain, DMatrix)
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle) xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
def boost(self, dtrain, grad, hess): def boost(self, dtrain, grad, hess):
""" update """ """ update
Args:
dtrain: DMatrix
the training DMatrix
grad: list
the first order of gradient
hess: list
the second order of gradient
"""
assert len(grad) == len(hess) assert len(grad) == len(hess)
assert isinstance(dtrain, DMatrix) assert isinstance(dtrain, DMatrix)
xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle, xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle,
@ -183,6 +354,14 @@ class Booster:
(ctypes.c_float*len(hess))(*hess), (ctypes.c_float*len(hess))(*hess),
len(grad)) len(grad))
def eval_set(self, evals, it = 0): def eval_set(self, evals, it = 0):
"""evaluates by metric
Args:
evals: list of tuple (DMatrix, string)
lists of items to be evaluated
it: int
Returns:
evals result
"""
for d in evals: for d in evals:
assert isinstance(d[0], DMatrix) assert isinstance(d[0], DMatrix)
assert isinstance(d[1], str) assert isinstance(d[1], str)
@ -192,25 +371,49 @@ class Booster:
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
def eval(self, mat, name = 'eval', it = 0): def eval(self, mat, name = 'eval', it = 0):
return self.eval_set( [(mat,name)], it) return self.eval_set( [(mat,name)], it)
def predict(self, data, output_margin=False, ntree_limit=0): def predict(self, data, output_margin=False):
""" """
predict with data predict with data
data: the dmatrix storing the input Args:
output_margin: whether output raw margin value that is untransformed data: DMatrix
ntree_limit: limit number of trees in prediction, default to 0, 0 means using all the trees the dmatrix storing the input
output_margin: bool
whether output raw margin value that is untransformed
Returns:
numpy array of prediction
""" """
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle, preds = xglib.XGBoosterPredict(self.handle, data.handle,
int(output_margin), ntree_limit, ctypes.byref(length)) int(output_margin), ctypes.byref(length))
return ctypes2numpy(preds, length.value, 'float32') return ctypes2numpy(preds, length.value, 'float32')
def save_model(self, fname): def save_model(self, fname):
""" save model to file """ """ save model to file
Args:
fname: string
file name of saving model
Returns:
None
"""
xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8'))) xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8')))
def load_model(self, fname): def load_model(self, fname):
"""load model from file""" """load model from file
Args:
fname: string
file name of saving model
Returns:
None
"""
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) ) xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) )
def dump_model(self, fo, fmap=''): def dump_model(self, fo, fmap=''):
"""dump model into text file""" """dump model into text file
Args:
fo: string
file name to be dumped
fmap: string, optional
file name of feature map names
Returns:
None
"""
if isinstance(fo,str): if isinstance(fo,str):
fo = open(fo,'w') fo = open(fo,'w')
need_close = True need_close = True
@ -249,7 +452,17 @@ class Booster:
return fmap return fmap
def evaluate(bst, evals, it, feval = None): def evaluate(bst, evals, it, feval = None):
"""evaluation on eval set""" """evaluation on eval set
Args:
bst: XGBoost object
object of XGBoost model
evals: list of tuple (DMatrix, string)
obj need to be evaluated
it: int
feval: optional
Returns:
eval result
"""
if feval != None: if feval != None:
res = '[%d]' % it res = '[%d]' % it
for dm, evname in evals: for dm, evname in evals:
@ -260,10 +473,24 @@ def evaluate(bst, evals, it, feval = None):
return res return res
def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None): def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None):
""" train a booster with given paramaters """ """ train a booster with given paramaters
Args:
params: dict
params of booster
dtrain: DMatrix
data to be trained
num_boost_round: int
num of round to be boosted
evals: list
list of items to be evaluated
obj:
feval:
"""
bst = Booster(params, [dtrain]+[ d[0] for d in evals ] ) bst = Booster(params, [dtrain]+[ d[0] for d in evals ] )
if obj == None: if obj is None:
for i in range(num_boost_round): for i in range(num_boost_round):
bst.update( dtrain, i ) bst.update( dtrain, i )
if len(evals) != 0: if len(evals) != 0:
@ -277,3 +504,29 @@ def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None
if len(evals) != 0: if len(evals) != 0:
sys.stderr.write(evaluate(bst, evals, i, feval)+'\n') sys.stderr.write(evaluate(bst, evals, i, feval)+'\n')
return bst return bst
def cv(params, dtrain, num_boost_round = 10, nfold=3, evals = [], \
weightscale=None, obj=None, feval=None, set_pos_weight=None):
""" cross validation with given paramaters
Args:
params: dict
params of booster
dtrain: DMatrix
data to be trained
num_boost_round: int
num of round to be boosted
nfold: int
folds to do cv
evals: list
list of items to be evaluated
obj:
feval:
set_pos_weight: bool, optional
Adjust pos weight by number
"""
cvfolds = mknfold(dtrain, nfold, params, 0, weightscale, evals)
for i in range(num_boost_round):
for f in cvfolds:
f.update(i)
res = aggcv([f.eval(i) for f in cvfolds])
sys.stderr.write(res+'\n')