From 66e74d222330adf3e41bb1e1f34ae1dd4aae15f0 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Thu, 5 Jul 2018 20:06:59 -0700 Subject: [PATCH] Fix get_uint_info() (#3442) * Add regression test --- python-package/xgboost/core.py | 13 ++++++++++--- tests/python/test_basic.py | 7 +++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index c03321a20..d35533293 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -147,8 +147,15 @@ def _check_call(ret): def ctypes2numpy(cptr, length, dtype): """Convert a ctypes pointer array to a numpy array. """ - if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)): - raise RuntimeError('expected float pointer') + NUMPY_TO_CTYPES_MAPPING = { + np.float32: ctypes.c_float, + np.uint32: ctypes.c_uint, + } + if dtype not in NUMPY_TO_CTYPES_MAPPING: + raise RuntimeError('Supported types: {}'.format(NUMPY_TO_CTYPES_MAPPING.keys())) + ctype = NUMPY_TO_CTYPES_MAPPING[dtype] + if not isinstance(cptr, ctypes.POINTER(ctype)): + raise RuntimeError('expected {} pointer'.format(ctype)) res = np.zeros(length, dtype=dtype) if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): raise RuntimeError('memmove failed') @@ -501,7 +508,7 @@ class DMatrix(object): Returns ------- info : array - a numpy array of float information of the data + a numpy array of unsigned integer information of the data """ length = c_bst_ulong() ret = ctypes.POINTER(ctypes.c_uint)() diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index e008b02c8..77336dcff 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -299,3 +299,10 @@ class TestBasic(unittest.TestCase): ) output = out.getvalue().strip() assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]' + + def test_get_info(self): + dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') + dtrain.get_float_info('label') + dtrain.get_float_info('weight') + dtrain.get_float_info('base_margin') + dtrain.get_uint_info('root_index')