Support numpy array interface (#6998)
This commit is contained in:
@@ -116,11 +116,6 @@ def _maybe_np_slice(data, dtype):
|
||||
'''
|
||||
try:
|
||||
if not data.flags.c_contiguous:
|
||||
warnings.warn(
|
||||
"Use of np.ndarray subsets (sliced data) is not recommended " +
|
||||
"because it will generate extra copies and increase " +
|
||||
"memory consumption. Consider using np.ascontiguousarray to " +
|
||||
"make the array contiguous.")
|
||||
data = np.array(data, copy=True, dtype=dtype)
|
||||
else:
|
||||
data = np.array(data, copy=False, dtype=dtype)
|
||||
@@ -130,44 +125,28 @@ def _maybe_np_slice(data, dtype):
|
||||
return data
|
||||
|
||||
|
||||
def _transform_np_array(data: np.ndarray) -> np.ndarray:
|
||||
if not isinstance(data, np.ndarray) and hasattr(data, '__array__'):
|
||||
data = np.array(data, copy=False)
|
||||
if len(data.shape) != 2:
|
||||
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
|
||||
data.shape)
|
||||
# flatten the array by rows and ensure it is float32. we try to avoid
|
||||
# data copies if possible (reshape returns a view when possible and we
|
||||
# explicitly tell np.array to try and avoid copying)
|
||||
flatten = np.array(data.reshape(data.size), copy=False,
|
||||
dtype=np.float32)
|
||||
flatten = _maybe_np_slice(flatten, np.float32)
|
||||
_check_complex(data)
|
||||
return flatten
|
||||
|
||||
|
||||
def _from_numpy_array(data, missing, nthread, feature_names, feature_types):
|
||||
"""Initialize data from a 2-D numpy matrix.
|
||||
|
||||
If ``mat`` does not have ``order='C'`` (aka row-major) or is
|
||||
not contiguous, a temporary copy will be made.
|
||||
|
||||
If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will
|
||||
be made.
|
||||
|
||||
So there could be as many as two temporary data copies; be mindful of
|
||||
input layout and type if memory use is a concern.
|
||||
|
||||
"""
|
||||
flatten: np.ndarray = _transform_np_array(data)
|
||||
if len(data.shape) != 2:
|
||||
raise ValueError(
|
||||
"Expecting 2 dimensional numpy.ndarray, got: ", data.shape
|
||||
)
|
||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
|
||||
flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||
c_bst_ulong(data.shape[0]),
|
||||
c_bst_ulong(data.shape[1]),
|
||||
ctypes.c_float(missing),
|
||||
ctypes.byref(handle),
|
||||
ctypes.c_int(nthread)))
|
||||
args = {
|
||||
"missing": float(missing),
|
||||
"nthread": int(nthread),
|
||||
}
|
||||
config = bytes(json.dumps(args), "utf-8")
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromArray(
|
||||
_array_interface(data),
|
||||
config,
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user