Cleanup str roundtrip using ctypes

This commit is contained in:
sinhrks
2015-09-16 20:37:19 +09:00
parent bad4a27b9f
commit bb6b7ded55
2 changed files with 163 additions and 109 deletions

View File

@@ -21,16 +21,66 @@ class XGBoostError(Exception):
"""Error throwed by xgboost trainer."""
pass
PY3 = (sys.version_info[0] == 3)
if sys.version_info[0] == 3:
if PY3:
# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = str,
unicode = str
else:
# pylint: disable=invalid-name
STRING_TYPES = basestring,
def from_pystr_to_cstr(data):
    """Convert a list of Python str to C pointer

    Parameters
    ----------
    data : list
        list of str. Text items are encoded as UTF-8; items that are
        already byte strings are passed through unchanged.

    Returns
    -------
    pointers : ctypes array of ``c_char_p``
        one entry per element of ``data``, in the same order

    Raises
    ------
    NotImplementedError
        if ``data`` is not a list
    TypeError
        if an item is neither text nor a byte string
    """
    if not isinstance(data, list):
        # only lists are handled until another container is actually needed
        raise NotImplementedError(
            'from_pystr_to_cstr only supports list, got {0}'.format(
                type(data).__name__))

    encoded = []
    for item in data:
        if isinstance(item, bytes):
            # already a byte string (``bytes`` is plain ``str`` on Python 2),
            # so it must not be encoded a second time
            encoded.append(item)
        elif hasattr(item, 'encode'):
            encoded.append(item.encode('utf-8'))
        else:
            raise TypeError(
                'expected str or bytes, got {0}'.format(type(item).__name__))

    pointers = (ctypes.c_char_p * len(encoded))()
    pointers[:] = encoded
    return pointers
def from_cstr_to_pystr(data, length):
    """Revert C pointer to Python str

    Parameters
    ----------
    data : ctypes pointer
        indexable pointer/array whose entries are byte strings
    length : ctypes integer
        its ``.value`` is the number of entries to read from ``data``

    Returns
    -------
    res : list
        the decoded strings, in order. Pure-ASCII entries are returned
        as native ``str``; other entries are decoded as UTF-8 text.
    """
    res = []
    for i in range(length.value):
        raw = data[i]
        try:
            # str() keeps ASCII entries as native ``str`` on Python 2
            res.append(str(raw.decode('ascii')))
        except UnicodeDecodeError:
            # decode() already yields text (``str`` on Python 3, ``unicode``
            # on Python 2), so no version-specific wrapper is needed
            res.append(raw.decode('utf-8'))
    return res
def find_lib_path():
"""Load find the path to xgboost dynamic library files.
@@ -787,21 +837,12 @@ class Booster(object):
sarr = ctypes.POINTER(ctypes.c_char_p)()
if self.feature_names is not None and fmap == '':
flen = int(len(self.feature_names))
fname = (ctypes.c_char_p * flen)()
ftype = (ctypes.c_char_p * flen)()
fname = from_pystr_to_cstr(self.feature_names)
# supports quantitative type only
# {'q': quantitative, 'i': indicator}
if sys.version_info[0] == 3:
features = [bytes(f, 'utf-8') for f in self.feature_names]
types = [bytes('q', 'utf-8')] * flen
else:
features = [f.encode('utf-8') if isinstance(f, unicode) else f
for f in self.feature_names]
types = ['q'] * flen
fname[:] = features
ftype[:] = types
ftype = from_pystr_to_cstr(['q'] * flen)
_check_call(_LIB.XGBoosterDumpModelWithFeatures(self.handle,
flen,
fname,
@@ -815,13 +856,7 @@ class Booster(object):
int(with_stats),
ctypes.byref(length),
ctypes.byref(sarr)))
res = []
for i in range(length.value):
try:
res.append(str(sarr[i].decode('ascii')))
except UnicodeDecodeError:
res.append(unicode(sarr[i].decode('utf-8')))
res = from_cstr_to_pystr(sarr, length)
return res
def get_fscore(self, fmap=''):