Use ctypes
This commit is contained in:
parent
6506a1c490
commit
48ac946d9f
@ -1,10 +1,9 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=too-many-arguments
|
||||
# pylint: disable=too-many-arguments, too-many-branches
|
||||
"""Core XGBoost Library."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import ctypes
|
||||
import platform
|
||||
@ -24,8 +23,9 @@ class XGBoostError(Exception):
|
||||
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
# pylint: disable=invalid-name
|
||||
# pylint: disable=invalid-name, redefined-builtin
|
||||
STRING_TYPES = str,
|
||||
unicode = str
|
||||
else:
|
||||
# pylint: disable=invalid-name
|
||||
STRING_TYPES = basestring,
|
||||
@ -184,15 +184,18 @@ class DMatrix(object):
|
||||
self.set_weight(weight)
|
||||
|
||||
# validate feature name
|
||||
if not isinstance(feature_names, list):
|
||||
feature_names = list(feature_names)
|
||||
if len(feature_names) != len(set(feature_names)):
|
||||
raise ValueError('feature_names must be unique')
|
||||
if len(feature_names) != self.num_col():
|
||||
raise ValueError('feature_names must have the same length as data')
|
||||
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
|
||||
for f in feature_names):
|
||||
raise ValueError('all feature_names must be alphanumerics')
|
||||
if not feature_names is None:
|
||||
if not isinstance(feature_names, list):
|
||||
feature_names = list(feature_names)
|
||||
if len(feature_names) != len(set(feature_names)):
|
||||
raise ValueError('feature_names must be unique')
|
||||
if len(feature_names) != self.num_col():
|
||||
msg = 'feature_names must have the same length as data'
|
||||
raise ValueError(msg)
|
||||
# prohibit symbols that may interfere with parsing, e.g. ``[]=.``
|
||||
if not all(isinstance(f, STRING_TYPES) and f.isalnum()
|
||||
for f in feature_names):
|
||||
raise ValueError('all feature_names must be alphanumerics')
|
||||
self.feature_names = feature_names
|
||||
|
||||
def _init_from_csr(self, csr):
|
||||
@ -411,13 +414,13 @@ class DMatrix(object):
|
||||
return ret.value
|
||||
|
||||
def num_col(self):
    """Get the number of columns (features) in the DMatrix.

    Returns
    -------
    number of columns : int
    """
    # XGDMatrixNumCol writes its result through a ``bst_ulong *``. Use
    # ``c_ulong`` -- the same ctypes type passed for the ``bst_ulong *``
    # length output of XGBoosterDumpModel -- so the output buffer is not
    # narrower than what the library writes into it.
    ret = ctypes.c_ulong()
    _check_call(_LIB.XGDMatrixNumCol(self.handle,
                                     ctypes.byref(ret)))
    return ret.value
|
||||
@ -611,7 +614,7 @@ class Booster(object):
|
||||
raise TypeError('expected DMatrix, got {}'.format(type(d[0]).__name__))
|
||||
if not isinstance(d[1], STRING_TYPES):
|
||||
raise TypeError('expected string, got {}'.format(type(d[1]).__name__))
|
||||
self._validate_feature_names(d)
|
||||
self._validate_feature_names(d[0])
|
||||
|
||||
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
||||
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
||||
@ -776,27 +779,46 @@ class Booster(object):
|
||||
"""
|
||||
Returns the model dump as a list of strings.
|
||||
"""
|
||||
res = []
|
||||
|
||||
length = ctypes.c_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
_check_call(_LIB.XGBoosterDumpModel(self.handle,
|
||||
c_str(fmap),
|
||||
int(with_stats),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
for i in range(length.value):
|
||||
res.append(str(sarr[i].decode('ascii')))
|
||||
if self.feature_names is not None and fmap == '':
|
||||
flen = int(len(self.feature_names))
|
||||
fname = (ctypes.c_char_p * flen)()
|
||||
ftype = (ctypes.c_char_p * flen)()
|
||||
|
||||
if self.feature_names is not None:
|
||||
defaults = ['f{0}'.format(i) for i in
|
||||
range(len(self.feature_names))]
|
||||
rep = dict((re.escape(k), v) for k, v in
|
||||
zip(defaults, self.feature_names))
|
||||
pattern = re.compile("|".join(rep))
|
||||
def _replace(expr):
|
||||
""" Replace matched group to corresponding values """
|
||||
return pattern.sub(lambda m: rep[re.escape(m.group(0))], expr)
|
||||
res = [_replace(r) for r in res]
|
||||
# supports quantitative type only
|
||||
# {'q': quantitative, 'i': indicator}
|
||||
if sys.version_info[0] == 3:
|
||||
features = [bytes(f, 'utf-8') for f in self.feature_names]
|
||||
types = [bytes('q', 'utf-8')] * flen
|
||||
else:
|
||||
features = [f.encode('utf-8') if isinstance(f, unicode) else f
|
||||
for f in self.feature_names]
|
||||
types = ['q'] * flen
|
||||
|
||||
fname[:] = features
|
||||
ftype[:] = types
|
||||
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
|
||||
flen,
|
||||
fname,
|
||||
ftype,
|
||||
int(with_stats),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
else:
|
||||
_check_call(_LIB.XGBoosterDumpModel(self.handle,
|
||||
c_str(fmap),
|
||||
int(with_stats),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
|
||||
res = []
|
||||
for i in range(length.value):
|
||||
try:
|
||||
res.append(str(sarr[i].decode('ascii')))
|
||||
except UnicodeDecodeError:
|
||||
res.append(unicode(sarr[i].decode('utf-8')))
|
||||
return res
|
||||
|
||||
def get_fscore(self, fmap=''):
|
||||
|
||||
@ -29,6 +29,26 @@ def test_basic():
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2-preds)) == 0
|
||||
|
||||
def test_feature_names():
    """Check that a DMatrix keeps user-supplied feature names and that a
    trained booster reports its feature scores under those names."""
    # Local, seeded generator: makes the run reproducible without touching
    # the global NumPy random state shared with other tests.
    rng = np.random.RandomState(1994)
    data = rng.randn(100, 5)
    target = np.array([0, 1] * 50)

    features = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']
    dm = xgb.DMatrix(data, label=target,
                     feature_names=features)
    assert dm.feature_names == features
    assert dm.num_row() == 100
    assert dm.num_col() == 5

    params = {'objective': 'multi:softprob',
              'eval_metric': 'mlogloss',
              'eta': 0.3,
              'num_class': 3}

    bst = xgb.train(params, dm, num_boost_round=10)
    scores = bst.get_fscore()
    # NOTE(review): presumably get_fscore() only lists features that were
    # actually used by the model, so this relies on all five being used --
    # confirm. Sorting the dict yields its keys in sorted order.
    assert sorted(scores) == features
|
||||
|
||||
def test_plotting():
|
||||
bst2 = xgb.Booster(model_file='xgb.model')
|
||||
# plotting
|
||||
|
||||
@ -445,9 +445,9 @@ int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||
|
||||
int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
|
||||
API_END();
|
||||
API_BEGIN();
|
||||
*out = static_cast<size_t>(static_cast<const DataMatrix*>(handle)->info.num_col());
|
||||
API_END();
|
||||
}
|
||||
|
||||
// xgboost implementation
|
||||
@ -580,3 +580,20 @@ int XGBoosterDumpModel(BoosterHandle handle,
|
||||
featmap, with_stats != 0, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// Dump the model using a feature map built from caller-supplied arrays:
// entry i of `fname` / `ftype` is registered as the name / type of feature i,
// for i in [0, fnum). The dump array is stored in `*out_models`; `len` is
// forwarded to GetModelDump.
int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
                                   int fnum,
                                   const char **fname,
                                   const char **ftype,
                                   int with_stats,
                                   bst_ulong *len,
                                   const char ***out_models) {
  API_BEGIN();
  utils::FeatMap featmap;
  int fid = 0;
  while (fid < fnum) {
    featmap.PushBack(fid, fname[fid], ftype[fid]);
    ++fid;
  }
  // Any non-zero with_stats requests statistics in the dump.
  const bool dump_stats = (with_stats != 0);
  Booster *booster = static_cast<Booster*>(handle);
  *out_models = booster->GetModelDump(featmap, dump_stats, len);
  API_END();
}
|
||||
|
||||
@ -331,4 +331,24 @@ XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
|
||||
int with_stats,
|
||||
bst_ulong *out_len,
|
||||
const char ***out_dump_array);
|
||||
|
||||
/*!
|
||||
* \brief dump model, return array of strings representing model dump
|
||||
* \param handle handle
|
||||
* \param fnum number of features
|
||||
* \param fname names of features
|
||||
* \param ftype types of features
|
||||
* \param with_stats whether to dump with statistics
|
||||
* \param len length of output array
|
||||
* \param out_models pointer to hold representing dump of each model
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||
int fnum,
|
||||
const char **fname,
|
||||
const char **ftype,
|
||||
int with_stats,
|
||||
bst_ulong *len,
|
||||
const char ***out_models);
|
||||
|
||||
#endif // XGBOOST_WRAPPER_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user