Merge pull request #178 from aldanor/master
[python] Fixed the dll import for relative paths + various cleanup.
This commit is contained in:
commit
15562126a6
@ -1,142 +1,185 @@
|
|||||||
"""
|
"""
|
||||||
xgboost: eXtreme Gradient Boosting library
|
xgboost: eXtreme Gradient Boosting library
|
||||||
Author: Tianqi Chen, Bing Xu
|
|
||||||
|
|
||||||
|
Authors: Tianqi Chen, Bing Xu
|
||||||
"""
|
"""
|
||||||
import ctypes
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import os
|
import os
|
||||||
# optinally have scipy sparse, though not necessary
|
|
||||||
import numpy as np
|
|
||||||
import sys
|
import sys
|
||||||
import numpy.ctypeslib
|
import ctypes
|
||||||
import scipy.sparse as scp
|
import collections
|
||||||
|
|
||||||
# set this line correctly
|
import numpy as np
|
||||||
if os.name == 'nt':
|
import scipy.sparse
|
||||||
XGBOOST_PATH = os.path.dirname(__file__)+'/../windows/x64/Release/xgboost_wrapper.dll'
|
|
||||||
|
__all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
|
||||||
|
|
||||||
|
if sys.version_info[0] == 3:
|
||||||
|
string_types = str,
|
||||||
else:
|
else:
|
||||||
XGBOOST_PATH = os.path.dirname(__file__)+'/libxgboostwrapper.so'
|
string_types = basestring,
|
||||||
|
|
||||||
|
|
||||||
|
def load_xglib():
|
||||||
|
dll_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
|
||||||
|
if os.name == 'nt':
|
||||||
|
dll_path = os.path.join(dll_path, '../windows/x64/Release/xgboost_wrapper.dll')
|
||||||
|
else:
|
||||||
|
dll_path = os.path.join(dll_path, 'libxgboostwrapper.so')
|
||||||
|
|
||||||
|
# load the xgboost wrapper library
|
||||||
|
lib = ctypes.cdll.LoadLibrary(dll_path)
|
||||||
|
|
||||||
|
# DMatrix functions
|
||||||
|
lib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
|
||||||
|
lib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
|
||||||
|
lib.XGDMatrixCreateFromCSC.restype = ctypes.c_void_p
|
||||||
|
lib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
|
||||||
|
lib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
|
||||||
|
lib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
|
lib.XGDMatrixGetUIntInfo.restype = ctypes.POINTER(ctypes.c_uint)
|
||||||
|
lib.XGDMatrixNumRow.restype = ctypes.c_ulong
|
||||||
|
|
||||||
|
# Booster functions
|
||||||
|
lib.XGBoosterCreate.restype = ctypes.c_void_p
|
||||||
|
lib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
|
lib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
|
||||||
|
lib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
|
||||||
|
|
||||||
|
return lib
|
||||||
|
|
||||||
|
# load the XGBoost library globally
|
||||||
|
xglib = load_xglib()
|
||||||
|
|
||||||
# load in xgboost library
|
|
||||||
xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH)
|
|
||||||
# DMatrix functions
|
|
||||||
xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
|
|
||||||
xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
|
|
||||||
xglib.XGDMatrixCreateFromCSC.restype = ctypes.c_void_p
|
|
||||||
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
|
|
||||||
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
|
|
||||||
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
|
|
||||||
xglib.XGDMatrixGetUIntInfo.restype = ctypes.POINTER(ctypes.c_uint)
|
|
||||||
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
|
|
||||||
# booster functions
|
|
||||||
xglib.XGBoosterCreate.restype = ctypes.c_void_p
|
|
||||||
xglib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
|
|
||||||
xglib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
|
|
||||||
xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
|
|
||||||
|
|
||||||
def ctypes2numpy(cptr, length, dtype):
|
def ctypes2numpy(cptr, length, dtype):
|
||||||
"""convert a ctypes pointer array to numpy array """
|
"""
|
||||||
assert isinstance(cptr, ctypes.POINTER(ctypes.c_float))
|
Convert a ctypes pointer array to a numpy array.
|
||||||
res = numpy.zeros(length, dtype=dtype)
|
"""
|
||||||
assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0])
|
if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
|
||||||
|
raise RuntimeError('expected float pointer')
|
||||||
|
res = np.zeros(length, dtype=dtype)
|
||||||
|
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
|
||||||
|
raise RuntimeError('memmove failed')
|
||||||
return res
|
return res
|
||||||
|
|
||||||
class DMatrix:
|
|
||||||
"""data matrix used in xgboost"""
|
|
||||||
# constructor
|
|
||||||
def __init__(self, data, label=None, missing=0.0, weight = None):
|
|
||||||
""" constructor of DMatrix
|
|
||||||
|
|
||||||
Args:
|
def c_str(string):
|
||||||
data: string/numpy array/scipy.sparse
|
return ctypes.c_char_p(string.encode('utf-8'))
|
||||||
data source, string type is the path of svmlight format txt file or xgb buffer
|
|
||||||
label: list or numpy 1d array, optional
|
|
||||||
label of training data
|
def c_array(ctype, values):
|
||||||
missing: float
|
return (ctype * len(values))(*values)
|
||||||
value in data which need to be present as missing value
|
|
||||||
weight: list or numpy 1d array, optional
|
|
||||||
weight for each instances
|
class DMatrix(object):
|
||||||
|
def __init__(self, data, label=None, missing=0.0, weight=None):
|
||||||
"""
|
"""
|
||||||
|
Data matrix used in XGBoost.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : string/numpy array/scipy.sparse
|
||||||
|
Data source, string type is the path of svmlight format txt file or xgb buffer.
|
||||||
|
label : list or numpy 1-D array (optional)
|
||||||
|
Label of the training data.
|
||||||
|
missing : float
|
||||||
|
Value in the data which needs to be present as a missing value.
|
||||||
|
weight : list or numpy 1-D array (optional)
|
||||||
|
Weight for each instance.
|
||||||
|
"""
|
||||||
|
|
||||||
# force into void_p, mac need to pass things in as void_p
|
# force into void_p, mac need to pass things in as void_p
|
||||||
if data is None:
|
if data is None:
|
||||||
self.handle = None
|
self.handle = None
|
||||||
return
|
return
|
||||||
if isinstance(data, str):
|
if isinstance(data, string_types):
|
||||||
self.handle = ctypes.c_void_p(
|
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromFile(c_str(data), 0))
|
||||||
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0))
|
elif isinstance(data, scipy.sparse.csr_matrix):
|
||||||
elif isinstance(data, scp.csr_matrix):
|
self._init_from_csr(data)
|
||||||
self.__init_from_csr(data)
|
elif isinstance(data, scipy.sparse.csc_matrix):
|
||||||
elif isinstance(data, scp.csc_matrix):
|
self._init_from_csc(data)
|
||||||
self.__init_from_csc(data)
|
elif isinstance(data, np.ndarray) and len(data.shape) == 2:
|
||||||
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
self._init_from_npy2d(data, missing)
|
||||||
self.__init_from_npy2d(data, missing)
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
csr = scp.csr_matrix(data)
|
csr = scipy.sparse.csr_matrix(data)
|
||||||
self.__init_from_csr(csr)
|
self._init_from_csr(csr)
|
||||||
except:
|
except:
|
||||||
raise Exception("can not intialize DMatrix from"+str(type(data)))
|
raise TypeError('can not intialize DMatrix from {}'.format(type(data).__name__))
|
||||||
if label != None:
|
if label is not None:
|
||||||
self.set_label(label)
|
self.set_label(label)
|
||||||
if weight !=None:
|
if weight is not None:
|
||||||
self.set_weight(weight)
|
self.set_weight(weight)
|
||||||
|
|
||||||
def __init_from_csr(self, csr):
|
def _init_from_csr(self, csr):
|
||||||
"""convert data from csr matrix"""
|
"""
|
||||||
assert len(csr.indices) == len(csr.data)
|
Initialize data from a CSR matrix.
|
||||||
|
"""
|
||||||
|
if len(csr.indices) != len(csr.data):
|
||||||
|
raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
|
||||||
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR(
|
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR(
|
||||||
(ctypes.c_ulong * len(csr.indptr))(*csr.indptr),
|
c_array(ctypes.c_ulong, csr.indptr),
|
||||||
(ctypes.c_uint * len(csr.indices))(*csr.indices),
|
c_array(ctypes.c_uint, csr.indices),
|
||||||
(ctypes.c_float * len(csr.data))(*csr.data),
|
c_array(ctypes.c_float, csr.data),
|
||||||
len(csr.indptr), len(csr.data)))
|
len(csr.indptr), len(csr.data)))
|
||||||
|
|
||||||
def __init_from_csc(self, csc):
|
def _init_from_csc(self, csc):
|
||||||
"""convert data from csr matrix"""
|
"""
|
||||||
assert len(csc.indices) == len(csc.data)
|
Initialize data from a CSC matrix.
|
||||||
|
"""
|
||||||
|
if len(csc.indices) != len(csc.data):
|
||||||
|
raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
|
||||||
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSC(
|
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSC(
|
||||||
(ctypes.c_ulong * len(csc.indptr))(*csc.indptr),
|
c_array(ctypes.c_ulong, csc.indptr),
|
||||||
(ctypes.c_uint * len(csc.indices))(*csc.indices),
|
c_array(ctypes.c_uint, csc.indices),
|
||||||
(ctypes.c_float * len(csc.data))(*csc.data),
|
c_array(ctypes.c_float, csc.data),
|
||||||
len(csc.indptr), len(csc.data)))
|
len(csc.indptr), len(csc.data)))
|
||||||
|
|
||||||
def __init_from_npy2d(self,mat,missing):
|
def _init_from_npy2d(self, mat, missing):
|
||||||
"""convert data from numpy matrix"""
|
"""
|
||||||
data = numpy.array(mat.reshape(mat.size), dtype='float32')
|
Initialize data from a 2-D numpy matrix.
|
||||||
|
"""
|
||||||
|
data = np.array(mat.reshape(mat.size), dtype=np.float32)
|
||||||
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat(
|
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat(
|
||||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||||
mat.shape[0], mat.shape[1], ctypes.c_float(missing)))
|
mat.shape[0], mat.shape[1], ctypes.c_float(missing)))
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
"""destructor"""
|
|
||||||
xglib.XGDMatrixFree(self.handle)
|
xglib.XGDMatrixFree(self.handle)
|
||||||
|
|
||||||
def get_float_info(self, field):
|
def get_float_info(self, field):
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
ret = xglib.XGDMatrixGetFloatInfo(self.handle, c_str(field), ctypes.byref(length))
|
||||||
ctypes.byref(length))
|
return ctypes2numpy(ret, length.value, np.float32)
|
||||||
return ctypes2numpy(ret, length.value, 'float32')
|
|
||||||
def get_uint_info(self, field):
|
def get_uint_info(self, field):
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
ret = xglib.XGDMatrixGetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
ret = xglib.XGDMatrixGetUIntInfo(self.handle, c_str(field), ctypes.byref(length))
|
||||||
ctypes.byref(length))
|
return ctypes2numpy(ret, length.value, np.uint32)
|
||||||
return ctypes2numpy(ret, length.value, 'uint32')
|
|
||||||
def set_float_info(self, field, data):
|
def set_float_info(self, field, data):
|
||||||
xglib.XGDMatrixSetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
xglib.XGDMatrixSetFloatInfo(self.handle, c_str(field),
|
||||||
(ctypes.c_float*len(data))(*data), len(data))
|
c_array(ctypes.c_float, data), len(data))
|
||||||
|
|
||||||
def set_uint_info(self, field, data):
|
def set_uint_info(self, field, data):
|
||||||
xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
xglib.XGDMatrixSetUIntInfo(self.handle, c_str(field),
|
||||||
(ctypes.c_uint*len(data))(*data), len(data))
|
c_array(ctypes.c_uint, data), len(data))
|
||||||
|
|
||||||
def save_binary(self, fname, silent=True):
|
def save_binary(self, fname, silent=True):
|
||||||
"""save DMatrix to XGBoost buffer
|
|
||||||
Args:
|
|
||||||
fname: string
|
|
||||||
name of buffer file
|
|
||||||
slient: bool, option
|
|
||||||
whether print info
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
"""
|
||||||
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
|
Save DMatrix to an XGBoost buffer.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fname : string
|
||||||
|
Name of the output buffer file.
|
||||||
|
silent : bool (optional; default: True)
|
||||||
|
If set, the output is suppressed.
|
||||||
|
"""
|
||||||
|
xglib.XGDMatrixSaveBinary(self.handle, c_str(fname), int(silent))
|
||||||
|
|
||||||
def set_label(self, label):
|
def set_label(self, label):
|
||||||
"""set label of dmatrix
|
"""set label of dmatrix
|
||||||
@ -149,12 +192,13 @@ class DMatrix:
|
|||||||
self.set_float_info('label', label)
|
self.set_float_info('label', label)
|
||||||
|
|
||||||
def set_weight(self, weight):
|
def set_weight(self, weight):
|
||||||
"""set weight of each instances
|
"""
|
||||||
Args:
|
Set weight of each instance.
|
||||||
weight: float
|
|
||||||
weight for positive instance
|
Parameters
|
||||||
Returns:
|
----------
|
||||||
None
|
weight : float
|
||||||
|
Weight for positive instance.
|
||||||
"""
|
"""
|
||||||
self.set_float_info('weight', weight)
|
self.set_float_info('weight', weight)
|
||||||
|
|
||||||
@ -170,159 +214,180 @@ class DMatrix:
|
|||||||
self.set_float_info('base_margin', margin)
|
self.set_float_info('base_margin', margin)
|
||||||
|
|
||||||
def set_group(self, group):
|
def set_group(self, group):
|
||||||
"""set group size of dmatrix, used for rank
|
|
||||||
Args:
|
|
||||||
group:
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
"""
|
||||||
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
|
Set group size of DMatrix (used for ranking).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
group : int
|
||||||
|
Group size.
|
||||||
|
"""
|
||||||
|
xglib.XGDMatrixSetGroup(self.handle, c_array(ctypes.c_uint, group), len(group))
|
||||||
|
|
||||||
def get_label(self):
|
def get_label(self):
|
||||||
"""get label from dmatrix
|
"""
|
||||||
Args:
|
Get the label of the DMatrix.
|
||||||
None
|
|
||||||
Returns:
|
Returns
|
||||||
list, label of data
|
-------
|
||||||
|
label : list
|
||||||
"""
|
"""
|
||||||
return self.get_float_info('label')
|
return self.get_float_info('label')
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
"""get weight from dmatrix
|
"""
|
||||||
Args:
|
Get the weight of the DMatrix.
|
||||||
None
|
|
||||||
Returns:
|
Returns
|
||||||
float, weight
|
-------
|
||||||
|
weight : float
|
||||||
"""
|
"""
|
||||||
return self.get_float_info('weight')
|
return self.get_float_info('weight')
|
||||||
|
|
||||||
def get_base_margin(self):
|
def get_base_margin(self):
|
||||||
"""get base_margin from dmatrix
|
"""
|
||||||
Args:
|
Get the base margin of the DMatrix.
|
||||||
None
|
|
||||||
Returns:
|
Returns
|
||||||
float, base margin
|
-------
|
||||||
|
base_margin : float
|
||||||
"""
|
"""
|
||||||
return self.get_float_info('base_margin')
|
return self.get_float_info('base_margin')
|
||||||
|
|
||||||
def num_row(self):
|
def num_row(self):
|
||||||
"""get number of rows
|
"""
|
||||||
Args:
|
Get the number of rows in the DMatrix.
|
||||||
None
|
|
||||||
Returns:
|
Returns
|
||||||
int, num rows
|
-------
|
||||||
|
number of rows : int
|
||||||
"""
|
"""
|
||||||
return xglib.XGDMatrixNumRow(self.handle)
|
return xglib.XGDMatrixNumRow(self.handle)
|
||||||
|
|
||||||
def slice(self, rindex):
|
def slice(self, rindex):
|
||||||
"""slice the DMatrix to return a new DMatrix that only contains rindex
|
"""
|
||||||
Args:
|
Slice the DMatrix and return a new DMatrix that only contains `rindex`.
|
||||||
rindex: list
|
|
||||||
list of index to be chosen
|
Parameters
|
||||||
Returns:
|
----------
|
||||||
res: DMatrix
|
rindex : list
|
||||||
new DMatrix with chosen index
|
List of indices to be selected.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
res : DMatrix
|
||||||
|
A new DMatrix containing only selected indices.
|
||||||
"""
|
"""
|
||||||
res = DMatrix(None)
|
res = DMatrix(None)
|
||||||
res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix(
|
res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix(
|
||||||
self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex)))
|
self.handle, c_array(ctypes.c_int, rindex), len(rindex)))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
class Booster:
|
|
||||||
"""learner class """
|
class Booster(object):
|
||||||
def __init__(self, params={}, cache=[], model_file = None):
|
def __init__(self, params=None, cache=(), model_file=None):
|
||||||
""" constructor
|
"""
|
||||||
Args:
|
Learner class.
|
||||||
params: dict
|
|
||||||
params for boosters
|
Parameters
|
||||||
cache: list
|
----------
|
||||||
list of cache item
|
params : dict
|
||||||
model_file: string
|
Parameters for boosters.
|
||||||
path of model file
|
cache : list
|
||||||
Returns:
|
List of cache items.
|
||||||
None
|
model_file : string
|
||||||
|
Path to the model file.
|
||||||
"""
|
"""
|
||||||
for d in cache:
|
for d in cache:
|
||||||
assert isinstance(d, DMatrix)
|
if not isinstance(d, DMatrix):
|
||||||
dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache])
|
raise TypeError('invalid cache item: {}'.format(type(d).__name__))
|
||||||
|
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
|
||||||
self.handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, len(cache)))
|
self.handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, len(cache)))
|
||||||
self.set_param({'seed':0})
|
self.set_param({'seed': 0})
|
||||||
self.set_param(params)
|
self.set_param(params or {})
|
||||||
if model_file != None:
|
if model_file is not None:
|
||||||
self.load_model(model_file)
|
self.load_model(model_file)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
xglib.XGBoosterFree(self.handle)
|
xglib.XGBoosterFree(self.handle)
|
||||||
|
|
||||||
def set_param(self, params, pv=None):
|
def set_param(self, params, pv=None):
|
||||||
if isinstance(params, dict):
|
if isinstance(params, collections.Mapping):
|
||||||
for k, v in params.items():
|
params = params.items()
|
||||||
xglib.XGBoosterSetParam(
|
elif isinstance(params, string_types) and pv is not None:
|
||||||
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
params = [(params, pv)]
|
||||||
ctypes.c_char_p(str(v).encode('utf-8')))
|
for k, v in params:
|
||||||
elif isinstance(params,str) and pv != None:
|
xglib.XGBoosterSetParam(self.handle, c_str(k), c_str(str(v)))
|
||||||
xglib.XGBoosterSetParam(
|
|
||||||
self.handle, ctypes.c_char_p(params.encode('utf-8')),
|
|
||||||
ctypes.c_char_p(str(pv).encode('utf-8')))
|
|
||||||
else:
|
|
||||||
for k, v in params:
|
|
||||||
xglib.XGBoosterSetParam(
|
|
||||||
self.handle, ctypes.c_char_p(k.encode('utf-8')),
|
|
||||||
ctypes.c_char_p(str(v).encode('utf-8')))
|
|
||||||
|
|
||||||
def update(self, dtrain, it, fobj=None):
|
def update(self, dtrain, it, fobj=None):
|
||||||
"""
|
"""
|
||||||
update
|
Update (one iteration).
|
||||||
Args:
|
|
||||||
dtrain: DMatrix
|
Parameters
|
||||||
the training DMatrix
|
----------
|
||||||
it: int
|
dtrain : DMatrix
|
||||||
current iteration number
|
Training data.
|
||||||
fobj: function
|
it : int
|
||||||
cutomzied objective function
|
Current iteration number.
|
||||||
Returns:
|
fobj : function
|
||||||
None
|
Customized objective function.
|
||||||
"""
|
"""
|
||||||
assert isinstance(dtrain, DMatrix)
|
if not isinstance(dtrain, DMatrix):
|
||||||
|
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
|
||||||
if fobj is None:
|
if fobj is None:
|
||||||
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
|
xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
|
||||||
else:
|
else:
|
||||||
pred = self.predict( dtrain )
|
pred = self.predict(dtrain)
|
||||||
grad, hess = fobj( pred, dtrain )
|
grad, hess = fobj(pred, dtrain)
|
||||||
self.boost( dtrain, grad, hess )
|
self.boost(dtrain, grad, hess)
|
||||||
|
|
||||||
def boost(self, dtrain, grad, hess):
|
def boost(self, dtrain, grad, hess):
|
||||||
""" update
|
|
||||||
Args:
|
|
||||||
dtrain: DMatrix
|
|
||||||
the training DMatrix
|
|
||||||
grad: list
|
|
||||||
the first order of gradient
|
|
||||||
hess: list
|
|
||||||
the second order of gradient
|
|
||||||
"""
|
"""
|
||||||
assert len(grad) == len(hess)
|
Update.
|
||||||
assert isinstance(dtrain, DMatrix)
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dtrain : DMatrix
|
||||||
|
The training DMatrix.
|
||||||
|
grad : list
|
||||||
|
The first order of gradient.
|
||||||
|
hess : list
|
||||||
|
The second order of gradient.
|
||||||
|
"""
|
||||||
|
if len(grad) != len(hess):
|
||||||
|
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
|
||||||
|
if not isinstance(dtrain, DMatrix):
|
||||||
|
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
|
||||||
xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle,
|
xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle,
|
||||||
(ctypes.c_float*len(grad))(*grad),
|
c_array(ctypes.c_float, grad),
|
||||||
(ctypes.c_float*len(hess))(*hess),
|
c_array(ctypes.c_float, hess),
|
||||||
len(grad))
|
len(grad))
|
||||||
|
|
||||||
def eval_set(self, evals, it = 0, feval = None):
|
def eval_set(self, evals, it=0, feval=None):
|
||||||
"""evaluates by metric
|
"""
|
||||||
Args:
|
Evaluate by a metric.
|
||||||
evals: list of tuple (DMatrix, string)
|
|
||||||
lists of items to be evaluated
|
Parameters
|
||||||
it: int
|
----------
|
||||||
current iteration
|
evals : list of tuples (DMatrix, string)
|
||||||
feval: function
|
List of items to be evaluated.
|
||||||
custom evaluation function
|
it : int
|
||||||
Returns:
|
Current iteration.
|
||||||
evals result
|
feval : function
|
||||||
|
Custom evaluation function.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
evaluation result
|
||||||
"""
|
"""
|
||||||
if feval is None:
|
if feval is None:
|
||||||
for d in evals:
|
for d in evals:
|
||||||
assert isinstance(d[0], DMatrix)
|
if not isinstance(d[0], DMatrix):
|
||||||
assert isinstance(d[1], str)
|
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
|
||||||
dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals])
|
if not isinstance(d[1], string_types):
|
||||||
evnames = (ctypes.c_char_p * len(evals))(
|
raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
|
||||||
* [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals])
|
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
||||||
|
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
||||||
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
|
||||||
else:
|
else:
|
||||||
res = '[%d]' % it
|
res = '[%d]' % it
|
||||||
@ -330,97 +395,115 @@ class Booster:
|
|||||||
name, val = feval(self.predict(dm), dm)
|
name, val = feval(self.predict(dm), dm)
|
||||||
res += '\t%s-%s:%f' % (evname, name, val)
|
res += '\t%s-%s:%f' % (evname, name, val)
|
||||||
return res
|
return res
|
||||||
def eval(self, mat, name = 'eval', it = 0):
|
|
||||||
return self.eval_set( [(mat,name)], it)
|
def eval(self, mat, name='eval', it=0):
|
||||||
|
return self.eval_set([(mat, name)], it)
|
||||||
|
|
||||||
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
|
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
|
||||||
"""
|
"""
|
||||||
predict with data
|
Predict with data.
|
||||||
Args:
|
|
||||||
data: DMatrix
|
Parameters
|
||||||
the dmatrix storing the input
|
----------
|
||||||
output_margin: bool
|
data : DMatrix
|
||||||
whether output raw margin value that is untransformed
|
The dmatrix storing the input.
|
||||||
ntree_limit: int
|
output_margin : bool
|
||||||
limit number of trees in prediction, default to 0, 0 means using all the trees
|
Whether to output the raw untransformed margin value.
|
||||||
pred_leaf: bool
|
ntree_limit : int
|
||||||
when this option is on, the output will be a matrix of (nsample, ntrees)
|
Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
with each record indicate the predicted leaf index of each sample in each tree
|
pred_leaf : bool
|
||||||
Note that the leaf index of tree is unique per tree, so you may find leaf 1 in both tree 1 and tree 0
|
When this option is on, the output will be a matrix of (nsample, ntrees)
|
||||||
Returns:
|
with each record indicating the predicted leaf index of each sample in each tree.
|
||||||
numpy array of prediction
|
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||||
|
in both tree 1 and tree 0.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
prediction : numpy array
|
||||||
"""
|
"""
|
||||||
option_mask = 0
|
option_mask = 0x00
|
||||||
if output_margin:
|
if output_margin:
|
||||||
option_mask += 1
|
option_mask |= 0x01
|
||||||
if pred_leaf:
|
if pred_leaf:
|
||||||
option_mask += 2
|
option_mask |= 0x02
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
||||||
option_mask, ntree_limit, ctypes.byref(length))
|
option_mask, ntree_limit, ctypes.byref(length))
|
||||||
preds = ctypes2numpy(preds, length.value, 'float32')
|
preds = ctypes2numpy(preds, length.value, np.float32)
|
||||||
if pred_leaf:
|
if pred_leaf:
|
||||||
preds = preds.astype('int32')
|
preds = preds.astype(np.int32)
|
||||||
nrow = data.num_row()
|
nrow = data.num_row()
|
||||||
if preds.size != nrow and preds.size % nrow == 0:
|
if preds.size != nrow and preds.size % nrow == 0:
|
||||||
preds = preds.reshape(nrow, preds.size / nrow)
|
preds = preds.reshape(nrow, preds.size / nrow)
|
||||||
return preds
|
return preds
|
||||||
|
|
||||||
def save_model(self, fname):
|
def save_model(self, fname):
|
||||||
""" save model to file
|
|
||||||
Args:
|
|
||||||
fname: string
|
|
||||||
file name of saving model
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
"""
|
||||||
xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8')))
|
Save the model to a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fname : string
|
||||||
|
Output file name.
|
||||||
|
"""
|
||||||
|
xglib.XGBoosterSaveModel(self.handle, c_str(fname))
|
||||||
|
|
||||||
def load_model(self, fname):
|
def load_model(self, fname):
|
||||||
"""load model from file
|
|
||||||
Args:
|
|
||||||
fname: string
|
|
||||||
file name of saving model
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
"""
|
||||||
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) )
|
Load the model from a file.
|
||||||
def dump_model(self, fo, fmap='', with_stats = False):
|
|
||||||
"""dump model into text file
|
Parameters
|
||||||
Args:
|
----------
|
||||||
fo: string
|
fname : string
|
||||||
file name to be dumped
|
Input file name.
|
||||||
fmap: string, optional
|
|
||||||
file name of feature map names
|
|
||||||
with_stats: bool, optional
|
|
||||||
whether output statistics of the split
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
"""
|
||||||
if isinstance(fo,str):
|
xglib.XGBoosterLoadModel(self.handle, c_str(fname))
|
||||||
fo = open(fo,'w')
|
|
||||||
|
def dump_model(self, fo, fmap='', with_stats=False):
|
||||||
|
"""
|
||||||
|
Dump model into a text file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fo : string
|
||||||
|
Output file name.
|
||||||
|
fmap : string, optional
|
||||||
|
Name of the file containing feature map names.
|
||||||
|
with_stats : bool (optional)
|
||||||
|
Controls whether the split statistics are output.
|
||||||
|
"""
|
||||||
|
if isinstance(fo, string_types):
|
||||||
|
fo = open(fo, 'w')
|
||||||
need_close = True
|
need_close = True
|
||||||
else:
|
else:
|
||||||
need_close = False
|
need_close = False
|
||||||
ret = self.get_dump(fmap, with_stats)
|
ret = self.get_dump(fmap, with_stats)
|
||||||
for i in range(len(ret)):
|
for i in range(len(ret)):
|
||||||
fo.write('booster[%d]:\n' %i)
|
fo.write('booster[{}]:\n'.format(i))
|
||||||
fo.write( ret[i] )
|
fo.write(ret[i])
|
||||||
if need_close:
|
if need_close:
|
||||||
fo.close()
|
fo.close()
|
||||||
|
|
||||||
def get_dump(self, fmap='', with_stats=False):
|
def get_dump(self, fmap='', with_stats=False):
|
||||||
"""get dump of model as list of strings """
|
"""
|
||||||
|
Returns the dump the model as a list of strings.
|
||||||
|
"""
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
sarr = xglib.XGBoosterDumpModel(self.handle,
|
sarr = xglib.XGBoosterDumpModel(self.handle, c_str(fmap),
|
||||||
ctypes.c_char_p(fmap.encode('utf-8')),
|
|
||||||
int(with_stats), ctypes.byref(length))
|
int(with_stats), ctypes.byref(length))
|
||||||
res = []
|
res = []
|
||||||
for i in range(length.value):
|
for i in range(length.value):
|
||||||
res.append( str(sarr[i]) )
|
res.append(str(sarr[i]))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def get_fscore(self, fmap=''):
|
def get_fscore(self, fmap=''):
|
||||||
""" get feature importance of each feature """
|
"""
|
||||||
|
Get feature importance of each feature.
|
||||||
|
"""
|
||||||
trees = self.get_dump(fmap)
|
trees = self.get_dump(fmap)
|
||||||
fmap = {}
|
fmap = {}
|
||||||
for tree in trees:
|
for tree in trees:
|
||||||
print (tree)
|
sys.stdout.write(str(tree) + '\n')
|
||||||
for l in tree.split('\n'):
|
for l in tree.split('\n'):
|
||||||
arr = l.split('[')
|
arr = l.split('[')
|
||||||
if len(arr) == 1:
|
if len(arr) == 1:
|
||||||
@ -430,56 +513,70 @@ class Booster:
|
|||||||
if fid not in fmap:
|
if fid not in fmap:
|
||||||
fmap[fid] = 1
|
fmap[fid] = 1
|
||||||
else:
|
else:
|
||||||
fmap[fid]+= 1
|
fmap[fid] += 1
|
||||||
return fmap
|
return fmap
|
||||||
|
|
||||||
def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None):
|
|
||||||
""" train a booster with given paramaters
|
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
|
||||||
Args:
|
|
||||||
params: dict
|
|
||||||
params of booster
|
|
||||||
dtrain: DMatrix
|
|
||||||
data to be trained
|
|
||||||
num_boost_round: int
|
|
||||||
num of round to be boosted
|
|
||||||
watchlist: list of pairs (DMatrix, string)
|
|
||||||
list of items to be evaluated during training, this allows user to watch performance on validation set
|
|
||||||
obj: function
|
|
||||||
cutomized objective function
|
|
||||||
feval: function
|
|
||||||
cutomized evaluation function
|
|
||||||
Returns: Booster model trained
|
|
||||||
"""
|
"""
|
||||||
bst = Booster(params, [dtrain]+[ d[0] for d in evals ] )
|
Train a booster with given parameters.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
params : dict
|
||||||
|
Booster params.
|
||||||
|
dtrain : DMatrix
|
||||||
|
Data to be trained.
|
||||||
|
num_boost_round: int
|
||||||
|
Number of boosting iterations.
|
||||||
|
watchlist : list of pairs (DMatrix, string)
|
||||||
|
List of items to be evaluated during training, this allows user to watch
|
||||||
|
performance on the validation set.
|
||||||
|
obj : function
|
||||||
|
Customized objective function.
|
||||||
|
feval : function
|
||||||
|
Customized evaluation function.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
booster : a trained booster model
|
||||||
|
"""
|
||||||
|
evals = list(evals)
|
||||||
|
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||||
for i in range(num_boost_round):
|
for i in range(num_boost_round):
|
||||||
bst.update( dtrain, i, obj )
|
bst.update(dtrain, i, obj)
|
||||||
if len(evals) != 0:
|
if len(evals) != 0:
|
||||||
bst_eval_set=bst.eval_set(evals, i, feval)
|
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||||
if isinstance(bst_eval_set,str):
|
if isinstance(bst_eval_set, string_types):
|
||||||
sys.stderr.write(bst_eval_set+'\n')
|
sys.stderr.write(bst_eval_set + '\n')
|
||||||
else:
|
else:
|
||||||
sys.stderr.write(bst_eval_set.decode()+'\n')
|
sys.stderr.write(bst_eval_set.decode() + '\n')
|
||||||
return bst
|
return bst
|
||||||
|
|
||||||
class CVPack:
|
|
||||||
|
class CVPack(object):
|
||||||
def __init__(self, dtrain, dtest, param):
|
def __init__(self, dtrain, dtest, param):
|
||||||
self.dtrain = dtrain
|
self.dtrain = dtrain
|
||||||
self.dtest = dtest
|
self.dtest = dtest
|
||||||
self.watchlist = watchlist = [ (dtrain,'train'), (dtest, 'test') ]
|
self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
|
||||||
self.bst = Booster(param, [dtrain,dtest])
|
self.bst = Booster(param, [dtrain, dtest])
|
||||||
|
|
||||||
def update(self, r, fobj):
|
def update(self, r, fobj):
|
||||||
self.bst.update(self.dtrain, r, fobj)
|
self.bst.update(self.dtrain, r, fobj)
|
||||||
|
|
||||||
def eval(self, r, feval):
|
def eval(self, r, feval):
|
||||||
return self.bst.eval_set(self.watchlist, r, feval)
|
return self.bst.eval_set(self.watchlist, r, feval)
|
||||||
|
|
||||||
def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None):
|
|
||||||
|
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
|
||||||
"""
|
"""
|
||||||
mk nfold list of cvpack from randidx
|
Make an n-fold list of CVPack from random indices.
|
||||||
"""
|
"""
|
||||||
|
evals = list(evals)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
randidx = np.random.permutation(dall.num_row())
|
randidx = np.random.permutation(dall.num_row())
|
||||||
kstep = len(randidx) / nfold
|
kstep = len(randidx) / nfold
|
||||||
idset = [randidx[ (i*kstep) : min(len(randidx),(i+1)*kstep) ] for i in range(nfold)]
|
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
|
||||||
ret = []
|
ret = []
|
||||||
for k in range(nfold):
|
for k in range(nfold):
|
||||||
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
|
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
|
||||||
@ -493,9 +590,10 @@ def mknfold(dall, nfold, param, seed, evals=[], fpreproc = None):
|
|||||||
ret.append(CVPack(dtrain, dtest, plst))
|
ret.append(CVPack(dtrain, dtest, plst))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def aggcv(rlist, show_stdv=True):
|
def aggcv(rlist, show_stdv=True):
|
||||||
"""
|
"""
|
||||||
aggregate cross validation results
|
Aggregate cross-validation results.
|
||||||
"""
|
"""
|
||||||
cvmap = {}
|
cvmap = {}
|
||||||
ret = rlist[0].split()[0]
|
ret = rlist[0].split()[0]
|
||||||
@ -503,15 +601,15 @@ def aggcv(rlist, show_stdv=True):
|
|||||||
arr = line.split()
|
arr = line.split()
|
||||||
assert ret == arr[0]
|
assert ret == arr[0]
|
||||||
for it in arr[1:]:
|
for it in arr[1:]:
|
||||||
if not isinstance(it,str):
|
if not isinstance(it, string_types):
|
||||||
it=it.decode()
|
it = it.decode()
|
||||||
k, v = it.split(':')
|
k, v = it.split(':')
|
||||||
if k not in cvmap:
|
if k not in cvmap:
|
||||||
cvmap[k] = []
|
cvmap[k] = []
|
||||||
cvmap[k].append(float(v))
|
cvmap[k].append(float(v))
|
||||||
for k, v in sorted(cvmap.items(), key = lambda x:x[0]):
|
for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
|
||||||
v = np.array(v)
|
v = np.array(v)
|
||||||
if not isinstance(ret,str):
|
if not isinstance(ret, string_types):
|
||||||
ret = ret.decode()
|
ret = ret.decode()
|
||||||
if show_stdv:
|
if show_stdv:
|
||||||
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
|
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
|
||||||
@ -519,33 +617,39 @@ def aggcv(rlist, show_stdv=True):
|
|||||||
ret += '\tcv-%s:%f' % (k, np.mean(v))
|
ret += '\tcv-%s:%f' % (k, np.mean(v))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \
|
|
||||||
obj = None, feval = None, fpreproc = None, show_stdv = True, seed = 0):
|
|
||||||
""" cross validation with given paramaters
|
|
||||||
Args:
|
|
||||||
params: dict
|
|
||||||
params of booster
|
|
||||||
dtrain: DMatrix
|
|
||||||
data to be trained
|
|
||||||
num_boost_round: int
|
|
||||||
num of round to be boosted
|
|
||||||
nfold: int
|
|
||||||
number of folds to do cv
|
|
||||||
metrics: list of strings
|
|
||||||
evaluation metrics to be watched in cv
|
|
||||||
obj: function
|
|
||||||
custom objective function
|
|
||||||
feval: function
|
|
||||||
custom evaluation function
|
|
||||||
fpreproc: function
|
|
||||||
preprocessing function that takes dtrain, dtest,
|
|
||||||
param and return transformed version of dtrain, dtest, param
|
|
||||||
show_stdv: bool
|
|
||||||
whether display standard deviation
|
|
||||||
seed: int
|
|
||||||
seed used to generate the folds, this is passed to numpy.random.seed
|
|
||||||
|
|
||||||
Returns: list(string) of evaluation history
|
def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
|
||||||
|
obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0):
|
||||||
|
"""
|
||||||
|
Cross-validation with given paramaters.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
params : dict
|
||||||
|
Booster params.
|
||||||
|
dtrain : DMatrix
|
||||||
|
Data to be trained.
|
||||||
|
num_boost_round : int
|
||||||
|
Number of boosting iterations.
|
||||||
|
nfold : int
|
||||||
|
Number of folds in CV.
|
||||||
|
metrics : list of strings
|
||||||
|
Evaluation metrics to be watched in CV.
|
||||||
|
obj : function
|
||||||
|
Custom objective function.
|
||||||
|
feval : function
|
||||||
|
Custom evaluation function.
|
||||||
|
fpreproc : function
|
||||||
|
Preprocessing function that takes (dtrain, dtest, param) and returns
|
||||||
|
transformed versions of those.
|
||||||
|
show_stdv : bool
|
||||||
|
Whether to display the standard deviation.
|
||||||
|
seed : int
|
||||||
|
Seed used to generate the folds (passed to numpy.random.seed).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
evaluation history : list(string)
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
|
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
|
||||||
@ -553,7 +657,6 @@ def cv(params, dtrain, num_boost_round = 10, nfold=3, metrics=[], \
|
|||||||
for f in cvfolds:
|
for f in cvfolds:
|
||||||
f.update(i, obj)
|
f.update(i, obj)
|
||||||
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
|
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
|
||||||
sys.stderr.write(res+'\n')
|
sys.stderr.write(res + '\n')
|
||||||
results.append(res)
|
results.append(res)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user