Accept numpy array view. (#4147)
* Accept array view (slice) in metainfo.
This commit is contained in:
@@ -219,6 +219,17 @@ def c_array(ctype, values):
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
def _get_length_and_stride(data):
|
||||
"Return length and stride of 1-D data."
|
||||
if isinstance(data, np.ndarray) and data.base is not None:
|
||||
length = len(data.base)
|
||||
stride = data.strides[0] // data.dtype.itemsize
|
||||
else:
|
||||
length = len(data)
|
||||
stride = 1
|
||||
return length, stride
|
||||
|
||||
|
||||
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
|
||||
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
|
||||
'float16': 'float', 'float32': 'float', 'float64': 'float',
|
||||
@@ -585,10 +596,13 @@ class DMatrix(object):
|
||||
The array of data to be set
|
||||
"""
|
||||
c_data = c_array(ctypes.c_float, data)
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(len(data))))
|
||||
length, stride = _get_length_and_stride(data)
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfoStrided(
|
||||
self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(stride),
|
||||
c_bst_ulong(length)))
|
||||
|
||||
def set_float_info_npy2d(self, field, data):
|
||||
"""Set float type property into the DMatrix
|
||||
@@ -604,10 +618,13 @@ class DMatrix(object):
|
||||
"""
|
||||
data = np.array(data, copy=False, dtype=np.float32)
|
||||
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(len(data))))
|
||||
length, stride = _get_length_and_stride(data)
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfoStrided(
|
||||
self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(stride),
|
||||
c_bst_ulong(length)))
|
||||
|
||||
def set_uint_info(self, field, data):
|
||||
"""Set uint type property into the DMatrix.
|
||||
@@ -620,10 +637,15 @@ class DMatrix(object):
|
||||
data: numpy array
|
||||
The array of data to be set
|
||||
"""
|
||||
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
|
||||
c_str(field),
|
||||
c_array(ctypes.c_uint, data),
|
||||
c_bst_ulong(len(data))))
|
||||
data = np.array(data, copy=False, dtype=ctypes.c_uint)
|
||||
c_data = c_array(ctypes.c_uint, data)
|
||||
length, stride = _get_length_and_stride(data)
|
||||
_check_call(_LIB.XGDMatrixSetUIntInfoStrided(
|
||||
self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(stride),
|
||||
c_bst_ulong(length)))
|
||||
|
||||
def save_binary(self, fname, silent=True):
|
||||
"""Save DMatrix to an XGBoost buffer.
|
||||
@@ -719,9 +741,7 @@ class DMatrix(object):
|
||||
group : array like
|
||||
Group size of each group
|
||||
"""
|
||||
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
|
||||
c_array(ctypes.c_uint, group),
|
||||
c_bst_ulong(len(group))))
|
||||
self.set_uint_info('group', group)
|
||||
|
||||
def get_label(self):
|
||||
"""Get the label of the DMatrix.
|
||||
|
||||
Reference in New Issue
Block a user