Fix get_uint_info() (#3442)

* Add regression test
This commit is contained in:
Philip Hyunsu Cho 2018-07-05 20:06:59 -07:00 committed by Tianqi Chen
parent 48d6e68690
commit 66e74d2223
2 changed files with 17 additions and 3 deletions

View File

@ -147,8 +147,15 @@ def _check_call(ret):
def ctypes2numpy(cptr, length, dtype):
"""Convert a ctypes pointer array to a numpy array.
"""
if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
raise RuntimeError('expected float pointer')
NUMPY_TO_CTYPES_MAPPING = {
np.float32: ctypes.c_float,
np.uint32: ctypes.c_uint,
}
if dtype not in NUMPY_TO_CTYPES_MAPPING:
raise RuntimeError('Supported types: {}'.format(NUMPY_TO_CTYPES_MAPPING.keys()))
ctype = NUMPY_TO_CTYPES_MAPPING[dtype]
if not isinstance(cptr, ctypes.POINTER(ctype)):
raise RuntimeError('expected {} pointer'.format(ctype))
res = np.zeros(length, dtype=dtype)
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
raise RuntimeError('memmove failed')
@ -501,7 +508,7 @@ class DMatrix(object):
Returns
-------
info : array
a numpy array of float information of the data
a numpy array of unsigned integer information of the data
"""
length = c_bst_ulong()
ret = ctypes.POINTER(ctypes.c_uint)()

View File

@ -299,3 +299,10 @@ class TestBasic(unittest.TestCase):
)
output = out.getvalue().strip()
assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]'
def test_get_info(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtrain.get_float_info('label')
dtrain.get_float_info('weight')
dtrain.get_float_info('base_margin')
dtrain.get_uint_info('root_index')