make python lint

This commit is contained in:
tqchen 2015-07-03 20:36:41 -07:00
parent 57ec922214
commit 59b91cf205
2 changed files with 263 additions and 169 deletions

View File

@ -1,9 +1,12 @@
# pylint: disable=invalid-name
"""Setup xgboost package."""
import os import os
import platform import platform
from setuptools import setup from setuptools import setup
class XGBoostLibraryNotFound(Exception): class XGBoostLibraryNotFound(Exception):
"""Exception to raise when xgboost library cannot be found."""
pass pass

View File

@ -6,7 +6,7 @@ Version: 0.40
Authors: Tianqi Chen, Bing Xu Authors: Tianqi Chen, Bing Xu
Early stopping by Zygmunt Zając Early stopping by Zygmunt Zając
""" """
# pylint: disable=too-many-arguments, too-many-locals, too-many-lines
from __future__ import absolute_import from __future__ import absolute_import
import os import os
@ -28,20 +28,25 @@ except ImportError:
SKLEARN_INSTALLED = False SKLEARN_INSTALLED = False
class XGBoostLibraryNotFound(Exception): class XGBoostLibraryNotFound(Exception):
"""Error thrown when the xgboost library is not found."""
pass pass
class XGBoostError(Exception): class XGBoostError(Exception):
"""Error thrown by the xgboost trainer."""
pass pass
__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train'] __all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
string_types = str, # pylint: disable=invalid-name
STRING_TYPES = str,
else: else:
string_types = basestring, # pylint: disable=invalid-name
STRING_TYPES = basestring,
def load_xglib(): def load_xglib():
"""Load the xgboost library."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [curr_path] dll_path = [curr_path]
if os.name == 'nt': if os.name == 'nt':
@ -55,7 +60,8 @@ def load_xglib():
dll_path = [os.path.join(p, 'libxgboostwrapper.so') for p in dll_path] dll_path = [os.path.join(p, 'libxgboostwrapper.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(dll_path) == 0: if len(dll_path) == 0:
raise XGBoostLibraryNotFound('cannot find find the files in the candicate path ' + str(dll_path)) raise XGBoostLibraryNotFound(
'cannot find find the files in the candicate path ' + str(dll_path))
lib = ctypes.cdll.LoadLibrary(lib_path[0]) lib = ctypes.cdll.LoadLibrary(lib_path[0])
# DMatrix functions # DMatrix functions
@ -79,12 +85,11 @@ def load_xglib():
return lib return lib
# load the XGBoost library globally # load the XGBoost library globally
xglib = load_xglib() _LIB = load_xglib()
def ctypes2numpy(cptr, length, dtype): def ctypes2numpy(cptr, length, dtype):
""" """Convert a ctypes pointer array to a numpy array.
Convert a ctypes pointer array to a numpy array.
""" """
if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)): if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
raise RuntimeError('expected float pointer') raise RuntimeError('expected float pointer')
@ -95,6 +100,7 @@ def ctypes2numpy(cptr, length, dtype):
def ctypes2buffer(cptr, length): def ctypes2buffer(cptr, length):
"""Convert ctypes pointer to buffer type."""
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)): if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
raise RuntimeError('expected char pointer') raise RuntimeError('expected char pointer')
res = bytearray(length) res = bytearray(length)
@ -105,14 +111,17 @@ def ctypes2buffer(cptr, length):
def c_str(string): def c_str(string):
"""Convert a python string to cstring."""
return ctypes.c_char_p(string.encode('utf-8')) return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values): def c_array(ctype, values):
"""Convert a python sequence of values to a c array of the given ctype."""
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
class DMatrix(object): class DMatrix(object):
"""Data Matrix used in XGBoost."""
def __init__(self, data, label=None, missing=0.0, weight=None, silent=False): def __init__(self, data, label=None, missing=0.0, weight=None, silent=False):
""" """
Data matrix used in XGBoost. Data matrix used in XGBoost.
@ -135,8 +144,8 @@ class DMatrix(object):
if data is None: if data is None:
self.handle = None self.handle = None
return return
if isinstance(data, string_types): if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromFile(c_str(data), int(silent))) self.handle = ctypes.c_void_p(_LIB.XGDMatrixCreateFromFile(c_str(data), int(silent)))
elif isinstance(data, scipy.sparse.csr_matrix): elif isinstance(data, scipy.sparse.csr_matrix):
self._init_from_csr(data) self._init_from_csr(data)
elif isinstance(data, scipy.sparse.csc_matrix): elif isinstance(data, scipy.sparse.csc_matrix):
@ -160,7 +169,7 @@ class DMatrix(object):
""" """
if len(csr.indices) != len(csr.data): if len(csr.indices) != len(csr.data):
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data))) raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR( self.handle = ctypes.c_void_p(_LIB.XGDMatrixCreateFromCSR(
c_array(ctypes.c_ulong, csr.indptr), c_array(ctypes.c_ulong, csr.indptr),
c_array(ctypes.c_uint, csr.indices), c_array(ctypes.c_uint, csr.indices),
c_array(ctypes.c_float, csr.data), c_array(ctypes.c_float, csr.data),
@ -172,7 +181,7 @@ class DMatrix(object):
""" """
if len(csc.indices) != len(csc.data): if len(csc.indices) != len(csc.data):
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data))) raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSC( self.handle = ctypes.c_void_p(_LIB.XGDMatrixCreateFromCSC(
c_array(ctypes.c_ulong, csc.indptr), c_array(ctypes.c_ulong, csc.indptr),
c_array(ctypes.c_uint, csc.indices), c_array(ctypes.c_uint, csc.indices),
c_array(ctypes.c_float, csc.data), c_array(ctypes.c_float, csc.data),
@ -183,34 +192,77 @@ class DMatrix(object):
Initialize data from a 2-D numpy matrix. Initialize data from a 2-D numpy matrix.
""" """
data = np.array(mat.reshape(mat.size), dtype=np.float32) data = np.array(mat.reshape(mat.size), dtype=np.float32)
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat( self.handle = ctypes.c_void_p(_LIB.XGDMatrixCreateFromMat(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
mat.shape[0], mat.shape[1], ctypes.c_float(missing))) mat.shape[0], mat.shape[1], ctypes.c_float(missing)))
def __del__(self): def __del__(self):
xglib.XGDMatrixFree(self.handle) _LIB.XGDMatrixFree(self.handle)
def get_float_info(self, field): def get_float_info(self, field):
"""Get float property from the DMatrix.
Parameters
----------
field: str
The field name of the information
Returns
-------
info : array
a numpy array of float information of the data
"""
length = ctypes.c_ulong() length = ctypes.c_ulong()
ret = xglib.XGDMatrixGetFloatInfo(self.handle, c_str(field), ctypes.byref(length)) ret = _LIB.XGDMatrixGetFloatInfo(self.handle, c_str(field), ctypes.byref(length))
return ctypes2numpy(ret, length.value, np.float32) return ctypes2numpy(ret, length.value, np.float32)
def get_uint_info(self, field): def get_uint_info(self, field):
"""Get unsigned integer property from the DMatrix.
Parameters
----------
field: str
The field name of the information
Returns
-------
info : array
a numpy array of unsigned integer information of the data
"""
length = ctypes.c_ulong() length = ctypes.c_ulong()
ret = xglib.XGDMatrixGetUIntInfo(self.handle, c_str(field), ctypes.byref(length)) ret = _LIB.XGDMatrixGetUIntInfo(self.handle, c_str(field), ctypes.byref(length))
return ctypes2numpy(ret, length.value, np.uint32) return ctypes2numpy(ret, length.value, np.uint32)
def set_float_info(self, field, data): def set_float_info(self, field, data):
xglib.XGDMatrixSetFloatInfo(self.handle, c_str(field), """Set float type property into the DMatrix.
Parameters
----------
field: str
The field name of the information
data: numpy array
The array of data to be set
"""
_LIB.XGDMatrixSetFloatInfo(self.handle, c_str(field),
c_array(ctypes.c_float, data), len(data)) c_array(ctypes.c_float, data), len(data))
def set_uint_info(self, field, data): def set_uint_info(self, field, data):
xglib.XGDMatrixSetUIntInfo(self.handle, c_str(field), """Set uint type property into the DMatrix.
Parameters
----------
field: str
The field name of the information
data: numpy array
The array of data to be set
"""
_LIB.XGDMatrixSetUIntInfo(self.handle, c_str(field),
c_array(ctypes.c_uint, data), len(data)) c_array(ctypes.c_uint, data), len(data))
def save_binary(self, fname, silent=True): def save_binary(self, fname, silent=True):
""" """Save DMatrix to an XGBoost buffer.
Save DMatrix to an XGBoost buffer.
Parameters Parameters
---------- ----------
@ -219,74 +271,74 @@ class DMatrix(object):
silent : bool (optional; default: True) silent : bool (optional; default: True)
If set, the output is suppressed. If set, the output is suppressed.
""" """
xglib.XGDMatrixSaveBinary(self.handle, c_str(fname), int(silent)) _LIB.XGDMatrixSaveBinary(self.handle, c_str(fname), int(silent))
def set_label(self, label): def set_label(self, label):
"""set label of dmatrix """Set label of dmatrix
Args:
label: list Parameters
label for DMatrix ----------
Returns: label: array like
None The label information to be set into DMatrix
""" """
self.set_float_info('label', label) self.set_float_info('label', label)
def set_weight(self, weight): def set_weight(self, weight):
""" """ Set weight of each instance.
Set weight of each instance.
Parameters Parameters
---------- ----------
weight : float weight : array like
Weight for positive instance. Weight for each data point
""" """
self.set_float_info('weight', weight) self.set_float_info('weight', weight)
def set_base_margin(self, margin): def set_base_margin(self, margin):
""" """ Set base margin of booster to start from.
set base margin of booster to start from
this can be used to specify a prediction value of This can be used to specify a prediction value of
existing model to be base_margin existing model to be base_margin
However, remember margin is needed, instead of transformed prediction However, remember margin is needed, instead of transformed prediction
e.g. for logistic regression: need to put in value before logistic transformation e.g. for logistic regression: need to put in value before logistic transformation
see also example/demo.py see also example/demo.py
Parameters
----------
margin: array like
Prediction margin of each datapoint
""" """
self.set_float_info('base_margin', margin) self.set_float_info('base_margin', margin)
def set_group(self, group): def set_group(self, group):
""" """Set group size of DMatrix (used for ranking).
Set group size of DMatrix (used for ranking).
Parameters Parameters
---------- ----------
group : int group : array like
Group size. Group size of each group
""" """
xglib.XGDMatrixSetGroup(self.handle, c_array(ctypes.c_uint, group), len(group)) _LIB.XGDMatrixSetGroup(self.handle, c_array(ctypes.c_uint, group), len(group))
def get_label(self): def get_label(self):
""" """Get the label of the DMatrix.
Get the label of the DMatrix.
Returns Returns
------- -------
label : list label : array
""" """
return self.get_float_info('label') return self.get_float_info('label')
def get_weight(self): def get_weight(self):
""" """Get the weight of the DMatrix.
Get the weight of the DMatrix.
Returns Returns
------- -------
weight : float weight : array
""" """
return self.get_float_info('weight') return self.get_float_info('weight')
def get_base_margin(self): def get_base_margin(self):
""" """Get the base margin of the DMatrix.
Get the base margin of the DMatrix.
Returns Returns
------- -------
@ -295,18 +347,16 @@ class DMatrix(object):
return self.get_float_info('base_margin') return self.get_float_info('base_margin')
def num_row(self): def num_row(self):
""" """Get the number of rows in the DMatrix.
Get the number of rows in the DMatrix.
Returns Returns
------- -------
number of rows : int number of rows : int
""" """
return xglib.XGDMatrixNumRow(self.handle) return _LIB.XGDMatrixNumRow(self.handle)
def slice(self, rindex): def slice(self, rindex):
""" """Slice the DMatrix and return a new DMatrix that only contains `rindex`.
Slice the DMatrix and return a new DMatrix that only contains `rindex`.
Parameters Parameters
---------- ----------
@ -319,13 +369,15 @@ class DMatrix(object):
A new DMatrix containing only selected indices. A new DMatrix containing only selected indices.
""" """
res = DMatrix(None) res = DMatrix(None)
res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix( res.handle = ctypes.c_void_p(_LIB.XGDMatrixSliceDMatrix(
self.handle, c_array(ctypes.c_int, rindex), len(rindex))) self.handle, c_array(ctypes.c_int, rindex), len(rindex)))
return res return res
class Booster(object): class Booster(object):
"""A Booster of XGBoost."""
def __init__(self, params=None, cache=(), model_file=None): def __init__(self, params=None, cache=(), model_file=None):
# pylint: disable=invalid-name
""" """
Learner class. Learner class.
@ -342,14 +394,14 @@ class Booster(object):
if not isinstance(d, DMatrix): if not isinstance(d, DMatrix):
raise TypeError('invalid cache item: {}'.format(type(d).__name__)) raise TypeError('invalid cache item: {}'.format(type(d).__name__))
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache]) dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
self.handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, len(cache))) self.handle = ctypes.c_void_p(_LIB.XGBoosterCreate(dmats, len(cache)))
self.set_param({'seed': 0}) self.set_param({'seed': 0})
self.set_param(params or {}) self.set_param(params or {})
if model_file is not None: if model_file is not None:
self.load_model(model_file) self.load_model(model_file)
def __del__(self): def __del__(self):
xglib.XGBoosterFree(self.handle) _LIB.XGBoosterFree(self.handle)
def __getstate__(self): def __getstate__(self):
# can't pickle ctypes pointers # can't pickle ctypes pointers
@ -367,10 +419,10 @@ class Booster(object):
if handle is not None: if handle is not None:
buf = handle buf = handle
dmats = c_array(ctypes.c_void_p, []) dmats = c_array(ctypes.c_void_p, [])
handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, 0)) handle = ctypes.c_void_p(_LIB.XGBoosterCreate(dmats, 0))
length = ctypes.c_ulong(len(buf)) length = ctypes.c_ulong(len(buf))
ptr = (ctypes.c_char * len(buf)).from_buffer(buf) ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
xglib.XGBoosterLoadModelFromBuffer(handle, ptr, length) _LIB.XGBoosterLoadModelFromBuffer(handle, ptr, length)
state['handle'] = handle state['handle'] = handle
self.__dict__.update(state) self.__dict__.update(state)
self.set_param({'seed': 0}) self.set_param({'seed': 0})
@ -382,8 +434,7 @@ class Booster(object):
return Booster(model_file=self.save_raw()) return Booster(model_file=self.save_raw())
def copy(self): def copy(self):
""" """Copy the booster object.
Copy the booster object
Returns Returns
-------- --------
@ -391,15 +442,16 @@ class Booster(object):
""" """
return self.__copy__() return self.__copy__()
def set_param(self, params, pv=None): def set_param(self, params, value=None):
"""Set parameters into the Booster."""
if isinstance(params, collections.Mapping): if isinstance(params, collections.Mapping):
params = params.items() params = params.items()
elif isinstance(params, string_types) and pv is not None: elif isinstance(params, STRING_TYPES) and value is not None:
params = [(params, pv)] params = [(params, value)]
for k, v in params: for key, val in params:
xglib.XGBoosterSetParam(self.handle, c_str(k), c_str(str(v))) _LIB.XGBoosterSetParam(self.handle, c_str(key), c_str(str(val)))
def update(self, dtrain, it, fobj=None): def update(self, dtrain, iteration, fobj=None):
""" """
Update (one iteration). Update (one iteration).
@ -407,7 +459,7 @@ class Booster(object):
---------- ----------
dtrain : DMatrix dtrain : DMatrix
Training data. Training data.
it : int iteration : int
Current iteration number. Current iteration number.
fobj : function fobj : function
Customized objective function. Customized objective function.
@ -415,7 +467,7 @@ class Booster(object):
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
if fobj is None: if fobj is None:
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle) _LIB.XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle)
else: else:
pred = self.predict(dtrain) pred = self.predict(dtrain)
grad, hess = fobj(pred, dtrain) grad, hess = fobj(pred, dtrain)
@ -438,20 +490,20 @@ class Booster(object):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))) raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle, _LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
c_array(ctypes.c_float, grad), c_array(ctypes.c_float, grad),
c_array(ctypes.c_float, hess), c_array(ctypes.c_float, hess),
len(grad)) len(grad))
def eval_set(self, evals, it=0, feval=None): def eval_set(self, evals, iteration=0, feval=None):
""" # pylint: disable=invalid-name
Evaluate by a metric. """Evaluate a set of data.
Parameters Parameters
---------- ----------
evals : list of tuples (DMatrix, string) evals : list of tuples (DMatrix, string)
List of items to be evaluated. List of items to be evaluated.
it : int iteration : int
Current iteration. Current iteration.
feval : function feval : function
Custom evaluation function. Custom evaluation function.
@ -464,20 +516,35 @@ class Booster(object):
for d in evals: for d in evals:
if not isinstance(d[0], DMatrix): if not isinstance(d[0], DMatrix):
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__)) raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
if not isinstance(d[1], string_types): if not isinstance(d[1], STRING_TYPES):
raise TypeError('expected string, got {}'.format(type(d[1]).__name__)) raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals]) dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals]) evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals)) return _LIB.XGBoosterEvalOneIter(self.handle, iteration, dmats, evnames, len(evals))
else: else:
res = '[%d]' % it res = '[%d]' % iteration
for dm, evname in evals: for dmat, evname in evals:
name, val = feval(self.predict(dm), dm) name, val = feval(self.predict(dmat), dmat)
res += '\t%s-%s:%f' % (evname, name, val) res += '\t%s-%s:%f' % (evname, name, val)
return res return res
def eval(self, mat, name='eval', it=0): def eval(self, data, name='eval', iteration=0):
"""Evaluate the model on data.
Parameters
---------
data : DMatrix
The dmatrix storing the input.
name : str (default = 'eval')
The name of the dataset
iteration : int (default = 0)
The current iteration number
"""
return self.eval_set([(data, name)], iteration)
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False): def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
""" """
@ -492,10 +559,13 @@ class Booster(object):
---------- ----------
data : DMatrix data : DMatrix
The dmatrix storing the input. The dmatrix storing the input.
output_margin : bool output_margin : bool
Whether to output the raw untransformed margin value. Whether to output the raw untransformed margin value.
ntree_limit : int ntree_limit : int
Limit number of trees in the prediction; defaults to 0 (use all trees). Limit number of trees in the prediction; defaults to 0 (use all trees).
pred_leaf : bool pred_leaf : bool
When this option is on, the output will be a matrix of (nsample, ntrees) When this option is on, the output will be a matrix of (nsample, ntrees)
with each record indicating the predicted leaf index of each sample in each tree. with each record indicating the predicted leaf index of each sample in each tree.
@ -512,7 +582,7 @@ class Booster(object):
if pred_leaf: if pred_leaf:
option_mask |= 0x02 option_mask |= 0x02
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle, preds = _LIB.XGBoosterPredict(self.handle, data.handle,
option_mask, ntree_limit, ctypes.byref(length)) option_mask, ntree_limit, ctypes.byref(length))
preds = ctypes2numpy(preds, length.value, np.float32) preds = ctypes2numpy(preds, length.value, np.float32)
if pred_leaf: if pred_leaf:
@ -531,8 +601,8 @@ class Booster(object):
fname : string fname : string
Output file name Output file name
""" """
if isinstance(fname, string_types): # assume file name if isinstance(fname, STRING_TYPES): # assume file name
xglib.XGBoosterSaveModel(self.handle, c_str(fname)) _LIB.XGBoosterSaveModel(self.handle, c_str(fname))
else: else:
raise TypeError("fname must be a string") raise TypeError("fname must be a string")
@ -545,7 +615,7 @@ class Booster(object):
an in-memory buffer representation of the model an in-memory buffer representation of the model
""" """
length = ctypes.c_ulong() length = ctypes.c_ulong()
cptr = xglib.XGBoosterGetModelRaw(self.handle, cptr = _LIB.XGBoosterGetModelRaw(self.handle,
ctypes.byref(length)) ctypes.byref(length))
return ctypes2buffer(cptr, length.value) return ctypes2buffer(cptr, length.value)
@ -559,44 +629,44 @@ class Booster(object):
Input file name or memory buffer(see also save_raw) Input file name or memory buffer(see also save_raw)
""" """
if isinstance(fname, str): # assume file name if isinstance(fname, str): # assume file name
xglib.XGBoosterLoadModel(self.handle, c_str(fname)) _LIB.XGBoosterLoadModel(self.handle, c_str(fname))
else: else:
buf = fname buf = fname
length = ctypes.c_ulong(len(buf)) length = ctypes.c_ulong(len(buf))
ptr = (ctypes.c_char * len(buf)).from_buffer(buf) ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
xglib.XGBoosterLoadModelFromBuffer(self.handle, ptr, length) _LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)
def dump_model(self, fo, fmap='', with_stats=False): def dump_model(self, fout, fmap='', with_stats=False):
""" """
Dump model into a text file. Dump model into a text file.
Parameters Parameters
---------- ----------
fo : string fout : string
Output file name. Output file name.
fmap : string, optional fmap : string, optional
Name of the file containing feature map names. Name of the file containing feature map names.
with_stats : bool (optional) with_stats : bool (optional)
Controls whether the split statistics are output. Controls whether the split statistics are output.
""" """
if isinstance(fo, string_types): if isinstance(fout, STRING_TYPES):
fo = open(fo, 'w') fout = open(fout, 'w')
need_close = True need_close = True
else: else:
need_close = False need_close = False
ret = self.get_dump(fmap, with_stats) ret = self.get_dump(fmap, with_stats)
for i in range(len(ret)): for i in range(len(ret)):
fo.write('booster[{}]:\n'.format(i)) fout.write('booster[{}]:\n'.format(i))
fo.write(ret[i]) fout.write(ret[i])
if need_close: if need_close:
fo.close() fout.close()
def get_dump(self, fmap='', with_stats=False): def get_dump(self, fmap='', with_stats=False):
""" """
Returns the dump of the model as a list of strings. Returns the dump of the model as a list of strings.
""" """
length = ctypes.c_ulong() length = ctypes.c_ulong()
sarr = xglib.XGBoosterDumpModel(self.handle, c_str(fmap), sarr = _LIB.XGBoosterDumpModel(self.handle, c_str(fmap),
int(with_stats), ctypes.byref(length)) int(with_stats), ctypes.byref(length))
res = [] res = []
for i in range(length.value): for i in range(length.value):
@ -604,14 +674,18 @@ class Booster(object):
return res return res
def get_fscore(self, fmap=''): def get_fscore(self, fmap=''):
""" """Get feature importance of each feature.
Get feature importance of each feature.
Parameters
----------
fmap: str (optional)
The name of feature map file
""" """
trees = self.get_dump(fmap) trees = self.get_dump(fmap)
fmap = {} fmap = {}
for tree in trees: for tree in trees:
for l in tree.split('\n'): for line in tree.split('\n'):
arr = l.split('[') arr = line.split('[')
if len(arr) == 1: if len(arr) == 1:
continue continue
fid = arr[1].split(']')[0] fid = arr[1].split(']')[0]
@ -625,8 +699,8 @@ class Booster(object):
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
early_stopping_rounds=None, evals_result=None): early_stopping_rounds=None, evals_result=None):
""" # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
Train a booster with given parameters. """Train a booster with given parameters.
Parameters Parameters
---------- ----------
@ -663,7 +737,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
bst = Booster(params, [dtrain] + [d[0] for d in evals]) bst = Booster(params, [dtrain] + [d[0] for d in evals])
if evals_result is not None: if evals_result is not None:
if type(evals_result) is not dict: if isinstance(evals_result, dict):
raise TypeError('evals_result has to be a dictionary') raise TypeError('evals_result has to be a dictionary')
else: else:
evals_name = [d[1] for d in evals] evals_name = [d[1] for d in evals]
@ -675,7 +749,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
bst.update(dtrain, i, obj) bst.update(dtrain, i, obj)
if len(evals) != 0: if len(evals) != 0:
bst_eval_set = bst.eval_set(evals, i, feval) bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, string_types): if isinstance(bst_eval_set, STRING_TYPES):
msg = bst_eval_set msg = bst_eval_set
else: else:
msg = bst_eval_set.decode() msg = bst_eval_set.decode()
@ -689,23 +763,24 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
else: else:
# early stopping # early stopping
if len(evals) < 1: if len(evals) < 1:
raise ValueError('For early stopping you need at least one set in evals.') raise ValueError('For early stopping you need at least one set in evals.')
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], early_stopping_rounds)) sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(\
evals[-1][1], early_stopping_rounds))
# is params a list of tuples? are we using multiple eval metrics? # is params a list of tuples? are we using multiple eval metrics?
if type(params) == list: if isinstance(params, list):
if len(params) != len(dict(params).items()): if len(params) != len(dict(params).items()):
raise ValueError('Check your params. Early stopping works with single eval metric only.') raise ValueError('Check your params.'\
'Early stopping works with single eval metric only.')
params = dict(params) params = dict(params)
# either minimize loss or maximize AUC/MAP/NDCG # either minimize loss or maximize AUC/MAP/NDCG
maximize_score = False maximize_score = False
if 'eval_metric' in params: if 'eval_metric' in params:
maximize_metrics = ('auc', 'map', 'ndcg') maximize_metrics = ('auc', 'map', 'ndcg')
if list(filter(lambda x: params['eval_metric'].startswith(x), maximize_metrics)): if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
maximize_score = True maximize_score = True
if maximize_score: if maximize_score:
@ -720,7 +795,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
bst.update(dtrain, i, obj) bst.update(dtrain, i, obj)
bst_eval_set = bst.eval_set(evals, i, feval) bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, string_types): if isinstance(bst_eval_set, STRING_TYPES):
msg = bst_eval_set msg = bst_eval_set
else: else:
msg = bst_eval_set.decode() msg = bst_eval_set.decode()
@ -748,17 +823,21 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
return bst return bst
class CVPack(object): class CVPack(object):
"""Auxiliary datastruct to hold one fold of CV."""
def __init__(self, dtrain, dtest, param): def __init__(self, dtrain, dtest, param):
"""Initialize the CVPack"""
self.dtrain = dtrain self.dtrain = dtrain
self.dtest = dtest self.dtest = dtest
self.watchlist = [(dtrain, 'train'), (dtest, 'test')] self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
self.bst = Booster(param, [dtrain, dtest]) self.bst = Booster(param, [dtrain, dtest])
def update(self, r, fobj): def update(self, iteration, fobj):
"""Update the boosters for one iteration"""
self.bst.update(self.dtrain, iteration, fobj)
def eval(self, r, feval): def eval(self, iteration, feval):
"""Evaluate the CVPack for one iteration."""
return self.bst.eval_set(self.watchlist, iteration, feval)
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None): def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
@ -785,6 +864,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
def aggcv(rlist, show_stdv=True): def aggcv(rlist, show_stdv=True):
# pylint: disable=invalid-name
""" """
Aggregate cross-validation results. Aggregate cross-validation results.
""" """
@ -794,7 +874,7 @@ def aggcv(rlist, show_stdv=True):
arr = line.split() arr = line.split()
assert ret == arr[0] assert ret == arr[0]
for it in arr[1:]: for it in arr[1:]:
if not isinstance(it, string_types): if not isinstance(it, STRING_TYPES):
it = it.decode() it = it.decode()
k, v = it.split(':') k, v = it.split(':')
if k not in cvmap: if k not in cvmap:
@ -802,7 +882,7 @@ def aggcv(rlist, show_stdv=True):
cvmap[k].append(float(v)) cvmap[k].append(float(v))
for k, v in sorted(cvmap.items(), key=lambda x: x[0]): for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
v = np.array(v) v = np.array(v)
if not isinstance(ret, string_types): if not isinstance(ret, STRING_TYPES):
ret = ret.decode() ret = ret.decode()
if show_stdv: if show_stdv:
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v)) ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
@ -813,8 +893,8 @@ def aggcv(rlist, show_stdv=True):
def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(), def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0): obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0):
""" # pylint: disable = invalid-name
Cross-validation with given parameters. """Cross-validation with given parameters.
Parameters Parameters
---------- ----------
@ -847,8 +927,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
results = [] results = []
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc) cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
for i in range(num_boost_round): for i in range(num_boost_round):
for f in cvfolds: for fold in cvfolds:
f.update(i, obj) fold.update(i, obj)
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv) res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
sys.stderr.write(res + '\n') sys.stderr.write(res + '\n')
results.append(res) results.append(res)
@ -857,16 +937,16 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
# Base classes used by the scikit-learn wrappers below.  They fall back to
# plain ``object`` so this module can still be imported without sklearn.
if SKLEARN_INSTALLED:
    XGBModelBase = BaseEstimator
    XGBRegressorBase = RegressorMixin
    XGBClassifierBase = ClassifierMixin
else:
    XGBModelBase = object
    XGBRegressorBase = object
    XGBClassifierBase = object
class XGBModel(XGBModelBase): class XGBModel(XGBModelBase):
""" # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
Implementation of the Scikit-Learn API for XGBoost. """Implementation of the Scikit-Learn API for XGBoost.
Parameters Parameters
---------- ----------
@ -902,8 +982,10 @@ class XGBModel(XGBModelBase):
Value in the data which needs to be present as a missing value. If Value in the data which needs to be present as a missing value. If
None, defaults to np.nan. None, defaults to np.nan.
""" """
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True, objective="reg:linear", def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0, subsample=1, colsample_bytree=1, silent=True, objective="reg:linear",
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1,
base_score=0.5, seed=0, missing=None): base_score=0.5, seed=0, missing=None):
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use this module') raise XGBoostError('sklearn needs to be installed in order to use this module')
@ -923,7 +1005,6 @@ class XGBModel(XGBModelBase):
self.base_score = base_score self.base_score = base_score
self.seed = seed self.seed = seed
self.missing = missing if missing is not None else np.nan self.missing = missing if missing is not None else np.nan
self._Booster = None self._Booster = None
def __setstate__(self, state): def __setstate__(self, state):
@ -936,9 +1017,9 @@ class XGBModel(XGBModelBase):
self.__dict__.update(state) self.__dict__.update(state)
def booster(self): def booster(self):
""" """Get the underlying xgboost Booster of this model.
get the underlying xgboost Booster of this model
will raise an exception when fit was not called This will raise an exception when fit was not called
Returns Returns
------- -------
@ -949,12 +1030,14 @@ class XGBModel(XGBModelBase):
return self._Booster return self._Booster
def get_params(self, deep=False):
    """Get the estimator parameters as a dict.

    Parameters
    ----------
    deep : bool
        Forwarded unchanged to the parent class ``get_params``.

    Returns
    -------
    params : dict
        The parameters reported by the parent class, with a NaN
        ``missing`` value replaced by None.
    """
    params = super(XGBModel, self).get_params(deep=deep)
    missing = params['missing']
    # sklearn doesn't handle nan (see #4725), so report NaN as None.
    # Test the value instead of identity with ``np.nan``: a NaN held in a
    # different float object (``float('nan')``, an unpickled copy, a numpy
    # scalar) would slip past an ``is`` comparison.
    if isinstance(missing, (float, np.floating)) and np.isnan(missing):
        params['missing'] = None
    return params
def get_xgb_params(self): def get_xgb_params(self):
"""Get xgboost type parameters."""
xgb_params = self.get_params() xgb_params = self.get_params()
xgb_params['silent'] = 1 if self.silent else 0 xgb_params['silent'] = 1 if self.silent else 0
@ -963,30 +1046,39 @@ class XGBModel(XGBModelBase):
xgb_params.pop('nthread', None) xgb_params.pop('nthread', None)
return xgb_params return xgb_params
def fit(self, data, y):
    # pylint: disable=invalid-name
    """Train a booster on ``data`` with labels ``y`` and keep it on the model.

    The training matrix is built with ``self.missing`` as the missing-value
    marker and training runs for ``self.n_estimators`` rounds.

    Returns
    -------
    self
    """
    dtrain = DMatrix(data, label=y, missing=self.missing)
    booster_params = self.get_xgb_params()
    self._Booster = train(booster_params, dtrain, self.n_estimators)
    return self
def predict(self, data):
    # pylint: disable=invalid-name
    """Return the fitted booster's predictions for ``data``.

    ``data`` is wrapped in a DMatrix using ``self.missing`` as the
    missing-value marker before being handed to the booster.
    """
    dtest = DMatrix(data, missing=self.missing)
    fitted = self.booster()
    return fitted.predict(dtest)
class XGBClassifier(XGBModel, XGBClassifier): class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
__doc__ = """ __doc__ = """
Implementation of the scikit-learn API for XGBoost classification Implementation of the scikit-learn API for XGBoost classification
""" + "\n".join(XGBModel.__doc__.split('\n')[2:]) """ + "\n".join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=0.1,
             n_estimators=100, silent=True,
             objective="binary:logistic",
             nthread=-1, gamma=0, min_child_weight=1,
             max_delta_step=0, subsample=1, colsample_bytree=1,
             base_score=0.5, seed=0, missing=None):
    """Create the classifier by forwarding every argument to XGBModel.

    The only difference from the parent signature is the default
    ``objective``, which is ``"binary:logistic"`` here.
    """
    super(XGBClassifier, self).__init__(
        max_depth=max_depth,
        learning_rate=learning_rate,
        n_estimators=n_estimators,
        silent=silent,
        objective=objective,
        nthread=nthread,
        gamma=gamma,
        min_child_weight=min_child_weight,
        max_delta_step=max_delta_step,
        subsample=subsample,
        colsample_bytree=colsample_bytree,
        base_score=base_score,
        seed=seed,
        missing=missing)
def fit(self, X, y, sample_weight=None): def fit(self, X, y, sample_weight=None):
# pylint: disable = attribute-defined-outside-init,arguments-differ
self.classes_ = list(np.unique(y)) self.classes_ = list(np.unique(y))
self.n_classes_ = len(self.classes_) self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2: if self.n_classes_ > 2:
@ -1001,29 +1093,29 @@ class XGBClassifier(XGBModel, XGBClassifier):
training_labels = self._le.transform(y) training_labels = self._le.transform(y)
if sample_weight is not None: if sample_weight is not None:
trainDmatrix = DMatrix(X, label=training_labels, weight=sample_weight, train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
missing=self.missing) missing=self.missing)
else: else:
trainDmatrix = DMatrix(X, label=training_labels, train_dmatrix = DMatrix(X, label=training_labels,
missing=self.missing) missing=self.missing)
self._Booster = train(xgb_options, trainDmatrix, self.n_estimators) self._Booster = train(xgb_options, train_dmatrix, self.n_estimators)
return self return self
def predict(self, data):
    """Predict a class label for every row of ``data``.

    The booster output is interpreted in one of two ways:

    * more than one dimension: one column per class, and the column with
      the largest value is chosen;
    * one dimension: a single score per row, mapped to index 1 when it is
      greater than 0.5 and to index 0 otherwise.

    The resulting indexes are passed through ``self._le.inverse_transform``
    (presumably the label encoder created in ``fit`` -- confirm there).
    """
    test_dmatrix = DMatrix(data, missing=self.missing)
    class_probs = self.booster().predict(test_dmatrix)
    if len(class_probs.shape) > 1:
        column_indexes = np.argmax(class_probs, axis=1)
    else:
        # Derive the 0/1 indexes from the predictions themselves instead of
        # sizing an array from ``data.shape[0]``, so inputs without a
        # ``.shape`` attribute no longer fail at this point.
        column_indexes = (class_probs > 0.5).astype(int)
    return self._le.inverse_transform(column_indexes)
def predict_proba(self, X): def predict_proba(self, data):
testDmatrix = DMatrix(X, missing=self.missing) test_dmatrix = DMatrix(data, missing=self.missing)
class_probs = self.booster().predict(testDmatrix) class_probs = self.booster().predict(test_dmatrix)
if self.objective == "multi:softprob": if self.objective == "multi:softprob":
return class_probs return class_probs
else: else:
@ -1031,9 +1123,8 @@ class XGBClassifier(XGBModel, XGBClassifier):
classzero_probs = 1.0 - classone_probs classzero_probs = 1.0 - classone_probs
return np.vstack((classzero_probs, classone_probs)).transpose() return np.vstack((classzero_probs, classone_probs)).transpose()
class XGBRegressor(XGBModel, XGBRegressorBase):
    # pylint: disable=missing-docstring
    # The documentation is assembled from XGBModel's docstring: its first two
    # lines are dropped and a regression-specific summary is put in front.
    # ``XGBModel.__doc__`` is None when Python runs with -OO (docstrings
    # stripped), so fall back to an empty string rather than raising
    # AttributeError while the module is being imported.
    __doc__ = """
    Implementation of the scikit-learn API for XGBoost regression
    """ + "\n".join((XGBModel.__doc__ or "").split('\n')[2:])