Use ctypes
This commit is contained in:
parent
6506a1c490
commit
48ac946d9f
@ -1,10 +1,9 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments, too-many-branches
|
||||||
"""Core XGBoost Library."""
|
"""Core XGBoost Library."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import ctypes
|
import ctypes
|
||||||
import platform
|
import platform
|
||||||
@ -24,8 +23,9 @@ class XGBoostError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
if sys.version_info[0] == 3:
|
if sys.version_info[0] == 3:
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name, redefined-builtin
|
||||||
STRING_TYPES = str,
|
STRING_TYPES = str,
|
||||||
|
unicode = str
|
||||||
else:
|
else:
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
STRING_TYPES = basestring,
|
STRING_TYPES = basestring,
|
||||||
@ -184,15 +184,18 @@ class DMatrix(object):
|
|||||||
self.set_weight(weight)
|
self.set_weight(weight)
|
||||||
|
|
||||||
# validate feature name
|
# validate feature name
|
||||||
if not isinstance(feature_names, list):
|
if not feature_names is None:
|
||||||
feature_names = list(feature_names)
|
if not isinstance(feature_names, list):
|
||||||
if len(feature_names) != len(set(feature_names)):
|
feature_names = list(feature_names)
|
||||||
raise ValueError('feature_names must be unique')
|
if len(feature_names) != len(set(feature_names)):
|
||||||
if len(feature_names) != self.num_col():
|
raise ValueError('feature_names must be unique')
|
||||||
raise ValueError('feature_names must have the same length as data')
|
if len(feature_names) != self.num_col():
|
||||||
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
|
msg = 'feature_names must have the same length as data'
|
||||||
for f in feature_names):
|
raise ValueError(msg)
|
||||||
raise ValueError('all feature_names must be alphanumerics')
|
# prohibit to use symbols may affect to parse. e.g. ``[]=.``
|
||||||
|
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
|
||||||
|
for f in feature_names):
|
||||||
|
raise ValueError('all feature_names must be alphanumerics')
|
||||||
self.feature_names = feature_names
|
self.feature_names = feature_names
|
||||||
|
|
||||||
def _init_from_csr(self, csr):
|
def _init_from_csr(self, csr):
|
||||||
@ -411,13 +414,13 @@ class DMatrix(object):
|
|||||||
return ret.value
|
return ret.value
|
||||||
|
|
||||||
def num_col(self):
|
def num_col(self):
|
||||||
"""Get the number of columns in the DMatrix.
|
"""Get the number of columns (features) in the DMatrix.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
number of columns : int
|
number of columns : int
|
||||||
"""
|
"""
|
||||||
ret = ctypes.c_ulong()
|
ret = ctypes.c_uint()
|
||||||
_check_call(_LIB.XGDMatrixNumCol(self.handle,
|
_check_call(_LIB.XGDMatrixNumCol(self.handle,
|
||||||
ctypes.byref(ret)))
|
ctypes.byref(ret)))
|
||||||
return ret.value
|
return ret.value
|
||||||
@ -611,7 +614,7 @@ class Booster(object):
|
|||||||
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
|
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
|
||||||
if not isinstance(d[1], STRING_TYPES):
|
if not isinstance(d[1], STRING_TYPES):
|
||||||
raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
|
raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
|
||||||
self._validate_feature_names(d)
|
self._validate_feature_names(d[0])
|
||||||
|
|
||||||
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
||||||
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
||||||
@ -776,27 +779,46 @@ class Booster(object):
|
|||||||
"""
|
"""
|
||||||
Returns the dump the model as a list of strings.
|
Returns the dump the model as a list of strings.
|
||||||
"""
|
"""
|
||||||
res = []
|
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||||
_check_call(_LIB.XGBoosterDumpModel(self.handle,
|
if self.feature_names is not None and fmap == '':
|
||||||
c_str(fmap),
|
flen = int(len(self.feature_names))
|
||||||
int(with_stats),
|
fname = (ctypes.c_char_p * flen)()
|
||||||
ctypes.byref(length),
|
ftype = (ctypes.c_char_p * flen)()
|
||||||
ctypes.byref(sarr)))
|
|
||||||
for i in range(length.value):
|
|
||||||
res.append(str(sarr[i].decode('ascii')))
|
|
||||||
|
|
||||||
if self.feature_names is not None:
|
# supports quantitative type only
|
||||||
defaults = ['f{0}'.format(i) for i in
|
# {'q': quantitative, 'i': indicator}
|
||||||
range(len(self.feature_names))]
|
if sys.version_info[0] == 3:
|
||||||
rep = dict((re.escape(k), v) for k, v in
|
features = [bytes(f, 'utf-8') for f in self.feature_names]
|
||||||
zip(defaults, self.feature_names))
|
types = [bytes('q', 'utf-8')] * flen
|
||||||
pattern = re.compile("|".join(rep))
|
else:
|
||||||
def _replace(expr):
|
features = [f.encode('utf-8') if isinstance(f, unicode) else f
|
||||||
""" Replace matched group to corresponding values """
|
for f in self.feature_names]
|
||||||
return pattern.sub(lambda m: rep[re.escape(m.group(0))], expr)
|
types = ['q'] * flen
|
||||||
res = [_replace(r) for r in res]
|
|
||||||
|
fname[:] = features
|
||||||
|
ftype[:] = types
|
||||||
|
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
|
||||||
|
flen,
|
||||||
|
fname,
|
||||||
|
ftype,
|
||||||
|
int(with_stats),
|
||||||
|
ctypes.byref(length),
|
||||||
|
ctypes.byref(sarr)))
|
||||||
|
else:
|
||||||
|
_check_call(_LIB.XGBoosterDumpModel(self.handle,
|
||||||
|
c_str(fmap),
|
||||||
|
int(with_stats),
|
||||||
|
ctypes.byref(length),
|
||||||
|
ctypes.byref(sarr)))
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for i in range(length.value):
|
||||||
|
try:
|
||||||
|
res.append(str(sarr[i].decode('ascii')))
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
res.append(unicode(sarr[i].decode('utf-8')))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def get_fscore(self, fmap=''):
|
def get_fscore(self, fmap=''):
|
||||||
|
|||||||
@ -29,6 +29,26 @@ def test_basic():
|
|||||||
# assert they are the same
|
# assert they are the same
|
||||||
assert np.sum(np.abs(preds2-preds)) == 0
|
assert np.sum(np.abs(preds2-preds)) == 0
|
||||||
|
|
||||||
|
def test_feature_names():
|
||||||
|
data = np.random.randn(100, 5)
|
||||||
|
target = np.array([0, 1] * 50)
|
||||||
|
|
||||||
|
features = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']
|
||||||
|
dm = xgb.DMatrix(data, label=target,
|
||||||
|
feature_names=features)
|
||||||
|
assert dm.feature_names == features
|
||||||
|
assert dm.num_row() == 100
|
||||||
|
assert dm.num_col() == 5
|
||||||
|
|
||||||
|
params={'objective': 'multi:softprob',
|
||||||
|
'eval_metric': 'mlogloss',
|
||||||
|
'eta': 0.3,
|
||||||
|
'num_class': 3}
|
||||||
|
|
||||||
|
bst = xgb.train(params, dm, num_boost_round=10)
|
||||||
|
scores = bst.get_fscore()
|
||||||
|
assert list(sorted(k for k in scores)) == features
|
||||||
|
|
||||||
def test_plotting():
|
def test_plotting():
|
||||||
bst2 = xgb.Booster(model_file='xgb.model')
|
bst2 = xgb.Booster(model_file='xgb.model')
|
||||||
# plotting
|
# plotting
|
||||||
|
|||||||
@ -445,9 +445,9 @@ int XGDMatrixNumRow(const DMatrixHandle handle,
|
|||||||
|
|
||||||
int XGDMatrixNumCol(const DMatrixHandle handle,
|
int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||||
bst_ulong *out) {
|
bst_ulong *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
|
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
// xgboost implementation
|
// xgboost implementation
|
||||||
@ -580,3 +580,20 @@ int XGBoosterDumpModel(BoosterHandle handle,
|
|||||||
featmap, with_stats != 0, len);
|
featmap, with_stats != 0, len);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||||
|
int fnum,
|
||||||
|
const char **fname,
|
||||||
|
const char **ftype,
|
||||||
|
int with_stats,
|
||||||
|
bst_ulong *len,
|
||||||
|
const char ***out_models) {
|
||||||
|
API_BEGIN();
|
||||||
|
utils::FeatMap featmap;
|
||||||
|
for (int i = 0; i < fnum; ++i) {
|
||||||
|
featmap.PushBack(i, fname[i], ftype[i]);
|
||||||
|
}
|
||||||
|
*out_models = static_cast<Booster*>(handle)->GetModelDump(
|
||||||
|
featmap, with_stats != 0, len);
|
||||||
|
API_END();
|
||||||
|
}
|
||||||
|
|||||||
@ -331,4 +331,24 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
|
|||||||
int with_stats,
|
int with_stats,
|
||||||
bst_ulong *out_len,
|
bst_ulong *out_len,
|
||||||
const char ***out_dump_array);
|
const char ***out_dump_array);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief dump model, return array of strings representing model dump
|
||||||
|
* \param handle handle
|
||||||
|
* \param fnum number of features
|
||||||
|
* \param fnum names of features
|
||||||
|
* \param fnum types of features
|
||||||
|
* \param with_stats whether to dump with statistics
|
||||||
|
* \param out_len length of output array
|
||||||
|
* \param out_dump_array pointer to hold representing dump of each model
|
||||||
|
* \return 0 when success, -1 when failure happens
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||||
|
int fnum,
|
||||||
|
const char **fname,
|
||||||
|
const char **ftype,
|
||||||
|
int with_stats,
|
||||||
|
bst_ulong *len,
|
||||||
|
const char ***out_models);
|
||||||
|
|
||||||
#endif // XGBOOST_WRAPPER_H_
|
#endif // XGBOOST_WRAPPER_H_
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user