Accept numpy array view. (#4147)

* Accept array view (slice) in metainfo.
This commit is contained in:
Jiaming Yuan
2019-02-18 22:21:34 +08:00
committed by GitHub
parent 0ff84d950e
commit a985a99cf0
6 changed files with 152 additions and 43 deletions

View File

@@ -219,6 +219,17 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
def _get_length_and_stride(data):
    """Return a ``(length, stride)`` pair for 1-D ``data``.

    For a numpy array whose ``base`` attribute is not ``None``, the length
    is that of the base object and the stride is the first-axis byte stride
    converted to a number of elements.  Every other input is described by
    ``len(data)`` with a stride of one.
    """
    # NOTE(review): for arrays with a base, the reported length is
    # ``len(data.base)`` rather than ``len(data)`` -- presumably the native
    # callee expects the extent of the underlying buffer; confirm against
    # the C API before relying on this for partial slices.
    has_base = isinstance(data, np.ndarray) and data.base is not None
    if not has_base:
        return len(data), 1
    # Length is evaluated before the stride, as in the original ordering.
    base_length = len(data.base)
    element_stride = data.strides[0] // data.dtype.itemsize
    return base_length, element_stride
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
@@ -585,10 +596,13 @@ class DMatrix(object):
The array of data to be set
"""
c_data = c_array(ctypes.c_float, data)
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
length, stride = _get_length_and_stride(data)
_check_call(_LIB.XGDMatrixSetFloatInfoStrided(
self.handle,
c_str(field),
c_data,
c_bst_ulong(stride),
c_bst_ulong(length)))
def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix
@@ -604,10 +618,13 @@ class DMatrix(object):
"""
data = np.array(data, copy=False, dtype=np.float32)
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
length, stride = _get_length_and_stride(data)
_check_call(_LIB.XGDMatrixSetFloatInfoStrided(
self.handle,
c_str(field),
c_data,
c_bst_ulong(stride),
c_bst_ulong(length)))
def set_uint_info(self, field, data):
"""Set uint type property into the DMatrix.
@@ -620,10 +637,15 @@ class DMatrix(object):
data: numpy array
The array of data to be set
"""
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
c_str(field),
c_array(ctypes.c_uint, data),
c_bst_ulong(len(data))))
data = np.array(data, copy=False, dtype=ctypes.c_uint)
c_data = c_array(ctypes.c_uint, data)
length, stride = _get_length_and_stride(data)
_check_call(_LIB.XGDMatrixSetUIntInfoStrided(
self.handle,
c_str(field),
c_data,
c_bst_ulong(stride),
c_bst_ulong(length)))
def save_binary(self, fname, silent=True):
"""Save DMatrix to an XGBoost buffer.
@@ -719,9 +741,7 @@ class DMatrix(object):
group : array like
Group size of each group
"""
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
c_array(ctypes.c_uint, group),
c_bst_ulong(len(group))))
self.set_uint_info('group', group)
def get_label(self):
"""Get the label of the DMatrix.