Cleanup str roundtrip using ctypes
This commit is contained in:
parent
bad4a27b9f
commit
bb6b7ded55
@ -21,16 +21,66 @@ class XGBoostError(Exception):
|
||||
"""Error throwed by xgboost trainer."""
|
||||
pass
|
||||
|
||||
PY3 = (sys.version_info[0] == 3)
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
if PY3:
|
||||
# pylint: disable=invalid-name, redefined-builtin
|
||||
STRING_TYPES = str,
|
||||
unicode = str
|
||||
else:
|
||||
# pylint: disable=invalid-name
|
||||
STRING_TYPES = basestring,
|
||||
|
||||
|
||||
def from_pystr_to_cstr(data):
|
||||
"""Convert a list of Python str to C pointer
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : list
|
||||
list of str
|
||||
"""
|
||||
|
||||
if isinstance(data, list):
|
||||
pointers = (ctypes.c_char_p * len(data))()
|
||||
if PY3:
|
||||
data = [bytes(d, 'utf-8') for d in data]
|
||||
else:
|
||||
data = [d.encode('utf-8') if isinstance(d, unicode) else d
|
||||
for d in data]
|
||||
pointers[:] = data
|
||||
return pointers
|
||||
else:
|
||||
# copy from above when we actually use it
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def from_cstr_to_pystr(data, length):
|
||||
"""Revert C pointer to Python str
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : ctypes pointer
|
||||
pointer to data
|
||||
length : ctypes pointer
|
||||
pointer to length of data
|
||||
"""
|
||||
if PY3:
|
||||
res = []
|
||||
for i in range(length.value):
|
||||
try:
|
||||
res.append(str(data[i].decode('ascii')))
|
||||
except UnicodeDecodeError:
|
||||
res.append(str(data[i].decode('utf-8')))
|
||||
else:
|
||||
res = []
|
||||
for i in range(length.value):
|
||||
try:
|
||||
res.append(str(data[i].decode('ascii')))
|
||||
except UnicodeDecodeError:
|
||||
res.append(unicode(data[i].decode('utf-8')))
|
||||
return res
|
||||
|
||||
|
||||
def find_lib_path():
|
||||
"""Load find the path to xgboost dynamic library files.
|
||||
|
||||
@ -787,21 +837,12 @@ class Booster(object):
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
if self.feature_names is not None and fmap == '':
|
||||
flen = int(len(self.feature_names))
|
||||
fname = (ctypes.c_char_p * flen)()
|
||||
ftype = (ctypes.c_char_p * flen)()
|
||||
|
||||
fname = from_pystr_to_cstr(self.feature_names)
|
||||
|
||||
# supports quantitative type only
|
||||
# {'q': quantitative, 'i': indicator}
|
||||
if sys.version_info[0] == 3:
|
||||
features = [bytes(f, 'utf-8') for f in self.feature_names]
|
||||
types = [bytes('q', 'utf-8')] * flen
|
||||
else:
|
||||
features = [f.encode('utf-8') if isinstance(f, unicode) else f
|
||||
for f in self.feature_names]
|
||||
types = ['q'] * flen
|
||||
|
||||
fname[:] = features
|
||||
ftype[:] = types
|
||||
ftype = from_pystr_to_cstr(['q'] * flen)
|
||||
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
|
||||
flen,
|
||||
fname,
|
||||
@ -815,13 +856,7 @@ class Booster(object):
|
||||
int(with_stats),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
|
||||
res = []
|
||||
for i in range(length.value):
|
||||
try:
|
||||
res.append(str(sarr[i].decode('ascii')))
|
||||
except UnicodeDecodeError:
|
||||
res.append(unicode(sarr[i].decode('utf-8')))
|
||||
res = from_cstr_to_pystr(sarr, length)
|
||||
return res
|
||||
|
||||
def get_fscore(self, fmap=''):
|
||||
|
||||
@ -6,16 +6,9 @@ import unittest
|
||||
|
||||
dpath = 'demo/data/'
|
||||
|
||||
|
||||
class TestBasic(unittest.TestCase):
|
||||
|
||||
def test_load_file_invalid(self):
|
||||
|
||||
self.assertRaises(ValueError, xgb.Booster,
|
||||
model_file='incorrect_path')
|
||||
|
||||
|
||||
def test_basic():
|
||||
def test_basic(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
param = {'max_depth':2, 'eta':1, 'silent':1, 'objective':'binary:logistic' }
|
||||
@ -41,7 +34,20 @@ def test_basic():
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2-preds)) == 0
|
||||
|
||||
def test_feature_names():
|
||||
def test_dmatrix_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
|
||||
# different length
|
||||
self.assertRaises(ValueError, xgb.DMatrix, data,
|
||||
feature_names=list('abcdef'))
|
||||
# contains duplicates
|
||||
self.assertRaises(ValueError, xgb.DMatrix, data,
|
||||
feature_names=['a', 'b', 'c', 'd', 'd'])
|
||||
# contains symbol
|
||||
self.assertRaises(ValueError, xgb.DMatrix, data,
|
||||
feature_names=['a', 'b', 'c', 'd', 'e=1'])
|
||||
|
||||
def test_feature_names(self):
|
||||
data = np.random.randn(100, 5)
|
||||
target = np.array([0, 1] * 50)
|
||||
|
||||
@ -64,8 +70,20 @@ def test_feature_names():
|
||||
scores = bst.get_fscore()
|
||||
assert list(sorted(k for k in scores)) == features
|
||||
|
||||
dummy = np.random.randn(5, 5)
|
||||
dm = xgb.DMatrix(dummy, feature_names=features)
|
||||
bst.predict(dm)
|
||||
|
||||
def test_plotting():
|
||||
# different feature name must raises error
|
||||
dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
|
||||
self.assertRaises(ValueError, bst.predict, dm)
|
||||
|
||||
def test_load_file_invalid(self):
|
||||
|
||||
self.assertRaises(ValueError, xgb.Booster,
|
||||
model_file='incorrect_path')
|
||||
|
||||
def test_plotting(self):
|
||||
bst2 = xgb.Booster(model_file='xgb.model')
|
||||
# plotting
|
||||
|
||||
@ -109,3 +127,4 @@ def test_plotting():
|
||||
assert isinstance(g, Digraph)
|
||||
ax = xgb.plot_tree(bst2, num_trees=0)
|
||||
assert isinstance(ax, Axes)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user