parent
48d6e68690
commit
66e74d2223
@ -147,8 +147,15 @@ def _check_call(ret):
|
|||||||
def ctypes2numpy(cptr, length, dtype):
|
def ctypes2numpy(cptr, length, dtype):
|
||||||
"""Convert a ctypes pointer array to a numpy array.
|
"""Convert a ctypes pointer array to a numpy array.
|
||||||
"""
|
"""
|
||||||
if not isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
|
NUMPY_TO_CTYPES_MAPPING = {
|
||||||
raise RuntimeError('expected float pointer')
|
np.float32: ctypes.c_float,
|
||||||
|
np.uint32: ctypes.c_uint,
|
||||||
|
}
|
||||||
|
if dtype not in NUMPY_TO_CTYPES_MAPPING:
|
||||||
|
raise RuntimeError('Supported types: {}'.format(NUMPY_TO_CTYPES_MAPPING.keys()))
|
||||||
|
ctype = NUMPY_TO_CTYPES_MAPPING[dtype]
|
||||||
|
if not isinstance(cptr, ctypes.POINTER(ctype)):
|
||||||
|
raise RuntimeError('expected {} pointer'.format(ctype))
|
||||||
res = np.zeros(length, dtype=dtype)
|
res = np.zeros(length, dtype=dtype)
|
||||||
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
|
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
|
||||||
raise RuntimeError('memmove failed')
|
raise RuntimeError('memmove failed')
|
||||||
@ -501,7 +508,7 @@ class DMatrix(object):
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
info : array
|
info : array
|
||||||
a numpy array of float information of the data
|
a numpy array of unsigned integer information of the data
|
||||||
"""
|
"""
|
||||||
length = c_bst_ulong()
|
length = c_bst_ulong()
|
||||||
ret = ctypes.POINTER(ctypes.c_uint)()
|
ret = ctypes.POINTER(ctypes.c_uint)()
|
||||||
|
|||||||
@ -299,3 +299,10 @@ class TestBasic(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
output = out.getvalue().strip()
|
output = out.getvalue().strip()
|
||||||
assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]'
|
assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]'
|
||||||
|
|
||||||
|
def test_get_info(self):
|
||||||
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
|
dtrain.get_float_info('label')
|
||||||
|
dtrain.get_float_info('weight')
|
||||||
|
dtrain.get_float_info('base_margin')
|
||||||
|
dtrain.get_uint_info('root_index')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user