From 8660ea91b59ba657cdab685a41ce0981c40e94bf Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 16 Feb 2015 16:03:47 +0000 Subject: [PATCH] Fixed the dll import for relative paths + various cleanup. - DLL import now works when __file__ is a relative path - Various PEP8 and whitespace fixes + whitespace cleanup - Docstring fixes (conform to numpydoc) - Added __all__ to the module - Fixed mutable default values - Removed print statements - py2/py3-compatible string-type checks - Replace asserts with proper exceptions - Make classes new-style (derive from object) --- wrapper/xgboost.py | 761 +++++++++++++++++++++++++-------------------- 1 file changed, 432 insertions(+), 329 deletions(-) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 1a2a4e1c2..affda3ca7 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -1,142 +1,185 @@ """ xgboost: eXtreme Gradient Boosting library -Author: Tianqi Chen, Bing Xu +Authors: Tianqi Chen, Bing Xu """ -import ctypes + +from __future__ import absolute_import + import os -# optinally have scipy sparse, though not necessary -import numpy as np import sys -import numpy.ctypeslib -import scipy.sparse as scp +import ctypes +import collections -# set this line correctly -if os.name == 'nt': - XGBOOST_PATH = os.path.dirname(__file__)+'/../windows/x64/Release/xgboost_wrapper.dll' +import numpy as np +import scipy.sparse + +__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train'] + +if sys.version_info[0] == 3: + string_types = str, else: - XGBOOST_PATH = os.path.dirname(__file__)+'/libxgboostwrapper.so' + string_types = basestring, + + +def load_xglib(): + dll_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + if os.name == 'nt': + dll_path = os.path.join(dll_path, '../windows/x64/Release/xgboost_wrapper.dll') + else: + dll_path = os.path.join(dll_path, 'libxgboostwrapper.so') + + # load the xgboost wrapper library + lib = ctypes.cdll.LoadLibrary(dll_path) + + # DMatrix functions + lib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p + lib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p + lib.XGDMatrixCreateFromCSC.restype = ctypes.c_void_p + lib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p + lib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p + lib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float) + lib.XGDMatrixGetUIntInfo.restype = ctypes.POINTER(ctypes.c_uint) + lib.XGDMatrixNumRow.restype = ctypes.c_ulong + + # Booster functions + lib.XGBoosterCreate.restype = ctypes.c_void_p + lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float) + lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p + lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) + + return lib + +# load the XGBoost library globally +xglib = load_xglib() -# load in xgboost library -xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH) -# DMatrix functions -xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p -xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p -xglib.XGDMatrixCreateFromCSC.restype = ctypes.c_void_p -xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p -xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p -xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float) -xglib.XGDMatrixGetUIntInfo.restype = ctypes.POINTER(ctypes.c_uint) -xglib.XGDMatrixNumRow.restype = ctypes.c_ulong -# booster functions -xglib.XGBoosterCreate.restype = ctypes.c_void_p -xglib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float) -xglib.XGBoosterEvalOneIter.restype = ctypes.c_char_p -xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) def ctypes2numpy(cptr, length, dtype): - """convert a ctypes pointer array to numpy array """ - assert isinstance(cptr, ctypes.POINTER(ctypes.c_float)) - res = numpy.zeros(length, dtype=dtype) - assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]) + """ + Convert a ctypes pointer array to a numpy array. + """ + if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)): + raise RuntimeError('expected float pointer') + res = np.zeros(length, dtype=dtype) + if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): + raise RuntimeError('memmove failed') return res -class DMatrix: - """data matrix used in xgboost""" - # constructor - def __init__(self, data, label=None, missing=0.0, weight = None): - """ constructor of DMatrix - Args: - data: string/numpy array/scipy.sparse - data source, string type is the path of svmlight format txt file or xgb buffer - label: list or numpy 1d array, optional - label of training data - missing: float - value in data which need to be present as missing value - weight: list or numpy 1d array, optional - weight for each instances +def c_str(string): + return ctypes.c_char_p(string.encode('utf-8')) + + +def c_array(ctype, values): + return (ctype * len(values))(*values) + + +class DMatrix(object): + def __init__(self, data, label=None, missing=0.0, weight=None): """ + Data matrix used in XGBoost. + + Parameters + ---------- + data : string/numpy array/scipy.sparse + Data source, string type is the path of svmlight format txt file or xgb buffer. + label : list or numpy 1-D array (optional) + Label of the training data. + missing : float + Value in the data which needs to be present as a missing value. + weight : list or numpy 1-D array (optional) + Weight for each instance. + """ + # force into void_p, mac need to pass things in as void_p if data is None: self.handle = None return - if isinstance(data, str): - self.handle = ctypes.c_void_p( - xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0)) - elif isinstance(data, scp.csr_matrix): - self.__init_from_csr(data) - elif isinstance(data, scp.csc_matrix): - self.__init_from_csc(data) - elif isinstance(data, numpy.ndarray) and len(data.shape) == 2: - self.__init_from_npy2d(data, missing) + if isinstance(data, string_types): + self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromFile(c_str(data), 0)) + elif isinstance(data, scipy.sparse.csr_matrix): + self._init_from_csr(data) + elif isinstance(data, scipy.sparse.csc_matrix): + self._init_from_csc(data) + elif isinstance(data, np.ndarray) and len(data.shape) == 2: + self._init_from_npy2d(data, missing) else: try: - csr = scp.csr_matrix(data) - self.__init_from_csr(csr) + csr = scipy.sparse.csr_matrix(data) + self._init_from_csr(csr) except: - raise Exception("can not intialize DMatrix from"+str(type(data))) - if label != None: + raise TypeError('can not intialize DMatrix from {}'.format(type(data).__name__)) + if label is not None: self.set_label(label) - if weight !=None: + if weight is not None: self.set_weight(weight) - def __init_from_csr(self, csr): - """convert data from csr matrix""" - assert len(csr.indices) == len(csr.data) + def _init_from_csr(self, csr): + """ + Initialize data from a CSR matrix. + """ + if len(csr.indices) != len(csr.data): + raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data))) self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR( - (ctypes.c_ulong * len(csr.indptr))(*csr.indptr), - (ctypes.c_uint * len(csr.indices))(*csr.indices), - (ctypes.c_float * len(csr.data))(*csr.data), + c_array(ctypes.c_ulong, csr.indptr), + c_array(ctypes.c_uint, csr.indices), + c_array(ctypes.c_float, csr.data), len(csr.indptr), len(csr.data))) - def __init_from_csc(self, csc): - """convert data from csr matrix""" - assert len(csc.indices) == len(csc.data) + def _init_from_csc(self, csc): + """ + Initialize data from a CSC matrix. + """ + if len(csc.indices) != len(csc.data): + raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data))) self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSC( - (ctypes.c_ulong * len(csc.indptr))(*csc.indptr), - (ctypes.c_uint * len(csc.indices))(*csc.indices), - (ctypes.c_float * len(csc.data))(*csc.data), + c_array(ctypes.c_ulong, csc.indptr), + c_array(ctypes.c_uint, csc.indices), + c_array(ctypes.c_float, csc.data), len(csc.indptr), len(csc.data))) - def __init_from_npy2d(self,mat,missing): - """convert data from numpy matrix""" - data = numpy.array(mat.reshape(mat.size), dtype='float32') + def _init_from_npy2d(self, mat, missing): + """ + Initialize data from a 2-D numpy matrix. + """ + data = np.array(mat.reshape(mat.size), dtype=np.float32) self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat( data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), mat.shape[0], mat.shape[1], ctypes.c_float(missing))) def __del__(self): - """destructor""" xglib.XGDMatrixFree(self.handle) + def get_float_info(self, field): length = ctypes.c_ulong() - ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), - ctypes.byref(length)) - return ctypes2numpy(ret, length.value, 'float32') + ret = xglib.XGDMatrixGetFloatInfo(self.handle, c_str(field), ctypes.byref(length)) + return ctypes2numpy(ret, length.value, np.float32) + def get_uint_info(self, field): length = ctypes.c_ulong() - ret = xglib.XGDMatrixGetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), - ctypes.byref(length)) - return ctypes2numpy(ret, length.value, 'uint32') + ret = xglib.XGDMatrixGetUIntInfo(self.handle, c_str(field), ctypes.byref(length)) + return ctypes2numpy(ret, length.value, np.uint32) + def set_float_info(self, field, data): - xglib.XGDMatrixSetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), - (ctypes.c_float*len(data))(*data), len(data)) + xglib.XGDMatrixSetFloatInfo(self.handle, c_str(field), + c_array(ctypes.c_float, data), len(data)) + def set_uint_info(self, field, data): - xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), - (ctypes.c_uint*len(data))(*data), len(data)) + xglib.XGDMatrixSetUIntInfo(self.handle, c_str(field), + c_array(ctypes.c_uint, data), len(data)) def save_binary(self, fname, silent=True): - """save DMatrix to XGBoost buffer - Args: - fname: string - name of buffer file - slient: bool, option - whether print info - Returns: - None """ - xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent)) + Save DMatrix to an XGBoost buffer. + + Parameters + ---------- + fname : string + Name of the output buffer file. + silent : bool (optional; default: True) + If set, the output is suppressed. + """ + xglib.XGDMatrixSaveBinary(self.handle, c_str(fname), int(silent)) def set_label(self, label): """set label of dmatrix @@ -149,12 +192,13 @@ class DMatrix: self.set_float_info('label', label) def set_weight(self, weight): - """set weight of each instances - Args: - weight: float - weight for positive instance - Returns: - None + """ + Set weight of each instance. + + Parameters + ---------- + weight : float + Weight for positive instance. """ self.set_float_info('weight', weight) @@ -170,159 +214,180 @@ class DMatrix: self.set_float_info('base_margin', margin) def set_group(self, group): - """set group size of dmatrix, used for rank - Args: - group: - - Returns: - None """ - xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group)) + Set group size of DMatrix (used for ranking). + + Parameters + ---------- + group : int + Group size. + """ + xglib.XGDMatrixSetGroup(self.handle, c_array(ctypes.c_uint, group), len(group)) def get_label(self): - """get label from dmatrix - Args: - None - Returns: - list, label of data + """ + Get the label of the DMatrix. + + Returns + ------- + label : list """ return self.get_float_info('label') def get_weight(self): - """get weight from dmatrix - Args: - None - Returns: - float, weight + """ + Get the weight of the DMatrix. + + Returns + ------- + weight : float """ return self.get_float_info('weight') + def get_base_margin(self): - """get base_margin from dmatrix - Args: - None - Returns: - float, base margin + """ + Get the base margin of the DMatrix. + + Returns + ------- + base_margin : float """ return self.get_float_info('base_margin') + def num_row(self): - """get number of rows - Args: - None - Returns: - int, num rows + """ + Get the number of rows in the DMatrix. + + Returns + ------- + number of rows : int """ return xglib.XGDMatrixNumRow(self.handle) + def slice(self, rindex): - """slice the DMatrix to return a new DMatrix that only contains rindex - Args: - rindex: list - list of index to be chosen - Returns: - res: DMatrix - new DMatrix with chosen index + """ + Slice the DMatrix and return a new DMatrix that only contains `rindex`. + + Parameters + ---------- + rindex : list + List of indices to be selected. + + Returns + ------- + res : DMatrix + A new DMatrix containing only selected indices. """ res = DMatrix(None) res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix( - self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex))) + self.handle, c_array(ctypes.c_int, rindex), len(rindex))) return res -class Booster: - """learner class """ - def __init__(self, params={}, cache=[], model_file = None): - """ constructor - Args: - params: dict - params for boosters - cache: list - list of cache item - model_file: string - path of model file - Returns: - None + +class Booster(object): + def __init__(self, params=None, cache=(), model_file=None): + """ + Learner class. + + Parameters + ---------- + params : dict + Parameters for boosters. + cache : list + List of cache items. + model_file : string + Path to the model file. """ for d in cache: - assert isinstance(d, DMatrix) - dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache]) + if not isinstance(d, DMatrix): + raise TypeError('invalid cache item: {}'.format(type(d).__name__)) + dmats = c_array(ctypes.c_void_p, [d.handle for d in cache]) self.handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, len(cache))) - self.set_param({'seed':0}) - self.set_param(params) - if model_file != None: + self.set_param({'seed': 0}) + self.set_param(params or {}) + if model_file is not None: self.load_model(model_file) + def __del__(self): xglib.XGBoosterFree(self.handle) + def set_param(self, params, pv=None): - if isinstance(params, dict): - for k, v in params.items(): - xglib.XGBoosterSetParam( - self.handle, ctypes.c_char_p(k.encode('utf-8')), - ctypes.c_char_p(str(v).encode('utf-8'))) - elif isinstance(params,str) and pv != None: - xglib.XGBoosterSetParam( - self.handle, ctypes.c_char_p(params.encode('utf-8')), - ctypes.c_char_p(str(pv).encode('utf-8'))) - else: - for k, v in params: - xglib.XGBoosterSetParam( - self.handle, ctypes.c_char_p(k.encode('utf-8')), - ctypes.c_char_p(str(v).encode('utf-8'))) + if isinstance(params, collections.Mapping): + params = params.items() + elif isinstance(params, string_types) and pv is not None: + params = [(params, pv)] + for k, v in params: + xglib.XGBoosterSetParam(self.handle, c_str(k), c_str(str(v))) def update(self, dtrain, it, fobj=None): """ - update - Args: - dtrain: DMatrix - the training DMatrix - it: int - current iteration number - fobj: function - cutomzied objective function - Returns: - None + Update (one iteration). + + Parameters + ---------- + dtrain : DMatrix + Training data. + it : int + Current iteration number. + fobj : function + Customized objective function. """ - assert isinstance(dtrain, DMatrix) + if not isinstance(dtrain, DMatrix): + raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) if fobj is None: xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle) else: - pred = self.predict( dtrain ) - grad, hess = fobj( pred, dtrain ) - self.boost( dtrain, grad, hess ) + pred = self.predict(dtrain) + grad, hess = fobj(pred, dtrain) + self.boost(dtrain, grad, hess) def boost(self, dtrain, grad, hess): - """ update - Args: - dtrain: DMatrix - the training DMatrix - grad: list - the first order of gradient - hess: list - the second order of gradient """ - assert len(grad) == len(hess) - assert isinstance(dtrain, DMatrix) + Update. + + Parameters + ---------- + dtrain : DMatrix + The training DMatrix. + grad : list + The first order of gradient. + hess : list + The second order of gradient. + """ + if len(grad) != len(hess): + raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))) + if not isinstance(dtrain, DMatrix): + raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle, - (ctypes.c_float*len(grad))(*grad), - (ctypes.c_float*len(hess))(*hess), + c_array(ctypes.c_float, grad), + c_array(ctypes.c_float, hess), len(grad)) - def eval_set(self, evals, it = 0, feval = None): - """evaluates by metric - Args: - evals: list of tuple (DMatrix, string) - lists of items to be evaluated - it: int - current iteration - feval: function - custom evaluation function - Returns: - evals result + def eval_set(self, evals, it=0, feval=None): + """ + Evaluate by a metric. + + Parameters + ---------- + evals : list of tuples (DMatrix, string) + List of items to be evaluated. + it : int + Current iteration. + feval : function + Custom evaluation function. + + Returns + ------- + evaluation result """ if feval is None: for d in evals: - assert isinstance(d[0], DMatrix) - assert isinstance(d[1], str) - dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals]) - evnames = (ctypes.c_char_p * len(evals))( - * [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals]) + if not isinstance(d[0], DMatrix): + raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__)) + if not isinstance(d[1], string_types): + raise TypeError('expected string, got {}'.format(type(d[1]).__name__)) + dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) + evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) else: res = '[%d]' % it @@ -330,97 +395,115 @@ class Booster: name, val = feval(self.predict(dm), dm) res += '\t%s-%s:%f' % (evname, name, val) return res - def eval(self, mat, name = 'eval', it = 0): - return self.eval_set( [(mat,name)], it) + + def eval(self, mat, name='eval', it=0): + return self.eval_set([(mat, name)], it) + def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False): """ - predict with data - Args: - data: DMatrix - the dmatrix storing the input - output_margin: bool - whether output raw margin value that is untransformed - ntree_limit: int - limit number of trees in prediction, default to 0, 0 means using all the trees - pred_leaf: bool - when this option is on, the output will be a matrix of (nsample, ntrees) - with each record indicate the predicted leaf index of each sample in each tree - Note that the leaf index of tree is unique per tree, so you may find leaf 1 in both tree 1 and tree 0 - Returns: - numpy array of prediction + Predict with data. + + Parameters + ---------- + data : DMatrix + The dmatrix storing the input. + output_margin : bool + Whether to output the raw untransformed margin value. + ntree_limit : int + Limit number of trees in the prediction; defaults to 0 (use all trees). + pred_leaf : bool + When this option is on, the output will be a matrix of (nsample, ntrees) + with each record indicating the predicted leaf index of each sample in each tree. + Note that the leaf index of a tree is unique per tree, so you may find leaf 1 + in both tree 1 and tree 0. + + Returns + ------- + prediction : numpy array """ - option_mask = 0 + option_mask = 0x00 if output_margin: - option_mask += 1 + option_mask |= 0x01 if pred_leaf: - option_mask += 2 + option_mask |= 0x02 length = ctypes.c_ulong() preds = xglib.XGBoosterPredict(self.handle, data.handle, - option_mask, ntree_limit, ctypes.byref(length)) - preds = ctypes2numpy(preds, length.value, 'float32') + option_mask, ntree_limit, ctypes.byref(length)) + preds = ctypes2numpy(preds, length.value, np.float32) if pred_leaf: - preds = preds.astype('int32') + preds = preds.astype(np.int32) nrow = data.num_row() if preds.size != nrow and preds.size % nrow == 0: - preds = preds.reshape(nrow, preds.size / nrow) + preds = preds.reshape(nrow, preds.size / nrow) return preds + def save_model(self, fname): - """ save model to file - Args: - fname: string - file name of saving model - Returns: - None """ - xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8'))) + Save the model to a file. + + Parameters + ---------- + fname : string + Output file name. + """ + xglib.XGBoosterSaveModel(self.handle, c_str(fname)) + def load_model(self, fname): - """load model from file - Args: - fname: string - file name of saving model - Returns: - None """ - xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) ) - def dump_model(self, fo, fmap='', with_stats = False): - """dump model into text file - Args: - fo: string - file name to be dumped - fmap: string, optional - file name of feature map names - with_stats: bool, optional - whether output statistics of the split - Returns: - None + Load the model from a file. + + Parameters + ---------- + fname : string + Input file name. """ - if isinstance(fo,str): - fo = open(fo,'w') + xglib.XGBoosterLoadModel(self.handle, c_str(fname)) + + def dump_model(self, fo, fmap='', with_stats=False): + """ + Dump model into a text file. + + Parameters + ---------- + fo : string + Output file name. + fmap : string, optional + Name of the file containing feature map names. + with_stats : bool (optional) + Controls whether the split statistics are output. + """ + if isinstance(fo, string_types): + fo = open(fo, 'w') need_close = True else: need_close = False ret = self.get_dump(fmap, with_stats) for i in range(len(ret)): - fo.write('booster[%d]:\n' %i) - fo.write( ret[i] ) + fo.write('booster[{}]:\n'.format(i)) + fo.write(ret[i]) if need_close: fo.close() + def get_dump(self, fmap='', with_stats=False): - """get dump of model as list of strings """ + """ + Returns the dump the model as a list of strings. + """ length = ctypes.c_ulong() - sarr = xglib.XGBoosterDumpModel(self.handle, - ctypes.c_char_p(fmap.encode('utf-8')), + sarr = xglib.XGBoosterDumpModel(self.handle, c_str(fmap), int(with_stats), ctypes.byref(length)) res = [] for i in range(length.value): - res.append( str(sarr[i]) ) + res.append(str(sarr[i])) return res + def get_fscore(self, fmap=''): - """ get feature importance of each feature """ + """ + Get feature importance of each feature. + """ trees = self.get_dump(fmap) fmap = {} for tree in trees: - print (tree) + sys.stdout.write(str(tree) + '\n') for l in tree.split('\n'): arr = l.split('[') if len(arr) == 1: @@ -430,56 +513,70 @@ class Booster: if fid not in fmap: fmap[fid] = 1 else: - fmap[fid]+= 1 + fmap[fid] += 1 return fmap -def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None): - """ train a booster with given paramaters - Args: - params: dict - params of booster - dtrain: DMatrix - data to be trained - num_boost_round: int - num of round to be boosted - watchlist: list of pairs (DMatrix, string) - list of items to be evaluated during training, this allows user to watch performance on validation set - obj: function - cutomized objective function - feval: function - cutomized evaluation function - Returns: Booster model trained + +def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None): """ - bst = Booster(params, [dtrain]+[ d[0] for d in evals ] ) + Train a booster with given parameters. + + Parameters + ---------- + params : dict + Booster params. + dtrain : DMatrix + Data to be trained. + num_boost_round: int + Number of boosting iterations. + watchlist : list of pairs (DMatrix, string) + List of items to be evaluated during training, this allows user to watch + performance on the validation set. + obj : function + Customized objective function. + feval : function + Customized evaluation function. + + Returns + ------- + booster : a trained booster model + """ + evals = list(evals) + bst = Booster(params, [dtrain] + [d[0] for d in evals]) for i in range(num_boost_round): - bst.update( dtrain, i, obj ) + bst.update(dtrain, i, obj) if len(evals) != 0: - bst_eval_set=bst.eval_set(evals, i, feval) - if isinstance(bst_eval_set,str): - sys.stderr.write(bst_eval_set+'\n') + bst_eval_set = bst.eval_set(evals, i, feval) + if isinstance(bst_eval_set, string_types): + sys.stderr.write(bst_eval_set + '\n') else: - sys.stderr.write(bst_eval_set.decode()+'\n') + sys.stderr.write(bst_eval_set.decode() + '\n') return bst -class CVPack: + +class CVPack(object): def __init__(self, dtrain, dtest, param): self.dtrain = dtrain self.dtest = dtest - self.watchlist = watchlist = [ (dtrain,'train'), (dtest, 'test') ] - self.bst = Booster(param, [dtrain,dtest]) + self.watchlist = [(dtrain, 'train'), (dtest, 'test')] + self.bst = Booster(param, [dtrain, dtest]) + def update(self, r, fobj): self.bst.update(self.dtrain, r, fobj) + def eval(self, r, feval): return self.bst.eval_set(self.watchlist, r, feval) -def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None): + +def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None): """ - mk nfold list of cvpack from randidx + Make an n-fold list of CVPack from random indices. """ + evals = list(evals) np.random.seed(seed) randidx = np.random.permutation(dall.num_row()) kstep = len(randidx) / nfold - idset = [randidx[ (i*kstep) : min(len(randidx),(i+1)*kstep) ] for i in range(nfold)] + idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)] ret = [] for k in range(nfold): dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i])) @@ -493,9 +590,10 @@ def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None): ret.append(CVPack(dtrain, dtest, plst)) return ret + def aggcv(rlist, show_stdv=True): """ - aggregate cross validation results + Aggregate cross-validation results. """ cvmap = {} ret = rlist[0].split()[0] @@ -503,15 +601,15 @@ def aggcv(rlist, show_stdv=True): arr = line.split() assert ret == arr[0] for it in arr[1:]: - if not isinstance(it,str): - it=it.decode() - k, v = it.split(':') + if not isinstance(it, string_types): + it = it.decode() + k, v = it.split(':') if k not in cvmap: cvmap[k] = [] cvmap[k].append(float(v)) - for k, v in sorted(cvmap.items(), key = lambda x:x[0]): + for k, v in sorted(cvmap.items(), key=lambda x: x[0]): v = np.array(v) - if not isinstance(ret,str): + if not isinstance(ret, string_types): ret = ret.decode() if show_stdv: ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v)) @@ -519,33 +617,39 @@ def aggcv(rlist, show_stdv=True): ret += '\tcv-%s:%f' % (k, np.mean(v)) return ret -def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \ - obj = None, feval = None, fpreproc = None, show_stdv = True, seed = 0): - """ cross validation with given paramaters - Args: - params: dict - params of booster - dtrain: DMatrix - data to be trained - num_boost_round: int - num of round to be boosted - nfold: int - number of folds to do cv - metrics: list of strings - evaluation metrics to be watched in cv - obj: function - custom objective function - feval: function - custom evaluation function - fpreproc: function - preprocessing function that takes dtrain, dtest, - param and return transformed version of dtrain, dtest, param - show_stdv: bool - whether display standard deviation - seed: int - seed used to generate the folds, this is passed to numpy.random.seed - Returns: list(string) of evaluation history +def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(), + obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0): + """ + Cross-validation with given paramaters. + + Parameters + ---------- + params : dict + Booster params. + dtrain : DMatrix + Data to be trained. + num_boost_round : int + Number of boosting iterations. + nfold : int + Number of folds in CV. + metrics : list of strings + Evaluation metrics to be watched in CV. + obj : function + Custom objective function. + feval : function + Custom evaluation function. + fpreproc : function + Preprocessing function that takes (dtrain, dtest, param) and returns + transformed versions of those. + show_stdv : bool + Whether to display the standard deviation. + seed : int + Seed used to generate the folds (passed to numpy.random.seed). + + Returns + ------- + evaluation history : list(string) """ results = [] cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc) @@ -553,7 +657,6 @@ def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \ for f in cvfolds: f.update(i, obj) res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv) - sys.stderr.write(res+'\n') + sys.stderr.write(res + '\n') results.append(res) return results -