Cleanup str roundtrip using ctypes
This commit is contained in:
parent
bad4a27b9f
commit
bb6b7ded55
@ -21,16 +21,66 @@ class XGBoostError(Exception):
|
|||||||
"""Error throwed by xgboost trainer."""
|
"""Error throwed by xgboost trainer."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
PY3 = (sys.version_info[0] == 3)
|
||||||
|
|
||||||
if sys.version_info[0] == 3:
|
if PY3:
|
||||||
# pylint: disable=invalid-name, redefined-builtin
|
# pylint: disable=invalid-name, redefined-builtin
|
||||||
STRING_TYPES = str,
|
STRING_TYPES = str,
|
||||||
unicode = str
|
|
||||||
else:
|
else:
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
STRING_TYPES = basestring,
|
STRING_TYPES = basestring,
|
||||||
|
|
||||||
|
|
||||||
|
def from_pystr_to_cstr(data):
|
||||||
|
"""Convert a list of Python str to C pointer
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : list
|
||||||
|
list of str
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(data, list):
|
||||||
|
pointers = (ctypes.c_char_p * len(data))()
|
||||||
|
if PY3:
|
||||||
|
data = [bytes(d, 'utf-8') for d in data]
|
||||||
|
else:
|
||||||
|
data = [d.encode('utf-8') if isinstance(d, unicode) else d
|
||||||
|
for d in data]
|
||||||
|
pointers[:] = data
|
||||||
|
return pointers
|
||||||
|
else:
|
||||||
|
# copy from above when we actually use it
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def from_cstr_to_pystr(data, length):
|
||||||
|
"""Revert C pointer to Python str
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : ctypes pointer
|
||||||
|
pointer to data
|
||||||
|
length : ctypes pointer
|
||||||
|
pointer to length of data
|
||||||
|
"""
|
||||||
|
if PY3:
|
||||||
|
res = []
|
||||||
|
for i in range(length.value):
|
||||||
|
try:
|
||||||
|
res.append(str(data[i].decode('ascii')))
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
res.append(str(data[i].decode('utf-8')))
|
||||||
|
else:
|
||||||
|
res = []
|
||||||
|
for i in range(length.value):
|
||||||
|
try:
|
||||||
|
res.append(str(data[i].decode('ascii')))
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
res.append(unicode(data[i].decode('utf-8')))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def find_lib_path():
|
def find_lib_path():
|
||||||
"""Load find the path to xgboost dynamic library files.
|
"""Load find the path to xgboost dynamic library files.
|
||||||
|
|
||||||
@ -787,21 +837,12 @@ class Booster(object):
|
|||||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||||
if self.feature_names is not None and fmap == '':
|
if self.feature_names is not None and fmap == '':
|
||||||
flen = int(len(self.feature_names))
|
flen = int(len(self.feature_names))
|
||||||
fname = (ctypes.c_char_p * flen)()
|
|
||||||
ftype = (ctypes.c_char_p * flen)()
|
fname = from_pystr_to_cstr(self.feature_names)
|
||||||
|
|
||||||
# supports quantitative type only
|
# supports quantitative type only
|
||||||
# {'q': quantitative, 'i': indicator}
|
# {'q': quantitative, 'i': indicator}
|
||||||
if sys.version_info[0] == 3:
|
ftype = from_pystr_to_cstr(['q'] * flen)
|
||||||
features = [bytes(f, 'utf-8') for f in self.feature_names]
|
|
||||||
types = [bytes('q', 'utf-8')] * flen
|
|
||||||
else:
|
|
||||||
features = [f.encode('utf-8') if isinstance(f, unicode) else f
|
|
||||||
for f in self.feature_names]
|
|
||||||
types = ['q'] * flen
|
|
||||||
|
|
||||||
fname[:] = features
|
|
||||||
ftype[:] = types
|
|
||||||
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
|
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
|
||||||
flen,
|
flen,
|
||||||
fname,
|
fname,
|
||||||
@ -815,13 +856,7 @@ class Booster(object):
|
|||||||
int(with_stats),
|
int(with_stats),
|
||||||
ctypes.byref(length),
|
ctypes.byref(length),
|
||||||
ctypes.byref(sarr)))
|
ctypes.byref(sarr)))
|
||||||
|
res = from_cstr_to_pystr(sarr, length)
|
||||||
res = []
|
|
||||||
for i in range(length.value):
|
|
||||||
try:
|
|
||||||
res.append(str(sarr[i].decode('ascii')))
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
res.append(unicode(sarr[i].decode('utf-8')))
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def get_fscore(self, fmap=''):
|
def get_fscore(self, fmap=''):
|
||||||
|
|||||||
@ -6,106 +6,125 @@ import unittest
|
|||||||
|
|
||||||
dpath = 'demo/data/'
|
dpath = 'demo/data/'
|
||||||
|
|
||||||
|
|
||||||
class TestBasic(unittest.TestCase):
|
class TestBasic(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_basic(self):
|
||||||
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
|
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||||
|
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
|
||||||
|
# specify validations set to watch performance
|
||||||
|
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||||
|
num_round = 2
|
||||||
|
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||||
|
# this is prediction
|
||||||
|
preds = bst.predict(dtest)
|
||||||
|
labels = dtest.get_label()
|
||||||
|
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
|
||||||
|
# error must be smaller than 10%
|
||||||
|
assert err < 0.1
|
||||||
|
|
||||||
|
# save dmatrix into binary buffer
|
||||||
|
dtest.save_binary('dtest.buffer')
|
||||||
|
# save model
|
||||||
|
bst.save_model('xgb.model')
|
||||||
|
# load model and data in
|
||||||
|
bst2 = xgb.Booster(model_file='xgb.model')
|
||||||
|
dtest2 = xgb.DMatrix('dtest.buffer')
|
||||||
|
preds2 = bst2.predict(dtest2)
|
||||||
|
# assert they are the same
|
||||||
|
assert np.sum(np.abs(preds2-preds)) == 0
|
||||||
|
|
||||||
|
def test_dmatrix_init(self):
|
||||||
|
data = np.random.randn(5, 5)
|
||||||
|
|
||||||
|
# different length
|
||||||
|
self.assertRaises(ValueError, xgb.DMatrix, data,
|
||||||
|
feature_names=list('abcdef'))
|
||||||
|
# contains duplicates
|
||||||
|
self.assertRaises(ValueError, xgb.DMatrix, data,
|
||||||
|
feature_names=['a', 'b', 'c', 'd', 'd'])
|
||||||
|
# contains symbol
|
||||||
|
self.assertRaises(ValueError, xgb.DMatrix, data,
|
||||||
|
feature_names=['a', 'b', 'c', 'd', 'e=1'])
|
||||||
|
|
||||||
|
def test_feature_names(self):
|
||||||
|
data = np.random.randn(100, 5)
|
||||||
|
target = np.array([0, 1] * 50)
|
||||||
|
|
||||||
|
cases = [['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5'],
|
||||||
|
[u'要因1', u'要因2', u'要因3', u'要因4', u'要因5']]
|
||||||
|
|
||||||
|
for features in cases:
|
||||||
|
dm = xgb.DMatrix(data, label=target,
|
||||||
|
feature_names=features)
|
||||||
|
assert dm.feature_names == features
|
||||||
|
assert dm.num_row() == 100
|
||||||
|
assert dm.num_col() == 5
|
||||||
|
|
||||||
|
params={'objective': 'multi:softprob',
|
||||||
|
'eval_metric': 'mlogloss',
|
||||||
|
'eta': 0.3,
|
||||||
|
'num_class': 3}
|
||||||
|
|
||||||
|
bst = xgb.train(params, dm, num_boost_round=10)
|
||||||
|
scores = bst.get_fscore()
|
||||||
|
assert list(sorted(k for k in scores)) == features
|
||||||
|
|
||||||
|
dummy = np.random.randn(5, 5)
|
||||||
|
dm = xgb.DMatrix(dummy, feature_names=features)
|
||||||
|
bst.predict(dm)
|
||||||
|
|
||||||
|
# different feature name must raises error
|
||||||
|
dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
|
||||||
|
self.assertRaises(ValueError, bst.predict, dm)
|
||||||
|
|
||||||
def test_load_file_invalid(self):
|
def test_load_file_invalid(self):
|
||||||
|
|
||||||
self.assertRaises(ValueError, xgb.Booster,
|
self.assertRaises(ValueError, xgb.Booster,
|
||||||
model_file='incorrect_path')
|
model_file='incorrect_path')
|
||||||
|
|
||||||
|
def test_plotting(self):
|
||||||
|
bst2 = xgb.Booster(model_file='xgb.model')
|
||||||
|
# plotting
|
||||||
|
|
||||||
def test_basic():
|
import matplotlib
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
matplotlib.use('Agg')
|
||||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
|
||||||
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
|
|
||||||
# specify validations set to watch performance
|
|
||||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
|
||||||
num_round = 2
|
|
||||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
|
||||||
# this is prediction
|
|
||||||
preds = bst.predict(dtest)
|
|
||||||
labels = dtest.get_label()
|
|
||||||
err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
|
|
||||||
# error must be smaller than 10%
|
|
||||||
assert err < 0.1
|
|
||||||
|
|
||||||
# save dmatrix into binary buffer
|
from matplotlib.axes import Axes
|
||||||
dtest.save_binary('dtest.buffer')
|
from graphviz import Digraph
|
||||||
# save model
|
|
||||||
bst.save_model('xgb.model')
|
|
||||||
# load model and data in
|
|
||||||
bst2 = xgb.Booster(model_file='xgb.model')
|
|
||||||
dtest2 = xgb.DMatrix('dtest.buffer')
|
|
||||||
preds2 = bst2.predict(dtest2)
|
|
||||||
# assert they are the same
|
|
||||||
assert np.sum(np.abs(preds2-preds)) == 0
|
|
||||||
|
|
||||||
def test_feature_names():
|
ax = xgb.plot_importance(bst2)
|
||||||
data = np.random.randn(100, 5)
|
assert isinstance(ax, Axes)
|
||||||
target = np.array([0, 1] * 50)
|
assert ax.get_title() == 'Feature importance'
|
||||||
|
assert ax.get_xlabel() == 'F score'
|
||||||
|
assert ax.get_ylabel() == 'Features'
|
||||||
|
assert len(ax.patches) == 4
|
||||||
|
|
||||||
cases = [['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5'],
|
ax = xgb.plot_importance(bst2, color='r',
|
||||||
[u'要因1', u'要因2', u'要因3', u'要因4', u'要因5']]
|
title='t', xlabel='x', ylabel='y')
|
||||||
|
assert isinstance(ax, Axes)
|
||||||
for features in cases:
|
assert ax.get_title() == 't'
|
||||||
dm = xgb.DMatrix(data, label=target,
|
assert ax.get_xlabel() == 'x'
|
||||||
feature_names=features)
|
assert ax.get_ylabel() == 'y'
|
||||||
assert dm.feature_names == features
|
assert len(ax.patches) == 4
|
||||||
assert dm.num_row() == 100
|
for p in ax.patches:
|
||||||
assert dm.num_col() == 5
|
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
|
||||||
|
|
||||||
params={'objective': 'multi:softprob',
|
|
||||||
'eval_metric': 'mlogloss',
|
|
||||||
'eta': 0.3,
|
|
||||||
'num_class': 3}
|
|
||||||
|
|
||||||
bst = xgb.train(params, dm, num_boost_round=10)
|
|
||||||
scores = bst.get_fscore()
|
|
||||||
assert list(sorted(k for k in scores)) == features
|
|
||||||
|
|
||||||
|
|
||||||
def test_plotting():
|
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
|
||||||
bst2 = xgb.Booster(model_file='xgb.model')
|
title=None, xlabel=None, ylabel=None)
|
||||||
# plotting
|
assert isinstance(ax, Axes)
|
||||||
|
assert ax.get_title() == ''
|
||||||
|
assert ax.get_xlabel() == ''
|
||||||
|
assert ax.get_ylabel() == ''
|
||||||
|
assert len(ax.patches) == 4
|
||||||
|
assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red
|
||||||
|
assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red
|
||||||
|
assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue
|
||||||
|
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
|
||||||
|
|
||||||
import matplotlib
|
g = xgb.to_graphviz(bst2, num_trees=0)
|
||||||
matplotlib.use('Agg')
|
assert isinstance(g, Digraph)
|
||||||
|
ax = xgb.plot_tree(bst2, num_trees=0)
|
||||||
|
assert isinstance(ax, Axes)
|
||||||
|
|
||||||
from matplotlib.axes import Axes
|
|
||||||
from graphviz import Digraph
|
|
||||||
|
|
||||||
ax = xgb.plot_importance(bst2)
|
|
||||||
assert isinstance(ax, Axes)
|
|
||||||
assert ax.get_title() == 'Feature importance'
|
|
||||||
assert ax.get_xlabel() == 'F score'
|
|
||||||
assert ax.get_ylabel() == 'Features'
|
|
||||||
assert len(ax.patches) == 4
|
|
||||||
|
|
||||||
ax = xgb.plot_importance(bst2, color='r',
|
|
||||||
title='t', xlabel='x', ylabel='y')
|
|
||||||
assert isinstance(ax, Axes)
|
|
||||||
assert ax.get_title() == 't'
|
|
||||||
assert ax.get_xlabel() == 'x'
|
|
||||||
assert ax.get_ylabel() == 'y'
|
|
||||||
assert len(ax.patches) == 4
|
|
||||||
for p in ax.patches:
|
|
||||||
assert p.get_facecolor() == (1.0, 0, 0, 1.0) # red
|
|
||||||
|
|
||||||
|
|
||||||
ax = xgb.plot_importance(bst2, color=['r', 'r', 'b', 'b'],
|
|
||||||
title=None, xlabel=None, ylabel=None)
|
|
||||||
assert isinstance(ax, Axes)
|
|
||||||
assert ax.get_title() == ''
|
|
||||||
assert ax.get_xlabel() == ''
|
|
||||||
assert ax.get_ylabel() == ''
|
|
||||||
assert len(ax.patches) == 4
|
|
||||||
assert ax.patches[0].get_facecolor() == (1.0, 0, 0, 1.0) # red
|
|
||||||
assert ax.patches[1].get_facecolor() == (1.0, 0, 0, 1.0) # red
|
|
||||||
assert ax.patches[2].get_facecolor() == (0, 0, 1.0, 1.0) # blue
|
|
||||||
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
|
|
||||||
|
|
||||||
g = xgb.to_graphviz(bst2, num_trees=0)
|
|
||||||
assert isinstance(g, Digraph)
|
|
||||||
ax = xgb.plot_tree(bst2, num_trees=0)
|
|
||||||
assert isinstance(ax, Axes)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user