Use ctypes

This commit is contained in:
sinhrks 2015-09-12 14:36:17 +09:00
parent 6506a1c490
commit 48ac946d9f
4 changed files with 115 additions and 36 deletions

View File

@ -1,10 +1,9 @@
# coding: utf-8
# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments, too-many-branches
"""Core XGBoost Library."""
from __future__ import absolute_import
import os
import re
import sys
import ctypes
import platform
@ -24,8 +23,9 @@ class XGBoostError(Exception):
# Python 2/3 compatibility shim for string handling.
# STRING_TYPES is a one-element tuple suitable for ``isinstance`` checks;
# on Python 3 the name ``unicode`` is aliased to ``str`` so code written
# against the Python 2 builtin keeps working.
if sys.version_info[0] == 3:
    # pylint: disable=invalid-name, redefined-builtin
    STRING_TYPES = (str,)
    unicode = str
else:
    # pylint: disable=invalid-name
    STRING_TYPES = (basestring,)
@ -184,15 +184,18 @@ class DMatrix(object):
self.set_weight(weight)
# validate feature name
if not isinstance(feature_names, list):
feature_names = list(feature_names)
if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col():
raise ValueError('feature_names must have the same length as data')
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
for f in feature_names):
raise ValueError('all feature_names must be alphanumerics')
if not feature_names is None:
if not isinstance(feature_names, list):
feature_names = list(feature_names)
if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col():
msg = 'feature_names must have the same length as data'
raise ValueError(msg)
# prohibit symbols that may interfere with parsing, e.g. ``[]=.``
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
for f in feature_names):
raise ValueError('all feature_names must be alphanumerics')
self.feature_names = feature_names
def _init_from_csr(self, csr):
@ -411,13 +414,13 @@ class DMatrix(object):
return ret.value
def num_col(self):
    """Get the number of columns (features) in the DMatrix.

    Returns
    -------
    number of columns : int
    """
    # XGDMatrixNumCol writes its result through a ``bst_ulong *``, so the
    # receiving buffer must be at least as wide as ``bst_ulong``.  A
    # ``c_uint`` buffer can be narrower than that and would then be overrun
    # by the C side; ``c_ulong`` is also what get_dump passes for its
    # ``bst_ulong *`` length argument.
    # NOTE(review): bst_ulong is presumably C ``unsigned long`` -- confirm
    # against the typedef in the wrapper header.
    ret = ctypes.c_ulong()
    _check_call(_LIB.XGDMatrixNumCol(self.handle,
                                     ctypes.byref(ret)))
    return ret.value
@ -611,7 +614,7 @@ class Booster(object):
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
if not isinstance(d[1], STRING_TYPES):
raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
self._validate_feature_names(d)
self._validate_feature_names(d[0])
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
@ -776,27 +779,46 @@ class Booster(object):
"""
Returns the model dump as a list of strings.
"""
res = []
length = ctypes.c_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()
_check_call(_LIB.XGBoosterDumpModel(self.handle,
c_str(fmap),
int(with_stats),
ctypes.byref(length),
ctypes.byref(sarr)))
for i in range(length.value):
res.append(str(sarr[i].decode('ascii')))
if self.feature_names is not None and fmap == '':
flen = int(len(self.feature_names))
fname = (ctypes.c_char_p * flen)()
ftype = (ctypes.c_char_p * flen)()
if self.feature_names is not None:
defaults = ['f{0}'.format(i) for i in
range(len(self.feature_names))]
rep = dict((re.escape(k), v) for k, v in
zip(defaults, self.feature_names))
pattern = re.compile("|".join(rep))
def _replace(expr):
""" Replace matched group to corresponding values """
return pattern.sub(lambda m: rep[re.escape(m.group(0))], expr)
res = [_replace(r) for r in res]
# supports quantitative type only
# {'q': quantitative, 'i': indicator}
if sys.version_info[0] == 3:
features = [bytes(f, 'utf-8') for f in self.feature_names]
types = [bytes('q', 'utf-8')] * flen
else:
features = [f.encode('utf-8') if isinstance(f, unicode) else f
for f in self.feature_names]
types = ['q'] * flen
fname[:] = features
ftype[:] = types
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
flen,
fname,
ftype,
int(with_stats),
ctypes.byref(length),
ctypes.byref(sarr)))
else:
_check_call(_LIB.XGBoosterDumpModel(self.handle,
c_str(fmap),
int(with_stats),
ctypes.byref(length),
ctypes.byref(sarr)))
res = []
for i in range(length.value):
try:
res.append(str(sarr[i].decode('ascii')))
except UnicodeDecodeError:
res.append(unicode(sarr[i].decode('utf-8')))
return res
def get_fscore(self, fmap=''):

View File

@ -29,6 +29,26 @@ def test_basic():
# assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0
def test_feature_names():
    """Feature names given to DMatrix are kept and show up in the f-scores."""
    expected_names = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']
    data = np.random.randn(100, 5)
    target = np.array([0, 1] * 50)

    dm = xgb.DMatrix(data, label=target, feature_names=expected_names)

    # the matrix reports the names and shape it was built with
    assert dm.feature_names == expected_names
    assert dm.num_row() == 100
    assert dm.num_col() == 5

    params = {
        'objective': 'multi:softprob',
        'eval_metric': 'mlogloss',
        'eta': 0.3,
        'num_class': 3,
    }
    bst = xgb.train(params, dm, num_boost_round=10)

    # the keys of the importance scores are the user-supplied feature names
    scores = bst.get_fscore()
    assert sorted(scores) == expected_names
def test_plotting():
bst2 = xgb.Booster(model_file='xgb.model')
# plotting

View File

@ -445,9 +445,9 @@ int XGDMatrixNumRow(const DMatrixHandle handle,
int XGDMatrixNumCol(const DMatrixHandle handle,
bst_ulong *out) {
API_BEGIN();
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
API_END();
API_BEGIN();
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
API_END();
}
// xgboost implementation
@ -580,3 +580,20 @@ int XGBoosterDumpModel(BoosterHandle handle,
featmap, with_stats != 0, len);
API_END();
}
// Dump the model using feature names and types passed in as arrays.
// A FeatMap is built from the first `fnum` entries of `fname`/`ftype`
// (entry k is registered under feature index k) and handed to
// Booster::GetModelDump; the resulting string array is stored in
// *out_models, and `len` is forwarded to GetModelDump.
int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
                                   int fnum,
                                   const char **fname,
                                   const char **ftype,
                                   int with_stats,
                                   bst_ulong *len,
                                   const char ***out_models) {
  API_BEGIN();
  Booster *booster = static_cast<Booster*>(handle);
  const bool dump_stats = (with_stats != 0);

  utils::FeatMap feature_map;
  int k = 0;
  while (k < fnum) {
    feature_map.PushBack(k, fname[k], ftype[k]);
    ++k;
  }

  *out_models = booster->GetModelDump(feature_map, dump_stats, len);
  API_END();
}

View File

@ -331,4 +331,24 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
int with_stats,
bst_ulong *out_len,
const char ***out_dump_array);
/*!
 * \brief dump model using the given feature names and types,
 *  return array of strings representing model dump
 * \param handle handle
 * \param fnum number of features
 * \param fname names of features
 * \param ftype types of features
 * \param with_stats whether to dump with statistics
 * \param len length of output array
 * \param out_models pointer to hold representing dump of each model
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
int fnum,
const char **fname,
const char **ftype,
int with_stats,
bst_ulong *len,
const char ***out_models);
#endif // XGBOOST_WRAPPER_H_