* Revert "Accept numpy array view. (#4147)"
This reverts commit a985a99cf0.
* Fix #4163: always copy sliced data
* Remove print() from the test; check shape equality
* Check if 'base' attribute exists
* Fix lint
* Address reviewer comment
* Fix lint
This commit is contained in:
Committed by: GitHub
Parent: cecbe0cf71
Commit: 2aaae2e7bb
@@ -219,17 +219,6 @@ def c_array(ctype, values):
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
def _get_length_and_stride(data):
|
||||
"Return length and stride of 1-D data."
|
||||
if isinstance(data, np.ndarray) and data.base is not None:
|
||||
length = len(data.base)
|
||||
stride = data.strides[0] // data.dtype.itemsize
|
||||
else:
|
||||
length = len(data)
|
||||
stride = 1
|
||||
return length, stride
|
||||
|
||||
|
||||
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
|
||||
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
|
||||
'float16': 'float', 'float32': 'float', 'float64': 'float',
|
||||
@@ -595,14 +584,16 @@ class DMatrix(object):
|
||||
data: numpy array
|
||||
The array of data to be set
|
||||
"""
|
||||
if getattr(data, 'base', None) is not None and \
|
||||
data.base is not None and isinstance(data, np.ndarray) \
|
||||
and isinstance(data.base, np.ndarray) and (not data.flags.c_contiguous):
|
||||
self.set_float_info_npy2d(field, data)
|
||||
return
|
||||
c_data = c_array(ctypes.c_float, data)
|
||||
length, stride = _get_length_and_stride(data)
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfoStrided(
|
||||
self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(stride),
|
||||
c_bst_ulong(length)))
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(len(data))))
|
||||
|
||||
def set_float_info_npy2d(self, field, data):
|
||||
"""Set float type property into the DMatrix
|
||||
@@ -616,15 +607,19 @@ class DMatrix(object):
|
||||
data: numpy array
|
||||
The array of data to be set
|
||||
"""
|
||||
data = np.array(data, copy=False, dtype=np.float32)
|
||||
if getattr(data, 'base', None) is not None and \
|
||||
data.base is not None and isinstance(data, np.ndarray) \
|
||||
and isinstance(data.base, np.ndarray) and (not data.flags.c_contiguous):
|
||||
warnings.warn("Use subset (sliced data) of np.ndarray is not recommended " +
|
||||
"because it will generate extra copies and increase memory consumption")
|
||||
data = np.array(data, copy=True, dtype=np.float32)
|
||||
else:
|
||||
data = np.array(data, copy=False, dtype=np.float32)
|
||||
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||
length, stride = _get_length_and_stride(data)
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfoStrided(
|
||||
self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(stride),
|
||||
c_bst_ulong(length)))
|
||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(len(data))))
|
||||
|
||||
def set_uint_info(self, field, data):
|
||||
"""Set uint type property into the DMatrix.
|
||||
@@ -637,15 +632,18 @@ class DMatrix(object):
|
||||
data: numpy array
|
||||
The array of data to be set
|
||||
"""
|
||||
data = np.array(data, copy=False, dtype=ctypes.c_uint)
|
||||
c_data = c_array(ctypes.c_uint, data)
|
||||
length, stride = _get_length_and_stride(data)
|
||||
_check_call(_LIB.XGDMatrixSetUIntInfoStrided(
|
||||
self.handle,
|
||||
c_str(field),
|
||||
c_data,
|
||||
c_bst_ulong(stride),
|
||||
c_bst_ulong(length)))
|
||||
if getattr(data, 'base', None) is not None and \
|
||||
data.base is not None and isinstance(data, np.ndarray) \
|
||||
and isinstance(data.base, np.ndarray) and (not data.flags.c_contiguous):
|
||||
warnings.warn("Use subset (sliced data) of np.ndarray is not recommended " +
|
||||
"because it will generate extra copies and increase memory consumption")
|
||||
data = np.array(data, copy=True, dtype=ctypes.c_uint)
|
||||
else:
|
||||
data = np.array(data, copy=False, dtype=ctypes.c_uint)
|
||||
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
|
||||
c_str(field),
|
||||
c_array(ctypes.c_uint, data),
|
||||
c_bst_ulong(len(data))))
|
||||
|
||||
def save_binary(self, fname, silent=True):
|
||||
"""Save DMatrix to an XGBoost buffer.
|
||||
@@ -741,7 +739,9 @@ class DMatrix(object):
|
||||
group : array like
|
||||
Group size of each group
|
||||
"""
|
||||
self.set_uint_info('group', group)
|
||||
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
|
||||
c_array(ctypes.c_uint, group),
|
||||
c_bst_ulong(len(group))))
|
||||
|
||||
def get_label(self):
|
||||
"""Get the label of the DMatrix.
|
||||
|
||||
Reference in New Issue
Block a user