[Python] Use appropriate integer types when calling native code. (#2361)
Don't use implicit conversions to c_int, which incidentally happen to work on (some) 64-bit platforms, but: * may lead to truncation of the input value to a 32-bit signed int, * cause segfaults on some 32-bit architectures (tested on Ubuntu ARM, but is also the likely cause of issue #1707). Also, when passing references use explicit 64-bit integers, where needed, instead of c_ulong, which is not guaranteed to be this large.
This commit is contained in:
parent
ed8da45f9d
commit
ed6384ecbf
@ -17,6 +17,9 @@ from .libpath import find_lib_path
|
||||
|
||||
from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED
|
||||
|
||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||
c_bst_ulong = ctypes.c_uint64
|
||||
|
||||
|
||||
class XGBoostError(Exception):
|
||||
"""Error thrown by xgboost trainer."""
|
||||
@ -258,7 +261,7 @@ class DMatrix(object):
|
||||
if isinstance(data, STRING_TYPES):
|
||||
self.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
|
||||
int(silent),
|
||||
ctypes.c_int(silent),
|
||||
ctypes.byref(self.handle)))
|
||||
elif isinstance(data, scipy.sparse.csr_matrix):
|
||||
self._init_from_csr(data)
|
||||
@ -290,8 +293,9 @@ class DMatrix(object):
|
||||
_check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr),
|
||||
c_array(ctypes.c_uint, csr.indices),
|
||||
c_array(ctypes.c_float, csr.data),
|
||||
len(csr.indptr), len(csr.data),
|
||||
csr.shape[1],
|
||||
ctypes.c_size_t(len(csr.indptr)),
|
||||
ctypes.c_size_t(len(csr.data)),
|
||||
ctypes.c_size_t(csr.shape[1]),
|
||||
ctypes.byref(self.handle)))
|
||||
|
||||
def _init_from_csc(self, csc):
|
||||
@ -304,8 +308,9 @@ class DMatrix(object):
|
||||
_check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr),
|
||||
c_array(ctypes.c_uint, csc.indices),
|
||||
c_array(ctypes.c_float, csc.data),
|
||||
len(csc.indptr), len(csc.data),
|
||||
csc.shape[0],
|
||||
ctypes.c_size_t(len(csc.indptr)),
|
||||
ctypes.c_size_t(len(csc.data)),
|
||||
ctypes.c_size_t(csc.shape[0]),
|
||||
ctypes.byref(self.handle)))
|
||||
|
||||
def _init_from_npy2d(self, mat, missing):
|
||||
@ -329,7 +334,8 @@ class DMatrix(object):
|
||||
self.handle = ctypes.c_void_p()
|
||||
missing = missing if missing is not None else np.nan
|
||||
_check_call(_LIB.XGDMatrixCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||
mat.shape[0], mat.shape[1],
|
||||
c_bst_ulong(mat.shape[0]),
|
||||
c_bst_ulong(mat.shape[1]),
|
||||
ctypes.c_float(missing),
|
||||
ctypes.byref(self.handle)))
|
||||
|
||||
@ -349,7 +355,7 @@ class DMatrix(object):
|
||||
info : array
|
||||
a numpy array of float information of the data
|
||||
"""
|
||||
length = ctypes.c_ulong()
|
||||
length = c_bst_ulong()
|
||||
ret = ctypes.POINTER(ctypes.c_float)()
|
||||
_check_call(_LIB.XGDMatrixGetFloatInfo(self.handle,
|
||||
c_str(field),
|
||||
@ -370,7 +376,7 @@ class DMatrix(object):
|
||||
info : array
|
||||
a numpy array of float information of the data
|
||||
"""
|
||||
length = ctypes.c_ulong()
|
||||
length = c_bst_ulong()
|
||||
ret = ctypes.POINTER(ctypes.c_uint)()
|
||||
_check_call(_LIB.XGDMatrixGetUIntInfo(self.handle,
|
||||
c_str(field),
|
||||
@ -392,7 +398,7 @@ class DMatrix(object):
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
||||
c_str(field),
|
||||
c_array(ctypes.c_float, data),
|
||||
len(data)))
|
||||
c_bst_ulong(len(data))))
|
||||
|
||||
def set_uint_info(self, field, data):
|
||||
"""Set uint type property into the DMatrix.
|
||||
@ -408,7 +414,7 @@ class DMatrix(object):
|
||||
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
|
||||
c_str(field),
|
||||
c_array(ctypes.c_uint, data),
|
||||
len(data)))
|
||||
c_bst_ulong(len(data))))
|
||||
|
||||
def save_binary(self, fname, silent=True):
|
||||
"""Save DMatrix to an XGBoost buffer.
|
||||
@ -422,7 +428,7 @@ class DMatrix(object):
|
||||
"""
|
||||
_check_call(_LIB.XGDMatrixSaveBinary(self.handle,
|
||||
c_str(fname),
|
||||
int(silent)))
|
||||
ctypes.c_int(silent)))
|
||||
|
||||
def set_label(self, label):
|
||||
"""Set label of dmatrix
|
||||
@ -470,7 +476,7 @@ class DMatrix(object):
|
||||
"""
|
||||
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
|
||||
c_array(ctypes.c_uint, group),
|
||||
len(group)))
|
||||
c_bst_ulong(len(group))))
|
||||
|
||||
def get_label(self):
|
||||
"""Get the label of the DMatrix.
|
||||
@ -506,7 +512,7 @@ class DMatrix(object):
|
||||
-------
|
||||
number of rows : int
|
||||
"""
|
||||
ret = ctypes.c_ulong()
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixNumRow(self.handle,
|
||||
ctypes.byref(ret)))
|
||||
return ret.value
|
||||
@ -518,7 +524,7 @@ class DMatrix(object):
|
||||
-------
|
||||
number of columns : int
|
||||
"""
|
||||
ret = ctypes.c_uint()
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixNumCol(self.handle,
|
||||
ctypes.byref(ret)))
|
||||
return ret.value
|
||||
@ -540,7 +546,7 @@ class DMatrix(object):
|
||||
res.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixSliceDMatrix(self.handle,
|
||||
c_array(ctypes.c_int, rindex),
|
||||
len(rindex),
|
||||
c_bst_ulong(len(rindex)),
|
||||
ctypes.byref(res.handle)))
|
||||
return res
|
||||
|
||||
@ -659,7 +665,8 @@ class Booster(object):
|
||||
|
||||
dmats = c_array(ctypes.c_void_p, [d.handle for d in cache])
|
||||
self.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, len(cache), ctypes.byref(self.handle)))
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
|
||||
ctypes.byref(self.handle)))
|
||||
self.set_param({'seed': 0})
|
||||
self.set_param(params or {})
|
||||
if model_file is not None:
|
||||
@ -685,8 +692,8 @@ class Booster(object):
|
||||
buf = handle
|
||||
dmats = c_array(ctypes.c_void_p, [])
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, 0, ctypes.byref(handle)))
|
||||
length = ctypes.c_ulong(len(buf))
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(0), ctypes.byref(handle)))
|
||||
length = c_bst_ulong(len(buf))
|
||||
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||
_check_call(_LIB.XGBoosterLoadModelFromBuffer(handle, ptr, length))
|
||||
state['handle'] = handle
|
||||
@ -756,7 +763,7 @@ class Booster(object):
|
||||
result : dictionary of attribute_name: attribute_value pairs of strings.
|
||||
Returns an empty dict if there's no attributes.
|
||||
"""
|
||||
length = ctypes.c_ulong()
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
_check_call(_LIB.XGBoosterGetAttrNames(self.handle,
|
||||
ctypes.byref(length),
|
||||
@ -816,7 +823,8 @@ class Booster(object):
|
||||
self._validate_features(dtrain)
|
||||
|
||||
if fobj is None:
|
||||
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle, iteration, dtrain.handle))
|
||||
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle, ctypes.c_int(iteration),
|
||||
dtrain.handle))
|
||||
else:
|
||||
pred = self.predict(dtrain)
|
||||
grad, hess = fobj(pred, dtrain)
|
||||
@ -844,7 +852,7 @@ class Booster(object):
|
||||
_check_call(_LIB.XGBoosterBoostOneIter(self.handle, dtrain.handle,
|
||||
c_array(ctypes.c_float, grad),
|
||||
c_array(ctypes.c_float, hess),
|
||||
len(grad)))
|
||||
c_bst_ulong(len(grad))))
|
||||
|
||||
def eval_set(self, evals, iteration=0, feval=None):
|
||||
# pylint: disable=invalid-name
|
||||
@ -874,8 +882,9 @@ class Booster(object):
|
||||
dmats = c_array(ctypes.c_void_p, [d[0].handle for d in evals])
|
||||
evnames = c_array(ctypes.c_char_p, [c_str(d[1]) for d in evals])
|
||||
msg = ctypes.c_char_p()
|
||||
_check_call(_LIB.XGBoosterEvalOneIter(self.handle, iteration,
|
||||
dmats, evnames, len(evals),
|
||||
_check_call(_LIB.XGBoosterEvalOneIter(self.handle, ctypes.c_int(iteration),
|
||||
dmats, evnames,
|
||||
c_bst_ulong(len(evals)),
|
||||
ctypes.byref(msg)))
|
||||
res = msg.value.decode()
|
||||
if feval is not None:
|
||||
@ -958,10 +967,11 @@ class Booster(object):
|
||||
|
||||
self._validate_features(data)
|
||||
|
||||
length = ctypes.c_ulong()
|
||||
length = c_bst_ulong()
|
||||
preds = ctypes.POINTER(ctypes.c_float)()
|
||||
_check_call(_LIB.XGBoosterPredict(self.handle, data.handle,
|
||||
option_mask, ntree_limit,
|
||||
ctypes.c_int(option_mask),
|
||||
ctypes.c_uint(ntree_limit),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(preds)))
|
||||
preds = ctypes2numpy(preds, length.value, np.float32)
|
||||
@ -995,7 +1005,7 @@ class Booster(object):
|
||||
-------
|
||||
a in memory buffer representation of the model
|
||||
"""
|
||||
length = ctypes.c_ulong()
|
||||
length = c_bst_ulong()
|
||||
cptr = ctypes.POINTER(ctypes.c_char)()
|
||||
_check_call(_LIB.XGBoosterGetModelRaw(self.handle,
|
||||
ctypes.byref(length),
|
||||
@ -1016,7 +1026,7 @@ class Booster(object):
|
||||
_check_call(_LIB.XGBoosterLoadModel(self.handle, c_str(fname)))
|
||||
else:
|
||||
buf = fname
|
||||
length = ctypes.c_ulong(len(buf))
|
||||
length = c_bst_ulong(len(buf))
|
||||
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
|
||||
|
||||
@ -1050,10 +1060,10 @@ class Booster(object):
|
||||
Returns the dump the model as a list of strings.
|
||||
"""
|
||||
|
||||
length = ctypes.c_ulong()
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
if self.feature_names is not None and fmap == '':
|
||||
flen = int(len(self.feature_names))
|
||||
flen = len(self.feature_names)
|
||||
|
||||
fname = from_pystr_to_cstr(self.feature_names)
|
||||
|
||||
@ -1065,10 +1075,10 @@ class Booster(object):
|
||||
ftype = from_pystr_to_cstr(self.feature_types)
|
||||
_check_call(_LIB.XGBoosterDumpModelExWithFeatures(
|
||||
self.handle,
|
||||
flen,
|
||||
ctypes.c_int(flen),
|
||||
fname,
|
||||
ftype,
|
||||
int(with_stats),
|
||||
ctypes.c_int(with_stats),
|
||||
c_str(dump_format),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
@ -1077,7 +1087,7 @@ class Booster(object):
|
||||
raise ValueError("No such file: {0}".format(fmap))
|
||||
_check_call(_LIB.XGBoosterDumpModelEx(self.handle,
|
||||
c_str(fmap),
|
||||
int(with_stats),
|
||||
ctypes.c_int(with_stats),
|
||||
c_str(dump_format),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user