Merge pull request #178 from aldanor/master

[python] Fixed the dll import for relative paths + various cleanup.
This commit is contained in:
Tianqi Chen 2015-02-16 09:51:40 -08:00
commit 15562126a6

View File

@ -1,142 +1,185 @@
""" """
xgboost: eXtreme Gradient Boosting library xgboost: eXtreme Gradient Boosting library
Author: Tianqi Chen, Bing Xu
Authors: Tianqi Chen, Bing Xu
""" """
import ctypes
from __future__ import absolute_import
import os import os
# optinally have scipy sparse, though not necessary
import numpy as np
import sys import sys
import numpy.ctypeslib import ctypes
import scipy.sparse as scp import collections
# set this line correctly import numpy as np
if os.name == 'nt': import scipy.sparse
XGBOOST_PATH = os.path.dirname(__file__)+'/../windows/x64/Release/xgboost_wrapper.dll'
__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
if sys.version_info[0] == 3:
string_types = str,
else: else:
XGBOOST_PATH = os.path.dirname(__file__)+'/libxgboostwrapper.so' string_types = basestring,
def load_xglib():
dll_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
if os.name == 'nt':
dll_path = os.path.join(dll_path, '../windows/x64/Release/xgboost_wrapper.dll')
else:
dll_path = os.path.join(dll_path, 'libxgboostwrapper.so')
# load the xgboost wrapper library
lib = ctypes.cdll.LoadLibrary(dll_path)
# DMatrix functions
lib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
lib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
lib.XGDMatrixCreateFromCSC.restype = ctypes.c_void_p
lib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
lib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
lib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
lib.XGDMatrixGetUIntInfo.restype = ctypes.POINTER(ctypes.c_uint)
lib.XGDMatrixNumRow.restype = ctypes.c_ulong
# Booster functions
lib.XGBoosterCreate.restype = ctypes.c_void_p
lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
return lib
# load the XGBoost library globally
xglib = load_xglib()
# load in xgboost library
xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH)
# DMatrix functions
xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
xglib.XGDMatrixCreateFromCSC.restype = ctypes.c_void_p
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGDMatrixGetUIntInfo.restype = ctypes.POINTER(ctypes.c_uint)
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
# booster functions
xglib.XGBoosterCreate.restype = ctypes.c_void_p
xglib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
def ctypes2numpy(cptr, length, dtype): def ctypes2numpy(cptr, length, dtype):
"""convert a ctypes pointer array to numpy array """ """
assert isinstance(cptr, ctypes.POINTER(ctypes.c_float)) Convert a ctypes pointer array to a numpy array.
res = numpy.zeros(length, dtype=dtype) """
assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]) if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
raise RuntimeError('expected float pointer')
res = np.zeros(length, dtype=dtype)
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
raise RuntimeError('memmove failed')
return res return res
class DMatrix:
"""data matrix used in xgboost"""
# constructor
def __init__(self, data, label=None, missing=0.0, weight = None):
""" constructor of DMatrix
Args: def c_str(string):
data: string/numpy array/scipy.sparse return ctypes.c_char_p(string.encode('utf-8'))
data source, string type is the path of svmlight format txt file or xgb buffer
label: list or numpy 1d array, optional
label of training data def c_array(ctype, values):
missing: float return (ctype * len(values))(*values)
value in data which need to be present as missing value
weight: list or numpy 1d array, optional
weight for each instances class DMatrix(object):
def __init__(self, data, label=None, missing=0.0, weight=None):
""" """
Data matrix used in XGBoost.
Parameters
----------
data : string/numpy array/scipy.sparse
Data source, string type is the path of svmlight format txt file or xgb buffer.
label : list or numpy 1-D array (optional)
Label of the training data.
missing : float
Value in the data which needs to be present as a missing value.
weight : list or numpy 1-D array (optional)
Weight for each instance.
"""
# force into void_p, mac need to pass things in as void_p # force into void_p, mac need to pass things in as void_p
if data is None: if data is None:
self.handle = None self.handle = None
return return
if isinstance(data, str): if isinstance(data, string_types):
self.handle = ctypes.c_void_p( self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromFile(c_str(data), 0))
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0)) elif isinstance(data, scipy.sparse.csr_matrix):
elif isinstance(data, scp.csr_matrix): self._init_from_csr(data)
self.__init_from_csr(data) elif isinstance(data, scipy.sparse.csc_matrix):
elif isinstance(data, scp.csc_matrix): self._init_from_csc(data)
self.__init_from_csc(data) elif isinstance(data, np.ndarray) and len(data.shape) == 2:
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2: self._init_from_npy2d(data, missing)
self.__init_from_npy2d(data, missing)
else: else:
try: try:
csr = scp.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
self.__init_from_csr(csr) self._init_from_csr(csr)
except: except:
raise Exception("can not intialize DMatrix from"+str(type(data))) raise TypeError('can not intialize DMatrix from {}'.format(type(data).__name__))
if label != None: if label is not None:
self.set_label(label) self.set_label(label)
if weight !=None: if weight is not None:
self.set_weight(weight) self.set_weight(weight)
def __init_from_csr(self, csr): def _init_from_csr(self, csr):
"""convert data from csr matrix""" """
assert len(csr.indices) == len(csr.data) Initialize data from a CSR matrix.
"""
if len(csr.indices) != len(csr.data):
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR( self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR(
(ctypes.c_ulong * len(csr.indptr))(*csr.indptr), c_array(ctypes.c_ulong, csr.indptr),
(ctypes.c_uint * len(csr.indices))(*csr.indices), c_array(ctypes.c_uint, csr.indices),
(ctypes.c_float * len(csr.data))(*csr.data), c_array(ctypes.c_float, csr.data),
len(csr.indptr), len(csr.data))) len(csr.indptr), len(csr.data)))
def __init_from_csc(self, csc): def _init_from_csc(self, csc):
"""convert data from csr matrix""" """
assert len(csc.indices) == len(csc.data) Initialize data from a CSC matrix.
"""
if len(csc.indices) != len(csc.data):
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSC( self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSC(
(ctypes.c_ulong * len(csc.indptr))(*csc.indptr), c_array(ctypes.c_ulong, csc.indptr),
(ctypes.c_uint * len(csc.indices))(*csc.indices), c_array(ctypes.c_uint, csc.indices),
(ctypes.c_float * len(csc.data))(*csc.data), c_array(ctypes.c_float, csc.data),
len(csc.indptr), len(csc.data))) len(csc.indptr), len(csc.data)))
def __init_from_npy2d(self,mat,missing): def _init_from_npy2d(self, mat, missing):
"""convert data from numpy matrix""" """
data = numpy.array(mat.reshape(mat.size), dtype='float32') Initialize data from a 2-D numpy matrix.
"""
data = np.array(mat.reshape(mat.size), dtype=np.float32)
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat( self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
mat.shape[0], mat.shape[1], ctypes.c_float(missing))) mat.shape[0], mat.shape[1], ctypes.c_float(missing)))
def __del__(self): def __del__(self):
"""destructor"""
xglib.XGDMatrixFree(self.handle) xglib.XGDMatrixFree(self.handle)
def get_float_info(self, field): def get_float_info(self, field):
length = ctypes.c_ulong() length = ctypes.c_ulong()
ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), ret = xglib.XGDMatrixGetFloatInfo(self.handle, c_str(field), ctypes.byref(length))
ctypes.byref(length)) return ctypes2numpy(ret, length.value, np.float32)
return ctypes2numpy(ret, length.value, 'float32')
def get_uint_info(self, field): def get_uint_info(self, field):
length = ctypes.c_ulong() length = ctypes.c_ulong()
ret = xglib.XGDMatrixGetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), ret = xglib.XGDMatrixGetUIntInfo(self.handle, c_str(field), ctypes.byref(length))
ctypes.byref(length)) return ctypes2numpy(ret, length.value, np.uint32)
return ctypes2numpy(ret, length.value, 'uint32')
def set_float_info(self, field, data): def set_float_info(self, field, data):
xglib.XGDMatrixSetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), xglib.XGDMatrixSetFloatInfo(self.handle, c_str(field),
(ctypes.c_float*len(data))(*data), len(data)) c_array(ctypes.c_float, data), len(data))
def set_uint_info(self, field, data): def set_uint_info(self, field, data):
xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), xglib.XGDMatrixSetUIntInfo(self.handle, c_str(field),
(ctypes.c_uint*len(data))(*data), len(data)) c_array(ctypes.c_uint, data), len(data))
def save_binary(self, fname, silent=True): def save_binary(self, fname, silent=True):
"""save DMatrix to XGBoost buffer
Args:
fname: string
name of buffer file
slient: bool, option
whether print info
Returns:
None
""" """
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent)) Save DMatrix to an XGBoost buffer.
Parameters
----------
fname : string
Name of the output buffer file.
silent : bool (optional; default: True)
If set, the output is suppressed.
"""
xglib.XGDMatrixSaveBinary(self.handle, c_str(fname), int(silent))
def set_label(self, label): def set_label(self, label):
"""set label of dmatrix """set label of dmatrix
@ -149,12 +192,13 @@ class DMatrix:
self.set_float_info('label', label) self.set_float_info('label', label)
def set_weight(self, weight): def set_weight(self, weight):
"""set weight of each instances """
Args: Set weight of each instance.
weight: float
weight for positive instance Parameters
Returns: ----------
None weight : float
Weight for positive instance.
""" """
self.set_float_info('weight', weight) self.set_float_info('weight', weight)
@ -170,159 +214,180 @@ class DMatrix:
self.set_float_info('base_margin', margin) self.set_float_info('base_margin', margin)
def set_group(self, group): def set_group(self, group):
"""set group size of dmatrix, used for rank
Args:
group:
Returns:
None
""" """
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group)) Set group size of DMatrix (used for ranking).
Parameters
----------
group : int
Group size.
"""
xglib.XGDMatrixSetGroup(self.handle, c_array(ctypes.c_uint, group), len(group))
def get_label(self): def get_label(self):
"""get label from dmatrix """
Args: Get the label of the DMatrix.
None
Returns: Returns
list, label of data -------
label : list
""" """
return self.get_float_info('label') return self.get_float_info('label')
def get_weight(self): def get_weight(self):
"""get weight from dmatrix """
Args: Get the weight of the DMatrix.
None
Returns: Returns
float, weight -------
weight : float
""" """
return self.get_float_info('weight') return self.get_float_info('weight')
def get_base_margin(self): def get_base_margin(self):
"""get base_margin from dmatrix """
Args: Get the base margin of the DMatrix.
None
Returns: Returns
float, base margin -------
base_margin : float
""" """
return self.get_float_info('base_margin') return self.get_float_info('base_margin')
def num_row(self): def num_row(self):
"""get number of rows """
Args: Get the number of rows in the DMatrix.
None
Returns: Returns
int, num rows -------
number of rows : int
""" """
return xglib.XGDMatrixNumRow(self.handle) return xglib.XGDMatrixNumRow(self.handle)
def slice(self, rindex): def slice(self, rindex):
"""slice the DMatrix to return a new DMatrix that only contains rindex """
Args: Slice the DMatrix and return a new DMatrix that only contains `rindex`.
rindex: list
list of index to be chosen Parameters
Returns: ----------
res: DMatrix rindex : list
new DMatrix with chosen index List of indices to be selected.
Returns
-------
res : DMatrix
A new DMatrix containing only selected indices.
""" """
res = DMatrix(None) res = DMatrix(None)
res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix( res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix(
self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex))) self.handle, c_array(ctypes.c_int, rindex), len(rindex)))
return res return res
class Booster:
"""learner class """ class Booster(object):
def __init__(self, params={}, cache=[], model_file = None): def __init__(self, params=None, cache=(), model_file=None):
""" constructor """
Args: Learner class.
params: dict
params for boosters Parameters
cache: list ----------
list of cache item params : dict
model_file: string Parameters for boosters.
path of model file cache : list
Returns: List of cache items.
None model_file : string
Path to the model file.
""" """
for d in cache: for d in cache:
assert isinstance(d, DMatrix) if not isinstance(d, DMatrix):
dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache]) raise TypeError('invalid cache item: {}'.format(type(d).__name__))
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
self.handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, len(cache))) self.handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, len(cache)))
self.set_param({'seed':0}) self.set_param({'seed': 0})
self.set_param(params) self.set_param(params or {})
if model_file != None: if model_file is not None:
self.load_model(model_file) self.load_model(model_file)
def __del__(self): def __del__(self):
xglib.XGBoosterFree(self.handle) xglib.XGBoosterFree(self.handle)
def set_param(self, params, pv=None): def set_param(self, params, pv=None):
if isinstance(params, dict): if isinstance(params, collections.Mapping):
for k, v in params.items(): params = params.items()
xglib.XGBoosterSetParam( elif isinstance(params, string_types) and pv is not None:
self.handle, ctypes.c_char_p(k.encode('utf-8')), params = [(params, pv)]
ctypes.c_char_p(str(v).encode('utf-8'))) for k, v in params:
elif isinstance(params,str) and pv != None: xglib.XGBoosterSetParam(self.handle, c_str(k), c_str(str(v)))
xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(params.encode('utf-8')),
ctypes.c_char_p(str(pv).encode('utf-8')))
else:
for k, v in params:
xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(k.encode('utf-8')),
ctypes.c_char_p(str(v).encode('utf-8')))
def update(self, dtrain, it, fobj=None): def update(self, dtrain, it, fobj=None):
""" """
update Update (one iteration).
Args:
dtrain: DMatrix Parameters
the training DMatrix ----------
it: int dtrain : DMatrix
current iteration number Training data.
fobj: function it : int
cutomzied objective function Current iteration number.
Returns: fobj : function
None Customized objective function.
""" """
assert isinstance(dtrain, DMatrix) if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
if fobj is None: if fobj is None:
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle) xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
else: else:
pred = self.predict( dtrain ) pred = self.predict(dtrain)
grad, hess = fobj( pred, dtrain ) grad, hess = fobj(pred, dtrain)
self.boost( dtrain, grad, hess ) self.boost(dtrain, grad, hess)
def boost(self, dtrain, grad, hess): def boost(self, dtrain, grad, hess):
""" update
Args:
dtrain: DMatrix
the training DMatrix
grad: list
the first order of gradient
hess: list
the second order of gradient
""" """
assert len(grad) == len(hess) Update.
assert isinstance(dtrain, DMatrix)
Parameters
----------
dtrain : DMatrix
The training DMatrix.
grad : list
The first order of gradient.
hess : list
The second order of gradient.
"""
if len(grad) != len(hess):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle, xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle,
(ctypes.c_float*len(grad))(*grad), c_array(ctypes.c_float, grad),
(ctypes.c_float*len(hess))(*hess), c_array(ctypes.c_float, hess),
len(grad)) len(grad))
def eval_set(self, evals, it = 0, feval = None): def eval_set(self, evals, it=0, feval=None):
"""evaluates by metric """
Args: Evaluate by a metric.
evals: list of tuple (DMatrix, string)
lists of items to be evaluated Parameters
it: int ----------
current iteration evals : list of tuples (DMatrix, string)
feval: function List of items to be evaluated.
custom evaluation function it : int
Returns: Current iteration.
evals result feval : function
Custom evaluation function.
Returns
-------
evaluation result
""" """
if feval is None: if feval is None:
for d in evals: for d in evals:
assert isinstance(d[0], DMatrix) if not isinstance(d[0], DMatrix):
assert isinstance(d[1], str) raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals]) if not isinstance(d[1], string_types):
evnames = (ctypes.c_char_p * len(evals))( raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
* [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals]) dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
else: else:
res = '[%d]' % it res = '[%d]' % it
@ -330,97 +395,115 @@ class Booster:
name, val = feval(self.predict(dm), dm) name, val = feval(self.predict(dm), dm)
res += '\t%s-%s:%f' % (evname, name, val) res += '\t%s-%s:%f' % (evname, name, val)
return res return res
def eval(self, mat, name = 'eval', it = 0):
return self.eval_set( [(mat,name)], it) def eval(self, mat, name='eval', it=0):
return self.eval_set([(mat, name)], it)
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False): def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
""" """
predict with data Predict with data.
Args:
data: DMatrix Parameters
the dmatrix storing the input ----------
output_margin: bool data : DMatrix
whether output raw margin value that is untransformed The dmatrix storing the input.
ntree_limit: int output_margin : bool
limit number of trees in prediction, default to 0, 0 means using all the trees Whether to output the raw untransformed margin value.
pred_leaf: bool ntree_limit : int
when this option is on, the output will be a matrix of (nsample, ntrees) Limit number of trees in the prediction; defaults to 0 (use all trees).
with each record indicate the predicted leaf index of each sample in each tree pred_leaf : bool
Note that the leaf index of tree is unique per tree, so you may find leaf 1 in both tree 1 and tree 0 When this option is on, the output will be a matrix of (nsample, ntrees)
Returns: with each record indicating the predicted leaf index of each sample in each tree.
numpy array of prediction Note that the leaf index of a tree is unique per tree, so you may find leaf 1
in both tree 1 and tree 0.
Returns
-------
prediction : numpy array
""" """
option_mask = 0 option_mask = 0x00
if output_margin: if output_margin:
option_mask += 1 option_mask |= 0x01
if pred_leaf: if pred_leaf:
option_mask += 2 option_mask |= 0x02
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle, preds = xglib.XGBoosterPredict(self.handle, data.handle,
option_mask, ntree_limit, ctypes.byref(length)) option_mask, ntree_limit, ctypes.byref(length))
preds = ctypes2numpy(preds, length.value, 'float32') preds = ctypes2numpy(preds, length.value, np.float32)
if pred_leaf: if pred_leaf:
preds = preds.astype('int32') preds = preds.astype(np.int32)
nrow = data.num_row() nrow = data.num_row()
if preds.size != nrow and preds.size % nrow == 0: if preds.size != nrow and preds.size % nrow == 0:
preds = preds.reshape(nrow, preds.size / nrow) preds = preds.reshape(nrow, preds.size / nrow)
return preds return preds
def save_model(self, fname): def save_model(self, fname):
""" save model to file
Args:
fname: string
file name of saving model
Returns:
None
""" """
xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8'))) Save the model to a file.
Parameters
----------
fname : string
Output file name.
"""
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
def load_model(self, fname): def load_model(self, fname):
"""load model from file
Args:
fname: string
file name of saving model
Returns:
None
""" """
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) ) Load the model from a file.
def dump_model(self, fo, fmap='', with_stats = False):
"""dump model into text file Parameters
Args: ----------
fo: string fname : string
file name to be dumped Input file name.
fmap: string, optional
file name of feature map names
with_stats: bool, optional
whether output statistics of the split
Returns:
None
""" """
if isinstance(fo,str): xglib.XGBoosterLoadModel(self.handle, c_str(fname))
fo = open(fo,'w')
def dump_model(self, fo, fmap='', with_stats=False):
"""
Dump model into a text file.
Parameters
----------
fo : string
Output file name.
fmap : string, optional
Name of the file containing feature map names.
with_stats : bool (optional)
Controls whether the split statistics are output.
"""
if isinstance(fo, string_types):
fo = open(fo, 'w')
need_close = True need_close = True
else: else:
need_close = False need_close = False
ret = self.get_dump(fmap, with_stats) ret = self.get_dump(fmap, with_stats)
for i in range(len(ret)): for i in range(len(ret)):
fo.write('booster[%d]:\n' %i) fo.write('booster[{}]:\n'.format(i))
fo.write( ret[i] ) fo.write(ret[i])
if need_close: if need_close:
fo.close() fo.close()
def get_dump(self, fmap='', with_stats=False): def get_dump(self, fmap='', with_stats=False):
"""get dump of model as list of strings """ """
Returns the dump the model as a list of strings.
"""
length = ctypes.c_ulong() length = ctypes.c_ulong()
sarr = xglib.XGBoosterDumpModel(self.handle, sarr = xglib.XGBoosterDumpModel(self.handle, c_str(fmap),
ctypes.c_char_p(fmap.encode('utf-8')),
int(with_stats), ctypes.byref(length)) int(with_stats), ctypes.byref(length))
res = [] res = []
for i in range(length.value): for i in range(length.value):
res.append( str(sarr[i]) ) res.append(str(sarr[i]))
return res return res
def get_fscore(self, fmap=''): def get_fscore(self, fmap=''):
""" get feature importance of each feature """ """
Get feature importance of each feature.
"""
trees = self.get_dump(fmap) trees = self.get_dump(fmap)
fmap = {} fmap = {}
for tree in trees: for tree in trees:
print (tree) sys.stdout.write(str(tree) + '\n')
for l in tree.split('\n'): for l in tree.split('\n'):
arr = l.split('[') arr = l.split('[')
if len(arr) == 1: if len(arr) == 1:
@ -430,56 +513,70 @@ class Booster:
if fid not in fmap: if fid not in fmap:
fmap[fid] = 1 fmap[fid] = 1
else: else:
fmap[fid]+= 1 fmap[fid] += 1
return fmap return fmap
def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None):
""" train a booster with given paramaters def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
Args:
params: dict
params of booster
dtrain: DMatrix
data to be trained
num_boost_round: int
num of round to be boosted
watchlist: list of pairs (DMatrix, string)
list of items to be evaluated during training, this allows user to watch performance on validation set
obj: function
cutomized objective function
feval: function
cutomized evaluation function
Returns: Booster model trained
""" """
bst = Booster(params, [dtrain]+[ d[0] for d in evals ] ) Train a booster with given parameters.
Parameters
----------
params : dict
Booster params.
dtrain : DMatrix
Data to be trained.
num_boost_round: int
Number of boosting iterations.
watchlist : list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch
performance on the validation set.
obj : function
Customized objective function.
feval : function
Customized evaluation function.
Returns
-------
booster : a trained booster model
"""
evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals])
for i in range(num_boost_round): for i in range(num_boost_round):
bst.update( dtrain, i, obj ) bst.update(dtrain, i, obj)
if len(evals) != 0: if len(evals) != 0:
bst_eval_set=bst.eval_set(evals, i, feval) bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set,str): if isinstance(bst_eval_set, string_types):
sys.stderr.write(bst_eval_set+'\n') sys.stderr.write(bst_eval_set + '\n')
else: else:
sys.stderr.write(bst_eval_set.decode()+'\n') sys.stderr.write(bst_eval_set.decode() + '\n')
return bst return bst
class CVPack:
class CVPack(object):
def __init__(self, dtrain, dtest, param): def __init__(self, dtrain, dtest, param):
self.dtrain = dtrain self.dtrain = dtrain
self.dtest = dtest self.dtest = dtest
self.watchlist = watchlist = [ (dtrain,'train'), (dtest, 'test') ] self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
self.bst = Booster(param, [dtrain,dtest]) self.bst = Booster(param, [dtrain, dtest])
def update(self, r, fobj): def update(self, r, fobj):
self.bst.update(self.dtrain, r, fobj) self.bst.update(self.dtrain, r, fobj)
def eval(self, r, feval): def eval(self, r, feval):
return self.bst.eval_set(self.watchlist, r, feval) return self.bst.eval_set(self.watchlist, r, feval)
def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None):
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
""" """
mk nfold list of cvpack from randidx Make an n-fold list of CVPack from random indices.
""" """
evals = list(evals)
np.random.seed(seed) np.random.seed(seed)
randidx = np.random.permutation(dall.num_row()) randidx = np.random.permutation(dall.num_row())
kstep = len(randidx) / nfold kstep = len(randidx) / nfold
idset = [randidx[ (i*kstep) : min(len(randidx),(i+1)*kstep) ] for i in range(nfold)] idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
ret = [] ret = []
for k in range(nfold): for k in range(nfold):
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i])) dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
@ -493,9 +590,10 @@ def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None):
ret.append(CVPack(dtrain, dtest, plst)) ret.append(CVPack(dtrain, dtest, plst))
return ret return ret
def aggcv(rlist, show_stdv=True): def aggcv(rlist, show_stdv=True):
""" """
aggregate cross validation results Aggregate cross-validation results.
""" """
cvmap = {} cvmap = {}
ret = rlist[0].split()[0] ret = rlist[0].split()[0]
@ -503,15 +601,15 @@ def aggcv(rlist, show_stdv=True):
arr = line.split() arr = line.split()
assert ret == arr[0] assert ret == arr[0]
for it in arr[1:]: for it in arr[1:]:
if not isinstance(it,str): if not isinstance(it, string_types):
it=it.decode() it = it.decode()
k, v = it.split(':') k, v = it.split(':')
if k not in cvmap: if k not in cvmap:
cvmap[k] = [] cvmap[k] = []
cvmap[k].append(float(v)) cvmap[k].append(float(v))
for k, v in sorted(cvmap.items(), key = lambda x:x[0]): for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
v = np.array(v) v = np.array(v)
if not isinstance(ret,str): if not isinstance(ret, string_types):
ret = ret.decode() ret = ret.decode()
if show_stdv: if show_stdv:
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v)) ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
@ -519,33 +617,39 @@ def aggcv(rlist, show_stdv=True):
ret += '\tcv-%s:%f' % (k, np.mean(v)) ret += '\tcv-%s:%f' % (k, np.mean(v))
return ret return ret
def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \
obj = None, feval = None, fpreproc = None, show_stdv = True, seed = 0):
""" cross validation with given paramaters
Args:
params: dict
params of booster
dtrain: DMatrix
data to be trained
num_boost_round: int
num of round to be boosted
nfold: int
number of folds to do cv
metrics: list of strings
evaluation metrics to be watched in cv
obj: function
custom objective function
feval: function
custom evaluation function
fpreproc: function
preprocessing function that takes dtrain, dtest,
param and return transformed version of dtrain, dtest, param
show_stdv: bool
whether display standard deviation
seed: int
seed used to generate the folds, this is passed to numpy.random.seed
Returns: list(string) of evaluation history def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0):
"""
Cross-validation with given paramaters.
Parameters
----------
params : dict
Booster params.
dtrain : DMatrix
Data to be trained.
num_boost_round : int
Number of boosting iterations.
nfold : int
Number of folds in CV.
metrics : list of strings
Evaluation metrics to be watched in CV.
obj : function
Custom objective function.
feval : function
Custom evaluation function.
fpreproc : function
Preprocessing function that takes (dtrain, dtest, param) and returns
transformed versions of those.
show_stdv : bool
Whether to display the standard deviation.
seed : int
Seed used to generate the folds (passed to numpy.random.seed).
Returns
-------
evaluation history : list(string)
""" """
results = [] results = []
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc) cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
@ -553,7 +657,6 @@ def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \
for f in cvfolds: for f in cvfolds:
f.update(i, obj) f.update(i, obj)
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv) res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
sys.stderr.write(res+'\n') sys.stderr.write(res + '\n')
results.append(res) results.append(res)
return results return results